dpo_trainer.py 22.5 KB
Newer Older
lvzhen's avatar
first  
lvzhen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from trl import DPOTrainer as HFDPOTrainer
from trl.trainer.utils import pad_to_length

from swift.llm.utils.template import Template
from swift.utils import get_logger
from .mixin import PushToMsHubMixin, SwiftMixin
from .utils import build_tokenized_answer, patch_trl, sort_by_max_length

logger = get_logger()


class DPOTrainer(PushToMsHubMixin, SwiftMixin, HFDPOTrainer):

    def __init__(self, *args, template: Template, sft_beta=0., test_oom_error=False, **kwargs):
        """Prepare the state used by `tokenize_row`, run the parent `__init__`, then post-process the datasets.

        Args:
            *args: Positional arguments forwarded unchanged to the parent `__init__`.
            template: Template used by `tokenize_row` to encode prompts and responses.
            sft_beta: Weight of the NLL term mixed into the DPO loss in
                `get_batch_loss_metrics`; the mixing only happens when it is > 0.
            test_oom_error: When true, the training set is replaced by the result of
                `sort_by_max_length(train_dataset, 20000, is_encoder_decoder)`.
            **kwargs: Forwarded to the parent `__init__` after `streaming` and
                `is_vision` have been popped. Both keys are required (a missing
                key raises `KeyError`), and `train_dataset` must be iterable
                because its first example is read to collect the column names.
        """
        self.template = template
        template._is_training = True
        self.sft_beta = sft_beta
        self.streaming = kwargs.pop('streaming')
        is_vision = kwargs.pop('is_vision')
        # NOTE(review): patch_trl is defined elsewhere; presumably it patches trl for
        # vision / non-vision models before the parent __init__ runs - confirm.
        patch_trl(is_vision)
        # The attributes below are read and written by tokenize_row. Presumably the
        # parent __init__ maps tokenize_row over the datasets, which is why they are
        # set before super().__init__ - TODO confirm against trl.
        self.processed_keys = []  # keys after tokenize_row mapping
        # column names of the first training example; removed from the datasets below
        self.column_names = list(next(iter(kwargs.get('train_dataset'))).keys())
        self._data_keys = []  # vision related key in _data
        # set to True by tokenize_row when it skips an example
        self.need_filter: bool = False

        super().__init__(*args, **kwargs)
        # remove origin columns to resolve conflict in streaming mode
        self.train_dataset = self.train_dataset.remove_columns(self.column_names)
        if self.eval_dataset is not None:
            self.eval_dataset = self.eval_dataset.remove_columns(self.column_names)

        # drop the rows that tokenize_row filled with None values
        if self.need_filter:
            self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
            if self.eval_dataset is not None:
                self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
        # token-length statistics are only computed (and `dataset_info` only set)
        # when not streaming
        if not self.streaming:
            train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder)

            if self.eval_dataset is not None:
                val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
                self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
            else:
                self.dataset_info = {'train_dataset': train_ds_info}
        if test_oom_error:
            self.train_dataset = sort_by_max_length(self.train_dataset, 20000, self.is_encoder_decoder)
        # performance
        self.perf: Dict[str, Any] = {
            'gen_time': 0.,
            'gen_len': 0,
            'memory': {},  # filled per CUDA device in `train`
            'model': self.model.get_trainable_parameters() if hasattr(self.model, 'get_trainable_parameters') else None,
        }
        # modify after init
        self.is_vision_model = is_vision
        # drops the last character of model_type; presumably this undoes a
        # one-character suffix added while patching - TODO confirm
        self.model.config.model_type = self.model.config.model_type[:-1]  # remove suffix

    def train(self, *args, **kwargs) -> torch.Tensor:
        """Run the parent training loop, then record peak reserved CUDA memory.

        All arguments are forwarded unchanged to the parent `train`. After it
        returns, `self.perf['memory']` gets one `'cuda:<index>'` entry per visible
        CUDA device, formatted as GiB with two decimals.

        Returns:
            Whatever the parent `train` returned.
        """
        train_output = super().train(*args, **kwargs)
        for device_idx in range(torch.cuda.device_count()):
            reserved_gib = torch.cuda.max_memory_reserved(device_idx) / 1024 / 1024 / 1024
            self.perf['memory'][f'cuda:{device_idx}'] = f'{reserved_gib:.2f}GiB'
        return train_output

    def _encode_prompt_tokens(self, feature) -> Optional[Dict]:
        """Encode `feature` with its response removed and return `prompt_`-prefixed tokens.

        Shared by the decoder-only and encoder-decoder branches of `tokenize_row`.
        As a side effect, `self._data_keys` is filled (once) with the keys found in
        the template's `_data` entry.

        Args:
            feature: One dataset row; it is copied, never modified.

        Returns:
            The prompt tokens with every key prefixed by `prompt_`, or None when
            the template output has no `input_ids` (handled by the caller as an
            over-long prompt that must be skipped).
        """
        prompt = feature.copy()
        prompt['response'] = None
        prompt_tokens = self.template.encode(prompt)[0]
        prompt_tokens.pop('labels', None)
        if 'input_ids' not in prompt_tokens:
            return None

        # for MLLM, lift the vision related entries out of `_data`; they are
        # reassembled into a `_data` dict in `concatenated_forward`
        if '_data' in prompt_tokens:
            _data = prompt_tokens['_data']
            if not self._data_keys:
                # keep a plain list: a `dict_keys` view cannot be pickled and
                # would keep the first row's `_data` dict alive
                self._data_keys = list(_data.keys())
            for key, value in _data.items():
                if key not in prompt_tokens:
                    prompt_tokens[key] = value
            prompt_tokens.pop('_data')

        # convert bfloat16 to float32 to avoid conflict in mapping
        if 'pixel_values' in prompt_tokens and prompt_tokens['pixel_values'].dtype == torch.bfloat16:
            prompt_tokens['pixel_values'] = prompt_tokens['pixel_values'].to(torch.float32)

        if 'attention_mask' not in prompt_tokens:
            prompt_tokens['attention_mask'] = [1] * len(prompt_tokens['input_ids'])

        return {f'prompt_{k}': v for k, v in prompt_tokens.items()}

    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
        """Tokenize one preference row into the flat dict consumed by the trainer.

        Args:
            feature: One dataset row; `response` and `rejected_response` are read
                directly, the rest is passed to the template.
            model: Only used for encoder-decoder models: when it provides
                `prepare_decoder_input_ids_from_labels`, decoder input ids are
                derived from the chosen / rejected labels.

        Returns:
            For decoder-only models: `chosen_*` / `rejected_*` entries
            (`input_ids`, `attention_mask`, `labels`) plus the `prompt_*` entries.
            For encoder-decoder models: `chosen_labels`, `rejected_labels`,
            optional `*_decoder_input_ids`, plus the `prompt_*` entries.
            When the prompt could not be encoded, every previously seen key is
            mapped to None and `self.need_filter` is set so that the row is
            filtered out afterwards.
        """
        prompt_tokens = self._encode_prompt_tokens(feature)
        # Skip examples that have too lengthy prompt to avoid conflict in following
        # processing (applies to both model kinds).
        if prompt_tokens is None:
            self.need_filter = True
            return {k: None for k in self.processed_keys}

        # encode the responses
        chosen_tokens = build_tokenized_answer(feature['response'], self.template)
        rejected_tokens = build_tokenized_answer(feature['rejected_response'], self.template)

        batch = {}
        if not self.is_encoder_decoder:
            chosen_tokens.update(prompt_tokens)
            rejected_tokens.update(prompt_tokens)

            longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    if self.truncation_mode == 'keep_start':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][:self.max_prompt_length]
                    elif self.truncation_mode == 'keep_end':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length:]
                    else:
                        raise ValueError(f'Unknown truncation mode: {self.truncation_mode}')

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    for k in ['input_ids', 'attention_mask']:
                        answer_tokens[k] = answer_tokens[k][:self.max_length - self.max_prompt_length]

            # Build prompt + response sequences; the prompt part of the labels is
            # masked with the label pad token so only the response is scored.
            sequences = {}
            for prefix, answer_tokens in (('chosen_', chosen_tokens), ('rejected_', rejected_tokens)):
                prompt_len = len(answer_tokens['prompt_input_ids'])
                sequence_tokens = {
                    k: answer_tokens[f'prompt_{k}'] + answer_tokens[k]
                    for k in ['input_ids', 'attention_mask']
                }
                sequence_tokens['labels'] = ([self.label_pad_token_id] * prompt_len
                                             + sequence_tokens['input_ids'][prompt_len:])
                sequences[prefix] = sequence_tokens
            sequences[''] = prompt_tokens

            for prefix, toks in sequences.items():
                for type_key, tokens in toks.items():
                    if type_key == 'token_type_ids':
                        continue
                    batch[f'{prefix}{type_key}'] = tokens

        else:
            # encoder-decoder: the responses become the labels, the prompt is the
            # encoder input
            batch['chosen_labels'] = chosen_tokens['input_ids']
            batch['rejected_labels'] = rejected_tokens['input_ids']

            if model is not None and hasattr(model, 'prepare_decoder_input_ids_from_labels'):
                batch['rejected_decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch['rejected_labels']))
                batch['chosen_decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch['chosen_labels']))

            batch.update(prompt_tokens)

        # remember the produced keys so that skipped rows can mirror them
        if not self.processed_keys:
            self.processed_keys = list(batch.keys())
        return batch

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal['train', 'eval'] = 'train',
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test.

        Args:
            model: The policy model; passed to `concatenated_forward`, and its
                `config.router_aux_loss_coef` scales the auxiliary loss when
                `self.aux_loss_enabled` is set.
            batch: The collated batch. When it already contains
                `reference_chosen_logps` and `reference_rejected_logps` they are
                used as-is; otherwise the reference log-probs are computed
                without gradients.
            train_eval: `'eval'` prefixes every metric key with `eval_`.

        Returns:
            A `(loss, metrics)` tuple: the mean loss (plus the weighted auxiliary
            loss when enabled) and a dict of CPU scalar tensors.
        """
        metrics = {}
        forward_output = self.concatenated_forward(model, batch)
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = forward_output[:5]
        if self.aux_loss_enabled:
            aux_loss = forward_output[5]

        # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
        if 'reference_chosen_logps' in batch and 'reference_rejected_logps' in batch:
            reference_chosen_logps = batch['reference_chosen_logps']
            reference_rejected_logps = batch['reference_rejected_logps']
        else:
            # Only the first two outputs are needed. Slicing (instead of unpacking
            # a fixed number of values) keeps this working when
            # `concatenated_forward` appends the auxiliary loss as a sixth element.
            with torch.no_grad():
                if self.ref_model is None:
                    with self.null_ref_context():
                        reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(
                            self.model, batch)[:2]
                else:
                    reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(
                        self.ref_model, batch)[:2]

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

        # optionally blend the DPO loss with the NLL loss of the chosen responses
        if self.sft_beta > 0.:
            losses = (1 - self.sft_beta) * losses + self.sft_beta * policy_nll_loss

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = 'eval_' if train_eval == 'eval' else ''
        metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu()
        metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu()
        metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu()
        metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu()
        metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu()
        metrics[f'{prefix}logps/ref_rejected'] = reference_rejected_logps.detach().mean().cpu()
        metrics[f'{prefix}logps/ref_chosen'] = reference_chosen_logps.detach().mean().cpu()
        metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu()
        metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu()
        if self.sft_beta > 0:
            metrics[f'{prefix}sft_loss'] = policy_nll_loss.detach().mean().cpu()

        if self.aux_loss_enabled:
            return losses.mean() + getattr(model.config, 'router_aux_loss_coef', 0.0) * aux_loss, metrics

        return losses.mean(), metrics

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        The concatenated batch holds all chosen rows first, followed by all
        rejected rows, so row `i` belongs to sample `i % len_chosen`.

        Args:
            model: The model to run (policy or reference).
            batch: The collated batch; `chosen_labels` gives the number of samples.

        Returns:
            `(chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)`,
            with `outputs.aux_loss` appended as a sixth element when
            `self.aux_loss_enabled` is set. `nll_loss` is the cross entropy of the
            chosen rows only.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            is_vision_model=self.is_vision_model,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch['chosen_labels'].shape[0]

        model_kwargs = {}

        if self.is_encoder_decoder:
            model_kwargs['labels'] = concatenated_batch['concatenated_labels']
            model_kwargs['decoder_input_ids'] = concatenated_batch.pop('concatenated_decoder_input_ids', None)

        if self.is_vision_model:
            # Here, we restore the _data, processing image information within the forward hook of the model.
            batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
            if self._data_keys:
                _data = [dict() for _ in range(batch_size)]
                for k in self._data_keys:
                    if k == 'input_ids':
                        values = concatenated_batch['concatenated_input_ids']
                        for i, d in enumerate(_data):
                            d[k] = values[i]
                    elif k == 'pixel_values':
                        # convert the dtype of the pixel values that may be converted to float32 in tokenize_row
                        model_dtype = self.accelerator.unwrap_model(model).dtype
                        # The vision data holds one entry per sample and is shared
                        # by the chosen and the rejected row of that sample.
                        for i, d in enumerate(_data):
                            d[k] = concatenated_batch[k][i % len_chosen].to(model_dtype)
                    else:
                        for i, d in enumerate(_data):
                            d[k] = concatenated_batch[k][i % len_chosen]
                model_kwargs['_data'] = _data

            if 'images' in concatenated_batch:
                model_kwargs['images'] = concatenated_batch['images']

        if self.aux_loss_enabled:
            model_kwargs['output_router_logits'] = True

        outputs = model(
            input_ids=concatenated_batch['concatenated_input_ids'],
            attention_mask=concatenated_batch['concatenated_attention_mask'],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        if all_logits.shape[:2] != concatenated_batch['concatenated_labels'].shape[:2]:
            # for llava, the model returns logits for the entire sequence,
            # including the image tokens (placed before the text tokens)
            seq_len = concatenated_batch['concatenated_labels'].shape[1]
            all_logits = all_logits[:, -seq_len:]

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            concatenated_batch['concatenated_labels'],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Labels are padded with `label_pad_token_id`, so that is the index
            # that has to be ignored (identical to the default when it is -100).
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
            # Flatten the tokens
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            return loss_fct(logits, labels)

        labels = concatenated_batch['concatenated_labels'].clone()
        nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        if self.loss_type == 'ipo':
            all_logps = all_logps / size_completion

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        is_vision_model: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """
        Concatenate the chosen and rejected inputs into a single tensor.

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids',
            which are tensors of shape (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            is_vision_model: Whether the model is an vision LLM.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated inputs_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(batch['chosen_labels'].shape[1], batch['rejected_labels'].shape[1])
        else:
            max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])

        for k in batch:
            if k.startswith('chosen') and isinstance(batch[k], torch.Tensor):
                if 'labels' in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith('_input_ids'):
                    pad_value = padding_value
                elif k.endswith('_attention_mask'):
                    pad_value = 0
                concatenated_key = k.replace('chosen', 'concatenated')
                concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
        for k in batch:
            if k.startswith('rejected') and isinstance(batch[k], torch.Tensor):
                if 'labels' in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith('_input_ids'):
                    pad_value = padding_value
                elif k.endswith('_attention_mask'):
                    pad_value = 0
                concatenated_key = k.replace('rejected', 'concatenated')
                concatenated_batch[concatenated_key] = torch.cat(
                    (
                        concatenated_batch[concatenated_key],
                        pad_to_length(batch[k], max_length, pad_value=pad_value),
                    ),
                    dim=0,
                ).to(device=device)

        if is_encoder_decoder:
            concatenated_batch['concatenated_input_ids'] = batch['prompt_input_ids'].repeat(2, 1).to(device=device)
            concatenated_batch['concatenated_attention_mask'] = (
                batch['prompt_attention_mask'].repeat(2, 1).to(device=device))

        # patch here
        if is_vision_model:
            # for keys appear in _data, we leave data collector in hook
            if 'prompt_pixel_values' in batch:
                pixel_values = [values for values in batch['prompt_pixel_values']]
                concatenated_batch['pixel_values'] = pixel_values

            if 'prompt_image_flags' in batch:
                image_flags = [torch.tensor(flags) for flags in batch['prompt_image_flags']]
                concatenated_batch['image_flags'] = image_flags

            if 'prompt_pixel_attention_mask' in batch:
                pixel_attention_mask = [mask for mask in batch['pixel_attention_mask']]
                concatenated_batch['pixel_attention_mask'] = pixel_attention_mask

            if 'prompt_image_sizes' in batch:
                concatenated_batch['image_sizes'] = batch['prompt_image_sizes']

            if 'prompt_images' in batch:
                # images not in _data, we manually execute data collector here
                concatenated_batch['images'] = batch['prompt_images'].squeeze(1).repeat(2, 1, 1, 1).to(device=device)
        return concatenated_batch

    @staticmethod
    def stat_dataset(llm_dataset, is_encoder_decoder: bool = False) -> Any:
        """Log and return token-length statistics of a tokenized preference dataset.

        The length of one example is the longer of its chosen and rejected
        sequence; for encoder-decoder models the prompt length is added, because
        the responses are stored as labels only.

        Args:
            llm_dataset: A `datasets.Dataset` (read column-wise) or any other
                iterable of rows with the tokenized columns.
            is_encoder_decoder: Selects the `*_labels` + `prompt_input_ids` columns
                instead of the `*_input_ids` columns.

        Returns:
            The statistics string produced by `stat_array`.
        """
        from datasets import Dataset as HfDataset
        from swift.utils.np_utils import stat_array
        _token_len = []
        if isinstance(llm_dataset, HfDataset):
            if is_encoder_decoder:
                prompt = llm_dataset['prompt_input_ids']
                chosen = llm_dataset['chosen_labels']
                rejected = llm_dataset['rejected_labels']
                for p, cc, rr in zip(prompt, chosen, rejected):
                    _token_len.append(max(len(cc), len(rr)) + len(p))
            else:
                chosen = llm_dataset['chosen_input_ids']
                rejected = llm_dataset['rejected_input_ids']
                for cc, rr in zip(chosen, rejected):
                    _token_len.append(max(len(cc), len(rr)))
        else:
            for d in llm_dataset:
                if is_encoder_decoder:
                    _token_len.append(
                        max(len(d['chosen_labels']), len(d['rejected_labels'])) + len(d['prompt_input_ids']))
                else:
                    _token_len.append(max(len(d['chosen_input_ids']), len(d['rejected_input_ids'])))
        _, stat_str = stat_array(_token_len)
        logger.info(f'Dataset Token Length: {stat_str}')
        return stat_str