from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel, trainer
from trl import DPOTrainer as HFDPOTrainer

from swift.llm.utils.template import Context, Template
from swift.llm.utils.utils import sort_by_max_length
from swift.utils import get_logger
from .callback import DefaultFlowCallbackNew, PrinterCallbackNew, ProgressCallbackNew
from .mixin import PushToMsHubMixin, SwiftMixin

logger = get_logger()


class DPOTrainer(PushToMsHubMixin, SwiftMixin, HFDPOTrainer):
    """DPO trainer that builds prompts through a swift ``Template``.

    Extends the TRL ``DPOTrainer`` with template-based tokenization for
    decoder-only models, an optional SFT term mixed into the DPO loss
    (``sft_beta``), and bookkeeping of dataset statistics and peak memory.
    """

    def __init__(self, *args, template: Template, sft_beta=0., test_oom_error=False, **kwargs):
        """Create the trainer.

        Args:
            *args: Forwarded to the parent trainer constructor.
            template: Template used to assemble and encode prompts.
            sft_beta: Weight of the SFT loss term; the term is only computed
                when this value is greater than 0.
            test_oom_error: When true, the training set is reordered with
                ``sort_by_max_length(self.train_dataset, 20000)``.
            **kwargs: Forwarded to the parent trainer constructor.
        """
        # Both attributes must exist before the parent constructor runs.
        # NOTE(review): presumably the parent tokenizes the datasets through
        # `tokenize_row` during construction - confirm against the installed trl.
        self.template = template
        self.sft_beta = sft_beta
        super().__init__(*args, **kwargs)
        train_ds_info = self.stat_dataset(self.train_dataset)
        # The evaluation dataset is optional; do not try to iterate `None`.
        val_ds_info = self.stat_dataset(self.eval_dataset) if self.eval_dataset is not None else None
        self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
        if test_oom_error:
            self.train_dataset = sort_by_max_length(self.train_dataset, 20000)
        # performance
        self.perf: Dict[str, Any] = {
            'gen_time': 0.,
            'gen_len': 0,
            'memory': {},
            'model': self.model.get_trainable_parameters() if hasattr(self.model, 'get_trainable_parameters') else None,
        }

    def train(self, *args, **kwargs) -> torch.Tensor:
        """Run the parent training loop, then record peak reserved CUDA memory.

        One entry per visible CUDA device is written to ``self.perf['memory']``
        under the key ``'cuda:<index>'``, formatted as GiB with two decimals.

        Returns:
            Whatever the parent ``train`` returns, unchanged.
            NOTE(review): the ``torch.Tensor`` annotation may not match the
            parent's actual return type - confirm.
        """
        res = super().train(*args, **kwargs)
        for i in range(torch.cuda.device_count()):
            self.perf['memory'][f'cuda:{i}'] = f'{torch.cuda.max_memory_reserved(i)/1024/1024/1024:.2f}GiB'
        return res

    def concat_template(self, feature):
        """Assemble the prompt context list for one preference sample.

        Args:
            feature: Mapping with ``'response'`` and ``'rejected_response'``
                and, optionally, ``'query'``, ``'system'`` and ``'history'``
                (an iterable of ``(query, response)`` pairs).

        Returns:
            A tuple ``(context_list, chosen, rejected, loss_scale)`` where the
            first and last items come from the template's
            ``_simplify_context_list`` and the middle two are taken verbatim
            from ``feature``.
        """
        query: Optional[str] = feature.get('query', None)
        system: Optional[str] = feature.get('system', None)
        # A key that is present but null (e.g. a dataset column with missing
        # values) must behave like an absent key, so `None` becomes `[]`.
        history: List = feature.get('history') or []
        if system is None:
            if self.template.use_default_system:
                system = self.template.default_system
        else:
            assert self.template.prefix_has_system is not None, 'not support `system`'
        res_context_list: List[Context] = []
        compute_loss_idx: List[float] = []
        if system is None:
            assert self.template.prefix != self.template.prefix_has_system, f'template.prefix: {self.template.prefix}'
            prefix = self.template.prefix
        else:
            prefix = self.template.prefix_has_system
        self.template._concat_context_list(prefix, res_context_list, compute_loss_idx, system=system)
        # Previous rounds are rendered in full: prompt, response, separator.
        for i, (q, r) in enumerate(history):
            self.template._concat_context_list(
                [
                    *self.template.prompt,
                    '{{RESPONSE}}',
                    *self.template.chat_sep  # noqa
                ],
                res_context_list,
                compute_loss_idx,
                query=q,
                response=r,
                round0=i)  # noqa
        # The current round contributes only its prompt part; the chosen and
        # rejected answers are returned separately.
        self.template._concat_context_list(
            self.template.prompt, res_context_list, compute_loss_idx, query=query, round0=len(history))
        res_context_list, compute_loss_idx = self.template._simplify_context_list(res_context_list, compute_loss_idx)

        return res_context_list, feature['response'], feature['rejected_response'], compute_loss_idx

    def build_tokenized_answer(self, prompt, answer, prompt_loss_scale):
        """Encode the prompt and one answer into separate token lists.

        Args:
            prompt: Context list as returned by ``concat_template``.
            answer: Answer text; encoded and followed by the encoded
                template suffix.
            prompt_loss_scale: Loss-scale list passed along with ``prompt``.

        Returns:
            Dict with ``prompt_input_ids``, ``prompt_attention_mask``,
            ``input_ids`` and ``attention_mask``; both masks are all ones.
        """
        input_ids, labels, loss_scale, kwargs = self.template._encode_context_list(prompt, prompt_loss_scale)
        tgt_input_ids = self.template._encode_context_list([answer], [1.0])[0]
        tgt_input_ids += self.template._encode_context_list(self.template.suffix, [1.0])[0]
        return dict(
            prompt_input_ids=input_ids,
            prompt_attention_mask=[1] * len(input_ids),
            input_ids=tgt_input_ids,
            attention_mask=[1] * len(tgt_input_ids),
        )

    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
        """Tokenize one preference sample into the flat batch layout.

        For decoder-only models the prompt is built with the template and the
        result holds ``chosen_*``, ``rejected_*`` and ``prompt_*`` entries
        (``token_type_ids`` entries are skipped). Encoder-decoder models are
        delegated to the parent implementation.

        Args:
            feature: One dataset row (see ``concat_template``).
            model: Only forwarded to the parent for encoder-decoder models.

        Returns:
            Dict of token lists for this row.

        Raises:
            ValueError: If the chosen or rejected answer is not a ``str``, or
                if ``self.truncation_mode`` is unknown while truncation is
                required.
        """
        batch = {}
        if not self.is_encoder_decoder:
            prompt, chosen, rejected, loss_scale = self.concat_template(feature)

            prompt_tokens, _, _, _ = self.template._encode_context_list(prompt, loss_scale)
            prompt_tokens = {
                'input_ids': prompt_tokens,
                'attention_mask': [1] * len(prompt_tokens),
            }
            prompt_tokens = {f'prompt_{k}': v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f'chosen should be an str but got {type(chosen)}')
            chosen_tokens = self.build_tokenized_answer(prompt, chosen, loss_scale)

            if not isinstance(rejected, str):
                raise ValueError(f'rejected should be an str but got {type(rejected)}')
            rejected_tokens = self.build_tokenized_answer(prompt, rejected, loss_scale)

            longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    if self.truncation_mode == 'keep_start':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][:self.max_prompt_length]
                    elif self.truncation_mode == 'keep_end':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length:]
                    else:
                        raise ValueError(f'Unknown truncation mode: {self.truncation_mode}')

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    for k in ['input_ids', 'attention_mask']:
                        answer_tokens[k] = answer_tokens[k][:self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f'prompt_{k}'] + chosen_tokens[k]
                for k in ['input_ids', 'attention_mask']
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f'prompt_{k}'] + rejected_tokens[k]
                for k in ['input_ids', 'attention_mask']
            }
            # Labels are a copy of the input ids with the prompt positions
            # replaced by the label padding id.
            chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
            _paddings = [self.label_pad_token_id] * len(chosen_tokens['prompt_input_ids'])
            chosen_sequence_tokens['labels'][:len(chosen_tokens['prompt_input_ids'])] = _paddings
            rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
            _paddings = [self.label_pad_token_id] * len(rejected_tokens['prompt_input_ids'])
            rejected_sequence_tokens['labels'][:len(rejected_tokens['prompt_input_ids'])] = _paddings

            for k, toks in {
                    'chosen_': chosen_sequence_tokens,
                    'rejected_': rejected_sequence_tokens,
                    '': prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == 'token_type_ids':
                        continue
                    batch[f'{k}{type_key}'] = tokens

        else:
            # encoder-decoder
            batch = super().tokenize_row(feature, model)

        return batch

    @staticmethod
    def stat_dataset(llm_dataset) -> Any:
        """Log and return token-length statistics of a tokenized dataset.

        The length of a sample is the larger of ``len(chosen_input_ids)`` and
        ``len(rejected_input_ids)``.

        Args:
            llm_dataset: A ``datasets.Dataset`` or any iterable of mappings
                exposing ``'chosen_input_ids'`` and ``'rejected_input_ids'``.

        Returns:
            The second item returned by ``stat_array`` (also logged).
        """
        _token_len = []
        from datasets import Dataset as HfDataset
        from swift.utils.np_utils import stat_array
        if isinstance(llm_dataset, HfDataset):
            # Read each column once instead of indexing the dataset per row.
            chosen = llm_dataset['chosen_input_ids']
            rejected = llm_dataset['rejected_input_ids']
            for cc, rr in zip(chosen, rejected):
                _token_len.append(max(len(cc), len(rr)))
        else:
            for d in llm_dataset:
                _token_len.append(max(len(d['chosen_input_ids']), len(d['rejected_input_ids'])))
        _, stat_str = stat_array(_token_len)
        logger.info(f'Dataset Token Length: {stat_str}')
        return stat_str

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal['train', 'eval'] = 'train',
    ):
        """Compute the DPO loss and other metrics for the given batch of inputs for train or test.

        Args:
            model: Policy model used for the forward pass.
            batch: Collated batch; may already carry ``reference_chosen_logps``
                and ``reference_rejected_logps``, in which case no reference
                forward pass is run.
            train_eval: ``'eval'`` prefixes every metric key with ``eval_``.

        Returns:
            A tuple ``(loss, metrics)``: the mean loss and a dict of metric
            values moved to CPU.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            concatenated_batch,
        ) = self.concatenated_forward(model, batch)

        # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
        if 'reference_chosen_logps' in batch and 'reference_rejected_logps' in batch:
            reference_chosen_logps = batch['reference_chosen_logps']
            reference_rejected_logps = batch['reference_rejected_logps']
        else:
            with torch.no_grad():
                if self.ref_model is None:
                    with self.null_ref_context():
                        (
                            reference_chosen_logps,
                            reference_rejected_logps,
                            _,
                            _,
                            _,
                        ) = self.concatenated_forward(self.model, batch)
                else:
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        _,
                        _,
                        _,
                    ) = self.concatenated_forward(self.ref_model, batch)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

        if self.sft_beta > 0.:
            chosen_labels = concatenated_batch['concatenated_labels'][:batch['chosen_labels'].shape[0]]
            # Use the same label handling as `concatenated_forward`, so a
            # custom pad id or an encoder-decoder model is treated consistently.
            sft_loss = -self.get_batch_logps(
                policy_chosen_logits,
                chosen_labels,
                average_log_prob=True,
                is_encoder_decoder=self.is_encoder_decoder,
                label_pad_token_id=self.label_pad_token_id,
            )
            # The DPO loss may hold two entries per sample; tile the SFT term to match.
            if losses.shape[0] == 2 * sft_loss.shape[0]:
                sft_loss = sft_loss.repeat(2, *sft_loss.shape[1:])
            losses = (1 - self.sft_beta) * losses + self.sft_beta * sft_loss

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = 'eval_' if train_eval == 'eval' else ''
        metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu()
        metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu()
        metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu()
        metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu()
        metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu()
        metrics[f'{prefix}logps/ref_rejected'] = reference_rejected_logps.detach(  # noqa
        ).mean().cpu()  # noqa
        metrics[f'{prefix}logps/ref_chosen'] = reference_chosen_logps.detach().mean().cpu()
        metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu()
        metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, Dict[str, torch.LongTensor]]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits,
            concatenated_batch)``; the concatenated batch is returned so the
            caller can reuse its labels.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        # The first `len_chosen` rows of the concatenated batch are the chosen samples.
        len_chosen = batch['chosen_labels'].shape[0]

        model_kwargs = ({
            'labels': concatenated_batch['concatenated_labels'],
            'decoder_input_ids': concatenated_batch.pop('concatenated_decoder_input_ids', None),
        } if self.is_encoder_decoder else {})
        all_logits = model(
            concatenated_batch['concatenated_input_ids'],
            attention_mask=concatenated_batch['concatenated_attention_mask'],
            **model_kwargs,
        ).logits

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch['concatenated_labels'],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, concatenated_batch)


# monkey patching
trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
trainer.PrinterCallback = PrinterCallbackNew