from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel, trainer
from trl import DPOTrainer as HFDPOTrainer

from swift.llm.utils.template import Context, Template
from swift.llm.utils.utils import sort_by_max_length
from swift.utils import get_logger
from .callback import DefaultFlowCallbackNew, PrinterCallbackNew, ProgressCallbackNew
from .mixin import PushToMsHubMixin, SwiftMixin

logger = get_logger()


# ref: https://github.com/princeton-nlp/SimPO/blob/main/scripts/simpo_trainer.py
class SimPOTrainer(PushToMsHubMixin, SwiftMixin, HFDPOTrainer):
    """Preference trainer implementing the SimPO objective on top of TRL's DPOTrainer.

    Unlike DPO, no reference model log-probs are used: the loss compares the
    length-averaged log-probabilities of the chosen and rejected responses under
    the policy alone (see `concatenated_forward`, which requests
    `average_log_prob=True`), shifted by a target margin `gamma` (see `simpo_loss`).
    """

    def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
        """Create the trainer.

        Args:
            template: swift chat template used to build and encode prompts.
            test_oom_error: if True, re-order the training dataset with
                `sort_by_max_length(..., 20000)` (presumably so long samples are
                hit early and OOMs surface quickly — confirm against that helper).
            **kwargs: forwarded to `DPOTrainer`, except `gamma` (required), which is
                popped here and used as the SimPO target reward margin.
        """
        self.template = template
        self.gamma = kwargs.pop('gamma')
        super().__init__(*args, **kwargs)
        train_ds_info = self.stat_dataset(self.train_dataset)
        # `stat_dataset` returns None when there is no eval dataset.
        val_ds_info = self.stat_dataset(self.eval_dataset)
        self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
        if test_oom_error:
            self.train_dataset = sort_by_max_length(self.train_dataset, 20000)
        # performance
        self.perf: Dict[str, Any] = {
            'gen_time': 0.,
            'gen_len': 0,
            'memory': {},
            'model': self.model.get_trainable_parameters() if hasattr(self.model, 'get_trainable_parameters') else None,
        }

    def train(self, *args, **kwargs) -> Any:
        """Run the parent training loop, then record peak reserved CUDA memory per device.

        Returns:
            Whatever the parent `train()` returns, unchanged.
        """
        res = super().train(*args, **kwargs)
        for i in range(torch.cuda.device_count()):
            self.perf['memory'][f'cuda:{i}'] = f'{torch.cuda.max_memory_reserved(i)/1024/1024/1024:.2f}GiB'
        return res

    def concat_template(self, feature):
        """Build the prompt context list for one preference example.

        Reads `query`, `system`, `history`, `response` and `rejected_response` from
        `feature`. The system prefix, every history round (prompt + response +
        chat separator) and the final prompt are concatenated, then simplified.

        Returns:
            (context_list, chosen_response, rejected_response, loss_scale_list)
        """
        query: Optional[str] = feature.get('query', None)
        system: Optional[str] = feature.get('system', None)
        # `or []` also covers a `history` key that is present but set to None.
        history: List = feature.get('history') or []
        if system is None:
            if self.template.use_default_system:
                system = self.template.default_system
        else:
            assert self.template.prefix_has_system is not None, 'not support `system`'
        res_context_list: List[Context] = []
        compute_loss_idx: List[float] = []
        if system is None:
            assert self.template.prefix != self.template.prefix_has_system, f'template.prefix: {self.template.prefix}'
            prefix = self.template.prefix
        else:
            prefix = self.template.prefix_has_system
        self.template._concat_context_list(prefix, res_context_list, compute_loss_idx, system=system)
        for i, (q, r) in enumerate(history):
            self.template._concat_context_list(
                [*self.template.prompt, '{{RESPONSE}}', *self.template.chat_sep],
                res_context_list,
                compute_loss_idx,
                query=q,
                response=r,
                round0=i)
        self.template._concat_context_list(
            self.template.prompt, res_context_list, compute_loss_idx, query=query, round0=len(history))
        res_context_list, compute_loss_idx = self.template._simplify_context_list(res_context_list, compute_loss_idx)
        return res_context_list, feature['response'], feature['rejected_response'], compute_loss_idx

    def build_tokenized_answer(self, prompt, answer, prompt_loss_scale):
        """Encode a prompt context list and an answer (followed by the template suffix).

        Returns:
            dict with `prompt_input_ids` / `prompt_attention_mask` for the prompt and
            `input_ids` / `attention_mask` for the answer + suffix tokens.
        """
        input_ids, _, _, _ = self.template._encode_context_list(prompt, prompt_loss_scale)
        tgt_input_ids = self.template._encode_context_list([answer], [1.0])[0]
        tgt_input_ids += self.template._encode_context_list(self.template.suffix, [1.0])[0]
        return dict(
            prompt_input_ids=input_ids,
            prompt_attention_mask=[1] * len(input_ids),
            input_ids=tgt_input_ids,
            attention_mask=[1] * len(tgt_input_ids),
        )

    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
        """Tokenize one preference example into DPO-style columns.

        For decoder-only models this produces `prompt_input_ids`,
        `prompt_attention_mask`, and `chosen_*` / `rejected_*` `input_ids`,
        `attention_mask` and `labels` (prompt positions in labels are set to
        `label_pad_token_id`). Encoder-decoder models defer to the parent.

        Raises:
            ValueError: if the chosen/rejected response is not a str, or the
                truncation mode is unknown.
        """
        batch = {}
        if not self.is_encoder_decoder:
            prompt, chosen, rejected, loss_scale = self.concat_template(feature)

            prompt_tokens, _, _, _ = self.template._encode_context_list(prompt, loss_scale)
            prompt_tokens = {
                'input_ids': prompt_tokens,
                'attention_mask': [1] * len(prompt_tokens),
            }
            prompt_tokens = {f'prompt_{k}': v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f'chosen should be an str but got {type(chosen)}')
            chosen_tokens = self.build_tokenized_answer(prompt, chosen, loss_scale)

            if not isinstance(rejected, str):
                raise ValueError(f'rejected should be an str but got {type(rejected)}')
            rejected_tokens = self.build_tokenized_answer(prompt, rejected, loss_scale)

            longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    if self.truncation_mode == 'keep_start':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][:self.max_prompt_length]
                    elif self.truncation_mode == 'keep_end':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length:]
                    else:
                        raise ValueError(f'Unknown truncation mode: {self.truncation_mode}')

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    for k in ['input_ids', 'attention_mask']:
                        answer_tokens[k] = answer_tokens[k][:self.max_length - self.max_prompt_length]

            # Create labels: full sequence = prompt + answer; prompt part is masked out.
            chosen_sequence_tokens = {
                k: chosen_tokens[f'prompt_{k}'] + chosen_tokens[k]
                for k in ['input_ids', 'attention_mask']
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f'prompt_{k}'] + rejected_tokens[k]
                for k in ['input_ids', 'attention_mask']
            }
            chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
            _paddings = [self.label_pad_token_id] * len(chosen_tokens['prompt_input_ids'])
            chosen_sequence_tokens['labels'][:len(chosen_tokens['prompt_input_ids'])] = _paddings
            rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
            _paddings = [self.label_pad_token_id] * len(rejected_tokens['prompt_input_ids'])
            rejected_sequence_tokens['labels'][:len(rejected_tokens['prompt_input_ids'])] = _paddings

            for k, toks in {
                    'chosen_': chosen_sequence_tokens,
                    'rejected_': rejected_sequence_tokens,
                    '': prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == 'token_type_ids':
                        continue
                    batch[f'{k}{type_key}'] = tokens

        else:
            # encoder-decoder
            batch = super().tokenize_row(feature, model)

        return batch

    @staticmethod
    def stat_dataset(llm_dataset) -> Any:
        """Log and return token-length statistics (max of chosen/rejected per row).

        Returns:
            The stat string produced by `stat_array`, or None when `llm_dataset`
            is None (e.g. no eval dataset configured).
        """
        if llm_dataset is None:
            # Without this guard, `for d in None` below raises TypeError.
            return None
        _token_len = []
        from datasets import Dataset as HfDataset
        from swift.utils.np_utils import stat_array
        if isinstance(llm_dataset, HfDataset):
            # Column access avoids materializing each row dict.
            chosen = llm_dataset['chosen_input_ids']
            rejected = llm_dataset['rejected_input_ids']
            for cc, rr in zip(chosen, rejected):
                _token_len.append(max(len(cc), len(rr)))
        else:
            for d in llm_dataset:
                _token_len.append(max(len(d['chosen_input_ids']), len(d['rejected_input_ids'])))
        _, stat_str = stat_array(_token_len)
        logger.info(f'Dataset Token Length: {stat_str}')
        return stat_str

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal['train', 'eval'] = 'train',
    ):
        """Compute the SimPO loss and other metrics for the given batch of inputs for train or test.

        Returns:
            (mean loss, metrics dict). Metric keys are prefixed with `eval_` when
            `train_eval == 'eval'`.
        """
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        losses, chosen_rewards, rejected_rewards = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = 'eval_' if train_eval == 'eval' else ''
        metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu()
        metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu()
        metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu()
        metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu()
        metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu()
        metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu()
        metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu()

        return losses.mean(), metrics

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        Log-probs are averaged over response tokens (`average_log_prob=True`), which
        is the length normalization SimPO relies on.

        Returns:
            (chosen_logps, rejected_logps, chosen_logits, rejected_logits); the first
            `len(chosen_labels)` rows of the concatenated batch are the chosen ones.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        len_chosen = batch['chosen_labels'].shape[0]

        model_kwargs = ({
            'labels': concatenated_batch['concatenated_labels'],
            'decoder_input_ids': concatenated_batch.pop('concatenated_decoder_input_ids', None),
        } if self.is_encoder_decoder else {})

        all_logits = model(
            concatenated_batch['concatenated_input_ids'],
            attention_mask=concatenated_batch['concatenated_attention_mask'],
            use_cache=False,
            **model_kwargs,
        ).logits

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch['concatenated_labels'],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

    def simpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the SimPO loss for a batch of policy model log probabilities.

        The margin is `beta * (logp_chosen - logp_rejected) - gamma`, implemented as
        `beta * (pi_logratios - gamma / beta)`.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses.
                Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses.
                Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the SimPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen
            and rejected responses, respectively (beta-scaled, detached log-probs).

        Raises:
            ValueError: if `self.loss_type` is neither 'sigmoid' nor 'hinge'.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        # Dividing gamma by beta lets the final `beta * logits` subtract exactly gamma.
        gamma_logratios = self.gamma / self.beta
        pi_logratios = pi_logratios.to(self.accelerator.device)
        logits = pi_logratios - gamma_logratios

        if self.loss_type == 'sigmoid':
            # Label-smoothed (conservative) sigmoid loss, as in DPO.
            losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                      - F.logsigmoid(-self.beta * logits) * self.label_smoothing)
        elif self.loss_type == 'hinge':
            losses = torch.relu(1 - self.beta * logits)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']")

        chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
        rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()

        return losses, chosen_rewards, rejected_rewards


# monkey patching
trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
trainer.PrinterCallback = PrinterCallbackNew