# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from collections import defaultdict
from concurrent import futures
from typing import Any, Callable, Optional, Tuple
from warnings import warn
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import whoami
from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
from .utils import PerPromptStatTracker
logger = get_logger(__name__)
MODEL_CARD_TEMPLATE = """---
license: apache-2.0
tags:
- trl
- ddpo
- diffusers
- reinforcement-learning
- text-to-image
- stable-diffusion
---
# {model_name}
This is a diffusion model that has been fine-tuned with reinforcement learning to
guide the model outputs according to a value function or human feedback. The model can be used for image generation conditioned on text.
"""
class DDPOTrainer(BaseTrainer):
"""
The DDPOTrainer uses Denoising Diffusion Policy Optimization (DDPO) to optimise diffusion models.
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
As of now, only Stable Diffusion based pipelines are supported.
Attributes:
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
details.
**reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
**prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
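Example (an illustrative sketch, not part of the public API; `config` and `sd_pipeline` are
assumed to be a `DDPOConfig` and a `DDPOStableDiffusionPipeline` built elsewhere, and the two
callables are hypothetical user code matching the signatures above):

    def reward_fn(images, prompts, prompt_metadata):
        # one scalar reward per image, plus arbitrary metadata
        return images.mean(dim=(1, 2, 3)), {}

    def prompt_fn():
        # a prompt string, plus arbitrary metadata
        return "a photo of a cat", {}

    trainer = DDPOTrainer(config, reward_fn, prompt_fn, sd_pipeline)
    trainer.train()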
"""
_tag_names = ["trl", "ddpo"]
def __init__(
self,
config: DDPOConfig,
reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
prompt_function: Callable[[], Tuple[str, Any]],
sd_pipeline: DDPOStableDiffusionPipeline,
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
):
if image_samples_hook is None:
warn("No image_samples_hook provided; no images will be logged")
self.prompt_fn = prompt_function
self.reward_fn = reward_function
self.config = config
self.image_samples_callback = image_samples_hook
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
if self.config.resume_from:
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
if "checkpoint_" not in os.path.basename(self.config.resume_from):
# get the most recent checkpoint in this directory
checkpoints = list(
filter(
lambda x: "checkpoint_" in x,
os.listdir(self.config.resume_from),
)
)
if len(checkpoints) == 0:
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
self.config.resume_from = os.path.join(
self.config.resume_from,
f"checkpoint_{checkpoint_numbers[-1]}",
)
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
# number of timesteps within each trajectory to train on
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
self.accelerator = Accelerator(
log_with=self.config.log_with,
mixed_precision=self.config.mixed_precision,
project_config=accelerator_project_config,
# we always accumulate gradients across timesteps; we want config.train_gradient_accumulation_steps to be the
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
# the total number of micro-steps to accumulate across before each optimizer step.
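# (e.g. train_gradient_accumulation_steps=4 with num_train_timesteps=50 accumulates over 200 backward passes per optimizer step)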
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
**self.config.accelerator_kwargs,
)
is_okay, message = self._config_check()
if not is_okay:
raise ValueError(message)
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
if self.accelerator.is_main_process:
self.accelerator.init_trackers(
self.config.tracker_project_name,
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
init_kwargs=self.config.tracker_kwargs,
)
logger.info(f"\n{config}")
set_seed(self.config.seed, device_specific=True)
self.sd_pipeline = sd_pipeline
self.sd_pipeline.set_progress_bar_config(
position=1,
disable=not self.accelerator.is_local_main_process,
leave=False,
desc="Timestep",
dynamic_ncols=True,
)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
if self.accelerator.mixed_precision == "fp16":
inference_dtype = torch.float16
elif self.accelerator.mixed_precision == "bf16":
inference_dtype = torch.bfloat16
else:
inference_dtype = torch.float32
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
trainable_layers = self.sd_pipeline.get_trainable_layers()
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if self.config.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
self.optimizer = self._setup_optimizer(trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers)
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
self.sd_pipeline.tokenizer(
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.sd_pipeline.tokenizer.model_max_length,
).input_ids.to(self.accelerator.device)
)[0]
if config.per_prompt_stat_tracking:
self.stat_tracker = PerPromptStatTracker(
config.per_prompt_stat_tracking_buffer_size,
config.per_prompt_stat_tracking_min_count,
)
# NOTE: autocast is necessary for non-LoRA training, but for LoRA training it is not needed and it uses
# more memory
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
else:
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
if self.config.async_reward_computation:
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
if config.resume_from:
logger.info(f"Resuming from {config.resume_from}")
self.accelerator.load_state(config.resume_from)
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
else:
self.first_epoch = 0
def compute_rewards(self, prompt_image_pairs, is_async=False):
if not is_async:
rewards = []
for images, prompts, prompt_metadata in prompt_image_pairs:
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
rewards.append(
(
torch.as_tensor(reward, device=self.accelerator.device),
reward_metadata,
)
)
else:
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
rewards = [(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) for reward, reward_metadata in rewards]
return zip(*rewards)
def step(self, epoch: int, global_step: int):
"""
Perform a single step of training.
Args:
epoch (int): The current epoch.
global_step (int): The current global step.
Side Effects:
- Model weights are updated
- Logs the statistics to the accelerator trackers.
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
Returns:
global_step (int): The updated global step.
"""
samples, prompt_image_data = self._generate_samples(
iterations=self.config.sample_num_batches_per_epoch,
batch_size=self.config.sample_batch_size,
)
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
rewards, rewards_metadata = self.compute_rewards(prompt_image_data, is_async=self.config.async_reward_computation)
for i, image_data in enumerate(prompt_image_data):
image_data.extend([rewards[i], rewards_metadata[i]])
if self.image_samples_callback is not None:
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
rewards = torch.cat(rewards)
rewards = self.accelerator.gather(rewards).cpu().numpy()
self.accelerator.log(
{
"reward": rewards,
"epoch": epoch,
"reward_mean": rewards.mean(),
"reward_std": rewards.std(),
},
step=global_step,
)
if self.config.per_prompt_stat_tracking:
# gather the prompts across processes
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
advantages = self.stat_tracker.update(prompts, rewards)
else:
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# ungather advantages; keep the entries corresponding to the samples on this process
samples["advantages"] = torch.as_tensor(advantages).reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index].to(self.accelerator.device)
del samples["prompt_ids"]
total_batch_size, num_timesteps = samples["timesteps"].shape
for inner_epoch in range(self.config.train_num_inner_epochs):
# shuffle samples along batch dimension
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
samples = {k: v[perm] for k, v in samples.items()}
# shuffle along time dimension independently for each sample
# build one permutation per sample, then use advanced indexing below to apply it to each per-timestep tensor
perms = torch.stack([torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)])
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
samples[key] = samples[key][
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
perms,
]
original_keys = samples.keys()
original_values = samples.values()
# rebatch them as user defined train_batch_size is different from sample_batch_size
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
# Transpose the list of original values
transposed_values = zip(*reshaped_values)
# Create new dictionaries for each row of transposed values
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
self.sd_pipeline.unet.train()
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
# ensure optimization step at the end of the inner epoch
if not self.accelerator.sync_gradients:
raise ValueError("Optimization step should have been performed by this point. Please check calculated gradient accumulation settings.")
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
self.accelerator.save_state()
return global_step
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
"""
Calculate the loss for a batch of an unpacked sample
Args:
latents (torch.Tensor):
The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
timesteps (torch.Tensor):
The timesteps sampled from the diffusion model, shape: [batch_size]
next_latents (torch.Tensor):
The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
log_probs (torch.Tensor):
The log probabilities of the latents, shape: [batch_size]
advantages (torch.Tensor):
The advantages of the latents, shape: [batch_size]
embeds (torch.Tensor):
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
Returns:
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
(all of these are of shape (1,))
"""
with self.autocast():
if self.config.train_cfg:
noise_pred = self.sd_pipeline.unet(
torch.cat([latents] * 2),
torch.cat([timesteps] * 2),
embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred = self.sd_pipeline.unet(
latents,
timesteps,
embeds,
).sample
# compute the log prob of next_latents given latents under the current model
scheduler_step_output = self.sd_pipeline.scheduler_step(
noise_pred,
timesteps,
latents,
eta=self.config.sample_eta,
prev_sample=next_latents,
)
log_prob = scheduler_step_output.log_probs
advantages = torch.clamp(
advantages,
-self.config.train_adv_clip_max,
self.config.train_adv_clip_max,
)
ratio = torch.exp(log_prob - log_probs)
loss = self.loss(advantages, self.config.train_clip_range, ratio)
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
return loss, approx_kl, clipfrac
def loss(
self,
advantages: torch.Tensor,
clip_range: float,
ratio: torch.Tensor,
):
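# A brief sketch of what follows: this is the PPO-style clipped surrogate objective. `ratio` is
# exp(new_log_prob - old_log_prob) for each sampled timestep; taking the maximum of the unclipped
# and clipped losses and averaging over the batch yields a pessimistic (clipped) policy update.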
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
ratio,
1.0 - clip_range,
1.0 + clip_range,
)
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
def _setup_optimizer(self, trainable_layers_parameters):
if self.config.train_use_8bit_adam:
import bitsandbytes
optimizer_cls = bitsandbytes.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
return optimizer_cls(
trainable_layers_parameters,
lr=self.config.train_learning_rate,
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
weight_decay=self.config.train_adam_weight_decay,
eps=self.config.train_adam_epsilon,
)
def _save_model_hook(self, models, weights, output_dir):
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
def _load_model_hook(self, models, input_dir):
self.sd_pipeline.load_checkpoint(models, input_dir)
models.pop() # ensures that accelerate doesn't try to handle loading of the model
def _generate_samples(self, iterations, batch_size):
"""
Generate samples from the model
Args:
iterations (int): Number of iterations to generate samples for
batch_size (int): Batch size to use for sampling
Returns:
samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]])
"""
samples = []
prompt_image_pairs = []
self.sd_pipeline.unet.eval()
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
for _ in range(iterations):
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
prompt_ids = self.sd_pipeline.tokenizer(
prompts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.sd_pipeline.tokenizer.model_max_length,
).input_ids.to(self.accelerator.device)
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
with self.autocast():
sd_output = self.sd_pipeline(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=sample_neg_prompt_embeds,
num_inference_steps=self.config.sample_num_steps,
guidance_scale=self.config.sample_guidance_scale,
eta=self.config.sample_eta,
output_type="pt",
)
images = sd_output.images
latents = sd_output.latents
log_probs = sd_output.log_probs
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
samples.append(
{
"prompt_ids": prompt_ids,
"prompt_embeds": prompt_embeds,
"timesteps": timesteps,
"latents": latents[:, :-1], # each entry is the latent before timestep t
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
"log_probs": log_probs,
"negative_prompt_embeds": sample_neg_prompt_embeds,
}
)
prompt_image_pairs.append([images, prompts, prompt_metadata])
return samples, prompt_image_pairs
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
"""
Train on a batch of samples. Main training segment
Args:
inner_epoch (int): The current inner epoch
epoch (int): The current epoch
global_step (int): The current global step
batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on
Side Effects:
- Model weights are updated
- Logs the statistics to the accelerator trackers.
Returns:
global_step (int): The updated global step
"""
info = defaultdict(list)
for i, sample in enumerate(batched_samples):
if self.config.train_cfg:
# concat negative prompts to sample prompts to avoid two forward passes
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
else:
embeds = sample["prompt_embeds"]
for j in range(self.num_train_timesteps):
with self.accelerator.accumulate(self.sd_pipeline.unet):
loss, approx_kl, clipfrac = self.calculate_loss(
sample["latents"][:, j],
sample["timesteps"][:, j],
sample["next_latents"][:, j],
sample["log_probs"][:, j],
sample["advantages"],
embeds,
)
info["approx_kl"].append(approx_kl)
info["clipfrac"].append(clipfrac)
info["loss"].append(loss)
self.accelerator.backward(loss)
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(
self.trainable_layers.parameters() if not isinstance(self.trainable_layers, list) else self.trainable_layers,
self.config.train_max_grad_norm,
)
self.optimizer.step()
self.optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if self.accelerator.sync_gradients:
# log training-related stuff
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
info = self.accelerator.reduce(info, reduction="mean")
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
self.accelerator.log(info, step=global_step)
global_step += 1
info = defaultdict(list)
return global_step
def _config_check(self) -> Tuple[bool, str]:
samples_per_epoch = self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
total_train_batch_size = self.config.train_batch_size * self.accelerator.num_processes * self.config.train_gradient_accumulation_steps
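# Example: sample_batch_size=6, train_batch_size=3, 2 processes, sample_num_batches_per_epoch=4 and
# train_gradient_accumulation_steps=2 give samples_per_epoch=48 and total_train_batch_size=12,
# which satisfies all three checks below.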
if not self.config.sample_batch_size >= self.config.train_batch_size:
return (
False,
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
)
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
return (
False,
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
)
if not samples_per_epoch % total_train_batch_size == 0:
return (
False,
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
)
return True, ""
def train(self, epochs: Optional[int] = None):
"""
Train the model for a given number of epochs
"""
global_step = 0
if epochs is None:
epochs = self.config.num_epochs
for epoch in range(self.first_epoch, epochs):
global_step = self.step(epoch, global_step)
def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None:
"""Creates and saves a model card for a TRL model.
Args:
path (`str`): The path to save the model card to.
model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`.
"""
try:
user = whoami()["name"]
# handle the offline case
except Exception:  # noqa
warnings.warn("Cannot retrieve user information; assuming you are running in offline mode.")
return
if not os.path.exists(path):
os.makedirs(path)
model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
f.write(model_card_content)
def _save_pretrained(self, save_directory):
self.sd_pipeline.save_pretrained(save_directory)
self.create_model_card(save_directory)
# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
DataCollator,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainingArguments,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
from .utils import (
DPODataCollatorWithPadding,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
import wandb
if is_deepspeed_available():
import deepspeed
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
class DPOTrainer(Trainer):
r"""
Initialize DPOTrainer.
Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForCausalLM`.
ref_model (`PreTrainedModelWrapper`):
Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
beta (`float`, defaults to 0.1):
The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
label_smoothing (`float`, defaults to 0):
The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` (the default DPO loss), `"hinge"` loss from the [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from the [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto_pair"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
args (`transformers.TrainingArguments`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
label_pad_token_id (`int`, defaults to `-100`):
The label pad token id. This argument is required if you want to use the default data collator.
padding_value (`int`, defaults to `0`):
The padding value if it is different to the tokenizer's pad_token_id.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
tokenizer (`transformers.PreTrainedTokenizerBase`):
The tokenizer to use for training. This argument is required if you want to use the default data collator.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
callbacks (`List[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
max_length (`int`, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
max_prompt_length (`int`, defaults to `None`):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
max_target_length (`int`, defaults to `None`):
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
If no model is provided, we need to know if the model_init returns an encoder-decoder.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model` and `ref_model`.
generate_during_eval (`bool`, defaults to `False`):
Whether to sample and log generations during evaluation step.
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
precompute_ref_log_probs (`bool`, defaults to `False`):
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
without the reference model and reduce the total GPU memory needed.
dataset_num_proc (`Optional[int]`, *optional*):
The number of workers to use to tokenize the data. Defaults to None.
model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string
ref_model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the ref model from a string
model_adapter_name (`str`, defaults to `None`):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`str`, defaults to `None`):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
reference_free (`bool`):
If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
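Example (an illustrative sketch; the model id, `training_args` and `train_dataset` are
placeholders assumed to be defined elsewhere, with the dataset providing "prompt", "chosen"
and "rejected" string columns, and `AutoTokenizer` imported from `transformers`):

    model = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")
    tokenizer = AutoTokenizer.from_pretrained("my-org/my-base-model")
    trainer = DPOTrainer(
        model=model,
        ref_model=None,  # a frozen reference copy is created automatically
        args=training_args,
        beta=0.1,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()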
"""
_tag_names = ["trl", "dpo"]
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
dpo_alpha: float = 1.0,
beta: float = 0.1,
gamma: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
padding_value: Optional[int] = None,
truncation_mode: str = "keep_end",
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
max_length: Optional[int] = None,
max_prompt_length: Optional[int] = None,
max_target_length: Optional[int] = None,
peft_config: Optional[Dict] = None,
is_encoder_decoder: Optional[bool] = None,
disable_dropout: bool = True,
generate_during_eval: bool = False,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
precompute_ref_log_probs: bool = False,
dataset_num_proc: Optional[int] = None,
model_init_kwargs: Optional[Dict] = None,
ref_model_init_kwargs: Optional[Dict] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
reference_free: bool = False,
):
if model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.")
if ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
elif not isinstance(ref_model, str):
raise ValueError("You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated.")
if isinstance(model, str):
warnings.warn("You passed a model_id to the DPOTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.")
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
if isinstance(ref_model, str):
warnings.warn("You passed a ref model_id to the DPOTrainer. This will automatically create an " "`AutoModelForCausalLM`")
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
# has been called in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False
if generate_during_eval and not is_wandb_available():
raise ValueError("`generate_during_eval=True` requires Weights and Biases to be installed." " Please install `wandb` to resolve.")
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif is_encoder_decoder is None:
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
else:
self.is_encoder_decoder = is_encoder_decoder
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.model_adapter_name = model_adapter_name
self.ref_adapter_name = ref_adapter_name
self.reference_free = reference_free
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or precompute_ref_log_probs:
# The `model` with adapters turned off will be used as the reference model
self.ref_model = None
else:
if is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model)
else:
self.ref_model = create_reference_model(model)
if tokenizer is None:
raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
if max_length is None:
warnings.warn(
"`max_length` is not set in the DPOTrainer's init" " it will default to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
if max_prompt_length is None:
warnings.warn(
"`max_prompt_length` is not set in the DPOTrainer's init" " it will default to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_prompt_length = 128
if max_target_length is None and self.is_encoder_decoder:
warnings.warn(
"When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" " it will default to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
max_target_length = 128
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=tokenizer.pad_token_id,
label_pad_token_id=label_pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
if args.remove_unused_columns:
args.remove_unused_columns = False
# warn users
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" " we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_dpo_data_collator = True
else:
self.use_dpo_data_collator = False
if disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
self.max_length = max_length
self.generate_during_eval = generate_during_eval
self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id
self.max_prompt_length = max_prompt_length
self.truncation_mode = truncation_mode
self.max_target_length = max_target_length
self.tokenizer = tokenizer
self.precompute_ref_log_probs = precompute_ref_log_probs
# Since ref_logps are precomputed on the first call to get_train/eval_dataloader,
# keep track of the first call to avoid recomputation on subsequent calls
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
warnings.warn("You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter.")
self.dpo_alpha = dpo_alpha
self.beta = beta
self.gamma = gamma
self.label_smoothing = label_smoothing
self.loss_type = loss_type
self._stored_metrics = defaultdict(lambda: defaultdict(list))
self.dataset_num_proc = dataset_num_proc
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
# with PartialState().local_main_process_first():
# # tokenize the dataset
# train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
# if eval_dataset is not None:
# eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
if not hasattr(self, "accelerator"):
raise AttributeError("Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.")
# Deepspeed Zero-3 does not support precompute_ref_log_probs
if self.is_deepspeed_enabled:
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
raise ValueError("You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`.")
if self.ref_model is None:
if not (self.is_peft_model or self.precompute_ref_log_probs):
raise ValueError("No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`")
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
"""
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
dataloader_params = {
"batch_size": self.args.per_device_train_batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"shuffle": False,
}
# prepare dataloader
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
reference_chosen_logps = []
reference_rejected_logps = []
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics((reference_chosen_logp, reference_rejected_logp))
reference_chosen_logps.append(reference_chosen_logp.cpu())
reference_rejected_logps.append(reference_rejected_logp.cpu())
all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
self.train_dataset = self.train_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
self.train_dataset = self.train_dataset.add_column(name="reference_rejected_logps", column=all_reference_rejected_logps)
self._precomputed_train_ref_log_probs = True
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
Args:
eval_dataset (`torch.utils.data.Dataset`, *optional*):
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
by the `model.forward()` method are automatically removed. It must implement `__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
dataloader_params = {
"batch_size": self.args.per_device_eval_batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"shuffle": False,
}
# prepare dataloader
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
reference_chosen_logps = []
reference_rejected_logps = []
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch)
reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics((reference_chosen_logp, reference_rejected_logp))
reference_chosen_logps.append(reference_chosen_logp.cpu())
reference_rejected_logps.append(reference_rejected_logp.cpu())
all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy()
all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy()
eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps)
eval_dataset = eval_dataset.add_column(name="reference_rejected_logps", column=all_reference_rejected_logps)
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
if self.eval_dataset is not None:
self.eval_dataset = eval_dataset
self._precomputed_eval_ref_log_probs = True
return super().get_eval_dataloader(eval_dataset=eval_dataset)
def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
Reference:
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
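Illustrative example (the token ids are made up): if enc(prompt) == [50, 51] on its own but
enc(prompt + answer) == [50, 73, 60] because the last prompt token merged with the start of the
answer, the split point is moved back by one token so that prompt ids + answer ids
([50] + [73, 60]) reproduce enc(prompt + answer) exactly.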
"""
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
# Prepare input tokens for token by token comparison
full_input_ids = np.array(full_tokenized["input_ids"])
if len(full_input_ids) != len(full_concat_input_ids):
raise ValueError("Prompt input ids and answer input ids should have the same length.")
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
# can be merged together when tokenizing prompt+answer. This could result
# on the last token from the prompt being different when tokenized on its own
# vs when done as prompt+answer.
response_token_ids_start_idx = len(prompt_input_ids)
# If the tokenized prompt differs from the prompt portion of the tokenized prompt+answer, it means the
# last prompt token has changed due to merging.
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
response_token_ids_start_idx -= 1
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
if len(prompt_input_ids) != len(prompt_attention_mask):
raise ValueError("Prompt input ids and attention mask should have the same length.")
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
return dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
"""Tokenize a single row from a DPO specific dataset.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
in case the prompt + chosen or prompt + rejected responses is/are too long. First
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
We also create the labels for the chosen/rejected responses, which are of length equal to
the sum of the length of the prompt and the chosen/rejected response, with
label_pad_token_id for the prompt tokens.
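An input `feature` is expected to look like (illustrative values):

    {"prompt": "What is 2 + 2?", "chosen": " 4.", "rejected": " 5."}

and the returned `batch` maps keys such as "chosen_input_ids", "chosen_labels",
"rejected_input_ids" and "prompt_input_ids" to plain Python lists of token ids.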
"""
batch = {}
prompt = feature["prompt"]
chosen = feature["chosen"]
rejected = feature["rejected"]
if not self.is_encoder_decoder:
# Check issues below for more details
# 1. https://github.com/huggingface/trl/issues/907
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
# 3. https://github.com/LianjiaTech/BELLE/issues/337
if not isinstance(prompt, str):
raise ValueError(f"prompt should be an str but got {type(prompt)}")
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
if not isinstance(chosen, str):
raise ValueError(f"chosen should be an str but got {type(chosen)}")
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
if not isinstance(rejected, str):
raise ValueError(f"rejected should be an str but got {type(rejected)}")
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
# Last prompt token might get merged by tokenizer and
# it should not be included for generation if that happens
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
for k, v in prompt_tokens.items():
prompt_tokens[k] = v[:prompt_len_input_ids]
# Make sure the prompts differ by one token at most
# and that their lengths differ by at most 1
num_diff_tokens = sum([a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])])
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
raise ValueError("Chosen and rejected prompt_input_ids might only differ on the " "last token due to tokenizer merge ops.")
# add BOS token to head of prompt
prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
# add EOS token to end of answer
chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
chosen_tokens["attention_mask"].append(1)
rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
rejected_tokens["attention_mask"].append(1)
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
# if combined sequence is too long, truncate the prompt
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
if self.truncation_mode == "keep_start":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
elif self.truncation_mode == "keep_end":
for k in ["prompt_input_ids", "prompt_attention_mask"]:
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
else:
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
# if that's still too long, truncate the response
for answer_tokens in [chosen_tokens, rejected_tokens]:
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
for k in ["input_ids", "attention_mask"]:
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
# Create labels
chosen_sequence_tokens = {k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]}
rejected_sequence_tokens = {k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]}
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"])
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"])
for k, toks in {
"chosen_": chosen_sequence_tokens,
"rejected_": rejected_sequence_tokens,
"": prompt_tokens,
}.items():
for type_key, tokens in toks.items():
if type_key == "token_type_ids":
continue
batch[f"{k}{type_key}"] = tokens
else:
chosen_tokens = self.tokenizer(chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True)
rejected_tokens = self.tokenizer(rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True)
prompt_tokens = self.tokenizer(prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True)
batch["chosen_labels"] = chosen_tokens["input_ids"]
batch["rejected_labels"] = rejected_tokens["input_ids"]
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(labels=batch["rejected_labels"])
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(labels=batch["chosen_labels"])
return batch
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(self.model).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or "default")
def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
"""Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset."""
compute_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
# compute reference logps
with torch.no_grad(), compute_ref_context_manager():
if self.ref_model is None:
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, padded_batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, padded_batch)
return reference_chosen_logps, reference_rejected_logps
@staticmethod
def concatenated_inputs(
batch: Dict[str, Union[List, torch.LongTensor]],
is_encoder_decoder: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
) -> Dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
Args:
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
is_encoder_decoder: Whether the model is an encoder-decoder model.
label_pad_token_id: The label pad token id.
padding_value: The padding value to use for the concatenated inputs_ids.
device: The device for the concatenated inputs.
Returns:
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
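For example, "chosen_input_ids" and "rejected_input_ids" (each of shape (batch_size, seq_len))
are padded to the longer of the two lengths and stacked along the batch dimension, yielding
"concatenated_input_ids" of shape (2 * batch_size, max_length) with the chosen half first.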
"""
concatenated_batch = {}
if is_encoder_decoder:
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
else:
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
for k in batch:
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("chosen", "concatenated")
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
for k in batch:
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("rejected", "concatenated")
concatenated_batch[concatenated_key] = torch.cat(
(
concatenated_batch[concatenated_key],
pad_to_length(batch[k], max_length, pad_value=pad_value),
),
dim=0,
).to(device=device)
if is_encoder_decoder:
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
# repeated_list = [
# batch['images'][0] * 2,
# batch['images'][1] * 2
# ]
concatenated_batch["concatenated_images"] = batch["images"] * 2
concatenated_batch["image_sizes"] = batch["image_sizes"] * 2
concatenated_batch["modalities"] = batch["modalities"] * 2
return concatenated_batch
def dpo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the DPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
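For the default "sigmoid" loss with label_smoothing == 0, this reduces to (a sketch of the
computation performed below):

    logits = (policy_chosen_logps - policy_rejected_logps) - (reference_chosen_logps - reference_rejected_logps)
    losses = -F.logsigmoid(self.beta * logits)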
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
if self.reference_free:
ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps
pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios
# print(f"pi log ratios: {pi_logratios}")
# print(f"ref log ratios: {ref_logratios}")
# print(f"logits: {logits}")
# The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative DPO loss.
if self.loss_type == "sigmoid":
losses = -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto_pair":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps
# As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
losses = torch.cat(
(
1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
),
0,
)
else:
raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']")
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)).detach()
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device)).detach()
return losses, chosen_rewards, rejected_rewards
@staticmethod
def get_batch_logps(
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
label_pad_token_id: int = -100,
is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
label_pad_token_id: The label pad token id.
is_encoder_decoder: Whether the model is an encoder-decoder model.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
if not is_encoder_decoder:
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != label_pad_token_id
# dummy token; we'll ignore the losses on these tokens later
labels[labels == label_pad_token_id] = 0
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)
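# Added worked example (hedged, illustration only): with batch_size=1, seq_len=3 and
# vocab_size=2, uniform logits give log_softmax = log(0.5) everywhere; after the
# one-token shift two positions remain and the pad mask drops the -100 label:
#
#   import torch
#   logits = torch.zeros(1, 3, 2)
#   labels = torch.tensor([[0, 1, -100]])
#   # get_batch_logps(logits, labels) -> tensor([-0.6931]),
#   # i.e. log(0.5) for the single unmasked label token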
def get_sft_loss(self, logits, labels):
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return loss
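# Added note: nn.CrossEntropyLoss defaults to ignore_index=-100, which matches the
# label_pad_token_id used elsewhere in this trainer, so padded positions in `labels`
# do not contribute to the SFT term. Minimal sketch (illustration only):
#
#   import torch
#   logits = torch.randn(1, 4, 10)                    # (batch, seq, vocab)
#   labels = torch.tensor([[-100, 2, 5, -100]])       # only two positions are scored
#   # trainer.get_sft_loss(logits, labels) shifts logits/labels by one position and
#   # then computes cross entropy on the unmasked positions only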
def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it's faster for FSDP.
"""
# import pdb; pdb.set_trace()
concatenated_batch = self.concatenated_inputs(
batch,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
padding_value=self.padding_value,
device=self.accelerator.device,
)
len_chosen = batch["chosen_labels"].shape[0]
# import pdb; pdb.set_trace()
all_logits, new_labels = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
labels=concatenated_batch["concatenated_labels"],
images=concatenated_batch["concatenated_images"],
image_sizes=concatenated_batch["image_sizes"],
modalities=concatenated_batch["modalities"],
use_cache=False,
dpo_forward=True,
)
all_logits = all_logits.to(torch.float32)
all_logps = self.get_batch_logps(
all_logits,
new_labels,
average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
# don't count image embeds logits
# loss_mask = new_labels != -100
# logits = [all_logits[i][loss_mask[i]] for i in range(loss_mask.shape[0])]
# chosen_logits = logits[:len_chosen]
# rejected_logits = logits[len_chosen:]
# chosen_logits = [l.detach().cpu().mean() for l in chosen_logits]
# rejected_logits = [l.detach().cpu().mean() for l in rejected_logits]
# chosen_logits = sum(chosen_logits)/len_chosen
# rejected_logits = sum(rejected_logits)/len_chosen
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
chosen_labels = new_labels[:len_chosen]
rejected_labels = new_labels[len_chosen:]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_labels, rejected_labels)
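# Added sketch of the chosen/rejected split performed above (assumes, as implied by
# `len_chosen = batch["chosen_labels"].shape[0]`, that the chosen examples are stacked
# first in the concatenated batch):
#
#   all_logps: (2 * len_chosen,) -> all_logps[:len_chosen] = chosen
#                                   all_logps[len_chosen:] = rejected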
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test.
CHANGE: 1. add sft loss
2. all gather metrics
"""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
chosen_labels,
rejected_labels,
) = self.concatenated_forward(model, batch)
# if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
) = self.concatenated_forward(
self.model, batch
)[:2]
else:
(
reference_chosen_logps,
reference_rejected_logps,
) = self.concatenated_forward(
self.ref_model, batch
)[:2]
unscaled_dpo_losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
unscaled_dpo_losses = unscaled_dpo_losses.mean()
dpo_losses = unscaled_dpo_losses * self.dpo_alpha
unscaled_sft_loss = self.get_sft_loss(policy_chosen_logits, chosen_labels)
sft_loss = unscaled_sft_loss * self.gamma
# print(sft_loss.shape, dpo_losses.shape)
losses = dpo_losses + sft_loss
# losses = sft_loss # sft only
# losses = dpo_losses # dpo only
reward_accuracies = (chosen_rewards > rejected_rewards).float()
def all_gather_tensor(tensor):
if torch.distributed.is_available() and torch.distributed.is_initialized():
tensor = tensor.detach()
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(gathered_tensor, tensor)
tensor = torch.cat(gathered_tensor, dim=0)
# else:
# print('not distributed')
return tensor
# gather chosen_rewards across devices
chosen_rewards = all_gather_tensor(chosen_rewards)
rejected_rewards = all_gather_tensor(rejected_rewards)
reward_accuracies = all_gather_tensor(reward_accuracies)
policy_chosen_logps = all_gather_tensor(policy_chosen_logps)
policy_rejected_logps = all_gather_tensor(policy_rejected_logps)
reference_chosen_logps = all_gather_tensor(reference_chosen_logps)
reference_rejected_logps = all_gather_tensor(reference_rejected_logps)
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}losses/dpo"] = unscaled_dpo_losses.cpu()
metrics[f"{prefix}losses/sft"] = unscaled_sft_loss.cpu()
metrics[f"{prefix}losses/total"] = losses.cpu()
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
# policy logps
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
# policy logits (exclude image tokens)
# metrics[f"{prefix}logits/rejected"] =policy_rejected_logits
# metrics[f"{prefix}logits/chosen"] = policy_chosen_logits
# reference logps
metrics[f"{prefix}ref_logps/rejected"] = reference_rejected_logps.mean().cpu()
metrics[f"{prefix}ref_logps/chosen"] = reference_chosen_logps.mean().cpu()
# metrics all pick .4 digits
# for k in metrics:
# metrics[k] = round(metrics[k].item(), 4)
return losses, metrics
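# Added note: the total objective returned above is a weighted sum, using the
# attribute names from this method:
#
#   losses = dpo_alpha * mean(dpo_loss) + gamma * sft_loss(chosen)
#
# e.g. dpo_alpha=1.0, gamma=0.0 recovers plain DPO, while dpo_alpha=0.0, gamma=1.0
# reduces to SFT on the chosen responses (compare the commented-out "sft only" /
# "dpo only" lines above).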
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
# force log the metrics
self.store_metrics(metrics, train_eval="train")
if return_outputs:
return (loss, metrics)
return loss
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager, as some hidden states are silently cast to full precision.
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
with generate_context_manager():
policy_output = model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
# if reference_output in batch use that otherwise use the reference model
if "reference_output" in batch:
reference_output = batch["reference_output"]
else:
if self.ref_model is None:
with self.null_ref_context():
reference_output = self.model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
else:
reference_output = self.ref_model.generate(
input_ids=batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id)
reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True)
return policy_output_decoded, reference_output_decoded
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
):
if not self.use_dpo_data_collator:
warnings.warn(
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
with torch.no_grad(), prediction_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
# force log the metrics
self.store_metrics(metrics, train_eval="eval")
if prediction_loss_only:
return (loss.detach(), None, None)
# logits for the chosen and rejected samples from model
logits_dict = {
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
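# Added note: "eval_logits/chosen" and "eval_logits/rejected" are not populated by
# get_batch_loss_metrics in this file (the corresponding lines are commented out), so
# this branch will raise a KeyError unless prediction_loss_only is True or those
# metrics are re-enabled.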
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
return (loss.detach(), logits, labels)
def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch)
self.log(
{
"game_log": wandb.Table(
columns=["Prompt", "Policy", "Ref Model"],
rows=[[prompt, pol[len(prompt) :], ref[len(prompt) :]] for prompt, pol, ref in zip(random_batch["prompt"], policy_output_decoded, ref_output_decoded)],
)
}
)
self.state.log_history.pop()
# Base evaluation
initial_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
return initial_output
def log(self, logs: Dict[str, float]) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)
@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the trainer tags (`self._tag_names`) when pushing the
model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForSeq2Seq,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainingArguments,
)
from transformers.trainer_utils import EvalLoopOutput
from ..core import PPODecorators
from ..import_utils import is_peft_available
if is_peft_available():
from peft import PeftModel
class IterativeSFTTrainer(Trainer):
"""
The IterativeSFTTrainer can be used to fine-tune models with methods that require some steps between optimization.
Attributes:
**model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
Check the documentation of `PreTrainedModel` for more details.
**args** (`transformers.TrainingArguments`): -- The arguments to use for training.
**tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
data. Check the documentation of `transformers.PreTrainedTokenizer` and
`transformers.PreTrainedTokenizerFast` for more details.
**optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training.
**data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and
passed along the dataloader.
**eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation.
**max_length** (`int`, defaults to `None`): -- The maximum length of the input.
**truncation_mode** (`str`, defaults to `keep_end`): -- The truncation mode to use, either `keep_end` or `keep_start`.
**preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): -- The function to use to preprocess the logits before computing the metrics.
**compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*): -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values.
**optimize_device_cache** (`bool`, *optional*, defaults to `False`) -- Optimize device cache for slightly more memory-efficient training.
"""
def __init__(
self,
model: PreTrainedModel = None,
args: TrainingArguments = None,
tokenizer: PreTrainedTokenizerBase = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
None,
None,
),
data_collator: Optional[DataCollator] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
max_length: Optional[int] = None,
truncation_mode: Optional[str] = "keep_end",
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
optimize_device_cache: Optional[bool] = False,
):
# Step 0: check positional arguments validity
if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
raise ValueError(f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}")
if not isinstance(model, PreTrainedModel):
raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
if not model.can_generate():
warnings.warn(f"The current model class {type(model)} is not compatible with `.generate()`" "Please make sure that this is intended.")
if optimizers[1] is None and args.max_steps == -1:
raise ValueError("When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`")
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.tokenizer = tokenizer
if data_collator is None:
if self.is_encoder_decoder:
warnings.warn("No data collator is provided. Using 'DataCollatorForSeq2Seq' with" "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8.")
self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8)
else:
warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'")
self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
else:
self.data_collator = data_collator
self.max_length = max_length
self.truncation_mode = truncation_mode
self.optimize_device_cache = optimize_device_cache
super().__init__(
model=model,
args=args,
data_collator=self.data_collator,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
self.create_optimizer_and_scheduler(self.args.max_steps)
# prepare model, optimizer and lr_scheduler
self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.model, self.optimizer, self.lr_scheduler)
self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"
if not hasattr(self, "accelerator"):
raise AttributeError("Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.")
PPODecorators.optimize_device_cache = self.optimize_device_cache
def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
if attention_mask is None:
attention_mask = [torch.ones_like(ids) for ids in input_ids]
if self.is_encoder_decoder:
input_data = self.data_collator([{"input_ids": ids, "attention_mask": att, "labels": lab} for ids, att, lab in zip(input_ids, attention_mask, labels)]).to(self.model.device)
input_data.pop("decoder_input_ids", None) # This is directly computed inside the model
input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100
else:
input_data = self.data_collator([{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]).to(self.model.device)
# truncate in case the user has provided input_ids, attention_mask and labels
if self.max_length is not None:
if self.truncation_mode == "keep_start":
input_data = {k: v[: self.max_length] for k, v in input_data.items()}
elif self.truncation_mode == "keep_end":
input_data = {k: v[-self.max_length :] for k, v in input_data.items()}
else:
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
return input_data
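# Added note: the slices above index the first (batch) dimension of the collated
# tensors, since each value in `input_data` has shape (batch, seq); sequence-length
# truncation for raw `texts` is handled separately by the tokenizer call in `step`
# (max_length/truncation).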
@staticmethod
def _step_safety_checker(
input_ids: List[torch.LongTensor],
attention_mask: List[torch.LongTensor],
labels: List[torch.LongTensor],
texts: List[str],
texts_labels: List[str],
):
"""
Check if the input data is valid for training.
Args:
input_ids (List[`torch.LongTensor`]):
List of tensors containing the input_ids
attention_mask (List[`torch.LongTensor`]):
List of tensors containing the attention_mask
labels (List[`torch.FloatTensor`]):
List of tensors containing the labels
texts (List[`str`]):
List of string containing the text input.
texts_labels (List[`str`]):
List of string containing the text labels.
Returns:
`tuple`: The input data.
"""
if texts is None:
if attention_mask is None:
for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]):
if not isinstance(tensor_list, list):
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
if not isinstance(tensor_list[0], torch.Tensor):
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
else:
for name, tensor_list in zip(["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]):
if not isinstance(tensor_list, list):
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
if not isinstance(tensor_list[0], torch.Tensor):
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
else:
if not isinstance(texts, list):
raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
if not isinstance(texts[0], str):
raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
if texts_labels is not None:
if not isinstance(texts_labels, list):
raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
if not isinstance(texts_labels[0], str):
raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")
return input_ids, attention_mask, labels, texts, texts_labels
@PPODecorators.empty_device_cache()
def step(
self,
input_ids: Optional[List[torch.LongTensor]] = None,
attention_mask: Optional[List[torch.LongTensor]] = None,
labels: Optional[List[torch.LongTensor]] = None,
texts: Optional[List[str]] = None,
texts_labels: Optional[List[str]] = None,
):
"""
Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels.
Args:
input_ids (List[`torch.LongTensor`]):
List of tensors containing the input_ids (if not provided, text will be used)
attention_mask (List[`torch.LongTensor`], *optional*):
List of tensors containing the attention_mask
labels (List[`torch.FloatTensor`], *optional*):
List of tensors containing the labels (if set to None, will default to input_ids)
texts (List[`str`], *optional*):
List of strings containing the text input (if not provided, input_ids will directly be used)
texts_labels (List[`str`], *optional*):
List of strings containing the text labels (if set to None, will default to text)
Returns:
`dict[str, Any]`: A summary of the training statistics
"""
self.model.train()
if self.state.global_step == 0:
self.tr_loss = torch.tensor(0.0).to(self.args.device)
self._globalstep_last_logged = self.state.global_step
if input_ids is None and texts is None:
raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
elif input_ids is not None and texts is not None:
warnings.warn("Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument.")
if labels is None and texts_labels is None and self.is_encoder_decoder:
raise ValueError("No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed.")
input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(input_ids, attention_mask, labels, texts, texts_labels)
if texts is not None:
model_inputs = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt")
input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]
if texts_labels is not None:
labels = self.tokenizer(texts_labels, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt")["input_ids"]
if labels is None:
warnings.warn("No labels are provided. Setting labels to input_ids")
labels = input_ids
model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)
model_inputs_names = list(model_inputs.keys())
batch_dict = {}
batch_dict.update(model_inputs)
def collator(data):
return_dict = dict()
for key in data[0]:
if key in ["input_ids", "attention_mask", "labels"]:
return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
return return_dict
batch_data = Dataset.from_dict(batch_dict)
batch_data.set_format("torch")
step_dataloader = DataLoader(
batch_data,
batch_size=self.args.per_device_train_batch_size,
shuffle=True,
collate_fn=collator,
)
for _, batch in enumerate(step_dataloader):
with self.accelerator.accumulate(self.model):
model_inputs = {k: batch[k] for k in model_inputs_names}
loss = self.compute_loss(self.model, model_inputs)
if self.args.n_gpu > 1:
loss = loss.mean()
tr_loss_step = loss.detach()
self.accelerator.backward(loss)
if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
self.accelerator.clip_grad_norm_(
self.model.parameters(),
self.args.max_grad_norm,
)
self.optimizer.step()
self.optimizer.zero_grad()
if self.lr_scheduler is not None:
self.lr_scheduler.step()
self.state.global_step += 1
# update stats etc
self.tr_loss += tr_loss_step
self._maybe_log_save_evaluate()
def _maybe_log_save_evaluate(self):
# check if eval is required
if self.args.eval_steps is not None:
if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
self.evaluate(self.eval_dataset)
# check if logging is required
if self.args.logging_steps is not None:
if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
logs: Dict[str, float] = {}
tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()
# reset tr_loss to zero
self.tr_loss -= self.tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()
self._globalstep_last_logged = self.state.global_step
self.log(logs)
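# Example usage (hedged sketch, kept as a comment; the model name and output path are
# placeholders, and real runs may need additional TrainingArguments):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   args = TrainingArguments(output_dir="./iterative_sft", max_steps=10, per_device_train_batch_size=2)
#   trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)
#   # one optimisation step on raw text; labels default to the input_ids
#   trainer.step(texts=["hello world", "iterative sft example"])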
from dataclasses import dataclass, field
from typing import List, Optional
from ..core import flatten_dict
@dataclass
class ModelConfig:
"""
Arguments which define the model and tokenizer to load.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={"help": ("The model checkpoint for weights initialization.")},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": ("Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights."),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
attn_implementation: Optional[str] = field(
default=None,
metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")},
)
use_peft: bool = field(
default=False,
metadata={"help": ("Whether to use PEFT or not for training.")},
)
lora_r: Optional[int] = field(
default=16,
metadata={"help": ("LoRA R value.")},
)
lora_alpha: Optional[int] = field(
default=32,
metadata={"help": ("LoRA alpha.")},
)
lora_dropout: Optional[float] = field(
default=0.05,
metadata={"help": ("LoRA dropout.")},
)
lora_target_modules: Optional[List[str]] = field(
default=None,
metadata={"help": ("LoRA target modules.")},
)
lora_modules_to_save: Optional[List[str]] = field(
default=None,
metadata={"help": ("Model layers to unfreeze & train")},
)
load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"})
load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"})
bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
def to_dict(self):
output_dict = {}
for key, value in self.__dict__.items():
output_dict[key] = value
return flatten_dict(output_dict)
def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional
import numpy as np
import tyro
from typing_extensions import Annotated
from trl.trainer.utils import exact_div
from ..core import flatten_dict
from ..import_utils import is_wandb_available
JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)]
@dataclass
class PPOConfig:
"""
Configuration class for PPOTrainer
"""
# common parameters
exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
"""the name of this experiment (by default is the file name without the extension name)"""
seed: int = 0
"""Seed value for random generations"""
log_with: Optional[Literal["wandb", "tensorboard"]] = None
"""Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
task_name: Optional[str] = None
"""Name of task to use - used only for tracking purposes"""
model_name: Optional[str] = "gpt2"
"""Name of model to use - used only for tracking purposes"""
query_dataset: Optional[str] = "imdb"
"""Name of dataset to query - used only for tracking purposes"""
reward_model: Optional[str] = "sentiment-analysis:lvwerra/distilbert-imdb"
"""The reward model to use - used only for tracking purposes"""
remove_unused_columns: bool = True
"""Remove unused columns from the dataset if `datasets.Dataset` is used"""
tracker_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'"""
accelerator_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for the accelerator"""
project_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
tracker_project_name: str = "trl"
"""Name of project to use for tracking"""
push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for pushing model to the hub during training (e.g. repo_id)"""
# hyperparameters
steps: int = 20000
"""Number of training steps"""
learning_rate: float = 1.41e-5
"""Adam learning rate"""
adap_kl_ctrl: bool = True
"""Use adaptive KL control, otherwise linear"""
init_kl_coef: Optional[float] = 0.2
"""Initial KL penalty coefficient (used for adaptive and linear control)"""
kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
"""kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution"""
target: Optional[float] = 6
"""Target KL value for adaptive KL control"""
horizon: Optional[float] = 10000
"""Horizon for adaptive KL control"""
gamma: float = 1
"""Gamma parameter for advantage calculation"""
lam: float = 0.95
"""Lambda parameter for advantage calculation"""
cliprange: float = 0.2
"""Range for clipping in PPO policy gradient loss"""
cliprange_value: float = 0.2
"""Range for clipping values in loss calculation"""
vf_coef: float = 0.1
"""Scaling factor for value loss"""
batch_size: int = 128
"""Number of samples per optimisation step"""
forward_batch_size: Optional[int] = None
"""DEPRECATED: use `mini_batch_size` instead, which does the same thing."""
mini_batch_size: int = 128
"""Number of samples optimized in each mini batch"""
gradient_accumulation_steps: int = 1
"""The number of gradient accumulation steps"""
world_size: tyro.conf.Suppress[int] = None
"""The world size for distributed training"""
ppo_epochs: int = 4
"""Number of optimisation epochs per batch of samples"""
max_grad_norm: Optional[float] = None
"""Maximum gradient norm for gradient clipping"""
optimize_cuda_cache: Optional[bool] = None
"""DEPRECATED: use `optimize_device_cache` instead, which does the same thing."""
optimize_device_cache: Optional[bool] = False
"""Optimize device cache for slightly more memory-efficient training"""
early_stopping: bool = False
"""Whether to stop the PPO optimization loop early is the KL too high"""
target_kl: float = 1
"""Stop early if we exceed this value by over 50%"""
compare_steps: int = 1
"""Number of steps between comparison of the current reward with the best seen so far"""
ratio_threshold: float = 10.0
"""Skip mini-batches with high PPO ratios that can cause loss spikes"""
use_score_scaling: bool = False
"""Use score scaling"""
use_score_norm: bool = False
"""Use score normalization. Only applicable if use_score_scaling is True"""
score_clip: Optional[float] = None
"""Score clipping"""
whiten_rewards: bool = False
"""Whiten the rewards before compute advantages"""
# computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
"""TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model"""
is_peft_model: Optional[tyro.conf.Suppress[bool]] = None
"""TO BE FILLED In RUNTIME: Whether the model is a PEFT model"""
backward_batch_size: tyro.conf.Suppress[int] = None
"""TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call"""
global_backward_batch_size: tyro.conf.Suppress[int] = None
"""TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes"""
global_batch_size: tyro.conf.Suppress[int] = None
"""TO BE FILLED In RUNTIME: the effective `batch_size` across all processes"""
if optimize_cuda_cache is not None:
warnings.warn("The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead.")
optimize_device_cache = optimize_cuda_cache
else:
optimize_device_cache = False
def __post_init__(self):
if self.forward_batch_size is not None:
warnings.warn(
"Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization."
)
self.mini_batch_size = self.forward_batch_size
self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
exact_div(
self.batch_size,
self.backward_batch_size,
"`batch_size`",
"`mini_batch_size * gradient_accumulation_steps`",
"`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
)
# check if wandb is installed
if self.log_with == "wandb":
# raise error if wandb is not installed
if not is_wandb_available():
raise ImportError("Please install wandb to use wandb logging. You can do this by running `pip install wandb`.")
self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size))
assert self.kl_penalty in ["kl", "abs", "mse", "full"]
def to_dict(self):
output_dict = {}
for key, value in self.__dict__.items():
output_dict[key] = value
return flatten_dict(output_dict)
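# Example of the batch-size constraint enforced in __post_init__ (illustration only):
#
#   config = PPOConfig(batch_size=128, mini_batch_size=32, gradient_accumulation_steps=2)
#   # backward_batch_size = 32 * 2 = 64 and 128 % 64 == 0, so the exact_div check passes;
#   # PPOConfig(batch_size=100, mini_batch_size=32) would fail that check instead.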
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import os
import time
import typing
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Union
import datasets
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available
from datasets import Dataset
from huggingface_hub import whoami
from packaging import version
from torch.optim import Adam
from transformers import (
DataCollatorForLanguageModeling,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
from ..core import (
WANDB_PADDING,
PPODecorators,
clip_by_value,
convert_to_scalar,
entropy_from_logits,
flatten_dict,
logprobs_from_logits,
masked_mean,
masked_var,
masked_whiten,
set_seed,
stack_dicts,
stats_to_np,
)
from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
if is_deepspeed_available():
import deepspeed
MODEL_CARD_TEMPLATE = """---
license: apache-2.0
tags:
- trl
- ppo
- transformers
- reinforcement-learning
---
# {model_name}
This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
## Usage
To use this model for inference, first install the TRL library:
```bash
python -m pip install trl
```
You can then generate text as follows:
```python
from transformers import pipeline
generator = pipeline("text-generation", model="{model_id}")
outputs = generator("Hello, my llama is cute")
```
If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
tokenizer = AutoTokenizer.from_pretrained("{model_id}")
model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}")
inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
```
"""
class PPOTrainer(BaseTrainer):
"""
The PPOTrainer uses Proximal Policy Optimization to optimise language models.
Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
https://github.com/openai/summarize-from-feedback
Attributes:
**config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
details.
**model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
Check the documentation of `PreTrainedModelWrapper` for more details.
**ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper`
for more details. If no reference model is provided, the trainer will create a reference model with the same
architecture as the model to be optimized with shared layers.
**tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
data. Check the documentation of `transformers.PreTrainedTokenizer` and
`transformers.PreTrainedTokenizerFast` for more details.
**dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
created outside the trainer; users need to design their own dataloader and make sure the batch
size that is used is the same as the one specified in the configuration object.
**optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
object.
**data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
passed along the dataloader
**num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
model, if no reference model is passed. If no number is provided, all the layers will be shared.
**lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
"""
_tag_names = ["trl", "ppo"]
def __init__(
self,
config: PPOConfig = None,
model: PreTrainedModelWrapper = None,
ref_model: Optional[PreTrainedModelWrapper] = None,
tokenizer: PreTrainedTokenizerBase = None,
dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
data_collator: Optional[typing.Callable] = None,
num_shared_layers: Optional[int] = None,
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
):
"""
Initialize PPOTrainer.
Args:
config (`PPOConfig`):
Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details.
model (`PreTrainedModelWrapper`):
Hugging Face transformer model with a value head.
ref_model (`PreTrainedModelWrapper`):
Hugging Face transformer model with a causal language modelling head. Used for KL penalty
tokenizer (`transformers.PreTrainedTokenizerBase`):
Hugging Face tokenizer
dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]):
PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
will be preprocessed by removing the columns that are not used by the model. If none is passed,
a warning will be raised in a multi-GPU setting.
optimizer (Optional[`torch.optim.Optimizer`]):
Optimizer used for training. If `None`, `Adam` is used as the default.
data_collator (Optional[function]):
Data collator function.
num_shared_layers (Optional[int]):
Number of shared layers between the model and the reference model. If `None`, all layers are shared.
used only if `ref_model` is `None`.
lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
Learning rate scheduler used for training.
"""
super().__init__(config)
# initial seed for reproducible experiments
set_seed(config.seed)
# Step 0: check positional arguments validity
if not isinstance(config, PPOConfig):
raise ValueError(f"config must be a PPOConfig, got {type(config)}")
if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
raise ValueError(f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}")
if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}")
# Step 1: Initialize Accelerator
self.accelerator = Accelerator(
log_with=config.log_with,
gradient_accumulation_steps=config.gradient_accumulation_steps,
project_config=ProjectConfiguration(**config.project_kwargs),
**config.accelerator_kwargs,
)
# Step 1.1 Runtime variables filled by the accelerator
config.world_size = self.accelerator.num_processes
config.global_backward_batch_size = config.backward_batch_size * config.world_size
config.global_batch_size = config.batch_size * config.world_size
self.model = model
self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
self.is_peft_model = getattr(self.model, "is_peft_model", False)
config.is_encoder_decoder = self.is_encoder_decoder
config.is_peft_model = self.is_peft_model
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
self.accelerator.init_trackers(
config.tracker_project_name,
config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
init_kwargs=config.tracker_kwargs,
)
self.is_using_text_environment = getattr(config, "use_text_environment", False)
if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
self.ref_model = ref_model
if num_shared_layers is not None:
warnings.warn(
"num_shared_layers is ignored when ref_model is provided. Two different models are used for the " "model and the reference model and no layers are shared.",
UserWarning,
)
elif ref_model is None and not self.is_peft_model:
self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
elif self.is_peft_model:
self.ref_model = None
else:
raise ValueError(f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " f"architectures are: {SUPPORTED_ARCHITECTURES} ")
self.optional_peft_ctx = self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter if self.is_peft_model else nullcontext
if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
raise ValueError("tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast")
self.tokenizer = tokenizer
if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
elif dataset is None:
warnings.warn(
"No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
UserWarning,
)
self.dataset = dataset
self._signature_columns = None
if self.dataset is not None:
self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
elif self.dataset is None and self.accelerator.num_processes > 1:
warnings.warn(
"No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
" prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
" and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
" refer to the documentation for more details.",
UserWarning,
)
self.dataloader = None
else:
self.dataloader = None
# Step 3: Initialize optimizer and data collator
self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
if optimizer is None:
self.optimizer = Adam(
filter(lambda p: p.requires_grad, self.model.parameters()),
lr=self.config.learning_rate,
)
else:
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
if self.lr_scheduler is not None:
lr_scheduler_class = torch.optim.lr_scheduler._LRScheduler if not is_torch_greater_2_0() else torch.optim.lr_scheduler.LRScheduler
if not isinstance(self.lr_scheduler, lr_scheduler_class):
raise ValueError("lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)")
if self.config.adap_kl_ctrl:
self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
else:
self.kl_ctl = FixedKLController(self.config.init_kl_coef)
# Safety checkers for DS integration
is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(self.accelerator.state, "deepspeed_plugin")
(
self.model,
self.optimizer,
self.data_collator,
self.dataloader,
self.lr_scheduler,
) = self.accelerator.prepare(
self.model,
self.optimizer,
self.data_collator,
self.dataloader,
self.lr_scheduler,
)
if is_deepspeed_used:
# Quantized models are already set on the correct device
if not self.is_peft_model and not (getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)):
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare(self.ref_model)
# In a distributed setup, only logging needs to be performed on the main process
# check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
# or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
self.is_distributed = self.accelerator.num_processes > 1
# init the current step
self.current_step = 0
# init variables for pushing model to hub
if config.push_to_hub_if_best_kwargs:
if "repo_id" not in config.push_to_hub_if_best_kwargs:
raise ValueError("You have to specify repo_id in order to push the model to the hub!")
self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
self.compare_step = 0
self.highest_reward = torch.tensor(-float("inf"))
# post process for PP
if not getattr(self.model, "is_sequential_parallel", False):
self.current_device = self.accelerator.device
else:
if is_xpu_available():
self.current_device = torch.device("xpu:0")
elif is_npu_available():
self.current_device = torch.device("npu:0")
else:
self.current_device = torch.device("cuda:0")
PPODecorators.optimize_device_cache = self.config.optimize_device_cache
self.running = RunningMoments(self.accelerator)
def _filter_kwargs(self, kwargs, target_func):
"""
filter the keyword arguments that are supported by the target function.
Args:
kwargs (dict):
Keyword arguments
target_func (function):
Target function
"""
return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()}
def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None):
"""
Prepare the dataloader for training.
Args:
dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]):
PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
will be preprocessed by removing the columns that are not used by the model.
data_collator (Optional[function]):
Data collator function.
Returns:
`torch.utils.data.DataLoader`: PyTorch dataloader
"""
if isinstance(dataset, Dataset):
dataset = self._remove_unused_columns(dataset)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=self.config.batch_size,
collate_fn=data_collator,
shuffle=True,
drop_last=True,
)
return dataloader
# Adapted from transformers.Trainer._set_signature_columns_if_needed
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# label => sentiment | we need query and response for logging purpose
self._signature_columns += ["label", "query", "response"]
# Adapted from transformers.Trainer._remove_unused_columns
def _remove_unused_columns(self, dataset: "Dataset"):
if not self.config.remove_unused_columns:
return dataset
self._set_signature_columns_if_needed()
signature_columns = self._signature_columns
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
columns = [k for k in signature_columns if k in dataset.column_names]
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
type=dataset.format["type"],
columns=columns,
format_kwargs=dataset.format["format_kwargs"],
)
return dataset
else:
return dataset.remove_columns(ignored_columns)
def generate(
self,
query_tensor: Union[torch.Tensor, List[torch.Tensor]],
length_sampler: Callable = None,
batch_size: int = 4,
return_prompt: bool = True,
generate_ref_response: bool = False,
**generation_kwargs,
):
"""
Generate response with the model given the query tensor.
call the `generate` method of the model.
Args:
query_tensor (`torch.LongTensor`):
A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
length_sampler (`Callable`, *optional*):
Callable that returns the number of newly generated tokens.
batch_size (`int`, *optional*):
Batch size used for generation, defaults to `4`.
return_prompt (`bool`, *optional*):
If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
generate_ref_response (`bool`, *optional*):
If set to `True` the reference response is also generated, defaults to `False`.
generation_kwargs (dict[str, Any]):
Keyword arguments for generation.
Returns:
`torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
"""
if generate_ref_response:
ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
response = self._generate_batched(
self.model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = self._generate_batched(
ref_model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
else:
if len(query_tensor.shape) == 2:
raise ValueError("query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)")
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
response = self.accelerator.unwrap_model(self.model).generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
if not return_prompt and not self.is_encoder_decoder:
response = response[:, query_tensor.shape[0] :]
if generate_ref_response:
ref_response = ref_response[:, query_tensor.shape[0] :]
if generate_ref_response:
return response, ref_response
return response
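# Example usage (hedged sketch; assumes a trainer instance `ppo_trainer` already exists
# and `query_tensor` is a 1-D LongTensor of prompt token ids):
#
#   generation_kwargs = {"do_sample": True, "top_k": 0, "top_p": 1.0, "max_new_tokens": 32,
#                        "pad_token_id": ppo_trainer.tokenizer.eos_token_id}
#   response = ppo_trainer.generate(query_tensor, return_prompt=False, **generation_kwargs)
#   text = ppo_trainer.tokenizer.decode(response.squeeze())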
def _generate_batched(
self,
model: PreTrainedModelWrapper,
query_tensors: List[torch.Tensor],
length_sampler: Callable = None,
batch_size: int = 4,
return_prompt: bool = True,
pad_to_multiple_of: int = None,
remove_padding: bool = True,
**generation_kwargs,
):
outputs = []
padding_side_default = self.tokenizer.padding_side
if not self.is_encoder_decoder:
self.tokenizer.padding_side = "left"
# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for i in range(0, len(query_tensors), batch_size):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
# prevent overflow if query tensors are not even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
for generation, mask in zip(generations, padded_inputs["attention_mask"]):
if not self.is_encoder_decoder:
output = generation[(1 - mask).sum() :] # remove padding
else:
output = generation
if not return_prompt and not self.is_encoder_decoder:
output = output[(mask).sum() :] # remove prompt
if remove_padding and self.tokenizer.eos_token_id in output:
pad_mask = output == self.tokenizer.eos_token_id
pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
output = output[: pad_start + 1] # keep the eos token at the end
outputs.append(output)
self.tokenizer.padding_side = padding_side_default
return outputs
def _step_safety_checker(
self,
batch_size: int,
queries: List[torch.LongTensor],
responses: List[torch.LongTensor],
scores: List[torch.FloatTensor],
masks: Optional[List[torch.LongTensor]] = None,
):
"""
Check if the input data is valid for training.
Args:
batch_size (int):
Batch size from the config file.
queries (List[`torch.LongTensor`]):
List of tensors containing the encoded queries of shape (`query_length`)
responses (List[`torch.LongTensor`]):
List of tensors containing the encoded responses of shape (`response_length`)
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
masks (List[`torch.LongTensor`], *optional*):
list of optional tensors containing the masks of shape (`query_length` + `response_length`)
Returns:
`tuple`: The input processed data.
"""
for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]):
if not isinstance(tensor_list, list):
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
if not isinstance(tensor_list[0], torch.Tensor):
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
if batch_size is not None and len(tensor_list) != batch_size:
raise ValueError(f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}")
# add queries, scores and responses on the correct device
queries = [tensor.to(self.current_device) for tensor in queries]
responses = [tensor.to(self.current_device) for tensor in responses]
scores = [tensor.to(self.current_device) for tensor in scores]
masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
# squeeze scores if needed
for i, score in enumerate(scores):
if score.dim() > 1:
raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}")
elif score.dim() == 1:
scores[i] = score.squeeze()
return queries, responses, scores, masks
@PPODecorators.empty_device_cache()
def step(
self,
queries: List[torch.LongTensor],
responses: List[torch.LongTensor],
scores: List[torch.FloatTensor],
response_masks: Optional[List[torch.LongTensor]] = None,
):
"""
Run a PPO optimisation step given a list of queries, model responses, and rewards.
Args:
queries (List[`torch.LongTensor`]):
List of tensors containing the encoded queries of shape (`query_length`)
responses (List[`torch.LongTensor`]):
List of tensors containing the encoded responses of shape (`response_length`)
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
            response_masks (List[`torch.LongTensor`], *optional*):
List of tensors containing masks of the response tokens.
Returns:
`dict[str, Any]`: A summary of the training statistics
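        Example:
            A minimal sketch of one PPO step (assumes `ppo_trainer` is an initialized `PPOTrainer` and `query_tensors`
            is a list of `config.batch_size` query tensors; the rewards here are illustrative placeholders):
            >>> response_tensors = [ppo_trainer.generate(q, return_prompt=False, max_new_tokens=16).squeeze(0) for q in query_tensors]
            >>> rewards = [torch.tensor(1.0) for _ in response_tensors]
            >>> stats = ppo_trainer.step(query_tensors, response_tensors, rewards)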
"""
bs = self.config.batch_size
queries, responses, scores, response_masks = self._step_safety_checker(bs, queries, responses, scores, response_masks)
scores = torch.tensor(scores, device=self.current_device)
if self.config.use_score_scaling:
# Score scaling
scores_mean, scores_std = self.running.update(scores)
tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device)
score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
if self.config.use_score_norm:
scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
else:
scores /= score_scaling_factor
if self.config.score_clip is not None:
# Score clipping
scores_dtype = scores.dtype
scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
# if we want to push best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = scores.mean()
# if the best reward ever seen
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
# push model to hub
self.push_to_hub(**self.push_to_hub_kwargs)
self.compare_step += 1
timing = dict()
t0 = time.time()
t = time.time()
model_inputs = self.prepare_model_inputs(queries, responses)
if self.is_distributed:
pad_first = self.tokenizer.padding_side == "left"
model_inputs["input_ids"] = self.accelerator.pad_across_processes(
model_inputs["input_ids"],
dim=1,
pad_index=self.tokenizer.pad_token_id,
pad_first=pad_first,
)
model_inputs["attention_mask"] = self.accelerator.pad_across_processes(model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first)
if self.is_encoder_decoder:
model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
model_inputs["decoder_input_ids"],
dim=1,
pad_index=self.tokenizer.pad_token_id,
pad_first=pad_first,
)
model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
model_inputs["decoder_attention_mask"],
dim=1,
pad_index=0,
pad_first=pad_first,
)
model_inputs_names = list(model_inputs.keys())
full_kl_penalty = self.config.kl_penalty == "full"
with torch.no_grad():
all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
self.model,
queries,
responses,
model_inputs,
response_masks=response_masks,
return_logits=full_kl_penalty,
)
with self.optional_peft_ctx():
ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
self.model if self.is_peft_model else self.ref_model,
queries,
responses,
model_inputs,
return_logits=full_kl_penalty,
)
timing["time/ppo/forward_pass"] = time.time() - t
with torch.no_grad():
t = time.time()
if full_kl_penalty:
active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
rewards, non_score_reward, kls = self.compute_rewards(scores, active_full_logprobs, ref_full_logprobs, masks)
else:
rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
timing["time/ppo/compute_rewards"] = time.time() - t
t = time.time()
values, advantages, returns = self.compute_advantages(values, rewards, masks)
timing["time/ppo/compute_advantages"] = time.time() - t
# upcast to float32 to avoid dataset issues
batch_dict = {
"queries": queries,
"responses": responses,
"logprobs": all_logprobs.to(torch.float32),
"values": values.to(torch.float32),
"masks": masks,
"advantages": advantages,
"returns": returns,
}
batch_dict.update(model_inputs)
t = time.time()
all_stats = []
early_stop = False
for _ in range(self.config.ppo_epochs):
if early_stop:
break
b_inds = np.random.permutation(bs)
for backward_batch_start in range(0, bs, self.config.backward_batch_size):
backward_batch_end = backward_batch_start + self.config.backward_batch_size
backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
mini_batch_end = mini_batch_start + self.config.mini_batch_size
mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
mini_batch_dict = {
"logprobs": batch_dict["logprobs"][mini_batch_inds],
"values": batch_dict["values"][mini_batch_inds],
"masks": batch_dict["masks"][mini_batch_inds],
# hacks: the queries and responses are ragged.
"queries": [batch_dict["queries"][i] for i in mini_batch_inds],
"responses": [batch_dict["responses"][i] for i in mini_batch_inds],
"advantages": batch_dict["advantages"][mini_batch_inds],
"returns": batch_dict["returns"][mini_batch_inds],
}
for k in model_inputs_names:
mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
with self.accelerator.accumulate(self.model):
model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
logprobs, logits, vpreds, _ = self.batched_forward_pass(
self.model,
mini_batch_dict["queries"],
mini_batch_dict["responses"],
model_inputs,
return_logits=True,
)
train_stats = self.train_minibatch(
mini_batch_dict["logprobs"],
mini_batch_dict["values"],
logprobs,
logits,
vpreds,
mini_batch_dict["masks"],
mini_batch_dict["advantages"],
mini_batch_dict["returns"],
)
all_stats.append(train_stats)
# typically, early stopping is done at the epoch level
if self.config.early_stopping:
policykl = train_stats["policy/policykl"]
early_stop = self._early_stop(policykl)
if early_stop:
break
timing["time/ppo/optimize_step"] = time.time() - t
t = time.time()
train_stats = stack_dicts(all_stats)
# reshape advantages/ratios such that they are not averaged.
train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)
stats = self.record_step_stats(
scores=scores,
logprobs=all_logprobs,
ref_logprobs=ref_logprobs,
non_score_reward=non_score_reward,
train_stats=train_stats,
kl_coef=self.kl_ctl.value,
masks=masks,
queries=queries,
responses=responses,
kls=kls,
)
# Gather/Reduce stats from all processes
if self.is_distributed:
stats = self.gather_stats(stats)
stats = stats_to_np(stats)
timing["time/ppo/calc_stats"] = time.time() - t
stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
# Update the KL control - multiply the batch_size by the number of processes
self.kl_ctl.update(
stats["objective/kl"],
self.config.batch_size * self.accelerator.num_processes,
)
# Log the total ppo time
timing["time/ppo/total"] = time.time() - t0
stats.update(timing)
# post-process stats for tensorboard and other loggers
if self.config.log_with != "wandb":
stats = convert_to_scalar(stats)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return stats
def _early_stop(self, policykl):
r"""
        Handles the early stopping logic. If the policy KL exceeds 1.5 times the target KL, the gradients are zeroed and
        the optimization step is skipped.
This also handles the multi-gpu case where the policy KL is averaged across all processes.
Args:
            policykl (`torch.Tensor`):
                The policy KL divergence for the current batch.
Returns:
`bool`: whether to early stop or not
"""
early_stop = False
if not self.config.early_stopping:
return early_stop
if not self.is_distributed and policykl > 1.5 * self.config.target_kl:
self.optimizer.zero_grad()
early_stop = True
elif self.is_distributed:
import torch.distributed as dist
# Wait for all processes to finish
dist.barrier()
            # all-reduce the policy KL and average it across processes
dist.all_reduce(policykl, dist.ReduceOp.SUM)
policykl /= self.accelerator.num_processes
if policykl > 1.5 * self.config.target_kl:
self.optimizer.zero_grad()
early_stop = True
return early_stop
def gather_stats(self, stats):
"""
Gather stats from all processes. Useful in the context of distributed training.
Args:
stats (dict[str, Any]):
a dictionary of stats to be gathered. The stats should contain torch tensors.
Returns:
`dict[str, Any]`: A dictionary of stats with the tensors gathered.
"""
import torch.distributed as dist
# Wait for all processes to finish
dist.barrier()
for k, v in stats.items():
if isinstance(v, torch.Tensor):
dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM)
v /= self.accelerator.num_processes
stats[k] = v
return stats
def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
if self.is_encoder_decoder:
input_data = self.data_collator([{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]).to(self.current_device)
decoder_inputs = self.data_collator([{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]).to(self.current_device)
input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
else:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator([{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]).to(self.current_device)
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: PreTrainedModelWrapper,
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
"""
Calculate model outputs in multiple batches.
Args:
queries (`torch.LongTensor`):
List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
responses (`torch.LongTensor`):
List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
return_logits (`bool`, *optional*, defaults to `False`):
Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
        Returns:
            (tuple):
                - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                    shape (`batch_size`, `response_length`)
                - all_logits (`torch.FloatTensor` or `None`): Logits of the model if `return_logits=True`, else `None`,
                    shape (`batch_size`, `response_length`, `vocab_size`)
                - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
                - all_masks (`torch.LongTensor`): Masks over the response tokens, shape (`batch_size`, `response_length`)
"""
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
model.eval()
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
logits, _, values = model(**input_kwargs)
if self.is_encoder_decoder:
input_ids = input_kwargs["decoder_input_ids"]
attention_mask = input_kwargs["decoder_attention_mask"]
else:
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]
for j in range(len(query_batch)):
if self.is_encoder_decoder:
                    # The decoder sequence always starts at index 1 (right after padding) in encoder-decoder models
start = 1
end = attention_mask[j, :].sum() - 1
else:
start = len(query_batch[j]) - 1 # logprobs starts from the second query token
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch[j] = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)
else:
del logits
all_values.append(values)
all_logprobs.append(logprobs)
all_masks.append(masks)
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1] if return_logits else None,
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
@PPODecorators.empty_device_cache()
def train_minibatch(
self,
old_logprobs: torch.FloatTensor,
values: torch.FloatTensor,
logprobs: torch.FloatTensor,
logits: torch.FloatTensor,
vpreds: torch.FloatTensor,
mask: torch.LongTensor,
advantages: torch.FloatTensor,
returns: torch.FloatTensor,
):
"""
Train one PPO minibatch
Args:
            old_logprobs (`torch.FloatTensor`):
                Log probabilities of the model at sampling time, shape [mini_batch_size, response_length]
            values (`torch.FloatTensor`):
                Values of the value head at sampling time, shape [mini_batch_size, response_length]
            logprobs (`torch.FloatTensor`):
                Current log probabilities of the model, shape [mini_batch_size, response_length]
            logits (`torch.FloatTensor`):
                Current logits of the model, shape [mini_batch_size, response_length, vocab_size]
            vpreds (`torch.FloatTensor`):
                Current values of the value head, shape [mini_batch_size, response_length]
            mask (`torch.LongTensor`):
                Mask over the response tokens, shape [mini_batch_size, response_length]
            advantages (`torch.FloatTensor`):
                Advantages, shape [mini_batch_size, response_length]
            returns (`torch.FloatTensor`):
                Returns, shape [mini_batch_size, response_length]
Returns:
train_stats (dict[str, `torch.Tensor`]):
Dictionary of training statistics
"""
self.model.train()
loss_p, loss_v, train_stats = self.loss(old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns)
loss = loss_p + loss_v
self.accelerator.backward(loss)
if self.config.max_grad_norm is not None:
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
self.optimizer.step()
# we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
# see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
self.optimizer.zero_grad()
return train_stats
def compute_rewards(
self,
scores: torch.FloatTensor,
logprobs: torch.FloatTensor,
ref_logprobs: torch.FloatTensor,
masks: torch.LongTensor,
):
"""
Compute per token rewards from scores and KL-penalty.
Args:
scores (`torch.FloatTensor`):
Scores from the reward model, shape (`batch_size`)
logprobs (`torch.FloatTensor`):
Log probabilities of the model, shape (`batch_size`, `response_length`)
            ref_logprobs (`torch.FloatTensor`):
                Log probabilities of the reference model, shape (`batch_size`, `response_length`)
            masks (`torch.LongTensor`):
                Masks over the response tokens, shape (`batch_size`, `response_length`)
Returns:
`torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`)
`torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`)
`torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`)
"""
rewards, non_score_rewards, kls = [], [], []
for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks):
# compute KL penalty (from difference in logprobs)
kl = self._kl_penalty(logprob, ref_logprob)
kls.append(kl)
non_score_reward = -self.kl_ctl.value * kl
non_score_rewards.append(non_score_reward)
reward = non_score_reward.clone()
last_non_masked_index = mask.nonzero()[-1]
# reward is preference model score + KL penalty
reward[last_non_masked_index] += score
rewards.append(reward)
return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls)
def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor:
if self.config.kl_penalty == "kl":
return logprob - ref_logprob
if self.config.kl_penalty == "abs":
return (logprob - ref_logprob).abs()
if self.config.kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
if self.config.kl_penalty == "full":
            # The argument order is flipped on purpose, see: https://github.com/pytorch/pytorch/issues/57459
return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
raise NotImplementedError
def compute_advantages(
self,
values: torch.FloatTensor,
rewards: torch.FloatTensor,
mask: torch.FloatTensor,
):
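        # Generalized Advantage Estimation (GAE), computed backwards over the response:
        #   delta_t = r_t + gamma * V_{t+1} - V_t
        #   A_t     = delta_t + gamma * lam * A_{t+1}
        # Returns are R_t = A_t + V_t; advantages are whitened before they are used in the loss.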
lastgaelam = 0
advantages_reversed = []
gen_len = rewards.shape[-1]
values = values * mask
rewards = rewards * mask
if self.config.whiten_rewards:
rewards = masked_whiten(rewards, mask, shift_mean=False)
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
returns = advantages + values
advantages = masked_whiten(advantages, mask)
advantages = advantages.detach()
return values, advantages, returns
def loss(
self,
old_logprobs: torch.FloatTensor,
values: torch.FloatTensor,
logits: torch.FloatTensor,
vpreds: torch.FloatTensor,
logprobs: torch.FloatTensor,
mask: torch.LongTensor,
advantages: torch.FloatTensor,
returns: torch.FloatTensor,
):
"""
Calculate policy and value losses.
Args:
            old_logprobs (`torch.FloatTensor`):
                Log probabilities of the model at sampling time, shape (`batch_size`, `response_length`)
            values (`torch.FloatTensor`):
                Values of the value head at sampling time, shape (`batch_size`, `response_length`)
            logits (`torch.FloatTensor`):
                Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
            vpreds (`torch.FloatTensor`):
                Current values of the value head, shape (`batch_size`, `response_length`)
            logprobs (`torch.FloatTensor`):
                Current log probabilities of the model, shape (`batch_size`, `response_length`)
            mask (`torch.LongTensor`):
                Mask over the response tokens, shape (`batch_size`, `response_length`)
            advantages (`torch.FloatTensor`):
                Advantages, shape (`batch_size`, `response_length`)
            returns (`torch.FloatTensor`):
                Returns, shape (`batch_size`, `response_length`)
        Returns:
            `tuple`: The policy loss, the scaled value loss (`vf_coef * vf_loss`) and a flattened dictionary of training statistics.
"""
vpredclipped = clip_by_value(
vpreds,
values - self.config.cliprange_value,
values + self.config.cliprange_value,
)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
ratio = torch.exp(logprobs - old_logprobs)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange)
pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
loss = pg_loss + self.config.vf_coef * vf_loss
avg_ratio = masked_mean(ratio, mask).item()
if avg_ratio > self.config.ratio_threshold:
warnings.warn(f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch.")
pg_loss = pg_loss * 0.0
vf_loss = vf_loss * 0.0
loss = loss * 0.0
entropy = masked_mean(entropy_from_logits(logits), mask)
approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
policykl = masked_mean(old_logprobs - logprobs, mask)
return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
stats = dict(
loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()),
policy=dict(
entropy=entropy.detach(),
approxkl=approxkl.detach(),
policykl=policykl.detach(),
clipfrac=pg_clipfrac.detach(),
advantages=advantages.detach(),
advantages_mean=masked_mean(advantages, mask).detach(),
ratio=ratio.detach(),
),
returns=dict(mean=return_mean.detach(), var=return_var.detach()),
val=dict(
vpred=masked_mean(vpreds, mask).detach(),
error=masked_mean((vpreds - returns) ** 2, mask).detach(),
clipfrac=vf_clipfrac.detach(),
mean=value_mean.detach(),
var=value_var.detach(),
),
)
return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
def record_step_stats(self, kl_coef: float, **data):
"""
Record training step statistics.
Args:
kl_coef (`float`):
KL coefficient
data (`dict`):
Dictionary of training step data
Returns:
stats (`dict`):
Dictionary of training step statistics
"""
mask = data.pop("masks")
kls = data.pop("kls")
        kl_list = (kls * mask).sum(axis=-1)
mean_kl = kl_list.mean()
mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
mean_non_score_reward = masked_mean(data["non_score_reward"], mask) # non_score_reward is size `batch_size`, `response_length`
mean_scores = data["scores"].mean() # scores is size `batch_size`
std_scores = data["scores"].std()
if mean_kl.item() < -1.0:
# warn users
warnings.warn(
f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training."
" sometimes this happens because the generation kwargs are not correctly set. Please make sure"
" that the generation kwargs are set correctly, or review your training hyperparameters."
)
stats = {
"objective/kl": mean_kl,
"objective/kl_dist": kl_list,
"objective/logprobs": data["logprobs"],
"objective/ref_logprobs": data["ref_logprobs"],
"objective/kl_coef": kl_coef,
"objective/entropy": mean_entropy,
"ppo/mean_non_score_reward": mean_non_score_reward,
"ppo/mean_scores": mean_scores,
"ppo/std_scores": std_scores,
}
# Log text properties
query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float)
response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float)
stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
stats["tokens/queries_dist"] = query_lens.cpu().numpy()
stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item()
stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item()
stats["tokens/responses_dist"] = response_lens.cpu().numpy()
for k, v in data["train_stats"].items():
stats[f"ppo/{k}"] = torch.mean(v, axis=0)
stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
return stats
def log_stats(
self,
stats: dict,
batch: dict,
rewards: List[torch.FloatTensor],
columns_to_log: List[str] = ["query", "response"],
):
"""
A function that logs all the training stats. Call it at the end of each epoch.
Args:
stats (dict[str, Any]):
A dictionary of training stats.
batch (dict[str, Any]):
A dictionary of batch data, this contains the queries and responses.
            rewards (`List[torch.FloatTensor]`):
                A list of reward tensors.
            columns_to_log (`List[str]`, *optional*):
                The columns of `batch` to log in the table of game logs, defaults to `["query", "response"]`.
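        Example:
            A minimal sketch (assumes `stats` and `batch` come from the same PPO step and that `batch` contains
            decoded "query" and "response" strings):
            >>> ppo_trainer.log_stats(stats, batch, rewards)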
"""
# all gather stats
if not isinstance(rewards, torch.Tensor):
rewards = torch.tensor(rewards).to(self.current_device)
rewards = self.accelerator.gather(rewards).flatten()
if self.config.log_with == "wandb":
import wandb
if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]):
raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.")
batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
if self.is_distributed:
gathered_batch_list = []
for b in batch_list:
flattened = gather_object(b)
gathered_batch_list.append(flattened)
batch_list = gathered_batch_list
# Log only if we are in the main process
if self.accelerator.is_main_process:
logs = {}
# Log stats
if "query" not in batch.keys() and "response" not in batch.keys():
# warn the user that the game logs will not be logged
warnings.warn("The game logs will not be logged because the batch does not contain the keys 'query' and " "'response'. ")
elif self.config.log_with == "wandb":
table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)})
logs.update(stats)
# manually cast in fp32 for bf16 torch tensors
for k, v in logs.items():
if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
logs[k] = v.float()
logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()
if self.config.log_with == "tensorboard":
# update the current step
self.current_step += 1
self.accelerator.log(
logs,
step=self.current_step if self.config.log_with == "tensorboard" else None,
)
def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None:
"""Creates and saves a model card for a TRL model.
Args:
path (`str`): The path to save the model card to.
model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
"""
try:
user = whoami()["name"]
# handle the offline case
except: # noqa
warnings.warn("Cannot retrieve user information assuming you are running in offline mode.")
return
if not os.path.exists(path):
os.makedirs(path)
model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
f.write(model_card_content)
def _save_pretrained(self, save_directory: str) -> None:
self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
self.create_model_card(save_directory)
def _show_tokens(self, tokens, masks):
from rich import print
from rich.text import Text
text = Text()
for i, (token, mask) in enumerate(zip(tokens, masks)):
if mask == 1:
text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
text.append(" ")
else:
text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
text.append(" ")
print(text)
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepspeed_plugin.deepspeed_config
if model is not None:
if hasattr(model, "config"):
hidden_size = max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from transformers import TrainingArguments
@dataclass
class RewardConfig(TrainingArguments):
"""
RewardConfig collects all training arguments related to the [`RewardTrainer`] class.
Using [`HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
max_length (`int`, *optional*, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
gradient_checkpointing (`bool`, *optional*, defaults to `True`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
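    Example:
        A minimal sketch of parsing the config from the command line (`output_dir` is a required `TrainingArguments`
        field; the values are illustrative):
        >>> from transformers import HfArgumentParser
        >>> parser = HfArgumentParser(RewardConfig)
        >>> reward_config = parser.parse_args_into_dataclasses(args=["--output_dir", "reward_model", "--max_length", "512"])[0]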
"""
max_length: Optional[int] = None
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from datasets import Dataset
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from ..import_utils import is_peft_available
from .reward_config import RewardConfig
from .utils import RewardDataCollatorWithPadding, compute_accuracy
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
class RewardTrainer(Trainer):
r"""
The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
`transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
predict which example in the pair is more relevant to the task at hand.
    The reward trainer expects a very specific format for the dataset. The dataset should contain at least 4 entries
    if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named
- `input_ids_chosen`
- `attention_mask_chosen`
- `input_ids_rejected`
- `attention_mask_rejected`
Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
If you don't pass a margin, no margin will be used.
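    Example:
        An illustrative (already tokenized) dataset entry; the token ids below are made up:
        >>> example = {
        ...     "input_ids_chosen": [101, 7592, 2088, 102],
        ...     "attention_mask_chosen": [1, 1, 1, 1],
        ...     "input_ids_rejected": [101, 7592, 102],
        ...     "attention_mask_rejected": [1, 1, 1],
        ... }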
"""
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: Optional[RewardConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
None,
None,
),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
max_length: Optional[int] = None,
peft_config: Optional[Dict] = None,
):
"""
Initialize RewardTrainer.
Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForSequenceClassification`.
args (`RewardConfig`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
tokenizer (`transformers.PreTrainedTokenizerBase`):
The tokenizer to use for training. This argument is required if you want to use the default data collator.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`):
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
callbacks (`List[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
max_length (`int`, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
"""
if type(args) == TrainingArguments:
warnings.warn(
"Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
FutureWarning,
)
if max_length is not None:
warnings.warn(
"The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
FutureWarning,
)
else:
if max_length is not None and args.max_length is not None:
raise ValueError("You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once.")
if max_length is not None and args.max_length is None:
warnings.warn(
"The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
FutureWarning,
)
if not is_peft_available() and peft_config is not None:
raise ValueError("PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models")
elif is_peft_available() and peft_config is not None:
if not isinstance(model, PeftModel):
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(inspect.signature(prepare_model_for_kbit_training).parameters)
preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
warnings.warn("You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.")
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
model = get_peft_model(model, peft_config)
if compute_metrics is None:
compute_metrics = compute_accuracy
if data_collator is None:
if tokenizer is None:
raise ValueError("max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding")
if type(args) == TrainingArguments:
if max_length is None:
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
else:
if max_length is None and args.max_length is None:
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
if max_length is None and args.max_length is not None:
max_length = args.max_length
data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)
if args.remove_unused_columns:
try: # for bc before https://github.com/huggingface/transformers/pull/25435
args.remove_unused_columns = False
except FrozenInstanceError:
args = replace(args, remove_unused_columns=False)
# warn users
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" " we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_reward_data_collator = True
else:
self.use_reward_data_collator = False
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers,
preprocess_logits_for_metrics,
)
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
if not self.use_reward_data_collator:
warnings.warn("The current compute_loss is implemented for RewardDataCollatorWithPadding," " if you are using a custom data collator make sure you know what you are doing or" " implement your own compute_loss method.")
rewards_chosen = model(
input_ids=inputs["input_ids_chosen"],
attention_mask=inputs["attention_mask_chosen"],
return_dict=True,
)["logits"]
rewards_rejected = model(
input_ids=inputs["input_ids_rejected"],
attention_mask=inputs["attention_mask_rejected"],
return_dict=True,
)["logits"]
# calculate loss, optionally modulate with margin
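        # pairwise ranking loss: -log(sigmoid(rewards_chosen - rewards_rejected [- margin]))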
if "margin" in inputs:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
else:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
if return_outputs:
return loss, {
"rewards_chosen": rewards_chosen,
"rewards_rejected": rewards_rejected,
}
return loss
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
with torch.no_grad():
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
if prediction_loss_only:
return (loss, None, None)
loss = loss.detach()
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
logits = nested_detach(logits)
# Stack accepted against rejected, mean over logits
# and softmax to get preferences between accepted and rejected to sum to 1
logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
labels = torch.zeros(logits.shape[0])
labels = self._prepare_inputs(labels)
return loss, logits, labels
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import inspect
import warnings
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from accelerate.state import PartialState
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollator,
DataCollatorForLanguageModeling,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainingArguments,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from ..extras.dataset_formatting import get_formatting_func_from_dataset
from ..import_utils import is_peft_available
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
neftune_post_forward_hook,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)
if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
class SFTTrainer(Trainer):
r"""
Class definition of the Supervised Finetuning Trainer (SFT Trainer).
This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods.
The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object.
Args:
model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]):
The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to
load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is
passed to the `peft_config` argument.
args (Optional[`transformers.TrainingArguments`]):
The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments`
for more information.
data_collator (Optional[`transformers.DataCollator`]):
The data collator to use for training.
train_dataset (Optional[`datasets.Dataset`]):
The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]):
The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
tokenizer (Optional[`transformers.PreTrainedTokenizer`]):
The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None):
The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values.
If not specified, only the loss will be computed during evaluation.
callbacks (`List[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
peft_config (`Optional[PeftConfig]`):
The PeftConfig object to use to initialize the PeftModel.
dataset_text_field (`Optional[str]`):
The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a
`ConstantLengthDataset` based on the `dataset_text_field` argument.
formatting_func (`Optional[Callable]`):
The formatting function to be used for creating the `ConstantLengthDataset`.
max_seq_length (`Optional[int]`):
The maximum sequence length to use for the `ConstantLengthDataset` and for automatically creating the Dataset. Defaults to `512`.
infinite (`Optional[bool]`):
Whether to use an infinite dataset or not. Defaults to `False`.
num_of_sequences (`Optional[int]`):
The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
chars_per_token (`Optional[float]`):
The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
packing (`Optional[bool]`):
Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences
of the dataset.
dataset_num_proc (`Optional[int]`):
The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None.
dataset_batch_size (`int`):
The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None,
tokenize the full dataset as a single batch. Defaults to 1000.
neftune_noise_alpha (`Optional[float]`):
If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instruction
fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the model from a string
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
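    Example:
        A minimal sketch (assumes a hypothetical `dataset` with a "text" column; the model id and hyperparameters are
        illustrative):
        >>> trainer = SFTTrainer(
        ...     model="facebook/opt-350m",
        ...     train_dataset=dataset,
        ...     dataset_text_field="text",
        ...     max_seq_length=512,
        ... )
        >>> trainer.train()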
"""
_tag_names = ["trl", "sft"]
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
dataset_text_field: Optional[str] = None,
packing: Optional[bool] = False,
formatting_func: Optional[Callable] = None,
max_seq_length: Optional[int] = None,
infinite: Optional[bool] = None,
num_of_sequences: Optional[int] = 1024,
chars_per_token: Optional[float] = 3.6,
dataset_num_proc: Optional[int] = None,
dataset_batch_size: int = 1000,
neftune_noise_alpha: Optional[float] = None,
model_init_kwargs: Optional[Dict] = None,
dataset_kwargs: Optional[Dict] = None,
):
if model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.")
if infinite is not None:
warnings.warn("The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length.")
if isinstance(model, str):
warnings.warn("You passed a model_id to the SFTTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.")
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
raise ValueError("You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument.")
if is_peft_available() and peft_config is not None:
if not isinstance(peft_config, PeftConfig):
raise ValueError("If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." f" and you passed a {type(peft_config)}.")
if not isinstance(model, PeftModel):
_support_gc_kwargs = hasattr(args, "gradient_checkpointing_kwargs") and "gradient_checkpointing_kwargs" in list(inspect.signature(prepare_model_for_kbit_training).parameters)
gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
preprare_model_kwargs = {"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)}
if _support_gc_kwargs:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
if args is not None:
args = dataclasses.replace(args, gradient_checkpointing=False)
elif getattr(args, "gradient_checkpointing", False) and ("use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
peft_module_casting_to_bf16(model)
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
if max_seq_length is None:
# to overcome some issues with broken tokenizers
max_seq_length = min(tokenizer.model_max_length, 1024)
warnings.warn(f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}")
self.dataset_num_proc = dataset_num_proc
self.dataset_batch_size = dataset_batch_size
self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
if neftune_noise_alpha is not None and self._trainer_supports_neftune:
args.neftune_noise_alpha = neftune_noise_alpha
warnings.warn("You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`.")
# self.neftune_noise_alpha is done at Trainer level
elif not self._trainer_supports_neftune:
self.neftune_noise_alpha = neftune_noise_alpha
if formatting_func is None and dataset_text_field is None:
            # check if dataset has ChatML format or instruction format and is supported
            # if not, it stays None
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
if not packing:
if dataset_text_field is None and formatting_func is None:
raise ValueError("You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument.")
if data_collator is None:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Pre-process the datasets only once per node. The remaining processes will use the cache.
with PartialState().local_main_process_first():
if dataset_kwargs is None:
dataset_kwargs = {}
if train_dataset is not None:
train_dataset = self._prepare_dataset(
train_dataset,
tokenizer,
packing,
dataset_text_field,
max_seq_length,
formatting_func,
num_of_sequences,
chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
**dataset_kwargs,
)
if eval_dataset is not None:
_multiple = isinstance(eval_dataset, dict)
_eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
_eval_datasets[_eval_dataset_name] = self._prepare_dataset(
_eval_dataset,
tokenizer,
packing,
dataset_text_field,
max_seq_length,
formatting_func,
num_of_sequences,
chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
**dataset_kwargs,
)
if not _multiple:
eval_dataset = _eval_datasets["singleton"]
if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
warnings.warn(
"You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to "
"overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code."
)
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
if self.args.max_steps > 0 and packing:
warnings.warn("You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.")
self.train_dataset.infinite = True
elif self.args.max_steps == -1 and packing:
self.train_dataset.infinite = False
@wraps(Trainer.train)
def train(self, *args, **kwargs):
# Activate neftune right before training.
if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
self.model = self._trl_activate_neftune(self.model)
output = super().train(*args, **kwargs)
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
unwrapped_model = unwrap_model(self.model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
self.neftune_hook_handle.remove()
del embeddings.neftune_noise_alpha
return output
@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
def _prepare_dataset(
self,
dataset,
tokenizer,
packing,
dataset_text_field,
max_seq_length,
formatting_func,
num_of_sequences,
chars_per_token,
remove_unused_columns=True,
append_concat_token=True,
add_special_tokens=True,
):
if dataset is None:
raise ValueError("The dataset should not be None")
# check if torch dataset / dataloader and do nothing
if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
return dataset
if not packing:
return self._prepare_non_packed_dataloader(
tokenizer,
dataset,
dataset_text_field,
max_seq_length,
formatting_func,
add_special_tokens,
remove_unused_columns,
)
else:
return self._prepare_packed_dataloader(
tokenizer,
dataset,
dataset_text_field,
max_seq_length,
num_of_sequences,
chars_per_token,
formatting_func,
append_concat_token,
add_special_tokens,
)
def _prepare_non_packed_dataloader(
self,
tokenizer,
dataset,
dataset_text_field,
max_seq_length,
formatting_func=None,
add_special_tokens=True,
remove_unused_columns=True,
):
use_formatting_func = formatting_func is not None and dataset_text_field is None
self._dataset_sanity_checked = False
# Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
def tokenize(element):
outputs = tokenizer(
element[dataset_text_field] if not use_formatting_func else formatting_func(element),
add_special_tokens=add_special_tokens,
truncation=True,
padding=False,
max_length=max_seq_length,
return_overflowing_tokens=False,
return_length=False,
)
if use_formatting_func and not self._dataset_sanity_checked:
if not isinstance(formatting_func(element), list):
raise ValueError("The `formatting_func` should return a list of processed strings since it can lead to silent bugs.")
else:
self._dataset_sanity_checked = True
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
signature_columns = ["input_ids", "labels", "attention_mask"]
extra_columns = list(set(dataset.column_names) - set(signature_columns))
if not remove_unused_columns and len(extra_columns) > 0:
warnings.warn(
"You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to "
f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns."
)
tokenized_dataset = dataset.map(
tokenize,
batched=True,
remove_columns=dataset.column_names if remove_unused_columns else None,
num_proc=self.dataset_num_proc,
batch_size=self.dataset_batch_size,
)
return tokenized_dataset
def _prepare_packed_dataloader(
self,
tokenizer,
dataset,
dataset_text_field,
max_seq_length,
num_of_sequences,
chars_per_token,
formatting_func=None,
append_concat_token=True,
add_special_tokens=True,
):
if dataset_text_field is not None or formatting_func is not None:
if tokenizer is None:
raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.")
constant_length_iterator = ConstantLengthDataset(
tokenizer,
dataset,
dataset_text_field=dataset_text_field,
formatting_func=formatting_func,
seq_length=max_seq_length,
infinite=False,
num_of_sequences=num_of_sequences,
chars_per_token=chars_per_token,
eos_token_id=tokenizer.eos_token_id,
append_concat_token=append_concat_token,
add_special_tokens=add_special_tokens,
)
def data_generator(constant_length_iterator):
for i in constant_length_iterator:
yield i
try:
packed_dataset = Dataset.from_generator(data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator})
except (DatasetGenerationError, SchemaInferenceError):
raise ValueError("Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence.")
return packed_dataset
else:
raise ValueError("You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`.")
def _trl_activate_neftune(self, model):
r"""
Activates NEFTune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
Since the transformers `Trainer` already has an `_activate_neftune` method, this method uses a different name to avoid conflicts.
"""
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
self.neftune_hook_handle = hook_handle
return model
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from accelerate import PartialState
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
from ..trainer.model_config import ModelConfig
if is_peft_available():
from peft import LoraConfig, PeftConfig
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
def __init__(self, init_kl_coef, target, horizon):
self.value = init_kl_coef
self.target = target
self.horizon = horizon
def update(self, current, n_steps):
target = self.target
proportional_error = np.clip(current / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current, n_steps):
pass
class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
"""
Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index'
when they do not come from the assistant. This ensures that the loss is only
calculated on the completion made by the assistant.
Args:
response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
'### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
differently if it does not have proper context.
instruction_template (`Union[str, List[int]]`): the template form that indicates the start of the human instruction, typically something like
'### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids.
mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
for flexibility and backwards-compatibility.
ignore_index (`int`, *optional*, defaults to `-100`):
The index to use to ignore the initial tokens with
"""
def __init__(
self,
response_template: Union[str, List[int]],
instruction_template: Union[str, List[int]] = None,
*args,
mlm: bool = False,
ignore_index: int = -100,
**kwargs,
):
super().__init__(*args, mlm=mlm, **kwargs)
self.instruction_template = instruction_template
if isinstance(instruction_template, str):
# The user provides a string, must tokenize
self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
else:
# The user already provides the token ids
self.instruction_token_ids = instruction_template
self.response_template = response_template
if isinstance(response_template, str):
# The user provides a string, must tokenize
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
else:
# The user already provides the token ids
self.response_token_ids = response_template
if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
warnings.warn(
"The pad_token_id and eos_token_id values of this tokenizer are identical. "
"If you are planning for multi-turn training, "
"it can result in the model continuously generating questions and answers without eos token. "
"To avoid this, set the pad_token_id to a different value."
)
self.ignore_index = ignore_index
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
if self.instruction_template is None:
for i in range(len(examples)):
response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
if self.response_token_ids == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist():
response_token_ids_start_idx = idx
if response_token_ids_start_idx is None:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
else:
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
# Make pytorch loss function ignore all tokens up through the end of the response key
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
else:
for i in range(len(examples)):
response_token_ids_idxs = []
human_token_ids_idxs = []
for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
# find the indexes of the start of a response.
if self.response_token_ids == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist():
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
human_token_ids = self.instruction_token_ids
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
# find the indexes of the start of a human answer.
if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
if len(human_token_ids_idxs) > 0 and len(response_token_ids_idxs) > 0 and human_token_ids_idxs[0] > response_token_ids_idxs[0]:
human_token_ids_idxs = [0] + human_token_ids_idxs
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
if idx != 0:
batch["labels"][i, start:end] = self.ignore_index
else:
batch["labels"][i, :end] = self.ignore_index
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index
return batch
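# Illustrative usage sketch (not part of the library): mask everything except the
# assistant completion so the loss is computed on the response only. The checkpoint
# name and the prompt template below are placeholders.
def _example_completion_only_collator():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    tokenizer.pad_token = tokenizer.eos_token
    collator = DataCollatorForCompletionOnlyLM(
        response_template="### Response:\n",
        tokenizer=tokenizer,
        mlm=False,
    )
    text = "### Question: What is 2 + 2?\n### Response:\n4"
    batch = collator([tokenizer(text)["input_ids"]])
    # Label positions up to (and including) the response template are set to -100,
    # so only the answer tokens contribute to the loss.
    return batch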
@dataclass
class RewardDataCollatorWithPadding:
r"""
Reward DataCollator class that pads the inputs to the maximum length of the batch.
Args:
tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used for encoding the data.
padding (`Union[bool, str, PaddingStrategy]`, `optional`, defaults to `True`):
Padding strategy to pass to the tokenizer.
max_length (`Optional[int]`, `optional`, defaults to `None`):
The maximum length of the sequence to be processed.
pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`):
If set will pad the sequence to a multiple of the provided value.
return_tensors (`str`, `optional`, defaults to `"pt"`):
The tensor type to use.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
features_chosen = []
features_rejected = []
margin = []
# check if we have a margin. If we do, we need to batch it as well
has_margin = "margin" in features[0]
for feature in features:
# check if the keys are named as expected
if "input_ids_chosen" not in feature or "input_ids_rejected" not in feature or "attention_mask_chosen" not in feature or "attention_mask_rejected" not in feature:
raise ValueError("The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`")
features_chosen.append(
{
"input_ids": feature["input_ids_chosen"],
"attention_mask": feature["attention_mask_chosen"],
}
)
features_rejected.append(
{
"input_ids": feature["input_ids_rejected"],
"attention_mask": feature["attention_mask_rejected"],
}
)
if has_margin:
margin.append(feature["margin"])
batch_chosen = self.tokenizer.pad(
features_chosen,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch_rejected = self.tokenizer.pad(
features_rejected,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)
batch = {
"input_ids_chosen": batch_chosen["input_ids"],
"attention_mask_chosen": batch_chosen["attention_mask"],
"input_ids_rejected": batch_rejected["input_ids"],
"attention_mask_rejected": batch_rejected["attention_mask"],
"return_loss": True,
}
if has_margin:
margin = torch.tensor(margin, dtype=torch.float)
batch["margin"] = margin
return batch
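# Illustrative usage sketch (not part of the library): each feature carries the
# chosen/rejected token ids and attention masks; the two sides are padded
# independently and `return_loss=True` is set for the reward-model forward pass.
# The checkpoint name is a placeholder and the token ids are toy values.
def _example_reward_collator():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    tokenizer.pad_token = tokenizer.eos_token
    collator = RewardDataCollatorWithPadding(tokenizer=tokenizer)
    features = [
        {
            "input_ids_chosen": [1, 2, 3],
            "attention_mask_chosen": [1, 1, 1],
            "input_ids_rejected": [4, 5],
            "attention_mask_rejected": [1, 1],
        },
        {
            "input_ids_chosen": [1, 2],
            "attention_mask_chosen": [1, 1],
            "input_ids_rejected": [4, 5, 6],
            "attention_mask_rejected": [1, 1, 1],
        },
    ]
    batch = collator(features)
    # batch["input_ids_chosen"] has shape (2, 3); the shorter sequence is padded
    # with the tokenizer's pad token, and its attention mask is padded with 0.
    return batch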
@dataclass
class DPODataCollatorWithPadding:
r"""
DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
Args:
pad_token_id (`int`, defaults to `0`):
The tokenizer's pad_token_id.
label_pad_token_id (`int`, defaults to `-100`):
The label used for masking.
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
Whether or not your model has an encoder-decoder architecture.
"""
tokenizer: PreTrainedTokenizerBase
pad_token_id: int = 0
label_pad_token_id: int = -100
is_encoder_decoder: Optional[bool] = False
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
# first, pad everything to the same length
padded_batch = {}
for k in features[0].keys():
if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"):
if self.is_encoder_decoder:
to_pad = [torch.LongTensor(ex[k]) for ex in features]
if (k.startswith("prompt")) and (k.endswith("input_ids")):
if self.pad_token_id is None:
raise ValueError(
"Padding is enabled, but the tokenizer is not configured with a padding token." " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" " before calling the trainer."
)
padding_value = self.pad_token_id
elif k.endswith("_attention_mask"):
padding_value = 0
elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
padding_value = self.label_pad_token_id
else:
raise ValueError(f"Unexpected key in batch '{k}'")
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
else:
# adapted from https://stackoverflow.com/questions/73256206
if "prompt" in k:
to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features]
else:
to_pad = [torch.LongTensor(ex[k]) for ex in features]
if k.endswith("_input_ids"):
if self.pad_token_id is None:
raise ValueError(
"Padding is enabled, but the tokenizer is not configured with a padding token." " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" " before calling the trainer."
)
padding_value = self.pad_token_id
elif k.endswith("_labels"):
padding_value = self.label_pad_token_id
elif k.endswith("_attention_mask"):
padding_value = 0
else:
raise ValueError(f"Unexpected key in batch '{k}'")
padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
# for the prompt, flip back so padding is on left side
if "prompt" in k:
padded_batch[k] = padded_batch[k].flip(dims=[1])
elif k.endswith("_logps"):
# the cached reference model logprobs
padded_batch[k] = torch.tensor([ex[k] for ex in features])
else:
padded_batch[k] = [ex[k] for ex in features]
return padded_batch
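# Illustrative usage sketch (not part of the library): keys ending in
# `_input_ids`, `_attention_mask` or `_labels` are padded; `prompt_*` keys are
# reversed before padding and flipped back afterwards so the padding ends up on
# the left of the prompt. The token ids below are toy values, and the unused
# `tokenizer` field is set to None for the sketch.
def _example_dpo_collator():
    collator = DPODataCollatorWithPadding(tokenizer=None, pad_token_id=0, label_pad_token_id=-100)
    features = [
        {"prompt_input_ids": [5, 6, 7], "chosen_input_ids": [8, 9], "chosen_labels": [8, 9]},
        {"prompt_input_ids": [5], "chosen_input_ids": [8, 9, 10], "chosen_labels": [8, 9, 10]},
    ]
    batch = collator(features)
    # batch["prompt_input_ids"][1] -> tensor([0, 0, 5])    (left-padded with pad_token_id)
    # batch["chosen_labels"][0]    -> tensor([8, 9, -100]) (right-padded with label_pad_token_id)
    return batch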
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
The dataset also formats the text before tokenization with a specific format that is provided
by the user.
Args:
tokenizer (`transformers.PreTrainedTokenizer`):
The processor used for processing the data.
dataset (`dataset.Dataset`):
Dataset with text files.
dataset_text_field (`str`, **optional**):
Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`.
formatting_func (`Callable`, **optional**):
Function that formats the text before tokenization. It is usually recommended that it follows a certain
pattern such as `"### Question: {question} ### Answer: {answer}"`
infinite (`bool`, *optional*, defaults to `False`):
If `True`, the iterator restarts from the beginning of the dataset once it is exhausted; otherwise iteration stops.
seq_length (`int`, *optional*, defaults to `1024`):
Length of token sequences to return.
num_of_sequences (`int`, *optional*, defaults to `1024`):
Number of token sequences to keep in buffer.
chars_per_token (`int`, *optional*, defaults to `3.6`):
Number of characters per token used to estimate number of tokens in text buffer.
eos_token_id (`int`, *optional*, defaults to `0`):
Id of the end of sequence token if the passed tokenizer does not have an EOS token.
shuffle (`bool`, *optional*, defaults to `True`):
Shuffle the examples before they are returned.
append_concat_token (`bool`, *optional*, defaults to `True`):
If true, appends `eos_token_id` at the end of each sample being packed.
add_special_tokens (`bool`, *optional*, defaults to `True`):
If true, the tokenizer adds special tokens to each sample being packed.
"""
def __init__(
self,
tokenizer,
dataset,
dataset_text_field=None,
formatting_func=None,
infinite=False,
seq_length=1024,
num_of_sequences=1024,
chars_per_token=3.6,
eos_token_id=0,
shuffle=True,
append_concat_token=True,
add_special_tokens=True,
):
self.tokenizer = tokenizer
if tokenizer.eos_token_id is None:
warnings.warn(
"The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds" f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id."
)
self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.shuffle = shuffle
self.append_concat_token = append_concat_token
self.add_special_tokens = add_special_tokens
if formatting_func is None:
self.formatting_func = lambda x: x[dataset_text_field]
else:
self.formatting_func = formatting_func
if formatting_func is not None:
if formatting_func.__code__.co_argcount > 1:
warnings.warn(
"The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
" which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
)
def __len__(self):
return len(self.dataset)
def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(self.formatting_func(next(iterator)))
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
warnings.warn("The dataset reached end and the iterator is reset to the start.")
else:
more_examples = False
break
tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
if self.append_concat_token:
tokenized_input = tokenized_input + [self.concat_token_id]
all_token_ids.extend(tokenized_input)
examples = []
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
examples.append(input_ids)
if self.shuffle:
random.shuffle(examples)
for example in examples:
self.current_size += 1
yield {
"input_ids": torch.LongTensor(example),
"labels": torch.LongTensor(example),
}
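# Illustrative usage sketch (not part of the library): pack a small text dataset
# into fixed-length chunks of `seq_length` tokens, separating samples with the
# EOS/concat token. The checkpoint name is a placeholder.
def _example_constant_length_dataset():
    from datasets import Dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    raw_dataset = Dataset.from_dict({"text": ["hello world, this is a packing test."] * 64})
    packed = ConstantLengthDataset(
        tokenizer,
        raw_dataset,
        dataset_text_field="text",
        seq_length=32,
        num_of_sequences=4,
    )
    example = next(iter(packed))
    # `input_ids` and `labels` are identical LongTensors of length `seq_length`.
    return example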
class RunningMoments:
def __init__(self, accelerator):
"""
Calculates the running mean and standard deviation of a data stream. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
"""
self.mean = 0
self.std = 1
self.var = 1
self.count = 1e-24
self.accelerator = accelerator
@torch.no_grad()
def update(self, xs: torch.Tensor) -> Tuple[float, float]:
"""
Updates running moments from batch's moments computed across ranks
"""
if self.accelerator.use_distributed:
xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
else:
xs_count = xs.numel()
xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
xs_mean, xs_var = xs_mean.float(), xs_var.float()
delta = xs_mean - self.mean
tot_count = self.count + xs_count
new_sum = xs_var * xs_count
# correct old_sum deviation accounting for the new mean
old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
tot_sum = old_sum + new_sum
self.mean += delta * xs_count / tot_count
self.var = tot_sum / tot_count
self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
self.count = tot_count
return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()
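# Illustrative sketch (single process, not part of the library): after several
# batches, the tracked mean/std converge to the statistics of the concatenation
# of all batches, which is what the parallel-variance update above computes.
def _example_running_moments():
    from accelerate import Accelerator

    moments = RunningMoments(Accelerator())
    xs = torch.randn(1000)
    for chunk in xs.split(100):
        moments.update(chunk)
    # moments.mean ~= xs.mean() and moments.std ~= xs.std() (unbiased), up to numerical error.
    return moments.mean, moments.std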
@torch.no_grad()
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
"""
Computes element-wise mean and variance of the tensor across processes. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
"""
xs = xs.to(accelerator.device)
sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
sum_and_count = accelerator.reduce(sum_and_count)
global_sum, count = sum_and_count
global_mean = global_sum / count
sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
sum_var = accelerator.reduce(sum_var)
global_var = sum_var / count
return global_mean.to(device), global_var.to(device), count.to(device)
def compute_accuracy(eval_pred) -> Dict[str, float]:
predictions, labels = eval_pred
# Here, predictions is rewards_chosen and rewards_rejected.
# We want to see how much of the time rewards_chosen > rewards_rejected.
if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
warnings.warn(f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading.")
predictions = np.argmax(predictions, axis=1)
accuracy = np.array(predictions == labels, dtype=float).mean().item()
return {"accuracy": accuracy}
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
if tensor.size(dim) >= length:
return tensor
else:
pad_size = list(tensor.shape)
pad_size[dim] = length - tensor.size(dim)
return torch.cat(
[
tensor,
pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device),
],
dim=dim,
)
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0
def exact_div(a, b, a_str, b_str, custom_error_message=""):
q = a // b
if a != q * b:
raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
return q
# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
class PerPromptStatTracker:
r"""
Class for tracking statistics per prompt. Mainly used to calculate advantages for the DDPO algorithm.
Args:
buffer_size (`int`):
Size of the buffer to keep for each prompt.
min_count (`int`):
Minimum number of samples to keep in the buffer before calculating the mean and std.
"""
def __init__(self, buffer_size, min_count):
self.buffer_size = buffer_size
self.min_count = min_count
self.stats = {}
def update(self, prompts, rewards):
prompts = np.array(prompts)
rewards = np.array(rewards)
unique = np.unique(prompts)
advantages = np.empty_like(rewards)
for prompt in unique:
prompt_rewards = rewards[prompts == prompt]
if prompt not in self.stats:
self.stats[prompt] = deque(maxlen=self.buffer_size)
self.stats[prompt].extend(prompt_rewards)
if len(self.stats[prompt]) < self.min_count:
mean = np.mean(rewards)
std = np.std(rewards) + 1e-6
else:
mean = np.mean(self.stats[prompt])
std = np.std(self.stats[prompt]) + 1e-6
advantages[prompts == prompt] = (prompt_rewards - mean) / std
return advantages
def get_stats(self):
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
def neftune_post_forward_hook(module, input, output):
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
torch.nn.Embedding layers. This method is slightly adapted from the original source code
that can be found here: https://github.com/neelsjain/NEFTune
Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set
`module.neftune_noise_alpha` to the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
def peft_module_casting_to_bf16(model):
from peft.tuners.tuners_utils import BaseTunerLayer
for name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
module = module.to(torch.bfloat16)
elif isinstance(module, torch.nn.LayerNorm) or "norm" in name:
module = module.to(torch.float32)
elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
if hasattr(module, "weight"):
if module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None):
if is_unsloth_available():
# Unsloth adds a new attribute in the model config `unsloth_version`
# to keep track of models that have been patched with unsloth.
if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None:
tag_names.append("unsloth")
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]:
if model_config.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant,
)
elif model_config.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
def get_kbit_device_map() -> Optional[Dict[str, int]]:
if is_xpu_available():
return {"": f"xpu:{PartialState().local_process_index}"}
elif torch.cuda.is_available():
return {"": PartialState().local_process_index}
else:
return None
def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
if model_config.use_peft is False:
return None
peft_config = LoraConfig(
r=model_config.lora_r,
lora_alpha=model_config.lora_alpha,
lora_dropout=model_config.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=model_config.lora_target_modules,
modules_to_save=model_config.lora_modules_to_save,
)
return peft_config
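# Illustrative sketch (not part of the library): translate a `ModelConfig` into
# the quantization / device-map / PEFT objects consumed by the trainers. The
# field values below are placeholders, and `peft` must be installed for the
# LoraConfig step.
def _example_model_config_helpers():
    model_config = ModelConfig(model_name_or_path="gpt2", load_in_4bit=True, use_peft=True)
    quantization_config = get_quantization_config(model_config)  # BitsAndBytesConfig(load_in_4bit=True, ...)
    device_map = get_kbit_device_map()  # {"": local process index} on CUDA/XPU, otherwise None
    peft_config = get_peft_config(model_config)  # LoraConfig built from the lora_* fields
    # These are typically forwarded to `AutoModelForCausalLM.from_pretrained(...)`
    # and to a trainer's `peft_config` argument.
    return quantization_config, device_map, peft_config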
# LLaVA-NeXT: A Strong Zero-shot Video Understanding Model
## Paper
`LLaVA-NeXT: A Strong Zero-shot Video Understanding Model`
* https://llava-vl.github.io/blog/2024-04-30-llava-next-video/
## Model Architecture
See [README.md](../README.md)
## Algorithm
See [README.md](../README.md)
## Dataset
## Training
## Inference
### Native
```bash
cd ..
bash scripts/video/demo/video_demo.sh /path/to/LLaVA-NeXT-Video-7B-DPO vicuna_v1 32 2 average no_token True playground/demo/xU25MMA2N4aVtYay.mp4
```
### hf
```bash
python inference_hf.py
```
## Result
![inference result](readme_imgs/result.png)
### Accuracy
## Application Scenarios
See [README.md](../README.md)
## Pretrained Weights
|model|url|
|:---:|:---:|
|LLaVA-NeXT-Video-7B-DPO|[hf](https://huggingface.co/lmms-lab/LLaVA-NeXT-Video-7B-DPO) \| [SCNet]() |
|LLaVA-NeXT-Video-7B-hf|[hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) \| [SCNet]() |
|LLaVA-NeXT-Video-7B-32K-hf|[hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-32K-hf) \| [SCNet]() |
|LLaVA-NeXT-Video-7B-DPO-hf|[hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-DPO-hf) \| [SCNet]() |
|LLaVA-NeXT-Video-34B-hf|[hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-34B-hf) \| [SCNet]() |
|LLaVA-NeXT-Video-34B-DPO-hf|[hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-34B-DPO-hf) \| [SCNet]() |
After downloading, place the model weights under `ckpts` (create this directory yourself).
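For example, an HF-format checkpoint can be fetched with `huggingface_hub` (illustrative; any repo from the table above works, adjust the paths as needed):
```python
from huggingface_hub import snapshot_download

# Download the HF-format weights into ./ckpts
snapshot_download(
    repo_id="llava-hf/LLaVA-NeXT-Video-7B-hf",
    local_dir="ckpts/LLaVA-NeXT-Video-7B-hf",
)
```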
## Source Repository & Feedback
See [README.md](../README.md)
## References
* https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA-NeXT-Video.md
* https://llava-vl.github.io/blog/2024-04-30-llava-next-video/
import av
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration
import os
from pathlib import Path
current_dir = str(Path(__file__).resolve().parent)
# model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"
model_id = os.path.join(current_dir, "ckpts", "LLaVA-NeXT-Video-7B-hf")
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to("cuda")
processor = LlavaNextVideoProcessor.from_pretrained(model_id)
def read_video_pyav(container, indices):
'''
Decode the video with PyAV decoder.
Args:
container (`av.container.input.InputContainer`): PyAV container.
indices (`List[int]`): List of frame indices to decode.
Returns:
result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
'''
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
# define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image", "video")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Why is this video funny?"},
{"type": "video"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
video_path = "./examples/jobs.mp4"
container = av.open(video_path)
# sample uniformly 8 frames from the video, can sample more for longer videos
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
clip = read_video_pyav(container, indices)
inputs_video = processor(text=prompt, videos=clip, padding=True, return_tensors="pt").to(model.device)
output = model.generate(**inputs_video, max_new_tokens=100, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))
{"Q": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.", "video_name": "playground/demo/xU25MMA2N4aVtYay.mp4", "pred": "The video features a large, fluffy white cloud with the word \"Sora\" prominently displayed in the center. The word \"Sora\" is written in a stylized, cursive font that stands out against the soft, cloudy backdrop. The cloud itself is set against a clear blue sky, which adds to the serene and ethereal atmosphere of the scene. The word \"Sora\" is the main focus of the video, and its placement in the center of the cloud draws the viewer's attention immediately. The overall effect is one of tranquility and beauty, with the word \"Sora\" serving as a poetic or symbolic element within the scene."}