#!/bin/bash
set -x
export VLLM_USE_V1=1
# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k
aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024
aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025
model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-7b-instruct/global_step_372
train_files="['$dapo_math_17k']"
test_files="['$aime_2025', '$aime_2024']"
# tool
tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml
# wandb
project_name=retool
experiment_name=qwen2.5-7b_dapo
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
# ================= algorithm =================
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_turns=16
max_prompt_length=2048
max_response_length=16384
actor_lr=1e-6
train_batch_size=64
ppo_mini_batch_size=16
n_resp_per_prompt=16
n_resp_per_prompt_val=30
# ================= performance =================
infer_tp=4 # vllm
train_sp=4 # train
offload=True
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 ))
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=$adv_estimator \
algorithm.use_kl_in_reward=$use_kl_in_reward \
algorithm.kl_ctrl.kl_coef=$kl_coef \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=True \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=$max_prompt_length \
data.max_response_length=$max_response_length \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.custom_cls.path=recipe/retool/retool.py \
data.custom_cls.name=CustomRLHFDataset \
custom_reward_function.path=recipe/retool/retool.py \
custom_reward_function.name=compute_score \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.optim.lr=$actor_lr \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
actor_rollout_ref.rollout.n=$n_resp_per_prompt \
actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
trainer.logger=['console','wandb'] \
trainer.project_name=$project_name \
trainer.experiment_name=$experiment_name \
trainer.n_gpus_per_node=8 \
trainer.val_before_train=True \
trainer.log_val_generations=20 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.default_local_dir=$default_local_dir \
trainer.test_freq=10 \
trainer.total_epochs=1 "$@"
#!/bin/bash
set -x
nnodes=1
nproc_per_node=8
master_addr=
master_port=
node_rank=${ARNOLD_ID:-0}
project_name=retool
experiment_name=multiturn-sft-qwen-2.5-7b-instruct
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
TRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet
EVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet
MODEL_PATH=$HDFS_ROOT/model/Qwen2.5-7B-Instruct
SAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name
torchrun --nnodes=$nnodes \
--nproc_per_node=$nproc_per_node \
--master-addr=$master_addr \
--master-port=$master_port \
--node-rank=$node_rank \
-m verl.trainer.fsdp_sft_trainer \
data.train_files=$TRAIN_DATA \
data.val_files=$EVAL_DATA \
data.max_length=16384 \
data.train_batch_size=32 \
data.multiturn.enable=true \
data.multiturn.messages_key=messages \
data.multiturn.tools_key=tools \
data.micro_batch_size_per_gpu=4 \
model.partial_pretrain=$MODEL_PATH \
model.strategy=fsdp \
trainer.default_local_dir=$SAVE_PATH \
trainer.project_name=wuxibin-multiturn-sft \
trainer.experiment_name=$experiment_name \
trainer.logger='["console","wandb"]' \
trainer.total_epochs=6 \
trainer.save_freq=62 \
ulysses_sequence_parallel_size=4 \
use_remove_padding=true
tools:
- class_name: "recipe.retool.retool.CustomSandboxFusionTool"
config:
sandbox_fusion_url: "https://***.apigateway-cn-beijing.volceapi.com/run_code"
num_workers: 128
enable_global_rate_limit: true
rate_limit: 128
default_timeout: 30
default_language: "python"
memory_limit_mb: 1024
type: native
tool_schema:
type: "function"
function:
name: "code_interpreter"
description: "A tool for executing code."
parameters:
type: "object"
properties:
code:
type: "string"
description: "The code to execute."
required: ["code"]
# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models
This repository hosts a `verl` recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory.
**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models:
1. **Synthetic Data Generation:** The model from the previous iteration generates responses, producing its own synthetic training data.
2. **Two-Player Game Setup:** A two-player game in which both players are played by a single LLM.
3. **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration.
Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)
[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)]
verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20)
---
## Key Function (`compute_online_dpo_loss`) and Related Works
SPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023).
This `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data.
Specifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets.
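For reference, with the default `loss_type="sigmoid"` and no label smoothing, the loss computed by `compute_online_dpo_loss` is the standard DPO objective (Rafailov et al., 2023) applied to the online-generated pairs:

$$
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\log \sigma\!\left(\beta \left[\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right]\right),
$$

where $y_w$ and $y_l$ are the chosen and rejected responses generated online for prompt $x$, and $\beta$ corresponds to the `dpo_beta` hyperparameter.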
**Reference Papers:**
* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024)
* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023)
* [Some Things Are More Cringe than Others: Preference Optimization with the Pairwise Cringe Loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023)
* [Iterative Preference Learning from Human Feedback: Bridging Theory and Practice for RLHF under KL-Constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023)
* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024)
* [Direct Language Model Alignment from Online AI Feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024)
## Our Online DPO Implementation
Our `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include:
* **No Critic:** Unlike PPO, we omit the value function critic.
* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline.
* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems).
* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences (see the usage sketch below).
* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles.
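The snippet below is a minimal usage sketch of these two helpers with random tensors standing in for real rewards and log probabilities. It assumes the recipe is importable as `recipe.spin.core_algos` (the import path used by `dp_actor.py`); the actual wiring is handled by `SpinTrainer` and the workers.

```python
# Minimal usage sketch of the preference helper and the online DPO loss.
# Random tensors stand in for real rewards / log-probs; responses are assumed
# interleaved as [resp1_p0, resp2_p0, resp1_p1, resp2_p1, ...].
import torch

from recipe.spin.core_algos import compute_online_dpo_loss, compute_onlinedpo_pref

num_prompts, seq_len, beta = 4, 8, 0.1

token_level_rewards = torch.randn(num_prompts * 2, seq_len)
response_mask = torch.ones(num_prompts * 2, seq_len)

# Boolean mask over the interleaved batch; True marks the chosen response of each pair.
chosen_mask = compute_onlinedpo_pref(token_level_rewards, response_mask)

# Sequence-level log-probs from the current policy and the reference model.
policy_logps = torch.randn(num_prompts * 2)
reference_logps = torch.randn(num_prompts * 2)

loss = compute_online_dpo_loss(
    policy_chosen_logps=policy_logps[chosen_mask],
    policy_rejected_logps=policy_logps[~chosen_mask],
    reference_chosen_logps=reference_logps[chosen_mask],
    reference_rejected_logps=reference_logps[~chosen_mask],
    beta=beta,
)
```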
---
## Algorithm
This recipe implements an online DPO algorithm on top of the `verl` reinforcement learning framework, providing an alternative to PPO for fine-tuning language models.
**Online Loop:** Instead of maximizing a scalar reward signal as in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training (a schematic of the loop is sketched at the end of this section):
1. **Generation:** The current model generates multiple responses for each prompt in a batch.
2. **Preference Labeling:** A function evaluates the generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done with a reward function or implicit ranking based on specific rules (in this recipe, rule-based ranking on math problems).
3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model.
**Connection with SPIN:**
Instead of only targeting a fixed data distribution, the online preference labeling in step 2 dynamically changes the target data distribution (here, rule-based ranking that selects the better of the two responses to a math problem). This explores the direction mentioned in Section 7 of the SPIN paper about a "dynamically changing target data distribution", which may lift LLM performance beyond the ceiling of fixed human-annotated data.
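As a rough illustration of the loop above, the sketch below strings the three steps together for one small batch. The `generate_two_responses`, `rule_based_rank`, and `sequence_logp` helpers are hypothetical toy placeholders (the recipe uses the rollout engine, a rule-based answer check, and the actor/reference workers for these); only the control flow mirrors the recipe.

```python
# Schematic of one online-DPO iteration: generation -> preference labeling -> update.
# All three helpers are toy placeholders; only the control flow mirrors the recipe.
import random

import torch

from recipe.spin.core_algos import compute_online_dpo_loss


def generate_two_responses(prompt: str) -> tuple[str, str]:
    return f"{prompt} -> answer A", f"{prompt} -> answer B"  # stand-in for rollout sampling (n=2)


def rule_based_rank(response: str) -> float:
    return random.random()  # stand-in for the rule-based math-answer check


def sequence_logp(response: str) -> torch.Tensor:
    return torch.randn(())  # stand-in for actor / reference log-prob computation


prompts = ["1 + 1 = ?", "2 * 3 = ?"]
chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps = [], [], [], []
for prompt in prompts:
    r1, r2 = generate_two_responses(prompt)  # 1. Generation
    if rule_based_rank(r1) >= rule_based_rank(r2):  # 2. Preference labeling
        chosen, rejected = r1, r2
    else:
        chosen, rejected = r2, r1
    chosen_logps.append(sequence_logp(chosen))
    rejected_logps.append(sequence_logp(rejected))
    ref_chosen_logps.append(sequence_logp(chosen))
    ref_rejected_logps.append(sequence_logp(rejected))

# 3. Update: DPO loss on the online preference tuples (optimizer step omitted).
loss = compute_online_dpo_loss(
    torch.stack(chosen_logps), torch.stack(rejected_logps),
    torch.stack(ref_chosen_logps), torch.stack(ref_rejected_logps), beta=0.1,
)
```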
---
## Reproduce the Experiment (Example Setup)
The following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct.
1. **Setup Environment (Example using Docker):**
```bash
# Start a container with GPU access and shared memory
docker run -it --name spin_test --gpus all \
--shm-size=32g \
--ipc=host \
-v /path/to/host/.cache:/root/.cache \
-e HF_TOKEN=<YOUR_HUGGINGFACE_TOKEN> \
lmsysorg/sglang:latest \
/bin/bash
# Inside the container or on your host machine:
# Ensure /tmp is writable
mkdir -p /tmp
chmod 1777 /tmp
# Install Python 3.10 (if not present) and venv
sudo apt update
sudo apt install -y python3.10 python3.10-venv tmux
python3 -m ensurepip --upgrade
# Create and activate a virtual environment
python3 -m venv ~/.python/spin_env
source ~/.python/spin_env/bin/activate
# Install uv (fast package installer)
python3 -m pip install uv
```
2. **Install verl and Dependencies:**
```bash
# Clone the verl repository and checkout the spin branch
cd ~
git clone git@github.com:volcengine/verl.git && cd verl
# Install flash-attn (handle potential build issues)
python3 -m uv pip install wheel packaging
python3 -m uv pip install flash-attn --no-build-isolation --no-deps
# Install verl with sglang extras
python3 -m uv pip install -e ".[sglang]"
```
*Note: If `flash-attn` installation fails, retry the manual installation steps above or consult its documentation.*
3. **Login & Download Data/Model:**
```bash
# Login to Weights & Biases (optional, for logging)
export WANDB_API_KEY=<YOUR_WANDB_API_KEY>
# wandb login
# Download the GSM8K dataset
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k # Adjusted path
# Download the base model (Example: Qwen2.5-3B-Instruct)
huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct
```
4. **Configure:**
* Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node).
* Pay attention to `actor_rollout_ref.model_path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`.
5. **Run Training:**
```bash
# Set CUDA visible devices (adjust based on your hardware and config)
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Launch the training script (e.g., test.sh or a custom script)
# Ensure test.sh points to the correct config and main script
bash recipe/spin/run_spin.sh
```
---
## Configuration
* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`).
* Key configuration sections:
* `data`: Paths to training/validation prompt files, batch sizes, sequence lengths.
* `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler).
* `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function.
* `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`.
* `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor).
---
## Key Files
* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`.
* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop.
* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP.
* `dp_actor.py`: Contains the actor class, including the DPO policy update logic.
* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`.
* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe.
* `run_spin.sh` (or similar): Example bash script for launching a training run.
* `README.md`: This file.
---
## Acknowledgement
We sincerely thank the `verl` community and our advisors for their contributions and guidance, including (adapted from SPPO):
* [Zixiang Chen](https://sites.google.com/view/zxchen)
* [Yuhao Yang](https://github.com/yhyang201)
* [Yifan Zhang](https://github.com/yifanzhang-pro)
* [Yongan Xiang](https://github.com/BearBiscuit05)
* [Junrong Lin](https://github.com/ocss884)
* [Yuxuan Tong](https://github.com/tongyx361)
* [Guangming Shen](https://github.com/PeterSH6)
* [Biao He](https://www.linkedin.com/in/biao-he/)
* [Qingquan Song](https://qingquansong.github.io/)
* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/)
* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)
---
# the spin config will override default ppo_trainer.yaml
hydra:
searchpath:
- file://verl/trainer/config
defaults:
- ppo_trainer
- _self_
actor_rollout_ref:
actor:
dpo_beta: 0.1
optim:
lr_warmup_steps: 15
rollout:
name: sglang
tensor_model_parallel_size: 2
gpu_memory_utilization: 0.5
val_kwargs:
n: 2 # 2 will trigger validation, 1 will bypass
algorithm:
adv_estimator: null
trainer:
log_val_generations: 0
ref_update_freq: 1
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
def __init__(self, init_kl_coef, target_kl, horizon):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current_kl, n_steps):
pass
def get_kl_controller(kl_ctrl):
if kl_ctrl.type == "fixed":
return FixedKLController(kl_coef=kl_ctrl.kl_coef)
elif kl_ctrl.type == "adaptive":
assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
else:
raise NotImplementedError
def compute_onlinedpo_pref(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
) -> torch.Tensor:
"""
Computes preferences between pairs of sequences based on summed rewards
and returns a mask aligned with the interleaved batch.
Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]
Args:
token_level_rewards: Tensor of shape [batch_size * 2, seq_len]
response_mask: Tensor of shape [batch_size * 2, seq_len]
Returns:
torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates
the corresponding entry is the chosen response for its pair.
Example: [True, False, False, True, ...] means for prompt 0,
response 1 was chosen; for prompt 1, response 2 was chosen.
"""
# print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----")
if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:
raise ValueError(
f"Input tensor batch dimension must be even for pair comparison, got shapes: "
f"{token_level_rewards.shape}, {response_mask.shape}"
)
if token_level_rewards.shape != response_mask.shape:
raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}")
# 1. Calculate Sequence Scores
scores = (token_level_rewards * response_mask).sum(dim=-1)
# print(f" Calculated sequence scores shape: {scores.shape}") # [batch_size * 2]
# 2. Reshape scores to group pairs: [batch_size, 2]
try:
score_pairs = scores.view(-1, 2)
except RuntimeError as e:
print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}")
raise e
print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2]
# 3. Compare scores to find which index (0 or 1) is the winner within each pair
# winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1
winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max
# Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max)
# Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1]
# print(f" Winner indices shape: {winner_indices.shape}") # [batch_size]
# print(f" Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}") # Counts number of 1s
# 4. Create the final [batch_size * 2] mask
num_pairs = score_pairs.shape[0]
full_batch_size = num_pairs * 2
# Create indices for the full batch [0, 1, 2, 3, ..., N*2-1]
# full_indices = torch.arange(full_batch_size, device=scores.device)
# Create indices corresponding to the winner within each pair's original index
# E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2]
# winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4]
pair_indices = torch.arange(num_pairs, device=scores.device)
winner_global_indices = (pair_indices * 2) + winner_indices
# Create boolean mask - True at the winner's position
output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device)
output_preference_mask[winner_global_indices] = True
# print(f" Output preference mask shape: {output_preference_mask.shape}") # Should be [batch_size * 2]
# print(f" Output mask True count (Chosen): {output_preference_mask.sum().item()}") # Should be batch_size
# print(f" Output mask False count (Rejected): {(~output_preference_mask).sum().item()}") # Should be batch_size
# print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----")
return output_preference_mask
def compute_online_dpo_loss(
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
beta: float,
label_smoothing: float = 0.0,
loss_type: str = "sigmoid",
reference_free: bool = False,
) -> torch.Tensor:
import torch.nn.functional as F
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
if reference_free:
ref_logratios = torch.zeros_like(pi_logratios)
logits = pi_logratios - ref_logratios
if loss_type == "sigmoid":
losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
elif loss_type == "ipo":
losses = (logits - 1 / (2 * beta)) ** 2
else:
raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.")
return losses.mean()
def get_batch_logps(
logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False
) -> torch.FloatTensor:
"""
Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`).
Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits and labels must have the same shape[:-1]")
# Ensure labels are contiguous and on the same device as logits
labels = labels.contiguous().to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Calculate per token log probability
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
per_token_logps = per_token_logps.view(
shift_logits.size(0), shift_logits.size(1)
) # Reshape back to (batch_size, seq_len-1)
# Create a mask for the labels that are not -100
loss_mask = shift_labels != -100
# Apply the mask to the per token log probabilities
masked_logps = per_token_logps * loss_mask
# Calculate the sum or average log probability per sequence
sequence_logps = masked_logps.sum(dim=-1)
if average_log_prob:
# Avoid division by zero for sequences with no valid tokens
num_valid_tokens = loss_mask.sum(dim=-1)
return sequence_logps / torch.clamp(num_valid_tokens, min=1)
else:
return sequence_logps
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import math
from collections import defaultdict
import numpy as np
import torch
from recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps
from verl import DataProto
from verl.utils.device import get_device_name
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.workers.actor import DataParallelPPOActor
__all__ = ["DataParallelPPOActor"]
class SPINDataParallelPPOActor(DataParallelPPOActor):
def compute_log_prob(self, data: DataProto) -> torch.Tensor:
"""Compute the log probability of the responses given input_ids, attention_mask and position_ids
Args:
data (DataProto): a DataProto containing keys
``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the
concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``.
``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64.
``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64.
``responses``: tensor of shape [batch_size, response_length]. torch.int64.
Returns:
torch.Tensor: the log_prob tensor
"""
# set to eval
self.actor_module.eval()
micro_batch_size = data.meta_info["micro_batch_size"]
temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if has_multi_modal_inputs:
num_micro_batches = data.batch.batch_size[0] // micro_batch_size
non_tensor_select_keys = ["multi_modal_inputs"]
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
log_probs_lst = []
for micro_batch in micro_batches:
if isinstance(micro_batch, DataProto):
micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch}
with torch.no_grad():
_, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature)
log_probs_lst.append(log_probs)
log_probs = torch.concat(log_probs_lst, dim=0)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
log_probs = log_probs[revert_indices]
return log_probs
def update_policy_dpo_with_ref(self, data: DataProto):
"""
Performs the DPO update step using pre-calculated reference log probs
from an external, periodically updated reference model.
"""
self.actor_module.train() # Ensure training mode
# --- Retrieve necessary data ---
try:
# Expects batch prepared by fit_dpo loop, including reference log probs
batch_td = data.batch
chosen_labels = batch_td["chosen_labels"]
rejected_labels = batch_td["rejected_labels"]
# ... other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ...
# === Get PRE-CALCULATED reference log probs from input data ===
reference_chosen_logps = batch_td["reference_chosen_logps"] # Should be sequence-level logps
reference_rejected_logps = batch_td["reference_rejected_logps"] # Should be sequence-level logps
# ============================================================
# Get DPO params from meta_info
# beta = data.meta_info.get('dpo_beta', 0.1) # Default beta
beta = self.config.get("dpo_beta", 0.1) # Default beta
loss_type = data.meta_info.get("dpo_loss_type", "sigmoid")
label_smoothing = data.meta_info.get("dpo_label_smoothing", 0.0)
# reference_free should now be False as we provide ref logps
reference_free = data.meta_info.get("reference_free", False) # Default False
except KeyError as e:
print(f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}")
print(f"Available keys in data.batch: {list(batch_td.keys())}") # Debug print
return {} # Return empty metrics on error
except Exception as e_data:
print(f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}")
return {}
# --- Micro-batching Setup ---
micro_batch_size = self.config.get("ppo_micro_batch_size_per_gpu")
if micro_batch_size is None:
# Fallback or default if not set, or raise error
micro_batch_size = 1 # Example fallback, adjust as needed
print(f"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}")
# raise ValueError("Config 'ppo_micro_batch_size_per_gpu' must be set.")
# Ensure chosen_input_ids exists before getting shape
if "chosen_input_ids" not in batch_td:
print("ERROR: 'chosen_input_ids' not found in batch_td for DPO update.")
return {}
bsz = batch_td["chosen_input_ids"].shape[0]
if bsz == 0:
print("Warning: DPO batch size is 0 in update_policy_dpo. Skipping update.")
return {"actor/dpo_loss": 0.0, "actor/grad_norm": 0.0} # Return zero metrics if batch is empty
num_micro_batches = math.ceil(bsz / micro_batch_size)
gradient_accumulation_steps = num_micro_batches
# --- Metrics Accumulation ---
total_loss = 0.0
accumulated_metrics = defaultdict(list)
metrics = {} # Final metrics dict
# --- Zero Gradients ---
self.actor_optimizer.zero_grad(set_to_none=True)
# --- Micro-batch Loop ---
for i in range(num_micro_batches):
start_idx = i * micro_batch_size
end_idx = min(start_idx + micro_batch_size, bsz)
if start_idx >= end_idx:
continue
# Slice the full DPO batch into micro-batches
# Important: Slice ALL required tensors, including labels and inputs
micro_batch_chosen_labels = chosen_labels[start_idx:end_idx]
micro_batch_rejected_labels = rejected_labels[start_idx:end_idx]
micro_batch_chosen_inputs = {
"input_ids": batch_td["chosen_input_ids"][start_idx:end_idx],
"attention_mask": batch_td["chosen_attention_mask"][start_idx:end_idx],
}
if "chosen_position_ids" in batch_td:
micro_batch_chosen_inputs["position_ids"] = batch_td["chosen_position_ids"][start_idx:end_idx]
micro_batch_rejected_inputs = {
"input_ids": batch_td["rejected_input_ids"][start_idx:end_idx],
"attention_mask": batch_td["rejected_attention_mask"][start_idx:end_idx],
}
if "rejected_position_ids" in batch_td:
micro_batch_rejected_inputs["position_ids"] = batch_td["rejected_position_ids"][start_idx:end_idx]
# Determine autocast dtype
autocast_dtype = torch.bfloat16 # Or get dynamically from config/FSDP settings
# --- Autocast Forward Pass ---
with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype):
# --- Step 1: Forward pass for CURRENT policy log probs (with grad) ---
policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False)
policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False)
# --- Step 2: Calculate CURRENT policy log probs using get_batch_logps ---
policy_chosen_logps = get_batch_logps(
policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False
)
policy_rejected_logps = get_batch_logps(
policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False
)
# --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) ---
# Slice the full batch reference logps for the current micro-batch
micro_ref_chosen_logps = reference_chosen_logps[start_idx:end_idx]
micro_ref_rejected_logps = reference_rejected_logps[start_idx:end_idx]
# --- The ActorAsRef calculation block is REMOVED ---
# --- Step 4: Calculate DPO Logits and Loss ---
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps # Uses pre-calculated values
logits = pi_logratios - ref_logratios # DPO logits
loss = compute_online_dpo_loss(
policy_chosen_logps=policy_chosen_logps, # Has grad
policy_rejected_logps=policy_rejected_logps, # Has grad
reference_chosen_logps=micro_ref_chosen_logps, # No grad (from input)
reference_rejected_logps=micro_ref_rejected_logps, # No grad (from input)
beta=beta,
label_smoothing=label_smoothing,
loss_type=loss_type,
reference_free=reference_free, # Should be False now
)
# --- Scale loss for gradient accumulation ---
scaled_loss = loss / gradient_accumulation_steps
# --- Accumulate Metrics ---
total_loss += loss.item() # Unscaled loss
accumulated_metrics["actor/dpo_loss_batch"].append(loss.item())
accumulated_metrics["actor/dpo_logits_batch"].append(logits.mean().item())
# Accumulate policy and reference log probs/ratios if needed for debugging
accumulated_metrics["actor/policy_chosen_logps_batch"].append(policy_chosen_logps.mean().item())
accumulated_metrics["actor/policy_rejected_logps_batch"].append(policy_rejected_logps.mean().item())
accumulated_metrics["actor/reference_chosen_logps_batch"].append(micro_ref_chosen_logps.mean().item())
accumulated_metrics["actor/reference_rejected_logps_batch"].append(
micro_ref_rejected_logps.mean().item()
)
# --- Backward Pass (outside autocast) ---
# Check if loss requires grad before backward
if scaled_loss.requires_grad:
scaled_loss.backward()
else:
print(f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward.")
# --- End Micro-batch Loop ---
# --- Optimizer Step (after accumulating gradients for all micro-batches) ---
grad_norm = self._optimizer_step()
# --- Populate Final Metrics ---
if num_micro_batches > 0 and bsz > 0: # Check if any processing happened
metrics["actor/dpo_loss"] = total_loss / num_micro_batches
metrics["actor/grad_norm"] = (
grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float("inf")
)
# Average other accumulated metrics
for key, val_list in accumulated_metrics.items():
if val_list:
metrics[key.replace("_batch", "")] = np.mean(val_list)
# Calculate accuracy / rewards / margins based on averaged logprobs if desired
if (
"actor/policy_chosen_logps" in metrics
and "actor/policy_rejected_logps" in metrics
and "actor/reference_chosen_logps" in metrics
and "actor/reference_rejected_logps" in metrics
):
policy_ratio_mean = metrics["actor/policy_chosen_logps"] - metrics["actor/policy_rejected_logps"]
ref_ratio_mean = metrics["actor/reference_chosen_logps"] - metrics["actor/reference_rejected_logps"]
logits_mean = policy_ratio_mean - ref_ratio_mean
metrics["actor/rewards_chosen"] = beta * (
metrics["actor/policy_chosen_logps"] - metrics["actor/reference_chosen_logps"]
)
metrics["actor/rewards_rejected"] = beta * (
metrics["actor/policy_rejected_logps"] - metrics["actor/reference_rejected_logps"]
)
metrics["actor/rewards_accuracies"] = float(logits_mean > 0) # Mean accuracy proxy
metrics["actor/rewards_margins"] = metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"]
else: # Handle case where no micro-batches were run (e.g., bsz=0)
metrics["actor/dpo_loss"] = 0.0
metrics["actor/grad_norm"] = 0.0
# Initialize other metrics to 0 or NaN as appropriate
for key in accumulated_metrics.keys():
metrics[key.replace("_batch", "")] = 0.0
metrics["actor/rewards_chosen"] = 0.0
metrics["actor/rewards_rejected"] = 0.0
metrics["actor/rewards_accuracies"] = 0.0
metrics["actor/rewards_margins"] = 0.0
return metrics # Return aggregated metrics
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
import psutil
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
def create_device_mesh(world_size, fsdp_size):
if fsdp_size < 0 or fsdp_size >= world_size:
device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
)
return device_mesh
def get_sharding_strategy(device_mesh):
from torch.distributed.fsdp import ShardingStrategy
if device_mesh.ndim == 1:
sharding_strategy = ShardingStrategy.FULL_SHARD
elif device_mesh.ndim == 2:
sharding_strategy = ShardingStrategy.HYBRID_SHARD
else:
raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
return sharding_strategy
class SPINRolloutRefWorker(ActorRolloutRefWorker):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)
if self._is_actor or self._is_rollout or self._is_ref:
# we need the model for actor and rollout
if self._is_actor or self._is_ref:
optim_config = self.config.actor.optim
fsdp_config = self.config.actor.fsdp_config
else:
optim_config = None
fsdp_config = OmegaConf.create()
self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (
self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="actor",
)
)
# get the original unwrapped module
self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
# load from checkpoint
if self._is_actor or self._is_ref:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelPPOActor(
config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
)
if self._is_rollout:
self.rollout, self.rollout_sharding_manager = self._build_rollout(
trust_remote_code=self.config.model.get("trust_remote_code", False)
)
if self._is_ref:
# self.ref_module_fsdp = self._build_model_optimizer(
# model_path=self.config.model.path,
# fsdp_config=self.config.ref.fsdp_config,
# optim_config=None,
# override_model_config=override_model_config,
# use_remove_padding=use_remove_padding,
# use_fused_kernels=use_fused_kernels,
# trust_remote_code=self.config.model.get("trust_remote_code", False),
# use_liger=self.config.model.get("use_liger", False),
# role="ref",
# )[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_config=self.config.actor.checkpoint,
)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_config=self.config.actor.checkpoint,
)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
assert self._is_ref
# Support all hardwares
data = data.to(get_device_id())
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["temperature"] = self.config.rollout.temperature
data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.ref_policy.compute_log_prob(data=data)
output = DataProto.from_dict(tensors={"ref_log_prob": output})
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.ref_policy.actor_module._handle.reshard(True)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
assert self._is_actor
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
# Support all hardwares
data = data.to(get_device_id())
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
data.meta_info["temperature"] = self.config.rollout.temperature
# perform recompute log_prob
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data)
output = self.actor.compute_log_prob(data=data)
output = DataProto.from_dict(
tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
)
output = self.ulysses_sharding_manager.postprocess_data(output)
output = output.to("cpu")
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.world_size > 1:
self.actor.actor_module._handle.reshard(True)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After compute_log_prob", logger=logger)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor_dpo(self, data: DataProto):
"""
Wrapper for actor update step. Handles FSDP state management.
Calls self.actor.update_policy which now contains DPO logic based
on pre-calculated log probabilities.
"""
# Support all hardwares
data = data.to(get_device_id())
assert self._is_actor # Make sure this worker has the actor role
if self.actor is None:
raise RuntimeError("Actor instance (self.actor) not initialized in worker.")
# --- FSDP State Management ---
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())
log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger)
# --- Ulysses Sharding (if used) ---
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
# --- Call the core update method (now containing DPO logic) ---
with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name
# Calls the modified update_policy method
metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION
delta_time = timer.last
# --- Add Performance Metrics ---
# MFU calculation might be less accurate/meaningful here for DPO
metrics["perf/approx_tokens_processed"] = torch.sum(
data.batch.get("attention_mask", torch.tensor(0))
).item() # Approx tokens
metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)
metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3)
metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
global_num_tokens = data.meta_info["global_token_num"]
estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size
# --- LR Scheduler Step ---
lr = self.actor_lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
self.actor_lr_scheduler.step()
log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger)
# --- Prepare Output ---
output = DataProto(meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
# --- FSDP State Management (Offload) ---
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
return output
# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
"""
    Note that we only implement reward models that are subclasses of AutoModelForTokenClassification.
"""
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=get_nccl_backend())
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.use_remove_padding = self.config.model.get("use_remove_padding", False)
# normalize config
if self.config.micro_batch_size is not None:
self.config.micro_batch_size //= torch.distributed.get_world_size()
self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
def _build_model(self, config):
# the following line is necessary
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForTokenClassification
# download the checkpoint from hdfs
local_path = copy_to_local(config.model.path)
if self.config.model.input_tokenizer is None:
self._do_switch_chat_template = False
else:
self._do_switch_chat_template = True
input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)
self.input_tokenizer = hf_tokenizer(
input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False)
)
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False))
trust_remote_code = config.model.get("trust_remote_code", False)
model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
model_config.num_labels = 1
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
model_config.classifier_dropout = 0.0
reward_module = AutoModelForTokenClassification.from_pretrained(
pretrained_model_name_or_path=local_path,
config=model_config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
reward_module.to(torch.bfloat16)
auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
reward_module = FSDP(
reward_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy, # zero3
sync_module_states=True,
cpu_offload=CPUOffload(offload_params=True),
forward_prefetch=False,
device_mesh=self.device_mesh,
)
return reward_module
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
self.reward_module = self._build_model(config=self.config)
def _forward_micro_batch(self, micro_batch):
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.reward_module(
input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False
) # prevent model thinks we are generating
reward_rmpad = output.logits
reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz)
# gather output if sp > 1
if self.ulysses_sequence_parallel_size > 1:
reward_rmpad = gather_outputs_and_unpad(
reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
# pad it back
rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
else:
output = self.reward_module(
input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
)
rm_score = output.logits # (batch_size, seq_len, 1)
rm_score = rm_score.squeeze(-1)
# extract the result of the last valid token
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
return rm_score
def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
batch_size = data.batch.batch_size[0]
# expand as token_level_reward
attention_mask = data.batch["attention_mask"]
position_ids = data.batch["position_ids"]
response_length = data.batch["responses"].shape[-1]
eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,)
token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen)
token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores
# select the response part
token_level_scores = token_level_scores[:, -response_length:]
return token_level_scores
def _switch_chat_template(self, data: DataProto):
src_max_length = data.batch["attention_mask"].shape[-1]
src_tokenizer = self.input_tokenizer
target_tokenizer = self.tokenizer
rm_input_ids = []
rm_attention_mask = []
for i in range(data.batch.batch_size[0]):
# extract raw prompt
if isinstance(data.non_tensor_batch["raw_prompt"][i], list):
chat: list = data.non_tensor_batch["raw_prompt"][i]
else:
chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
# extract response
response_ids = data.batch["responses"][i]
response_length = response_ids.shape[-1]
valid_response_length = data.batch["attention_mask"][i][-response_length:].sum()
valid_response_ids = response_ids[:valid_response_length]
# decode
response = src_tokenizer.decode(valid_response_ids)
# remove bos and eos
response = response.replace(src_tokenizer.eos_token, "")
chat.append({"role": "assistant", "content": response})
prompt_with_chat_template = target_tokenizer.apply_chat_template(
chat, add_generation_prompt=False, tokenize=False
)
if self.rank == 0 and i == 0:
# for debugging purpose
print(f"Switch template. chat: {prompt_with_chat_template}")
# the maximum length is actually determined by the reward model itself
max_length = self.config.get("max_length", src_max_length)
if max_length is None:
max_length = src_max_length
model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
input_ids, attention_mask = verl_F.postprocess_data(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
max_length=max_length,
pad_token_id=target_tokenizer.pad_token_id,
left_pad=False, # right padding
truncation=self.config.get("truncation", "right"),
) # truncate from the right
rm_input_ids.append(input_ids)
rm_attention_mask.append(attention_mask)
rm_input_ids = torch.cat(rm_input_ids, dim=0)
rm_attention_mask = torch.cat(rm_attention_mask, dim=0)
rm_position_ids = compute_position_id_with_mask(rm_attention_mask)
rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids}
return DataProto.from_dict(rm_inputs)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
import itertools
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
# Support all hardware types
data = data.to(get_device_id())
if self._do_switch_chat_template:
rm_data = self._switch_chat_template(data)
else:
rm_input_ids = data.batch["input_ids"]
rm_attention_mask = data.batch["attention_mask"]
rm_position_ids = data.batch["position_ids"]
rm_inputs = {
"input_ids": rm_input_ids,
"attention_mask": rm_attention_mask,
"position_ids": rm_position_ids,
}
rm_data = DataProto.from_dict(rm_inputs)
# Support all hardware types
rm_data.batch = rm_data.batch.to(get_device_id())
# perform forward computation
with self.ulysses_sharding_manager:
rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
data = self.ulysses_sharding_manager.preprocess_data(data=data)
use_dynamic_bsz = self.config.use_dynamic_bsz
if use_dynamic_bsz:
max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
else:
micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)
output = []
for micro_batch in micro_batches:
rm_score = self._forward_micro_batch(micro_batch)
output.append(rm_score)
scores = torch.cat(output, dim=0) # (batch_size)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
scores = scores[revert_indices]
token_level_scores = self._expand_to_token_level(data, scores)
# Note that these are only the raw scores; they may not be the final rewards used for RL training
output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
self.reward_module._handle.reshard(True)
output = output.to("cpu")
return output
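# Note on the dynamic-bsz path above (a hedged sketch, assuming get_reverse_idx returns the
# inverse permutation): rearrange_micro_batches reorders samples so each micro-batch stays
# under max_token_len, and the reverse index restores the original ordering of the scores, e.g.
#   indices  = [[2, 0], [1]]          # hypothetical micro-batch ordering
#   flat     = [2, 0, 1]
#   reverse  = get_reverse_idx(flat)  # -> [1, 2, 0]
#   scores_in_original_order = scores[torch.tensor(reverse, dtype=torch.long)]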
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import hydra
import ray
from recipe.spin.spin_trainer import RaySPINTrainer
from verl.trainer.ppo.reward import get_custom_reward_fn
@hydra.main(config_path="config", config_name="spin_trainer", version_base=None)
def main(config):
run_ppo(config)
def run_ppo(config) -> None:
# TODO(linjunrong.ocss884): this env var is left to resolve the SGLang conflict with Ray device
# isolation; it will be removed in the future
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
}
)
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
@ray.remote(num_cpus=1) # please make sure the TaskRunner task is not scheduled on the head node
class TaskRunner:
def run(self, config):
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLMs; may be None
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
# from recipe.spin.fsdp_workers import ActorRolloutRefWorker
from recipe.spin.fsdp_workers import SPINRolloutRefWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from recipe.spin.spin_trainer import ResourcePoolManager, Role
role_worker_mapping = {
# Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.ActorRollout: ray.remote(SPINRolloutRefWorker),
# Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
# Role.Critic: global_pool_id,
}
if config.reward_model.enable:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from recipe.spin.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# use reference model
# if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
# role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
from verl.workers.reward_manager import get_reward_manager_cls
# Note(haibin.lin): please make sure custom reward managers are imported and
# registered via `verl.workers.reward_manager.register`
reward_manager_name = config.reward_model.get("reward_manager", "naive")
reward_manager_cls = get_reward_manager_cls(reward_manager_name)
compute_score = get_custom_reward_fn(config)
reward_kwargs = dict(config.reward_model.get("reward_kwargs", {}))
reward_fn = reward_manager_cls(
tokenizer=tokenizer,
num_examine=0,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
**reward_kwargs,
)
# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(
tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RaySPINTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit_dpo()
if __name__ == "__main__":
main()
set -e
set -x
VISIBLE_DEVICES="4,5,6,7"
export HYDRA_FULL_ERROR=1
CUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size=8 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=64 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=console \
trainer.val_before_train=True \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=1 \
+trainer.log_freq=1 \
trainer.ref_update_freq=1 \
trainer.total_epochs=1000 2>&1 | tee verl_demo.log
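# Sketch (not part of the recipe, paths are hypothetical): to resume a run, the trainer keys
# read by _load_checkpoint in recipe/spin/spin_trainer.py can be appended to the command above:
#   trainer.resume_mode=resume_path \
#   trainer.resume_from_path=$HOME/checkpoint/spin/global_step_100 \
# With trainer.resume_mode=auto, the latest checkpoint under trainer.default_local_dir (if any)
# is picked up automatically; trainer.save_freq=-1 above disables periodic checkpoint saving.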
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import traceback
import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Optional
import numpy as np
import ray
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
from recipe.spin import core_algos
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.metric_utils import (
compute_throughout_metrics,
compute_timing_metrics,
process_validation_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import Role
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
WorkerType = type[Worker]
class AdvantageEstimator(str, Enum):
"""
Using an enumeration class to avoid spelling errors in adv_estimator
"""
GAE = "gae"
GRPO = "grpo"
REINFORCE_PLUS_PLUS = "reinforce_plus_plus"
REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline"
REMAX = "remax"
RLOO = "rloo"
@dataclass
class ResourcePoolManager:
"""
Defines the resource pool specification and the mapping from roles to resource pools.
Resource pools are initialized first.
"""
resource_pool_spec: dict[str, list[int]]
mapping: dict[Role, str]
resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
def create_resource_pool(self):
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
# max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
# For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
# For the Megatron backend, we recommend max_colocate_count>1 so that different WorkerGroups
# can be used for different models.
resource_pool = RayResourcePool(
process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
)
self.resource_pool_dict[resource_pool_name] = resource_pool
self._check_resource_available()
def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker_cls"""
return self.resource_pool_dict[self.mapping[role]]
def get_n_gpus(self) -> int:
"""Get the number of gpus in this cluster."""
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
def _check_resource_available(self):
"""Check if the resource pool can be satisfied in this ray cluster."""
node_available_resources = ray.state.available_resources_per_node()
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
# check total required gpus can be satisfied
total_available_gpus = sum(node_available_gpus.values())
total_required_gpus = sum(
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
)
if total_available_gpus < total_required_gpus:
raise ValueError(
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}"
)
# check each resource pool can be satisfied, O(#resource_pools * #nodes)
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
for node, available_gpus in node_available_gpus.items():
if available_gpus >= num_gpus:
node_available_gpus[node] -= num_gpus
num_nodes -= 1
if num_nodes == 0:
break
if num_nodes > 0:
raise ValueError(
f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this "
f"ray cluster"
)
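# Illustrative usage (mirrors how main_spin.py builds the pools; values are hypothetical):
#   resource_pool_spec = {"global_pool": [8] * 2}   # 2 nodes x 8 GPUs each
#   mapping = {Role.ActorRollout: "global_pool", Role.RefPolicy: "global_pool"}
#   manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
#   manager.create_resource_pool()                  # also checks GPU availability in the cluster
#   pool = manager.get_resource_pool(Role.ActorRollout)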
def _compute_response_info(batch: DataProto) -> dict[str, Any]:
"""Placeholder: Computes prompt and response lengths."""
try:
# Assuming 'prompts' and 'responses' keys exist after generation/union
prompt_len = batch.batch["prompts"].shape[1]
resp_len = batch.batch["responses"].shape[1]
# This is simplified - real implementation might use attention masks
# to get actual lengths per sample.
batch_size = batch.batch.batch_size[0]
prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device)
response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device)
# Try getting actual lengths from attention mask if possible (more accurate)
if "response_mask" in batch.batch:
response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float()
# A more accurate alternative would derive prompt lengths from the masks, e.g.:
#   prompt_lengths_tensor = batch.batch["attention_mask"].sum(dim=1).float() - response_lengths_tensor
# That logic depends on how the masks are constructed, so fall back to the padded prompt shape:
prompt_lengths_tensor = torch.tensor(
[batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device
)
return {
"prompt_length": prompt_lengths_tensor,
"response_length": response_lengths_tensor,
"max_response_length": resp_len,
"max_prompt_length": prompt_len, # Or from config if fixed padding
}
except KeyError as e:
print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.")
# Return default/dummy values if keys are missing
b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1
max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0
max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0
return {
"prompt_length": torch.zeros(b_size),
"response_length": torch.zeros(b_size),
"max_response_length": max_resp,
"max_prompt_length": max_prompt,
}
# --- Modified Metric Function ---
def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
"""
Computes and returns metrics relevant for the DPO-like process.
Assumes 'batch' contains results after generation and preference marking,
potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc.
Removes PPO-specific advantage/return/critic metrics.
"""
print("---- [DEBUG] Computing DPO Data Metrics ----")
metrics = {}
try:
# --- Scores and Rewards (from reward_fn) ---
if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None:
sequence_score = batch.batch["token_level_scores"].sum(-1)
metrics.update(
{
"reward/score/mean": torch.mean(sequence_score).item(),
"reward/score/max": torch.max(sequence_score).item(),
"reward/score/min": torch.min(sequence_score).item(),
}
)
else:
print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.")
if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None:
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
metrics.update(
{
"reward/rewards/mean": torch.mean(sequence_reward).item(),
"reward/rewards/max": torch.max(sequence_reward).item(),
"reward/rewards/min": torch.min(sequence_reward).item(),
}
)
else:
print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.")
# --- DPO Specific Metrics (if stored previously) ---
if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None:
metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item()
else:
print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.")
if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None:
metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item()
else:
print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.")
if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None:
metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item()
else:
print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.")
# Add metrics based on the 'preferences' mask if available
# if "preferences" in batch.batch and batch.batch["preferences"] is not None:
# prefs_mask = batch.batch["preferences"] # Shape [batch_size * n]
# Calculate accuracy based on RM scores (assuming higher score -> True in mask)
# Requires chosen/rejected scores to be available or recalculated
# This is complex here, better calculated in the main loop or update function
# --- Length Metrics ---
response_info = _compute_response_info(batch)
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]
max_response_length = response_info["max_response_length"]
max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config
metrics.update(
{
"response_length/mean": torch.mean(response_length).item(),
"response_length/max": torch.max(response_length).item(),
"response_length/min": torch.min(response_length).item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(),
"prompt_length/mean": torch.mean(prompt_length).item(),
"prompt_length/max": torch.max(prompt_length).item(),
"prompt_length/min": torch.min(prompt_length).item(),
# Prompt clip ratio might need adjustment based on how max_prompt_length is defined
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(),
}
)
except KeyError as e:
print(f"ERROR in compute_dpo_data_metrics: Missing key {e}")
except Exception as e:
print(f"ERROR in compute_dpo_data_metrics: {e}")
traceback.print_exc()
print(f"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----")
return metrics
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
responses = data.batch["responses"]
response_length = responses.size(1)
token_level_scores = data.batch["token_level_scores"]
batch_size = data.batch.batch_size[0]
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
# compute kl between ref_policy and current policy
# When apply_kl_penalty is called, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
kld = core_algos.kl_penalty(
data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
token_level_rewards = token_level_scores - beta * kld
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards
metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}
return data, metrics
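# Worked example of the penalty above (hypothetical numbers, response_mask all ones):
#   token_level_scores  = [0.0, 0.0, 1.0]   # reward only on the final token
#   kld                 = [0.2, 0.1, 0.4]
#   beta                = 0.05
#   token_level_rewards = [0.0 - 0.01, 0.0 - 0.005, 1.0 - 0.02] = [-0.01, -0.005, 0.98]
# current_kl is the mask-averaged kld and drives the adaptive KL controller update.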
def compute_response_mask(data: DataProto):
responses = data.batch["responses"]
response_length = responses.size(1)
attention_mask = data.batch["attention_mask"]
return attention_mask[:, -response_length:]
def compute_onlineDPO_pref(data: DataProto):
"""
Wrapper to compute DPO preference and add it to the DataProto batch.
Includes debugging prints.
"""
# print(f"\n---- [DEBUG] Entering compute_onlineDPO_pref ----")
# print(f" Input batch keys: {list(data.batch.keys())}")
# Check inputs
rewards_tensor = data.batch.get("token_level_rewards")
mask_tensor = data.batch.get("response_mask")
if rewards_tensor is None or mask_tensor is None:
print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!")
# Handle error case - maybe return original data or raise?
# Returning original data for now to potentially allow skipping
return data
try:
preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor)
# Store the result
data.batch["preferences"] = preferences
except AttributeError:
print("ERROR: Function 'compute_onlinedpo_pref' not found in core_algos.py!")
# Assign dummy value or raise error
data.batch["preferences"] = None # Indicate failure
except Exception as e_pref:
print(f"ERROR during core_algos.compute_onlinedpo_pref: {e_pref}")
import traceback
traceback.print_exc()
data.batch["preferences"] = None # Indicate failure
# print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----")
return data
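# Hedged sketch of what core_algos.compute_onlinedpo_pref is expected to return (the actual
# implementation lives in recipe/spin/core_algos.py and may differ): with rollout.n == 2 the
# generations are interleaved per prompt, so a boolean mask marking the higher-reward sample
# of each pair, suitable for the boolean indexing done later in fit_dpo, could look like:
#   seq_rewards = (token_level_rewards * response_mask).sum(-1)       # (bsz * n,)
#   pairs       = seq_rewards.view(-1, 2)                             # (bsz, 2)
#   chosen_col  = pairs.argmax(dim=-1)                                # 0 or 1 per prompt
#   preferences = torch.zeros_like(seq_rewards, dtype=torch.bool)
#   preferences[torch.arange(pairs.shape[0]) * 2 + chosen_col] = True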
@contextmanager
def _timer(name: str, timing_raw: dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
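# Usage example for the timing helper above:
#   timing_raw: dict[str, float] = {}
#   with _timer("gen", timing_raw):
#       ...  # timed work
#   print(timing_raw["gen"])  # elapsed seconds recorded by codetiming.Timer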
class RaySPINTrainer:
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role having its own ray_worker_group_cls,
# i.e., support a different backend for each role
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
collate_fn=None,
train_sampler: Optional[Sampler] = None,
device_name=None,
):
# assert get_torch_device().is_available(), 'cuda must be available on driver'
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert self.hybrid_engine, "Currently, only the hybrid engine is supported"
if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
self.async_rollout_mode = False
self.device_name = device_name if device_name else self.config.trainer.device
# define in-reward KL control
# KL loss control is currently not supported
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, (
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
)
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'."
)
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported "
f"(the former is deprecated)."
)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor",
)
if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref",
)
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(
config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout",
)
if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(
config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic"
)
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(
config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model"
)
# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert (
config.actor_rollout_ref.actor.ppo_mini_batch_size
% config.actor_rollout_ref.actor.ppo_micro_batch_size
== 0
)
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean",
"seq-mean-token-sum",
"seq-mean-token-mean",
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"
if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
print("NOTICE: You have both enabled in-reward kl and kl loss.")
# critic
if self.use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
if (
config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1
or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1
):
assert config.actor_rollout_ref.model.use_remove_padding, (
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
)
if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}:
if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
assert config.critic.model.use_remove_padding, (
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
)
if config.data.get("val_batch_size", None) is not None:
print(
"WARNING: val_batch_size is deprecated. Validation datasets are sent to the inference engines "
"as a whole batch, and the engines schedule memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, (
"validation gen temperature should be greater than 0 when enabling do_sample"
)
print("[validate_config] All configuration checks passed successfully!")
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
"""
Creates the train and validation dataloaders.
"""
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
if train_dataset is None:
train_dataset = create_rl_dataset(
self.config.data.train_files, self.config.data, self.tokenizer, self.processor
)
if val_dataset is None:
val_dataset = create_rl_dataset(
self.config.data.val_files, self.config.data, self.tokenizer, self.processor
)
self.train_dataset, self.val_dataset = train_dataset, val_dataset
if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
collate_fn = default_collate_fn
self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
num_workers=self.config.data.get("dataloader_num_workers", 8),
drop_last=True,
collate_fn=collate_fn,
sampler=train_sampler,
)
val_batch_size = self.config.data.val_batch_size # Prefer config value if set
if val_batch_size is None:
val_batch_size = len(self.val_dataset)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=val_batch_size,
num_workers=self.config.data.get("dataloader_num_workers", 8),
shuffle=False,
drop_last=False,
collate_fn=collate_fn,
)
assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
print(
f"Size of train dataloader: {len(self.train_dataloader)}, "
f"Size of val dataloader: {len(self.val_dataloader)}"
)
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f"Total training steps: {self.total_training_steps}")
try:
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
if OmegaConf.select(self.config, "critic.optim"):
self.config.critic.optim.total_training_steps = total_training_steps
except Exception as e:
print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
def _maybe_log_val_generations(self, inputs, outputs, scores):
"""Log a table of validation samples to the configured logger (wandb or swanlab)"""
generations_to_log = self.config.trainer.log_val_generations
if generations_to_log == 0:
return
import numpy as np
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores, strict=True))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling
rng = np.random.RandomState(42)
rng.shuffle(samples)
# Take first N samples after shuffling
samples = samples[:generations_to_log]
# Log to each configured logger
self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
def _validate(self):
data_source_lst = []
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
# Lists to collect samples for the table
sample_inputs = []
sample_outputs = []
sample_scores = []
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
# repeat test batch
test_batch = test_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
)
# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
return {}
# Store original inputs
input_ids = test_batch.batch["input_ids"]
# TODO: Can we keep special tokens except for padding tokens?
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_inputs" in test_batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
if "raw_prompt" in test_batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in test_batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
test_gen_batch = test_batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
test_gen_batch.meta_info = {
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
"recompute_log_prob": False,
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
"validate": True,
}
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
if not self.async_rollout_mode:
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
else:
test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print("validation generation end")
# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
result = self.val_reward_fn(test_batch, return_dict=True)
reward_tensor = result["reward_tensor"]
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_extra_infos_dict["reward"].extend(scores)
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
# dump generations
val_data_dir = self.config.trainer.get("validation_data_dir", None)
if val_data_dir:
self._dump_generations(
inputs=sample_inputs,
outputs=sample_outputs,
scores=sample_scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=val_data_dir,
)
for key_info, lst in reward_extra_infos_dict.items():
assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
data_sources = np.concatenate(data_source_lst, axis=0)
print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print
print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print
data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
print(
f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}"
) # Added Print
metric_dict = {}
for data_source, var2metric2val in data_src2var2metric2val.items():
core_var = "acc" if "acc" in var2metric2val else "reward"
for var_name, metric2val in var2metric2val.items():
n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
for metric_name, metric_val in metric2val.items():
if (
(var_name == core_var)
and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
and (f"@{n_max}" in metric_name)
):
metric_sec = "val-core"
else:
metric_sec = "val-aux"
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
metric_dict[pfx] = metric_val
return metric_dict
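# Example of the metric keys assembled above (data source and counts are illustrative):
#   "val-core/gsm8k/reward/mean@30" -> core metric: mean over the largest sample count n_max
#   "val-aux/gsm8k/reward/best@4"   -> auxiliary: smaller n or a non-core variable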
def init_workers(self):
"""Init resource pool and worker group"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
if self.hybrid_engine:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role="actor_rollout",
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
else:
raise NotImplementedError
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref"
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different
# parallel size,
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to
# different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
self.wg_dicts = []
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
wg_kwargs["device_name"] = self.device_name
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
# keep a reference to the WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
self.wg_dicts.append(wg_dict)
if self.use_critic:
self.critic_wg = all_wg["critic"]
self.critic_wg.init_model()
if self.use_reference_policy:
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = all_wg["rm"]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg["actor_rollout"]
self.actor_rollout_wg.init_model()
def _save_checkpoint(self):
# path: given_path + `/global_step_{global_steps}` + `/actor`
local_global_step_folder = os.path.join(
self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
)
print(f"local_global_step_folder: {local_global_step_folder}")
actor_local_path = os.path.join(local_global_step_folder, "actor")
actor_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
)
remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
if remove_previous_ckpt_in_save:
print(
"Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and "
"max_critic_ckpt_to_keep=1 instead"
)
max_actor_ckpt_to_keep = (
self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
)
max_critic_ckpt_to_keep = (
self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
)
self.actor_rollout_wg.save_checkpoint(
actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep
)
if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, "critic")
critic_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
)
self.critic_wg.save_checkpoint(
critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
)
# save dataloader
dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
dataloader_state_dict = self.train_dataloader.state_dict()
torch.save(dataloader_state_dict, dataloader_local_path)
# latest checkpointed iteration tracker (for atomic usage)
local_latest_checkpointed_iteration = os.path.join(
self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
)
with open(local_latest_checkpointed_iteration, "w") as f:
f.write(str(self.global_steps))
def _load_checkpoint(self):
if self.config.trainer.resume_mode == "disable":
return 0
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
raise NotImplementedError("load from hdfs is not implemented yet")
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
working_dir = os.getcwd()
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
# find global_step_folder
if self.config.trainer.resume_mode == "auto":
if global_step_folder is None:
print("Training from scratch")
return 0
else:
if self.config.trainer.resume_mode == "resume_path":
assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
assert "global_step_" in self.config.trainer.resume_from_path, (
"resume ckpt must specify the global_steps"
)
global_step_folder = self.config.trainer.resume_from_path
if not os.path.isabs(global_step_folder):
working_dir = os.getcwd()
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f"Load from checkpoint folder: {global_step_folder}")
# set global step
self.global_steps = int(global_step_folder.split("global_step_")[-1])
print(f"Setting global step to {self.global_steps}")
print(f"Resuming from {global_step_folder}")
actor_path = os.path.join(global_step_folder, "actor")
critic_path = os.path.join(global_step_folder, "critic")
# load actor
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
# load critic
if self.use_critic:
self.critic_wg.load_checkpoint(
critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
# load dataloader state
# TODO: from remote not implemented yet
dataloader_local_path = os.path.join(global_step_folder, "data.pt")
if os.path.exists(dataloader_local_path):
dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
self.train_dataloader.load_state_dict(dataloader_state_dict)
else:
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch["attention_mask"]
batch_size = attention_mask.shape[0]
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
world_size = self.actor_rollout_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(
global_seqlen_lst, k_partitions=world_size, equal_size=True
)
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
batch.reorder(global_idx)
global_balance_stats = log_seqlen_unbalance(
seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
)
metrics.update(global_balance_stats)
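# Illustrative sketch of the balancing above (hypothetical numbers): with per-sample token
# counts [900, 100, 800, 200] and world_size=2, get_seqlen_balanced_partitions aims to return
# index partitions such as [[0, 3], [2, 1]] (totals 1100 vs. 900); batch.reorder() then lays
# the samples out so the dispatch function hands each dp rank one roughly balanced partition.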
def fit_dpo(self): # Renamed from the standard PPO fit loop for clarity
"""
The training loop of Online DPO using a periodically updated reference model.
The driver process calls worker groups for computation.
Advantage computation is replaced by DPO logic.
"""
import traceback # Ensure traceback is imported
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
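# For reference, the sigmoid DPO objective that update_actor_dpo is expected to optimize
# (the exact loss lives in the worker / core_algos and may add label smoothing):
#   L_DPO = -log sigmoid( beta * [ (log pi_theta(y_c|x) - log pi_ref(y_c|x))
#                                  - (log pi_theta(y_r|x) - log pi_ref(y_r|x)) ] )
# where y_c / y_r are the chosen / rejected responses selected via the 'preferences' mask.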
# Initialize logger
logger = None
try:
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False),
)
except Exception as e:
print(f"Warning: Failed to initialize logger: {e}")
self.global_steps = 0
# Load checkpoint before doing anything
loaded_step = self._load_checkpoint()
# _load_checkpoint sets self.global_steps in place when resuming and returns 0/None otherwise,
# so fall back to the in-place value instead of resetting a resumed run back to step 1
resumed_step = loaded_step if loaded_step else self.global_steps
self.global_steps = resumed_step + 1 if resumed_step > 0 else 1
print(
f"Starting Online DPO training from global step {self.global_steps}. "
f"Total steps: {self.total_training_steps}"
)
print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}")
# Check if reference policy is configured correctly for this mode
if not self.use_reference_policy:
print(
"WARNING: 'use_reference_policy' is False. Periodic reference model update requires a "
"reference policy worker. DPO updates might fail or use incorrect logic."
)
# Consider raising an error if strict adherence is required:
# raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True "
# "and a configured reference worker.")
# Perform validation before training
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
print("Running validation before Online DPO training...")
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
if logger and val_metrics:
logger.log(data=val_metrics, step=max(0, self.global_steps - 1))
if self.config.trainer.get("val_only", False):
print("Validation only mode enabled. Exiting training.")
if logger and hasattr(logger, "finish"):
logger.finish()
return
# Add tqdm progress bar
progress_bar = tqdm(
total=self.total_training_steps,
initial=self.global_steps,
desc="Online DPO Training Progress",
position=0,
leave=True,
)
last_val_metrics = None
should_stop = False
for epoch in range(self.config.trainer.total_epochs):
if should_stop:
break
print(f"--- Starting Online DPO Epoch {epoch} ---")
try:
train_iterator = iter(self.train_dataloader)
except TypeError:
print("Warning: Dataloader is not iterable.")
train_iterator = self.train_dataloader # Fallback attempt
for batch_idx, batch_dict in enumerate(train_iterator):
if self.global_steps > self.total_training_steps:
should_stop = True
break
metrics = {}
timing_raw = {}
step_timer = Timer(logger=None)
ref_log_prob_computed = False # Flag to track if ref log probs were computed
try: # Outer try-except for the whole step
step_timer.start()
with _timer("step", timing_raw):
batch: DataProto = DataProto.from_single_dict(batch_dict)
current_batch_size = batch.batch.batch_size[0]
print(
f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: "
f"{current_batch_size}"
)
# --- Reference Model Update ---
ref_update_freq = self.config.trainer.get("ref_update_freq", -1)
if (
self.use_reference_policy
and ref_update_freq > 0
and self.global_steps % ref_update_freq == 0
):
print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...")
try:
# --- This requires careful implementation with FSDP ---
# 1. Save actor state dict (potentially to CPU memory or disk)
# This needs to be done collectively across actor worker ranks.
# The checkpoint_manager might be adaptable, or use FSDP APIs directly.
# Example placeholder using a conceptual save/load mechanism:
actor_state_path = "/tmp/actor_state_mid" # Temporary path
self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic
# 2. Load the state dict onto the reference model worker group
# This also needs collective loading on the ref worker ranks.
self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic
print(f"[Step {self.global_steps}] Reference Model Weights Updated.")
# Optionally remove the temporary state file
# os.remove(actor_state_path) # Needs rank-aware removal or shared storage
except Exception as sync_e:
print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}")
traceback.print_exc()
# Pop keys for generation
pop_batch_keys = ["input_ids", "attention_mask"]
if "position_ids" in batch.batch:
pop_batch_keys.append("position_ids")
pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else []
if "multi_modal_inputs" in batch.non_tensor_batch.keys():
pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"])
original_non_tensor_data = batch.non_tensor_batch
gen_batch = batch.pop(
batch_keys=pop_batch_keys,
non_tensor_batch_keys=pop_non_tensor_keys,
)
gen_batch = gen_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
# (Add Debug prints for gen_batch if needed)
# Generate sequences (chosen/rejected pairs)
with _timer("gen", timing_raw):
try:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
# (Add Debug prints for gen_batch_output if needed)
except Exception as gen_e:
print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!")
print(gen_e)
traceback.print_exc()
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
step_timer.stop()
continue
# Combine original prompts with generated sequences
batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object
)
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# (Add Debug prints after union if needed)
# Compute response mask (needed for ref logprob calc and DPO prep)
batch.batch["response_mask"] = compute_response_mask(batch)
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef
# fallback) ---
# Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed
# unless used for other metrics or a fallback. Keep it for now.
with _timer("policy_log_prob", timing_raw):
policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs'
# (Debug prints for old_log_probs)
# --- Compute Log Probs using the EXTERNAL Reference Model ---
if self.use_reference_policy:
with _timer("ref_log_prob_dpo", timing_raw):
# print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----")
try:
# 'batch' contains interleaved chosen/rejected sequences
ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(
batch
) # Returns DataProto with 'ref_log_prob'
batch = batch.union(
ref_log_prob_output
) # Adds 'ref_log_prob' key [batch_size * n, seq_len]
ref_log_prob_computed = True # Mark success
# print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: "
# f"{batch.batch['ref_log_prob'].shape} ----")
except Exception as ref_e:
print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}")
traceback.print_exc()
batch.batch["ref_log_prob"] = None # Mark as failed
ref_log_prob_computed = False
else:
print(
"Warning: Skipping external reference log prob calculation as use_reference_policy "
"is False."
)
# DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor
# --- Compute Rewards/Scores (used to determine preference) ---
with _timer("reward_calc", timing_raw):
# (Reward calculation logic using RM or reward_fn as before)
# ... Ensure this calculates 'token_level_rewards' or similar ...
if self.use_rm:
reward_tensor_rm = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor_rm) # Adds 'rm_scores'
reward_extra_infos_dict = {}
try:
if self.reward_fn is None:
# print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! "
# f"Using dummy rewards. ----")
# Use rm_scores if available, otherwise zeros
reward_tensor = batch.batch.get(
"rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
)
else:
reward_result = self.reward_fn(batch, return_dict=True)
reward_tensor = reward_result["reward_tensor"] # Final combined reward
reward_extra_infos_dict = reward_result.get("reward_extra_info", {})
except Exception:
# print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. '
# f'Using dummy rewards. ----')
traceback.print_exc()
reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
reward_extra_infos_dict = {}
# Use 'token_level_rewards' as the key for preference calculation
batch.batch["token_level_rewards"] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update(
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
)
# --- Determine Preferences ---
# Uses 'token_level_rewards' to determine chosen/rejected based on score
batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key
# --- Prepare DPO Batch ---
dpo_update_batch_proto = None # Initialize
with _timer("prepare_dpo_batch", timing_raw):
try:
if "preferences" not in batch.batch or batch.batch["preferences"] is None:
raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.")
# Check if reference log probs were computed successfully (if needed)
if self.use_reference_policy and not ref_log_prob_computed:
raise ValueError("Reference log probs required but failed to compute.")
# Check required base keys
required_keys = ["input_ids", "attention_mask", "response_mask"]
for rk in required_keys:
if rk not in batch.batch or batch.batch[rk] is None:
raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.")
preferences_mask = batch.batch["preferences"] # Shape [batch_size * n]
not_preferences_mask = ~preferences_mask
# Gather Chosen/Rejected Base Tensors
chosen_input_ids = batch.batch["input_ids"][preferences_mask]
chosen_attention_mask = batch.batch["attention_mask"][preferences_mask]
rejected_input_ids = batch.batch["input_ids"][not_preferences_mask]
rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask]
chosen_position_ids = (
batch.batch.get("position_ids")[preferences_mask]
if "position_ids" in batch.batch
else None
)
rejected_position_ids = (
batch.batch.get("position_ids")[not_preferences_mask]
if "position_ids" in batch.batch
else None
)
# Create Labels
print("WARNING: Creating DPO labels using configured max_prompt_length...")
prompt_len = self.config.data.max_prompt_length
chosen_labels = chosen_input_ids.clone()
chosen_labels[:, :prompt_len] = -100
rejected_labels = rejected_input_ids.clone()
rejected_labels[:, :prompt_len] = -100
# Calculate and Gather Reference Log Probs (Sequence Level)
if self.use_reference_policy:
ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len]
response_mask_full = batch.batch[
"response_mask"
] # Response mask [bsz * n, seq_len]
ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(
dim=-1
) # Sequence level [bsz * n]
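# Summing masked token-level ref log probs yields one sequence-level log prob per sample.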
reference_chosen_logps = ref_sequence_logps[preferences_mask]
reference_rejected_logps = ref_sequence_logps[not_preferences_mask]
else:
# If not using external ref, DPO needs ActorAsRef logic in dp_actor
# We won't add the keys here, dp_actor will handle it (or fail if not modified)
print(
"Info: Not adding explicit reference logps to DPO batch "
"(use_reference_policy=False)."
)
reference_chosen_logps = None # Explicitly None
reference_rejected_logps = None
# Package Tensors
dpo_tensors = {
"chosen_input_ids": chosen_input_ids,
"chosen_attention_mask": chosen_attention_mask,
"chosen_labels": chosen_labels,
"rejected_input_ids": rejected_input_ids,
"rejected_attention_mask": rejected_attention_mask,
"rejected_labels": rejected_labels,
}
# Conditionally add reference logps if computed
if reference_chosen_logps is not None:
dpo_tensors["reference_chosen_logps"] = reference_chosen_logps
if reference_rejected_logps is not None:
dpo_tensors["reference_rejected_logps"] = reference_rejected_logps
# Add position ids if they exist
if chosen_position_ids is not None:
dpo_tensors["chosen_position_ids"] = chosen_position_ids
if rejected_position_ids is not None:
dpo_tensors["rejected_position_ids"] = rejected_position_ids
# Prepare Meta Info
dpo_meta = {
"dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1),
"dpo_loss_type": OmegaConf.select(
self.config.algorithm, "dpo_loss_type", default="sigmoid"
),
"dpo_label_smoothing": OmegaConf.select(
self.config.algorithm, "dpo_label_smoothing", default=0.0
),
"use_reference_policy": self.use_reference_policy,
"reference_free": not self.use_reference_policy, # False if using external ref
"global_step": self.global_steps,
}
dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta)
# print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----")
# print(f" Keys: {list(dpo_update_batch_proto.batch.keys())}")
# print(f" Meta Info: {dpo_meta}")
except Exception as e_prep:
print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}")
traceback.print_exc()
dpo_update_batch_proto = None # Skip update on error
# --- Actor Update Step ---
actor_output = None
if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto is not None:
with _timer("update_actor", timing_raw):
# Pass the batch containing reference log probs (if computed)
# The modified update_actor_dpo expects them if reference_free=False
actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto)
if actor_output and "metrics" in actor_output.meta_info:
metrics.update(reduce_metrics(actor_output.meta_info["metrics"]))
elif dpo_update_batch_proto is None:
print(
f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error."
)
# --- Validation and Saving ---
test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1)
is_last_step = self.global_steps >= self.total_training_steps
if (
self.val_reward_fn is not None
and test_freq > 0
and (is_last_step or self.global_steps % test_freq == 0)
):
print(f"\nRunning DPO validation at step {self.global_steps}...")
val_timing_raw = {}
with _timer("testing", val_timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
if val_metrics:
metrics["time/validation_run"] = val_timing_raw.get("testing", 0)
metrics.update(val_metrics)
else:
print("Validation skipped or returned no metrics.")
save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0):
print(f"\nSaving DPO checkpoint at step {self.global_steps}...")
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere)
metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0)
# --- End main step timer context ---
# --- Metrics calculation AFTER the 'step' timer block ---
metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
n_gpus = self.resource_pool_manager.get_n_gpus()
if "step" in timing_raw:
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
else:
print(
f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. "
f"Skipping throughput."
)
step_timer.stop()
metrics["time/step"] = step_timer.last
# Log metrics
log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1)
if logger and self.global_steps % log_freq == 0:
log_payload = metrics.copy()
# Add learning rate to log payload
if actor_output and "actor/lr" in metrics:
log_payload["actor/lr"] = metrics["actor/lr"]
print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}")
try:
logger.log(data=log_payload, step=self.global_steps)
except Exception as e:
print(f"Logging failed at step {self.global_steps}: {e}")
# Update progress bar
postfix_metrics = {
k: f"{v:.3f}" if isinstance(v, float) else v
for k, v in metrics.items()
if isinstance(v, int | float)
}
progress_bar.set_postfix(postfix_metrics)
except Exception as step_e:
print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!")
print(f"Caught Exception: {step_e}")
traceback.print_exc()
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
step_timer.stop()
should_stop = True
break
if is_last_step or should_stop:
print(f"Stopping DPO training at step {self.global_steps}.")
break
self.global_steps += 1
progress_bar.update(1)
# End of epoch handling
if hasattr(self.train_dataloader, "reset"):
try:
self.train_dataloader.reset()
except Exception as e:
print(f"Warning: Failed to reset train dataloader state: {e}")
if should_stop:
break
# --- Final cleanup and logging ---
progress_bar.close()
final_step = max(0, self.global_steps - 1)
print(f"Online DPO Training finished at step {final_step}.")
# Save final checkpoint
save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0):
print(f"Saving final DPO checkpoint at step {final_step}...")
self._save_checkpoint()
# Final validation run
if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False):
print("Running final validation...")
last_val_metrics = self._validate()
if last_val_metrics and logger:
last_val_metrics["final_validation"] = True
try:
logger.log(data=last_val_metrics, step=final_step)
except Exception as e:
print(f"[Final Val Metrics Log Error]: {e}")
pprint(f"Final validation metrics: {last_val_metrics}")
if logger and hasattr(logger, "finish"):
logger.finish()
print("Online DPO Training Run Complete.")
# SPPO: Self-Play Preference Optimization for Language Model Alignment
This repository hosts the community implementation of the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4, and it can outperform models trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM converges to the von Neumann winner (i.e., the Nash equilibrium) under general, potentially intransitive preferences, and it is empirically validated through extensive evaluations on multiple datasets.
Paper Authors: [Yue Wu](https://yuewu.us/)\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)
verl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20)
[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)]
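In this recipe, the self-play update is realized as a squared regression: each sampled response is scored, sequence-level rewards are centered with a soft-mean baseline (`compute_advantage` in `sppo_ray_trainer.py`), and the actor regresses the response log-ratio onto the centered score (`compute_sppo_loss` in `dp_actor.py`). A sketch of the per-sample objective, where `eta` corresponds to the `sppo_eta` hyperparameter (which this recipe also reuses as the soft-mean temperature):
```math
\mathcal{L}_i(\theta) = \left( \log \frac{\pi_\theta(y_i \mid x)}{\pi_{\mathrm{old}}(y_i \mid x)} - \eta\, \hat r_i \right)^2,
\qquad
\hat r_i = r_i - \mathrm{SoftMean}_\beta(r_{1:n})
```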
## Reproduce the Experiment
We evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework.
```bash
git clone git@github.com:volcengine/verl.git
cd verl
python3 -m uv pip install -e ".[sglang]"
export WANDB_API_KEY=<YOUR_WANDB_API_KEY>
python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math
huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct
export CUDA_VISIBLE_DEVICES=0,1,2,3
bash recipe/sppo/run_qwen2.5-7b_rm.sh
```
Note that the installation may occasionally fail to build flash-attn. If this happens, you can install it manually by running:
```bash
python3 -m uv pip install wheel
python3 -m uv pip install packaging
python3 -m uv pip install flash-attn --no-build-isolation --no-deps
```
## Acknowledgement
We sincerely thank the contribution and guidance from:
- [Yue Wu](https://yuewu.us/)
- [Chendong Wang](https://cdwang96.github.io/)
- [Yifan Zhang](https://github.com/yifanzhang-pro)
- [Yongan Xiang](https://github.com/BearBiscuit05)
- [Junrong Lin](https://github.com/ocss884)
- [Yuxuan Tong](https://github.com/tongyx361)
- [Guangming Shen](https://github.com/PeterSH6)
- [Biao He](https://www.linkedin.com/in/biao-he/)
- [Qingquan Song](https://qingquansong.github.io/)
- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from verl.workers.config import FSDPActorConfig
@dataclass
class SPPOActorConfig(FSDPActorConfig):
sppo_eta: float = 1.0
# the sppo config will override default ppo_trainer.yaml
hydra:
searchpath:
- file://verl/trainer/config
defaults:
- ppo_trainer
- _self_
actor_rollout_ref:
actor:
_target_: recipe.sppo.config.SPPOActorConfig
# sppo_eta is an additional hyperparameter for SPPO that is not available in
# verl core. Specifying _target_ with SPPOActorConfig is needed to
# extend the verl ActorConfig with custom fields.
# Alternatively, it is also possible to use the `extra` field natively supported
# by all verl core dataclasses, without having to define SPPOActorConfig:
# extra:
# sppo_eta: 1.0
sppo_eta: 1.0
optim:
lr_warmup_steps: 15
rollout:
name: sglang
tensor_model_parallel_size: 2
gpu_memory_utilization: 0.5
val_kwargs:
n: 2 # 2 will trigger validation, 1 will bypass
algorithm:
adv_estimator: null
sppo_eta: 1.0
trainer:
log_val_generations: 0
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import torch
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.trainer.ppo.core_algos import agg_loss, kl_penalty
from verl.utils.device import get_device_id
from verl.utils.profiler import GPUMemoryLogger
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import rearrange_micro_batches
from verl.workers.actor.dp_actor import DataParallelPPOActor
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
def compute_sppo_loss(
old_log_prob: torch.Tensor, # (bs, seq_len)
log_prob: torch.Tensor, # (bs, seq_len)
rewards: torch.Tensor, # (bs,)
response_mask: torch.Tensor, # (bs, seq_len)
eta: float = 1.0,
loss_agg_mode: str = "token-mean",
):
"""
SPPO Loss computation.
"""
# Compute log-ratios over masked tokens
log_prob_sum = (log_prob * response_mask).sum(dim=1) # (bs,)
old_log_prob_sum = (old_log_prob * response_mask).sum(dim=1) # (bs,)
log_ratios = log_prob_sum - old_log_prob_sum # (bs,)
scaled_rewards = eta * (rewards)
loss_vec = (log_ratios - scaled_rewards) ** 2 # (bs,)
if loss_agg_mode == "token-mean":
sample_mask = response_mask.any(dim=1).float() # (bs,)
loss = verl_F.masked_mean(loss_vec, sample_mask)
else:
raise NotImplementedError(f"loss_agg_mode '{loss_agg_mode}' is not supported for SPPO")
return loss, log_ratios, scaled_rewards
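# Illustrative, commented-out usage of compute_sppo_loss (a hypothetical sanity
# check with random tensors, not part of the training path):
#
#   bs, seq_len = 4, 8
#   response_mask = torch.ones(bs, seq_len)
#   old_log_prob = torch.randn(bs, seq_len)
#   log_prob = old_log_prob + 0.01 * torch.randn(bs, seq_len)
#   rewards = torch.randn(bs)
#   loss, log_ratios, scaled = compute_sppo_loss(old_log_prob, log_prob, rewards, response_mask, eta=1.0)
#   # loss is a scalar; log_ratios and scaled both have shape (bs,)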
class DataParallelSPPOActor(DataParallelPPOActor):
@GPUMemoryLogger(role="dp actor", logger=logger)
def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()
temperature = data.meta_info["temperature"] # temperature must be in data.meta_info to avoid silent errors
multi_turn = data.meta_info.get("multi_turn", False)
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "seq_level_rewards"]
if multi_turn:
select_keys.append("loss_mask")
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
batch = data.select(batch_keys=select_keys).batch
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
if has_multi_modal_inputs:
num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size
non_tensor_select_keys = ["multi_modal_inputs"]
dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches)
else:
dataloader = batch.split(self.config.ppo_mini_batch_size)
metrics = {}
for epoch in range(self.config.ppo_epochs):
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if has_multi_modal_inputs:
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
)
num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu
micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches)
elif self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
self.gradient_accumulation = (
self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
)
# split batch into micro_batches
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
self.actor_optimizer.zero_grad()
for data in micro_batches:
# Support all hardwares
if isinstance(data, DataProto):
data = {**data.batch.to(get_device_id()), **data.non_tensor_batch}
else:
data = data.to(get_device_id()) # actor device is cpu when using offload
responses = data["responses"]
response_length = responses.size(1)
attention_mask = data["attention_mask"]
if multi_turn:
response_mask = data["loss_mask"][:, -response_length:]
else:
response_mask = attention_mask[:, -response_length:]
old_log_prob = data["old_log_probs"]
rewards = data["seq_level_rewards"]
entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode
eta = self.config.get("sppo_eta", 1.0)
# all return: (bsz, response_length)
calculate_entropy = False
if entropy_coeff != 0:
calculate_entropy = True
entropy, log_prob = self._forward_micro_batch(
micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy
)
pg_loss, log_ratios, preference = compute_sppo_loss(
old_log_prob=old_log_prob,
log_prob=log_prob,
rewards=rewards,
response_mask=response_mask,
eta=eta,
loss_agg_mode=loss_agg_mode,
)
if entropy_coeff != 0:
entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff
else:
policy_loss = pg_loss
if self.config.use_kl_loss:
ref_log_prob = data["ref_log_prob"]
# compute kl loss
kld = kl_penalty(
logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type
)
kl_loss = agg_loss(
loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode
)
policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
metrics["actor/kl_loss"] = kl_loss.detach().item()
metrics["actor/kl_coef"] = self.config.kl_loss_coef
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = policy_loss / self.gradient_accumulation
loss.backward()
data = {
"actor/loss": loss.detach().item(),
"actor/log_ratio_mean": log_ratios.mean().detach().item(),
"actor/preference_mean": preference.mean().detach().item(),
}
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
data = {"actor/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, data)
self.actor_optimizer.zero_grad()
return metrics
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer, because ray_trainer is also used by other main entry points.
"""
import os
import hydra
import ray
from verl.trainer.ppo.reward import load_reward_manager
from .sppo_ray_trainer import RaySPPOTrainer
@hydra.main(config_path="config", config_name="sppo_trainer", version_base=None)
def main(config):
run_ppo(config)
def run_ppo(config) -> None:
# TODO(linjunrong.ocss884): this env var is a workaround for the SGLang conflict with Ray device
# isolation; it will be resolved in the future
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
},
num_cpus=config.ray_init.num_cpus,
)
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLMs, may be None
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
from verl.single_controller.ray import RayWorkerGroup
from .sppo_worker import SPPOActorRolloutRefWorker # , CriticWorker
actor_rollout_cls = SPPOActorRolloutRefWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker
actor_rollout_cls = ActorRolloutRefWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
# sppo does not use critic
role_worker_mapping = {
Role.ActorRollout: ray.remote(actor_rollout_cls),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
}
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RaySPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == "__main__":
main()
# Disclaimer: the model used in this script is for academic purposes only.
set -x
# Data preparation scripts are available in ``examples/data_preprocess``.
# Example usage:
#
# python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math
# python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
gsm8k_train_path=$HOME/data/math/train.parquet
gsm8k_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"
# prepare model ckpt
huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct &
# huggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 &
wait
python3 -m recipe.sppo.main_sppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path="$HOME/models/Qwen2.5-7B-Instruct" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='sppo-sglang' \
trainer.val_before_train=True \
trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=1 \
trainer.total_epochs=1000 $@
# Note that we set lr_warmup_steps = 15 in config/sppo_trainer.yaml
# The experiment will converge to 0.656 on MATH dataset after 20 epochs
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import uuid
from copy import deepcopy
from pprint import pprint
from typing import Optional
import numpy as np
import ray
import torch
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import reduce_metrics
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.profiler.performance import simple_timer
from verl.utils.tracking import ValidationGenerationsLogger
def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
"""
Compute SoftMean_β(x) = (1/β) * log( (1/n) * Σ exp(β * x_i) )
Falls back to arithmetic mean when β=0.
"""
if beta == 0.0:
return x.mean(dim=dim, keepdim=keepdim)
# cast beta to tensor on same device/dtype
beta_t = x.new_tensor(beta)
# numerically-stable logsumexp(β x)
lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim)
n = x.size(dim)
log_n = x.new_tensor(n).log()
return (lse - log_n) / beta_t
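# Behavior sketch (illustrative values): for x = [0., 1., 2.],
#   beta = 0  -> 1.0 (arithmetic mean)
#   beta >> 0 -> approaches max(x) = 2.0
#   beta << 0 -> approaches min(x) = 0.0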
def compute_advantage(data: DataProto, beta=1.0):
rewards = data.batch["token_level_rewards"].sum(axis=-1) # (bs, )
s_mean = softmean(rewards, beta, keepdim=True) # (1,) batch-level soft-mean baseline
rewards = rewards - s_mean # (bs, )
data.batch["seq_level_rewards"] = rewards # (bs, )
return data
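# Illustrative example of the centering above (assuming beta = 1.0): if the summed
# sequence rewards are [1., 2., 3.], softmean([1., 2., 3.], 1.0) ≈ 2.31, so
# seq_level_rewards ≈ [-1.31, -0.31, 0.69].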
class RaySPPOTrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
collate_fn=None,
train_sampler: Optional[Sampler] = None,
device_name=None,
):
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert self.hybrid_engine, "Currently, only the hybrid engine is supported"
if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
self.device_name = device_name if device_name else self.config.trainer.device
# define in-reward KL control
# kl loss control currently not supported
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
self.use_critic = False
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the
worker group through RPC to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
is_last_step = self.global_steps >= self.total_training_steps
with simple_timer("step", timing_raw):
# generate a batch
with simple_timer("gen", timing_raw):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with simple_timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
with simple_timer("reward", timing_raw):
# compute reward model score
if self.use_rm:
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
# recompute old_log_probs
with simple_timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with simple_timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with simple_timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with simple_timer("adv", timing_raw):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
batch.batch["seq_level_rewards"] = batch.batch["token_level_scores"]
beta = self.config.algorithm.sppo_eta
batch = compute_advantage(batch, beta=beta)
# update critic
if self.use_critic:
with simple_timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with simple_timer("update_actor", timing_raw):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with simple_timer("dump_rollout_generations", timing_raw):
print(batch.batch.keys())
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
self._dump_generations(
inputs=inputs,
outputs=outputs,
scores=scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=rollout_data_dir,
)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with simple_timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with simple_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# training metrics
metrics.update(
{
"training/global_step": self.global_steps,
"training/epoch": epoch,
}
)
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from omegaconf import OmegaConf, open_dict
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer
from verl.utils.import_utils import import_external_libs
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.fsdp_workers import ActorRolloutRefWorker
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN"))
class SPPOActorRolloutRefWorker(ActorRolloutRefWorker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
from .dp_actor import DataParallelSPPOActor
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get("external_lib", None))
override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
use_remove_padding = self.config.model.get("use_remove_padding", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)
if self._is_actor or self._is_rollout:
# we need the model for actor and rollout
if self._is_actor:
optim_config = self.config.actor.optim
fsdp_config = self.config.actor.fsdp_config
else:
optim_config = None
fsdp_config = OmegaConf.create()
self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (
self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="actor",
)
)
# get the original unwrapped module
self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
log_gpu_memory_usage("After offload actor model during init", logger=logger)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)
# load from checkpoint
if self._is_actor:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelSPPOActor(
config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
)
if self._is_rollout:
self.rollout, self.rollout_sharding_manager = self._build_rollout(
trust_remote_code=self.config.model.get("trust_remote_code", False)
)
if self._is_ref:
self.ref_module_fsdp = self._build_model_optimizer(
model_path=self.config.model.path,
fsdp_config=self.config.ref.fsdp_config,
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="ref",
)[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)
if self._is_actor:
self.flops_counter = FlopsCounter(self.actor_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
optimizer=self.actor.actor_optimizer,
lr_scheduler=self.actor_lr_scheduler,
processing_class=self.processor if self.processor is not None else self.tokenizer,
checkpoint_config=self.config.actor.checkpoint,
)