from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from trl import ORPOTrainer as HFORPOTrainer

from swift.llm.utils.template import Template
from swift.utils import get_logger
from .mixin import PushToMsHubMixin, SwiftMixin
from .utils import build_tokenized_answer, patch_trl, sort_by_max_length

logger = get_logger()


class ORPOTrainer(PushToMsHubMixin, SwiftMixin, HFORPOTrainer):

    def __init__(self, *args, template: Template, test_oom_error=False, **kwargs):
        self.template = template
        template._is_training = True
        self.streaming = kwargs.pop('streaming')
        is_vision = kwargs.pop('is_vision')
        patch_trl(is_vision)
        self.processed_keys = []  # keys produced by the tokenize_row mapping
        self.column_names = list(next(iter(kwargs.get('train_dataset'))).keys())
        self._data_keys = []  # vision-related keys in _data
        self.need_filter: bool = False
        super().__init__(*args, **kwargs)
        self.train_dataset = self.train_dataset.remove_columns(self.column_names)
        if self.eval_dataset is not None:
            self.eval_dataset = self.eval_dataset.remove_columns(self.column_names)
        if self.need_filter:
            # drop rows whose prompt was too long to tokenize (tokenize_row marks them with None)
            self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
            if self.eval_dataset is not None:
                self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None)
        if not self.streaming:
            train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder)
            if self.eval_dataset is not None:
                val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder)
                self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info}
            else:
                self.dataset_info = {'train_dataset': train_ds_info}
        if test_oom_error:
            self.train_dataset = sort_by_max_length(self.train_dataset, 20000, self.is_encoder_decoder)
        # performance
        self.perf: Dict[str, Any] = {
            'gen_time': 0.,
            'gen_len': 0,
            'memory': {},
            'model': self.model.get_trainable_parameters()
            if hasattr(self.model, 'get_trainable_parameters') else None,
        }
        self.model.config.model_type = self.model.config.model_type[:-1]  # remove suffix
        self.is_vision_model = is_vision

    def train(self, *args, **kwargs) -> torch.Tensor:
        res = super().train(*args, **kwargs)
        for i in range(torch.cuda.device_count()):
            self.perf['memory'][f'cuda:{i}'] = f'{torch.cuda.max_memory_reserved(i)/1024/1024/1024:.2f}GiB'
        return res

    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
        batch = {}
        if not self.is_encoder_decoder:
            # encode without response
            prompt = feature.copy()
            prompt['response'] = None
            prompt_tokens = self.template.encode(prompt)[0]
            prompt_tokens.pop('labels', None)
            # Skip examples whose prompt is too long, to avoid conflicts in subsequent processing
            if 'input_ids' not in prompt_tokens:
                self.need_filter = True
                return {k: None for k in self.processed_keys}
            # for MLLMs, pop the vision-related data to be processed later
            if '_data' in prompt_tokens:
                if not self._data_keys:
                    self._data_keys = prompt_tokens['_data'].keys()
                for key in prompt_tokens['_data'].keys():
                    if key not in prompt_tokens:
                        prompt_tokens[key] = prompt_tokens['_data'][key]
                prompt_tokens.pop('_data')
            # convert bfloat16 to float32 to avoid conflicts in the dataset mapping
            if 'pixel_values' in prompt_tokens and prompt_tokens['pixel_values'].dtype == torch.bfloat16:
                prompt_tokens['pixel_values'] = prompt_tokens['pixel_values'].to(torch.float32)
            if 'attention_mask' not in prompt_tokens:
                prompt_tokens['attention_mask'] = [1] * len(prompt_tokens['input_ids'])
            prompt_tokens = {f'prompt_{k}': v for k, v in prompt_tokens.items()}
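            # Note: from here on every prompt field carries a 'prompt_' prefix, so the
            # chosen/rejected answer tokens built below can be merged in without key collisions.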
            # encode with response
            chosen_tokens = build_tokenized_answer(feature['response'], self.template)
            chosen_tokens.update(prompt_tokens)
            rejected_tokens = build_tokenized_answer(feature['rejected_response'], self.template)
            rejected_tokens.update(prompt_tokens)

            longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))

            # if the combined sequence is too long, truncate the prompt
            for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    if self.truncation_mode == 'keep_start':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][:self.max_prompt_length]
                    elif self.truncation_mode == 'keep_end':
                        for k in ['prompt_input_ids', 'prompt_attention_mask']:
                            answer_tokens[k] = answer_tokens[k][-self.max_prompt_length:]
                    else:
                        raise ValueError(f'Unknown truncation mode: {self.truncation_mode}')

            # if that's still too long, truncate the response
            for answer_tokens in [chosen_tokens, rejected_tokens]:
                if len(answer_tokens['prompt_input_ids']) + longer_response_length > self.max_length:
                    for k in ['input_ids', 'attention_mask']:
                        answer_tokens[k] = answer_tokens[k][:self.max_length - self.max_prompt_length]

            # Create labels
            chosen_sequence_tokens = {
                k: chosen_tokens[f'prompt_{k}'] + chosen_tokens[k]
                for k in ['input_ids', 'attention_mask']
            }
            rejected_sequence_tokens = {
                k: rejected_tokens[f'prompt_{k}'] + rejected_tokens[k]
                for k in ['input_ids', 'attention_mask']
            }
            chosen_sequence_tokens['labels'] = chosen_sequence_tokens['input_ids'][:]
            _paddings = [self.label_pad_token_id] * len(chosen_tokens['prompt_input_ids'])
            chosen_sequence_tokens['labels'][:len(chosen_tokens['prompt_input_ids'])] = _paddings
            rejected_sequence_tokens['labels'] = rejected_sequence_tokens['input_ids'][:]
            _paddings = [self.label_pad_token_id] * len(rejected_tokens['prompt_input_ids'])
            rejected_sequence_tokens['labels'][:len(rejected_tokens['prompt_input_ids'])] = _paddings

            for k, toks in {
                    'chosen_': chosen_sequence_tokens,
                    'rejected_': rejected_sequence_tokens,
                    '': prompt_tokens,
            }.items():
                for type_key, tokens in toks.items():
                    if type_key == 'token_type_ids':
                        continue
                    batch[f'{k}{type_key}'] = tokens
        else:
            # encoder-decoder
            prompt = feature.copy()
            prompt['response'] = None
            prompt_tokens = self.template.encode(prompt)[0]
            prompt_tokens.pop('labels', None)
            if '_data' in prompt_tokens:
                if not self._data_keys:
                    self._data_keys = prompt_tokens['_data'].keys()
                for key in prompt_tokens['_data'].keys():
                    if key not in prompt_tokens:
                        prompt_tokens[key] = prompt_tokens['_data'][key]
                prompt_tokens.pop('_data')
            if 'pixel_values' in prompt_tokens and prompt_tokens['pixel_values'].dtype == torch.bfloat16:
                # datasets do not accept bfloat16; convert to float32
                prompt_tokens['pixel_values'] = prompt_tokens['pixel_values'].to(torch.float32)
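            # `datasets.map` serializes features via Arrow, which has no bfloat16 type;
            # float32 round-trips safely and is cast back to the model dtype before the
            # forward pass (see concatenated_forward).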
            if 'attention_mask' not in prompt_tokens:
                prompt_tokens['attention_mask'] = [1] * len(prompt_tokens['input_ids'])
            prompt_tokens = {f'prompt_{k}': v for k, v in prompt_tokens.items()}

            # encode with response
            chosen_tokens = build_tokenized_answer(feature['response'], self.template)
            rejected_tokens = build_tokenized_answer(feature['rejected_response'], self.template)

            batch['chosen_labels'] = chosen_tokens['input_ids']
            batch['rejected_labels'] = rejected_tokens['input_ids']
            if model is not None and hasattr(model, 'prepare_decoder_input_ids_from_labels'):
                batch['rejected_decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch['rejected_labels']))
                batch['chosen_decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(
                    labels=torch.tensor(batch['chosen_labels']))
            batch.update(prompt_tokens)
        if not self.processed_keys:
            self.processed_keys = list(batch.keys())
        return batch

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        if self.is_vision_model:
            # pass the device so manually collated vision tensors land alongside the text inputs
            concatenated_batch = self.concatenated_vision_inputs(
                batch, concatenated_batch, device=self.accelerator.device)
        len_chosen = batch['chosen_labels'].shape[0]

        if self.is_encoder_decoder and self.decoder_start_token_id is None:
            self.decoder_start_token_id = self.tokenizer.pad_token_id

        model_kwargs = ({
            'decoder_input_ids': self._shift_right(concatenated_batch['concatenated_labels']),
        } if self.is_encoder_decoder else {})
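        # Layout after concatenated_inputs: the first len_chosen rows hold the chosen
        # sequences and the remaining rows the rejected ones, so one forward pass
        # yields both sets of logits.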
        if self.is_vision_model:
            # Here we restore _data; the image information is processed within the model's forward hook.
            batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
            if self._data_keys:
                _data = [dict() for _ in range(batch_size)]
                for k in self._data_keys:
                    if k == 'input_ids':
                        _data = [{
                            **d, k: concatenated_batch['concatenated_input_ids'][i]
                        } for i, d in enumerate(_data)]
                    elif k == 'pixel_values':
                        # restore the dtype of pixel values that may have been converted to float32 in tokenize_row
                        model_dtype = self.accelerator.unwrap_model(model).dtype
                        # vision-related data is stored once per example; paired responses share the same entry
                        _data = [{**d, k: concatenated_batch[k][i // 2].to(model_dtype)} for i, d in enumerate(_data)]
                    else:
                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
                model_kwargs['_data'] = _data
            if 'images' in concatenated_batch:
                model_kwargs['images'] = concatenated_batch['images']

        if self.aux_loss_enabled:
            model_kwargs['output_router_logits'] = True

        outputs = model(
            input_ids=concatenated_batch['concatenated_input_ids'],
            attention_mask=concatenated_batch['concatenated_attention_mask'],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        if all_logits.shape[:2] != concatenated_batch['concatenated_labels'].shape[:2]:
            # for llava, the model returns logits for the entire sequence,
            # including the image tokens (placed before the text tokens)
            seq_len = concatenated_batch['concatenated_labels'].shape[1]
            all_logits = all_logits[:, -seq_len:]

        def cross_entropy_loss(logits, labels):
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
            return loss

        if self.is_encoder_decoder:
            labels = concatenated_batch['concatenated_labels'].clone()
        else:
            labels = concatenated_batch['concatenated_input_ids'].clone()
            attention_mask = concatenated_batch['concatenated_attention_mask']
            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

        chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch['concatenated_labels'],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

    @staticmethod
    def concatenated_vision_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        concatenated_batch: Dict[str, torch.LongTensor],
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        if 'prompt_pixel_values' in batch:
            pixel_values = [values for values in batch['prompt_pixel_values']]
            concatenated_batch['pixel_values'] = pixel_values
        if 'prompt_image_flags' in batch:
            image_flags = [torch.tensor(flags) for flags in batch['prompt_image_flags']]
            concatenated_batch['image_flags'] = image_flags
        if 'prompt_pixel_attention_mask' in batch:
            pixel_attention_mask = [mask for mask in batch['prompt_pixel_attention_mask']]
            concatenated_batch['pixel_attention_mask'] = pixel_attention_mask
        if 'prompt_image_sizes' in batch:
            concatenated_batch['image_sizes'] = batch['prompt_image_sizes']
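        # The vision tensors above are stored once per example; concatenated_forward
        # later pairs each of them with both the chosen and rejected rows via `i // 2` indexing.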
        if 'prompt_images' in batch:
            # images are not in _data, so the data-collator step is executed manually here
            concatenated_batch['images'] = batch['prompt_images'].squeeze(1).repeat(2, 1, 1, 1).to(device=device)
        return concatenated_batch

    @staticmethod
    def stat_dataset(llm_dataset, is_encoder_decoder: bool = False) -> Any:
        _token_len = []
        from datasets import Dataset as HfDataset
        from swift.utils.np_utils import stat_array
        if isinstance(llm_dataset, HfDataset):
            if is_encoder_decoder:
                prompt = llm_dataset['prompt_input_ids']
                chosen = llm_dataset['chosen_labels']
                rejected = llm_dataset['rejected_labels']
                for p, cc, rr in zip(prompt, chosen, rejected):
                    _token_len.append(max(len(cc), len(rr)) + len(p))
            else:
                chosen = llm_dataset['chosen_input_ids']
                rejected = llm_dataset['rejected_input_ids']
                for cc, rr in zip(chosen, rejected):
                    _token_len.append(max(len(cc), len(rr)))
        else:
            for d in llm_dataset:
                if is_encoder_decoder:
                    _token_len.append(
                        max(len(d['chosen_labels']), len(d['rejected_labels'])) + len(d['prompt_input_ids']))
                else:
                    _token_len.append(max(len(d['chosen_input_ids']), len(d['rejected_input_ids'])))
        _, stat_str = stat_array(_token_len)
        logger.info(f'Dataset Token Length: {stat_str}')
        return stat_str
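

# Usage sketch (illustrative only; not executed as part of this module).
# The call below is inferred from __init__ above: `streaming` and `is_vision`
# are popped from kwargs and `template` is keyword-only. `model`, `orpo_config`,
# `tokenizer`, and `train_dataset` are placeholders that a caller such as
# swift's ORPO entrypoint would supply.
#
#   trainer = ORPOTrainer(
#       model=model,                  # a transformers PreTrainedModel
#       args=orpo_config,             # a trl.ORPOConfig
#       train_dataset=train_dataset,  # rows with 'response' / 'rejected_response'
#       tokenizer=tokenizer,
#       template=template,            # swift Template used by tokenize_row
#       streaming=False,
#       is_vision=False,
#   )
#   trainer.train()
#   print(trainer.perf['memory'])     # per-GPU peak reserved memory after training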