# Recipe: One Step Off Policy Async Trainer
**Author:** `https://github.com/meituan-search`
Last updated: 07/17/2025.
## Introduction
### Background
The reinforcement learning training process currently implemented in verl is synchronous, adhering to the algorithmic
workflows of established methods such as PPO, GRPO, and DAPO. In each step, training samples are generated by the latest
model, and the model is updated only after training on them completes. While this approach keeps training on-policy
and stabilizes RL training, it suffers from severe efficiency issues.
Model updates must wait for the longest output in the generation phase to complete.
During the generation of long-tail samples, GPUs remain idle, resulting in significant underutilization.
The more severe the long-tail problem in sample generation, the lower the overall training efficiency.
For example, in DAPO 32B training, the Rollout phase accounts for approximately 70% of the total time,
and increasing resources does not reduce the Rollout duration.
![DAPO 32B Math Performance](
https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/dapo_32b_math.png)
> source data: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=nwusertongyuxuan361
### Solution
We have implemented the **One Step Off Async Trainer** to alleviate this issue. The approach parallelizes generation
and training by using the samples generated in the previous step for the current training step. It also partitions
resources explicitly, allocating dedicated GPUs to generation and automatically assigning the remainder to training.
By reducing the resources devoted to the generation phase, we mitigate GPU idle time during long-tail sample
generation. Throughout this process, the generation parameters stay exactly one step behind the training parameters,
i.e., training remains one-step off-policy.
![One Step Off Policy Diagram](
https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_policy.png)
> reference: [AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning](
> https://arxiv.org/abs/2505.24298)
Our core contributions include:
1. **Parallel Generation and Training**:
Samples for the next batch are asynchronously generated while the current batch is being trained.
2. **Resource Isolation**:
Unlike `hybrid_engine`, this method requires explicit resource allocation for rollout, with remaining resources
automatically assigned to training.
3. **NCCL Parameter Synchronization**:
Employs NCCL communication primitives for seamless parameter transfer between generation and training modules.
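
To make the resource isolation concrete: the recipe builds two disjoint resource pools, one for training and one for
rollout, from the `trainer.*` and `rollout.*` settings. The following is a minimal sketch that mirrors the
`resource_pool_spec` assembled in this recipe's `main_ppo.py` (simplified to plain integers here):
```python
# Sketch: two disjoint GPU pools derived from the trainer.* and rollout.* settings
# (mirrors the resource_pool_spec built in this recipe's main_ppo.py).
def build_resource_pool_spec(trainer_nnodes, trainer_n_gpus_per_node,
                             rollout_nnodes, rollout_n_gpus_per_node):
    return {
        "actor_pool": [trainer_n_gpus_per_node] * trainer_nnodes,    # training GPUs (actor/critic/ref)
        "rollout_pool": [rollout_n_gpus_per_node] * rollout_nnodes,  # dedicated generation GPUs
    }

# e.g. the 2-node H20 setup below: 6 training + 2 rollout GPUs per node
print(build_resource_pool_spec(2, 6, 2, 2))
# {'actor_pool': [6, 6], 'rollout_pool': [2, 2]}
```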
### Experimental Results
- **Machine Configuration**: 2 nodes, 16 H20 GPUs in total (8 per node)
  - Generation: 4 GPUs
  - Training: 12 GPUs
- **Model**: Qwen2.5-Math-7B
- **Rollout Configuration**:
  - **Max Response Length**: FSDP2: 20,480 tokens; Megatron: 8,192 tokens
- **Algorithm**: DAPO
- **Rollout Engine**: vLLM
| training mode | engine | step | gen | wait_prev_gen | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | acc/maj@32/mean |
|------------------------|---------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
| colocate sync | VLLM+FSDP2 | 749 | 321 | - | 247 | 88 | 286 | 19h18m | 0.5948 | 0.417 |
| one-step-overlap async | VLLM+FSDP2 | 520 | - | 45 | 458 | 108 | 337 | 15h34m(+23%) | 0.6165 | 0.494 |
| colocate sync | VLLM+Megatron | 699 | 207 | - | 162 | 119 | 344 | 18h21m | 0.605 | 0.4217 |
| one-step-overlap async | VLLM+Megatron | 566 | - | 59 | 501 | 120 | 347 | 13h06m (+40%) | 0.6569 | 0.4038 |
* colocate sync: step ≈ gen + old_log_prob + update_actor
* one-step-overlap async: step ≈ wait_prev_gen + old_log_prob + update_actor
![One Step Off Megatron Performance](
https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one_step_off_megatron.png)
> source data: https://wandb.ai/hou-zg-meituan/one-step-off-policy?nw=nwuserhouzg
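
As a quick sanity check on the table, the headline improvements can be recomputed from the reported wall-clock totals.
The small, self-contained snippet below (timings copied from the table) recovers roughly the +23% and +40% figures:
```python
def to_minutes(hm: str) -> int:
    """Convert a 'XXhYYm' wall-clock string into minutes."""
    hours, minutes = hm.split("h")
    return int(hours) * 60 + int(minutes.rstrip("m"))

# Total training times copied from the table: (colocate sync, one-step-overlap async)
runs = {
    "VLLM+FSDP2": ("19h18m", "15h34m"),
    "VLLM+Megatron": ("18h21m", "13h06m"),
}

for engine, (sync_time, async_time) in runs.items():
    improvement = to_minutes(sync_time) / to_minutes(async_time) - 1
    print(f"{engine}: {improvement:.0%} faster end to end")  # ~24% and ~40%
```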
## Implementation
### One Step Off Policy Async Pipeline
Our **One Step Off Policy Async Pipeline** integrates into the existing training logic at minimal cost and requires
no additional sample-storage management. The core mechanism uses `_async_gen_next_batch` to launch asynchronous
rollout generation, while `_create_continuous_iterator` keeps the data stream uninterrupted across epoch transitions.
```python
# iterator generator; simplifies one-step integration into the training process
def _create_continuous_iterator(self):
    for epoch in range(self.config.trainer.total_epochs):
        iterator = iter(self.train_dataloader)
        for batch_dict in iterator:
            yield epoch, batch_dict

# read the next batch of samples, sync parameters and launch async gen_seq
def _async_gen_next_batch(self, continuous_iterator):
    # read train data
    try:
        epoch, batch_dict = next(continuous_iterator)
    except StopIteration:
        return None
    batch = DataProto.from_single_dict(batch_dict)
    gen_batch = batch_process(batch)
    # sync weights from actor to rollout
    self.sync_rollout_weights()
    # async generation
    gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
    # encapsulate as a future
    return GenerationBatchFuture(epoch, batch, gen_batch_output)


continuous_iterator = self._create_continuous_iterator()
# run rollout first to achieve one-step-off
batch_data_future = self._async_gen_next_batch(continuous_iterator)
while batch_data_future is not None:
    # wait for the gen_seq result from the previous step
    batch = batch_data_future.get()
    # launch the next async call to generate sequences
    batch_data_future = self._async_gen_next_batch(continuous_iterator)
    # compute advantages
    batch = critic.compute_values(batch)
    batch = reference.compute_log_prob(batch)
    batch = reward.compute_reward(batch)
    batch = compute_advantages(batch)
    # model update
    critic_metrics = critic.update_critic(batch)
    actor_metrics = actor.update_actor(batch)
```
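
The loop above wraps the in-flight generation in a `GenerationBatchFuture`. The actual class lives in the recipe code;
the sketch below is only an assumption-level simplification of the behavior the loop relies on: it carries the epoch,
the original batch, and the pending rollout output, and `get()` blocks until generation finishes and merges the result
back into the batch.
```python
class GenerationBatchFuture:
    """Minimal sketch: pairs a training batch with its in-flight rollout output."""

    def __init__(self, epoch, batch, gen_batch_output):
        self.epoch = epoch                        # epoch the batch was drawn from
        self.batch = batch                        # prompts / metadata (DataProto)
        self.gen_batch_output = gen_batch_output  # pending async rollout result

    def get(self):
        # Block until the rollout workers finish generating (the non-blocking worker
        # group call is assumed to return a future-like object), then merge the
        # generated sequences back into the original batch.
        output = self.gen_batch_output
        if hasattr(output, "get"):
            output = output.get()
        return self.batch.union(output)  # merge via verl's DataProto utilities
```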
### Parameter Synchronization
A key result is that the NCCL-based weight update for the rollout model is very fast: in most cases the latency is
under 300 ms, which is negligible for RLHF.
> **sync_rollout_weights**: The time for synchronizing parameters from actor to rollout is extremely short and can
> almost be ignored because the transfer is implemented with NCCL.
```python
class ActorRolloutRefWorker:
    # actor acquires the meta-info of model parameters for parameter sync
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def get_actor_weights_info(self):
        params = self._get_actor_params()
        ret = []
        for key, tensor in params.items():
            ret.append((key, tensor.size(), tensor.dtype))
        self._weights_info = ret
        return ret

    # rollout sets the meta-info of model parameters for parameter sync
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def set_actor_weights_info(self, weights_info):
        self._weights_info = weights_info


class AsyncRayPPOTrainer(RayPPOTrainer):
    def init_workers(self):
        ...
        # rollout obtains the meta-info of model parameters from the actor for parameter sync
        weights_info = self.actor_wg.get_actor_weights_info()[0]
        self.rollout_wg.set_actor_weights_info(weights_info)

        # Create an actor-rollout communication group for parameter sync
        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
        collective.create_collective_group(
            actor_rollout_workers,
            len(actor_rollout_workers),
            list(range(0, len(actor_rollout_workers))),
            backend="nccl",
            group_name="actor_rollout",
        )
```
```python
# the driver process calls the actor and rollout respectively to sync parameters via NCCL
def sync_rollout_weights(self):
    self.actor_wg.sync_rollout_weights()
    ray.get(self.rollout_wg.sync_rollout_weights())


# FSDP model parameter sync
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights(self):
    params = self._get_actor_params() if self._is_actor else None
    if self._is_rollout:
        inference_model = (
            self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        )
        patch_vllm_moe_model_weight_loader(inference_model)
    # model parameters are broadcast tensor-by-tensor from actor to rollout
    for key, shape, dtype in self._weights_info:
        tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
        if self._is_actor:
            assert key in params
            origin_data = params[key]
            if hasattr(origin_data, "full_tensor"):
                origin_data = origin_data.full_tensor()
            if torch.distributed.get_rank() == 0:
                tensor.copy_(origin_data)
        from ray.util.collective import collective

        collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
        if self._is_rollout:
            inference_model.load_weights([(key, tensor)])
```
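
To verify the sub-300 ms claim on your own hardware, the driver-side call can be timed directly. This is only a
measurement sketch; the `trainer` argument stands for a running trainer instance and the helper name is hypothetical:
```python
import time

def timed_sync_rollout_weights(trainer):
    # Wrap the driver-side sync (actor -> rollout NCCL broadcast shown above) and log its latency.
    start = time.perf_counter()
    trainer.sync_rollout_weights()
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"sync_rollout_weights took {elapsed_ms:.1f} ms")  # typically < 300 ms per the results above
    return elapsed_ms
```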
## Usage
### FSDP2 Configuration Example
```shell
# actor and rollout are placed on separate resources (hybrid_engine=False);
# trainer.* and rollout.* split the GPUs between training and rollout
python3 -m recipe.one_step_off_policy.async_main_ppo \
    --config-path=config \
    --config-name='one_step_off_ppo_trainer.yaml' \
    actor_rollout_ref.actor.strategy=fsdp2 \
    actor_rollout_ref.hybrid_engine=False \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node=6 \
    rollout.nnodes=1 \
    rollout.n_gpus_per_node=2
```
### Megatron Configuration Example
```shell
# actor and rollout are placed on separate resources (hybrid_engine=False);
# trainer.* and rollout.* split the GPUs between training and rollout
python3 -m recipe.one_step_off_policy.async_main_ppo \
    --config-path=config \
    --config-name='one_step_off_ppo_megatron_trainer.yaml' \
    actor_rollout_ref.actor.strategy=megatron \
    actor_rollout_ref.hybrid_engine=False \
    trainer.nnodes=1 \
    trainer.n_gpus_per_node=6 \
    rollout.nnodes=1 \
    rollout.n_gpus_per_node=2
```
### Configuration Guidelines
1. **Card Number Relationships**
   Maintain either of these relationships for optimal batch distribution (a small validation helper is sketched at
   the end of this section):
   - `actor_rollout_ref.rollout.n` should be an integer divisor of
     `trainer.n_gpus_per_node * trainer.nnodes`
   - `actor_rollout_ref.rollout.n * data.train_batch_size` should be evenly divisible by
     `trainer.n_gpus_per_node * trainer.nnodes`
   > Rationale: this ensures that training samples can be evenly distributed across the training GPUs when only part
   > of the resources is used for generation.
2. **Dynamic Resource Tuning**
   Adjust `trainer.nnodes`, `trainer.n_gpus_per_node`, `rollout.nnodes`, and `rollout.n_gpus_per_node` based on phase
   durations:
   - **Ideal state**: rollout and training phases have comparable durations
   - **Diagnostic metrics**:
     - Monitor the `wait_prev_gen` duration
     - Analyze the `sequence_length` distribution
   - **Adjustment strategy**:
     - High `wait_prev_gen` + uniform sequence lengths → increase rollout resources
     - High `wait_prev_gen` + long-tail sequences → optimize stopping criteria (adding resources won't help)
   > **wait_prev_gen**: the time spent waiting for the previous rollout to finish (i.e., the part that is not fully
   > overlapped with training).
**Resource Configuration Strategies:**
- **Resource-constrained scenario**: optimize utilization by adjusting the GPU allocation ratio, keeping the number of
  nodes equal so that training and rollout share nodes.
  - Configure `trainer.nnodes = rollout.nnodes` with
    `trainer.n_gpus_per_node + rollout.n_gpus_per_node = physical_gpus_per_node`, and control the rollout share by
    adjusting `n_gpus_per_node`.
- **Resource-abundant scenario**: optimize performance by adjusting the number of nodes, keeping the number of GPUs
  per node equal so that training and rollout parallelism can be scaled independently.
  - Configure `trainer.n_gpus_per_node = rollout.n_gpus_per_node` and control the rollout resource allocation by
    adjusting `trainer.nnodes` and `rollout.nnodes`.
> **Note**: The total number of nodes required by the system is not simply `trainer.nnodes + rollout.nnodes`. The
> actual calculation depends on GPU capacity:
> - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node <= physical_gpus_per_node`,
> the required node count is `max(trainer.nnodes, rollout.nnodes)`
> - When `trainer.n_gpus_per_node + rollout.n_gpus_per_node > physical_gpus_per_node`,
> the required node count is `trainer.nnodes + rollout.nnodes`
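
The divisibility rule from guideline 1 and the node-count rule from the note above are easy to check mechanically
before launching a run. The helper below is hypothetical (not part of the recipe) and only encodes the two rules
stated in this section:
```python
def check_one_step_off_config(trainer_nnodes, trainer_gpus_per_node,
                              rollout_nnodes, rollout_gpus_per_node,
                              rollout_n, train_batch_size,
                              physical_gpus_per_node=8):
    """Validate a planned one-step-off resource split and return the required node count."""
    train_gpus = trainer_nnodes * trainer_gpus_per_node

    # Guideline 1: generated samples must split evenly across the training GPUs.
    assert (train_gpus % rollout_n == 0) or ((rollout_n * train_batch_size) % train_gpus == 0), \
        "rollout.n (or rollout.n * train_batch_size) must divide evenly across training GPUs"

    # Note above: whether trainer and rollout fit on the same physical nodes decides the node count.
    if trainer_gpus_per_node + rollout_gpus_per_node <= physical_gpus_per_node:
        return max(trainer_nnodes, rollout_nnodes)
    return trainer_nnodes + rollout_nnodes

# Example: the 2-node H20 setup from the experiments (6 training + 2 rollout GPUs per node).
print(check_one_step_off_config(2, 6, 2, 2, rollout_n=12, train_batch_size=512))  # -> 2 nodes
```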
## Functional Support
| Category | Support Situation |
|--------------------|-----------------------------------------------------------------------------------------------------------------|
| train engine | FSDP2 <br/> Megatron |
| rollout engine | vLLM |
| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |
| Reward | all |
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_megatron_trainer
  - _self_

# config for the rollout (only for resource isolation)
rollout:
  # Number of nodes used in the rollout
  nnodes: 1
  # Number of GPUs per node
  n_gpus_per_node: 8
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

# config for the rollout (only for resource isolation)
rollout:
  # Number of nodes used in the rollout
  nnodes: 1
  # Number of GPUs per node
  n_gpus_per_node: 8
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-one-step-off-4-12'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=2
sp_size=4
fsdp_size=2
python3 -m recipe.one_step_off_policy.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}"
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-colocate'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=True
gen_tp=2
sp_size=4
fsdp_size=2
# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361
python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-megatron-one-step-off-4-12'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=2
train_tp=2
train_pp=2
# TODO: support dynamic_bsz for megatron
# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
python3 -m recipe.one_step_off_policy.main_ppo \
--config-path=config \
--config-name='one_step_off_ppo_megatron_trainer.yaml' \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=megatron \
critic.strategy=megatron \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.optim.clip_grad=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.param_offload=${ref_offload} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}"
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0519a1-megatron-colocate'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=16
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=True
gen_tp=2
train_tp=2
train_pp=2
# TODO: support dynamic_bsz for megatron
# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
python3 -m verl.trainer.main_ppo \
--config-path=config \
--config-name='ppo_megatron_trainer.yaml' \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=megatron \
critic.strategy=megatron \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.optim.clip_grad=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os

import torch
import torch.distributed
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoConfig

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
from verl.utils.device import (
    get_device_name,
    get_nccl_backend,
    get_torch_device,
)
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
    fsdp_version,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import get_generation_config, update_model_config
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker
from verl.workers.fsdp_workers import CriticWorker

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

device_name = get_device_name()

__all__ = ["ActorRolloutRefWorker", "AsyncActorRolloutRefWorker", "CriticWorker", "RolloutWorker"]


class ActorRolloutRefWorker(ARRWorker):
    def _get_actor_params(self):
        assert self._is_actor
        params = self.actor_module_fsdp.state_dict()
        from verl.utils.model import convert_weight_keys

        params = convert_weight_keys(
            params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
        )
        return params

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def sync_rollout_weights(self):
        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
        assert hasattr(self, "_weights_info") and self._weights_info is not None

        params = self._get_actor_params() if self._is_actor else None
        if self._is_rollout:
            inference_model = (
                self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
            )
            patch_vllm_moe_model_weight_loader(inference_model)
        for key, shape, dtype in self._weights_info:
            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
            if self._is_actor:
                assert key in params
                origin_data = params[key]
                if hasattr(origin_data, "full_tensor"):
                    origin_data = origin_data.full_tensor()
                if torch.distributed.get_rank() == 0:
                    tensor.copy_(origin_data)
            from ray.util.collective import collective

            collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
            if self._is_rollout:
                inference_model.load_weights([(key, tensor)])

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def get_actor_weights_info(self):
        assert self._is_actor
        if hasattr(self, "_weights_info"):
            return self._weights_info
        if fsdp_version(self.actor_module_fsdp) == 1:
            from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

            FSDP.set_state_dict_type(
                self.actor_module_fsdp,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )
        params = self._get_actor_params()
        ret = []
        for key, tensor in params.items():
            ret.append((key, tensor.size(), tensor.dtype))
        self._weights_info = ret
        return ret


class RolloutWorker(ActorRolloutRefWorker):
    def __init__(self, config: DictConfig, role: str):
        Worker.__init__(self)
        assert role == "rollout"
        self.config = config
        import torch.distributed

        if not torch.distributed.is_initialized():
            rank = int(os.environ.get("RANK", 0))
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            torch.distributed.init_process_group(
                backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
                rank=rank,
                world_size=world_size,
                init_method=os.environ.get("DIST_INIT_METHOD", None),
            )

        # TODO(haibin.lin):
        # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig,
        # it will actually convert the ProfilerConfig dataclass back to a DictConfig.
        # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py)
        # as they provides DictConfig-like interface
        # The benefit of creating the dataclass config is to perform validation during __post_init__
        profiler_config = omega_conf_to_dataclass(config.rollout.get("profiler", {}))
        DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config))

        self._is_rollout = True
        self._is_actor = False

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))

        use_shm = self.config.model.get("use_shm", False)
        local_path = copy_to_local(self.config.model.path, use_shm=use_shm)
        trust_remote_code = self.config.model.get("trust_remote_code", False)

        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
        self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)

        if self.config.model.get("custom_chat_template", None) is not None:
            if self.processor is not None:
                self.processor.chat_template = self.config.model.custom_chat_template
            else:
                self.tokenizer.chat_template = self.config.model.custom_chat_template

        # override model kwargs
        actor_model_config = AutoConfig.from_pretrained(
            local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
        )

        # patch for kimi-vl
        if getattr(actor_model_config, "model_type", None) == "kimi_vl":
            actor_model_config.text_config.topk_method = "greedy"

        self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)

        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f"Model config after override: {actor_model_config}")

        infer_tp = self.config.rollout.tensor_model_parallel_size
        dp = self.world_size // infer_tp
        assert self.world_size % infer_tp == 0, (
            f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
        )
        rollout_device_mesh = init_device_mesh(
            device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
        )
        rollout_name = self.config.rollout.name
        assert rollout_name == "vllm"

        from verl.workers.rollout.vllm_rollout import vLLMRollout

        log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)

        from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout

        vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
        rollout = vllm_rollout_cls(
            model_path=local_path,
            config=self.config.rollout,
            tokenizer=self.tokenizer,
            model_hf_config=actor_model_config,
            device_mesh=rollout_device_mesh,
            trust_remote_code=trust_remote_code,
        )
        log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)

        from .vllm_sharding_manager import VLLMShardingManager

        rollout_sharding_manager = VLLMShardingManager(
            inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
        )

        log_gpu_memory_usage("After building sharding manager", logger=logger)

        self.rollout = rollout
        self.rollout_sharding_manager = rollout_sharding_manager

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
    def async_generate_sequences(self, *args, **kwargs):
        return super().generate_sequences(*args, **kwargs)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def set_actor_weights_info(self, weights_info):
        assert self._is_rollout
        self._weights_info = weights_info


class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError
set -x
project_name='GRPO'
exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
python3 -m recipe.one_step_off_policy.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.train_batch_size=1152 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=192 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.val_before_train=True \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=2 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" "$@"
set -x
project_name='GRPO'
exp_name='GRPO-Qwen2.5-3b-Instruct-gsm8k-fsdp2-one-step-off-2-6'
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen/Qwen2.5-3B-Instruct"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
python3 -m recipe.one_step_off_policy.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.train_batch_size=1152 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=192 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.val_before_train=True \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=2 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" "$@"
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""
import os
import socket
import hydra
import ray
from omegaconf import OmegaConf
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.trainer.ppo.reward import load_reward_manager
from .ray_trainer import OneStepOffRayTrainer
@hydra.main(config_path="config", config_name="one_step_off_ppo_trainer", version_base=None)
def main(config):
run_ppo(config)
# Define a function to run the PPO-like training process
def run_ppo(config) -> None:
# Check if Ray is not initialized
if not ray.is_initialized():
# Initialize Ray with a local cluster configuration
# Set environment variables in the runtime environment to control tokenizer parallelism,
# NCCL debug level, VLLM logging level, and allow runtime LoRA updating
# `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
ray.init(
runtime_env=get_ppo_ray_runtime_env(),
num_cpus=config.ray_init.num_cpus,
)
# Create a remote instance of the TaskRunner class, and
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
if (
OmegaConf.select(config.trainer, "profile_steps") is not None
and len(OmegaConf.select(config.trainer, "profile_steps")) > 0
):
nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
else:
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
# [Optional] get the path of the timeline trace file from the configuration, default to None
# This file is used for performance analysis
timeline_json_file = config.ray_init.get("timeline_json_file", None)
if timeline_json_file:
ray.timeline(filename=timeline_json_file)
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
# Print the initial configuration. `resolve=True` will evaluate symbolic values.
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
pprint(OmegaConf.to_container(config, resolve=True))
OmegaConf.resolve(config)
# Download the checkpoint from HDFS to the local machine.
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
# Instantiate the tokenizer and processor.
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
# Define worker classes based on the actor strategy.
if config.actor_rollout_ref.actor.strategy == "fsdp2":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray import RayWorkerGroup
from .fsdp_workers import (
ActorRolloutRefWorker,
AsyncActorRolloutRefWorker,
CriticWorker,
RolloutWorker,
)
actor_rollout_cls = (
AsyncActorRolloutRefWorker
if config.actor_rollout_ref.rollout.mode == "async"
else ActorRolloutRefWorker
)
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from .megatron_workers import (
ActorRolloutRefWorker,
AsyncActorRolloutRefWorker,
CriticWorker,
RolloutWorker,
)
actor_rollout_cls = (
AsyncActorRolloutRefWorker
if config.actor_rollout_ref.rollout.mode == "async"
else ActorRolloutRefWorker
)
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from .ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.Actor: ray.remote(actor_rollout_cls),
Role.Rollout: ray.remote(RolloutWorker),
Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = "actor_pool"
assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0"
assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0"
assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"
actor_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes
rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
resource_pool_spec = {
"actor_pool": actor_pool,
"rollout_pool": rollout_pool,
}
mapping = {
Role.Actor: "actor_pool",
Role.Rollout: "rollout_pool",
Role.Critic: "actor_pool",
}
print(f"resource_pool_spec: {resource_pool_spec}")
# We should adopt a multi-source reward function here:
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# finally, we combine all the rewards together
# The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in ["fsdp2"]:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# Add a reference policy worker if KL loss or KL reward is used.
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
# Load the reward manager for training and validation.
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
val_reward_fn = load_reward_manager(
config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
from verl.utils.dataset.rl_dataset import collate_fn
# Create training and validation datasets.
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)
# Initialize the PPO trainer.
trainer = OneStepOffRayTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
train_sampler=train_sampler,
device_name=config.trainer.device,
)
# Initialize the workers of the trainer.
trainer.init_workers()
# Start the training process.
trainer.fit()
if __name__ == "__main__":
main()
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import torch
import torch.distributed
from omegaconf import DictConfig, OmegaConf
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.debug import (
log_gpu_memory_usage,
)
from verl.utils.device import get_device_name, get_torch_device
from verl.utils.fs import copy_to_local
from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
from verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker
from verl.workers.megatron_workers import CriticWorker, RewardModelWorker
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
__all__ = ["ActorRolloutRefWorker", "AsyncActorRolloutRefWorker", "CriticWorker", "RewardModelWorker", "RolloutWorker"]
class ActorRolloutRefWorker(ARRWorker):
def __init__(self, config: DictConfig, role: str):
assert role in ["actor", "ref"]
tmp_role = "ref" if role == "ref" else "actor_rollout"
super().__init__(config, tmp_role)
if role == "actor":
self._is_rollout = False
self.role = role
def _get_actor_params_generator(self):
assert self._is_actor
from verl.models.mcore import get_mcore_weight_converter
from verl.utils.megatron_utils import per_tensor_generator
layer_name_mapping = {
"qkv_layer_name": "self_attention.linear_qkv.",
"gate_proj_layer_name": "linear_fc1.",
}
weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
generator = per_tensor_generator(
self.actor.actor_module,
self.actor_model_config,
weight_converter,
self.tf_config,
layer_name_mapping,
)
return generator
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights(self):
assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
assert hasattr(self, "_weights_info") and self._weights_info is not None
params_generator = self._get_actor_params_generator() if self._is_actor else None
if self._is_rollout:
inference_model = (
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
)
patch_vllm_moe_model_weight_loader(inference_model)
for key, shape, dtype in self._weights_info:
if self._is_actor:
weight_key, weight = next(params_generator)
assert key == weight_key
assert shape == weight.size()
assert dtype == weight.dtype
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
if self._is_actor and torch.distributed.get_rank() == 0:
tensor.copy_(weight)
from ray.util.collective import collective
collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
if self._is_rollout:
inference_model.load_weights([(key, tensor)])
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self):
assert self._is_actor
if hasattr(self, "_weights_info"):
return self._weights_info
params_generator = self._get_actor_params_generator()
ret = []
for key, tensor in params_generator:
ret.append((key, tensor.size(), tensor.dtype))
self._weights_info = ret
return ret
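# NOTE: rough sketch of the actor -> rollout weight-sync protocol implemented above:
#   1. the trainer calls get_actor_weights_info() on the actor group once to obtain the
#      (name, shape, dtype) list and passes it to the rollout group via set_actor_weights_info();
#   2. on every sync, both groups call sync_rollout_weights(): actor ranks walk the per-tensor
#      generator, rank 0 copies each tensor into the broadcast buffer, and the tensor is broadcast
#      over the "actor_rollout" collective group created by the trainer;
#   3. rollout ranks load each received tensor into the vLLM engine via load_weights().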
class RolloutWorker(ActorRolloutRefWorker):
def __init__(self, config: DictConfig, role: str):
assert role == "rollout"
ARRWorker.__init__(self, config, role)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
if self.config.model.get("external_lib", None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
from verl.utils.torch_dtypes import PrecisionType
override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
override_transformer_config = {}
self.param_dtype = torch.bfloat16
self.dtype = PrecisionType.to_dtype(self.param_dtype)
trust_remote_code = self.config.model.get("trust_remote_code", False)
from verl.utils.model import get_generation_config
self._init_hf_config_and_tf_config(
self.config.model.path,
self.config.model.path,
self.dtype,
override_model_config,
override_transformer_config,
trust_remote_code,
)
self.generation_config = get_generation_config(self.local_path)
from torch.distributed.device_mesh import init_device_mesh
assert self.config.rollout.name == "vllm"
assert self.config.rollout.mode == "sync"
from verl.workers.rollout.vllm_rollout import vLLMRollout
from .vllm_sharding_manager import VLLMShardingManager
# NOTE(sgm): If the QKV and gate_up projection layers are concatenated together in the actor,
# we will reorganize their weight format when resharding from actor to rollout.
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
)
rollout_device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)
log_gpu_memory_usage("Before building vllm rollout", logger=None)
local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False))
from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
rollout = vllm_rollout_cls(
model_path=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.hf_config,
device_mesh=rollout_device_mesh,
trust_remote_code=trust_remote_code,
)
log_gpu_memory_usage("After building vllm rollout", logger=logger)
sharding_manager = VLLMShardingManager(
inference_engine=rollout.inference_engine,
device_mesh=rollout_device_mesh,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
self.rollout, self.sharding_manager = rollout, sharding_manager
self.rollout.sharding_manager = sharding_manager
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
def async_generate_sequences(self, *args, **kwargs):
return super().generate_sequences(*args, **kwargs)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_actor_weights_info(self, weights_info):
assert self._is_rollout
self._weights_info = weights_info
class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
def __init__(self, *args, **kwargs):
raise NotImplementedError
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import uuid
from pprint import pprint
import numpy as np
import ray
import torch
from omegaconf import OmegaConf
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
)
from verl.trainer.ppo.ray_trainer import (
RayPPOTrainer,
ResourcePoolManager,
Role,
WorkerType,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.debug import marked_timer
from verl.utils.metric import (
reduce_metrics,
)
from verl.utils.tracking import ValidationGenerationsLogger
class GenerationBatchFuture:
"""
Wrapper class for encapsulating batch generation results
"""
def __init__(self, epoch, batch, gen_batch_output):
"""
:param epoch: current epoch
:param batch: Input batch data
:param gen_batch_output: Generated sequences from the main model (DataProtoFuture)
"""
self.epoch = epoch
self.batch = batch
self.gen_batch_output = gen_batch_output
def get(self):
"""
Get the actual results by calling get() method on gen_batch_output
Returns:
tuple: (epoch, batch, gen_batch_result)
- epoch: Current epoch number
- batch: Original input batch data
- gen_batch_result: Result from gen_batch_output.get(), or gen_batch_output itself
"""
# Call get() method on gen_batch_output if available
if hasattr(self.gen_batch_output, "get"):
gen_batch_result = self.gen_batch_output.get()
else:
gen_batch_result = self.gen_batch_output
return self.epoch, self.batch, gen_batch_result
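# Minimal usage sketch (for illustration only): fit() keeps exactly one in-flight future at a time.
#
#   future = self._async_gen_next_batch(continuous_iterator)    # launch generation for step k+1
#   ... train on the batch of step k ...
#   epoch, batch, gen_batch_output = future.get()                # block until step k+1 rollouts arrive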
class OneStepOffRayTrainer(RayPPOTrainer):
# TODO: support an individual ray_worker_group_cls for each role,
# i.e., support a different backend for each role
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None,
train_dataset: Dataset | None = None,
val_dataset: Dataset | None = None,
collate_fn=None,
train_sampler: Sampler | None = None,
device_name="cuda",
):
"""
Initialize distributed PPO trainer with Ray backend.
Note that this trainer runs on the driver process on a single CPU/GPU node.
Args:
config: Configuration object containing training parameters.
tokenizer: Tokenizer used for encoding and decoding text.
role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.
resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.
ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.
processor: Optional data processor, used for multimodal data
reward_fn: Function for computing rewards during training.
val_reward_fn: Function for computing rewards during validation.
train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.
val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
collate_fn: Function to collate data samples into batches.
train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda".
"""
# Store the tokenizer for text processing
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert not self.hybrid_engine
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name
self.validation_generations_logger = ValidationGenerationsLogger()
# if ref_in_actor is True, the reference policy will be the actor without LoRA applied
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
# define in-reward KL control
# KL loss control is currently not supported
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
elif self.config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO,
AdvantageEstimator.GRPO_PASSK,
AdvantageEstimator.REINFORCE_PLUS_PLUS,
# AdvantageEstimator.REMAX,  # TODO: REMAX advantage estimator is not yet supported in one_step_off_policy
AdvantageEstimator.RLOO,
AdvantageEstimator.OPO,
AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
AdvantageEstimator.GPG,
]:
self.use_critic = False
else:
raise NotImplementedError
self._validate_config()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
def _validate(self):
self.actor_rollout_wg = self.rollout_wg
ret = super()._validate()
self.actor_rollout_wg = self.actor_wg
return ret
def init_workers(self):
"""Initialize distributed training workers using Ray backend.
Creates:
1. Ray resource pools from configuration
2. Worker groups for each role (actor, critic, etc.)
"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
for role, role_name in [(Role.Actor, "actor"), (Role.Rollout, "rollout")]:
resource_pool = self.resource_pool_manager.get_resource_pool(role)
role_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[role],
config=self.config.actor_rollout_ref,
role=role_name,
)
self.resource_pool_to_cls[resource_pool][role_name] = role_cls
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role="ref",
profile_option=self.config.trainer.npu_profile.options,
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
# create a reward model worker if a model-based RM is enabled
if self.use_rm:
# we create an RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`.
# Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
if OmegaConf.select(self.config.trainer, "profile_steps") is not None:
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps")
assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, (
"worker_nsight_options must be set when profile_steps is set"
)
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
OmegaConf.select(self.config.trainer, "worker_nsight_options")
)
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
device_name=self.device_name,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
if self.use_critic:
self.critic_wg = all_wg["critic"]
self.critic_wg.init_model()
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = all_wg["rm"]
self.rm_wg.init_model()
self.actor_wg = all_wg["actor"]
self.rollout_wg = all_wg["rollout"]
self.actor_wg.init_model()
self.rollout_wg.init_model()
self.actor_rollout_wg = self.actor_wg # to stay compatible with functions that have not been modified
weights_info = self.actor_wg.get_actor_weights_info()[0]
self.rollout_wg.set_actor_weights_info(weights_info)
from ray.util.collective import collective
actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
collective.create_collective_group(
actor_rollout_workers,
len(actor_rollout_workers),
list(range(0, len(actor_rollout_workers))),
backend="nccl",
group_name="actor_rollout",
)
self.sync_rollout_weights()
# create async rollout manager and request scheduler
self.async_rollout_mode = False
if self.config.actor_rollout_ref.rollout.mode == "async" and self._is_rollout:
from verl.workers.rollout.async_server import AsyncLLMServerManager
self.async_rollout_mode = True
self.async_rollout_manager = AsyncLLMServerManager(
config=self.config,
worker_group=self.rollout_wg,
)
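# NOTE: the "actor_rollout" collective group created above spans all actor workers followed by all
# rollout workers; the actor's global rank 0 is the broadcast source used by sync_rollout_weights()
# below. The sync_rollout_weights() call in init_workers ensures the rollout engine starts from the
# same weights as the actor before the first generation.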
def sync_rollout_weights(self):
if not self.hybrid_engine:
self.actor_wg.sync_rollout_weights()
ray.get(self.rollout_wg.sync_rollout_weights())
def _create_continuous_iterator(self):
"""
Create a continuous data iterator across epochs.
"""
for epoch in range(self.config.trainer.total_epochs):
iterator = iter(self.train_dataloader)
for batch_dict in iterator:
yield epoch, batch_dict
def _async_gen_next_batch(self, continuous_iterator):
"""
Call parameter synchronization and asynchronous sequence generation.
"""
try:
epoch, batch_dict = next(continuous_iterator)
except StopIteration:
return None
except Exception as e:
print(f"Error in async_gen_next_batch: {e}")
return None
batch = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
if "multi_modal_data" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("multi_modal_data")
if "raw_prompt" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("raw_prompt")
if "tools_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("tools_kwargs")
if "interaction_kwargs" in batch.non_tensor_batch:
non_tensor_batch_keys_to_pop.append("interaction_kwargs")
gen_batch = batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
# sync weights from actor to rollout
self.sync_rollout_weights()
# async generation
gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch)
return GenerationBatchFuture(epoch, batch, gen_batch_output)
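# NOTE: _async_gen_next_batch is what keeps the trainer exactly one step off-policy: it first syncs
# the freshly updated actor weights to the rollout workers, then dispatches a non-blocking
# async_generate_sequences() call and immediately returns a GenerationBatchFuture. fit() trains on
# the previous batch while this future is produced and only blocks on it at the start of the next
# step (the "wait_prev_gen" timer).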
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
# iterator across epochs
continuous_iterator = self._create_continuous_iterator()
# Start the first asynchronous generation task.
batch_data_future = self._async_gen_next_batch(continuous_iterator)
while batch_data_future is not None:
do_profile = (
self.global_steps in self.config.trainer.profile_steps
if self.config.trainer.profile_steps is not None
else False
)
if do_profile:
self.actor_wg.start_profile()
if not self.hybrid_engine:
self.rollout_wg.start_profile()
if self.use_reference_policy:
self.ref_policy_wg.start_profile()
if self.use_critic:
self.critic_wg.start_profile()
if self.use_rm:
self.rm_wg.start_profile()
metrics = {}
timing_raw = {}
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# wait for the previous batch
with marked_timer("wait_prev_gen", timing_raw, color="red"):
epoch, batch, gen_batch_output = batch_data_future.get()
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
# launch the next async generation (and sync weights from actor to rollout)
with marked_timer("sync_rollout_weights", timing_raw, color="purple"):
if not is_last_step:
batch_data_future = self._async_gen_next_batch(continuous_iterator)
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
with marked_timer("reward", timing_raw, color="yellow"):
# compute reward model score
if self.use_rm:
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
# recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, color="blue"):
old_log_prob = self.actor_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
rollout_old_log_probs = batch.batch["rollout_log_probs"]
actor_old_log_probs = batch.batch["old_log_probs"]
attention_mask = batch.batch["attention_mask"]
responses = batch.batch["responses"]
response_length = responses.size(1)
response_mask = attention_mask[:, -response_length:]
rollout_probs = torch.exp(rollout_old_log_probs)
actor_probs = torch.exp(actor_old_log_probs)
rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
rollout_probs_diff_max = torch.max(rollout_probs_diff)
rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
rollout_probs_diff_std = torch.std(rollout_probs_diff)
metrics.update(
{
"training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
"training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
"training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
}
)
if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)
# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with marked_timer("dump_rollout_generations", timing_raw, color="green"):
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
self._dump_generations(
inputs=inputs,
outputs=outputs,
scores=scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=rollout_data_dir,
)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()
# training metrics
metrics.update(
{
"training/global_step": self.global_steps,
"training/epoch": epoch,
}
)
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflops and theoretical tflops
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# TODO: make a canonical logger that supports various backends
logger.log(data=metrics, step=self.global_steps)
progress_bar.update(1)
self.global_steps += 1
if do_profile:
self.actor_wg.stop_profile()
if not self.hybrid_engine:
self.rollout_wg.stop_profile()
if self.use_reference_policy:
self.ref_policy_wg.stop_profile()
if self.use_critic:
self.critic_wg.stop_profile()
if self.use_rm:
self.rm_wg.stop_profile()
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from torch.distributed.device_mesh import DeviceMesh
from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.utils.debug import GPUMemoryLogger
from verl.utils.device import get_torch_device
from verl.utils.torch_functional import check_device_is_available
from verl.workers.sharding_manager.base import BaseShardingManager
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class VLLMShardingManager(BaseShardingManager):
@check_device_is_available()
def __init__(self, inference_engine, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
self.inference_engine = inference_engine
inference_engine.wake_up()
assert device_mesh is not None
assert inference_engine is not None
self.tp_size = self.device_mesh["infer_tp"].size()
self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
self.timing = {}
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
get_torch_device().manual_seed(gen_dp_rank + 1000)
self.gen_random_states = get_torch_device().get_rng_state()
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def __enter__(self):
get_torch_device().set_rng_state(self.gen_random_states)
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def __exit__(self, exc_type, exc_value, traceback):
self.gen_random_states = get_torch_device().get_rng_state()
self.inference_engine.reset_prefix_cache()
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def preprocess_data(self, data: DataProto) -> DataProto:
"""All gather across tp group to make each rank has identical input."""
if self.tp_size == 1:
return data
group = vllm_ps.get_tensor_model_parallel_group().device_group
all_gather_data_proto(data=data, process_group=group)
return data
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def postprocess_data(self, data: DataProto) -> DataProto:
"""Get chunk data of this tp rank since we do all gather in preprocess."""
if self.tp_size == 1:
return data
return data.chunk(chunks=self.tp_size)[self.tp_rank]
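# NOTE: rough usage sketch (handled inside the rollout worker; shown here for illustration only):
#
#   with sharding_manager:                                       # restores the generation RNG state
#       prompts = sharding_manager.preprocess_data(prompts)      # all-gather across infer_tp ranks
#       output = rollout.generate_sequences(prompts)
#       output = sharding_manager.postprocess_data(output)       # keep only this tp rank's chunk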
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# the prime config will override default ppo_trainer.yaml
hydra:
searchpath:
- file://verl/trainer/config
defaults:
- ppo_trainer
- _self_
data:
filter_accuracy: True
accuracy_lower_bound: 0.2
accuracy_upper_bound: 0.8
oversample_factor: 4.0 # Sample more responses than the batch size. Prompts satisfying the filter will be prioritized.
filter_truncate: True
truncation: right
actor_rollout_ref:
hybrid_engine: True
model:
use_remove_padding: True
rollout:
# number of responses sampled per prompt (i.e., the number of sampling times)
n: 4
actor:
entropy_coeff: 0.001
reward_model:
enable: True
strategy: fsdp
model:
ref_path: ${reward_model.model.path}
use_remove_padding: True
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
fused_kernel_options:
impl_backend: torch # triton, torch
tokenizer_path: ${actor_rollout_ref.model.path}
enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
ref_type: freeze
fsdp_config:
min_num_params: 0
param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
update: before # ``before`` for double-forward, ``after`` for single-forward
optim:
lr: 1e-6
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null
warmup_style: constant
total_training_steps: -1 # must be overridden by program
weight_decay: 0.
grad_clip: 10.0
beta_train: 0.05
loss_type: ce # currently only supports ce loss
prime_granularity: token
prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train
mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
reward_manager: prime
algorithm:
adv_estimator: rloo
# currently supports rloo; it treats different sources of reward separately.
kl_ctrl:
type: fixed
kl_coef: 0.000
reward_gt_coef: 5
reward_dpo_coef: 5
trainer:
project_name: prime
experiment_name: examples
val_before_train: False
balance_batch: False
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other main entry points.
"""
import hydra
import ray
from .prime_ray_trainer import RayPRIMETrainer
@hydra.main(config_path="config", config_name="prime_trainer", version_base=None)
def main(config):
run_prime(config)
def run_prime(config, compute_score=None):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}},
num_cpus=config.ray_init.num_cpus,
)
ray.get(main_task.remote(config, compute_score))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_local_path_from_hdfs
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
}
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
if config.reward_model.enable:
from .prime_fsdp_workers import PRIMERewardModelWorker
role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == "naive":
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == "prime":
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
else:
raise NotImplementedError
reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)
# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPRIMETrainer(
config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == "__main__":
main()
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import verl
import verl.utils.torch_functional as verl_F
def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
# calculate the rloo reward for each reward source separately, then sum them
def masked_rloo(reward_tensor_original, mask_tensor):
reward_tensor = reward_tensor_original.clone()
reward_tensor[~mask_tensor] = 0
for start_pos in range(0, reward_tensor.shape[0], n_samples):
cur_rewards_mean = torch.cat(
[
reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True)
for pos in range(start_pos, start_pos + n_samples)
],
dim=0,
)
cur_rewards_sum = cur_rewards_mean.sum()
cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = (
reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]]
* (n_samples / (n_samples - 1))
- cur_reward_baseline
)
return reward_tensor
reward_tensors = []
with torch.no_grad():
if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:
reward_tensor = data.batch["rm_scores"]
reward_mask = response_mask.bool()
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:
reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
prompt_ids = data.batch["prompts"]
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1)
reward_mask[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1,
] = True
reward_tensor[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1,
] = data.batch["acc"]
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)
final_reward_tensor = sum(reward_tensors)
returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns.clone()
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns
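# NOTE: masked_rloo above applies a leave-one-out (RLOO) baseline within each group of n_samples
# responses to the same prompt: with r_i denoting the (masked) mean reward of sample i, each masked
# token reward is rescaled by n / (n - 1) and shifted by sum_j r_j / (n - 1), so that for a sparse
# reward placed on a single token this reduces to r_i - mean_{j != i}(r_j), i.e. every response is
# compared against the average reward of the other responses to the same prompt.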
def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
return cur_dpo_loss
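# NOTE: compute_ce_dpo_loss_rm treats the summed implicit reward of a response,
# sigmoid(beta * sum_t q_t), as the predicted probability that the response is correct and fits it
# against the binary accuracy label `acc` with a cross-entropy loss.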
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"):
# we always assume that the BoN size equals n_samples
# mode1: use acc as rm
# mode2: use Q as rm
cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta
other_Q = torch.zeros_like(cur_Q)
for i in range(token_level_scores.shape[0]):
Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]]
if len(Q_chosen) > 0:
other_Q[i] = Q_chosen.mean() * beta
else:
other_Q[i] = 0
dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))
if bon_mode == "none":
dpo_loss = dpo_loss.mean()
else:
weight = torch.zeros_like(dpo_loss)
n_samples = acc_bc.shape[1]
if bon_mode == "bon_rm":
for i in range(token_level_scores.shape[0]):
weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)
elif bon_mode == "bon_acc":
for i in range(token_level_scores.shape[0]):
weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)
else:
raise NotImplementedError
dpo_loss = (dpo_loss * weight).sum()
return dpo_loss
def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
dpo_acc = []
for start_id in range(0, token_level_scores.shape[0], n_samples):
cur_scores = (
token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]
).sum(dim=1)
def get_upper_triangle(tensor_x):
diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
return diff_matrix[upper_tri_indices]
cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1]
cur_score_diff = get_upper_triangle(cur_scores) # in R
cur_score_prediction = (cur_score_diff > 0).float() # in [0,1]
if cur_acc_diff.abs().sum() == 0:
cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5
else:
cur_acc = (
((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()
).sum() / cur_acc_diff.abs().sum()
dpo_acc.append(cur_acc.unsqueeze(0))
return torch.cat(dpo_acc, dim=0).mean()
def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):
return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a multiprocess PPOCritic
"""
import itertools
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.utils.device import get_device_name
from verl.utils.py_functional import append_to_dict
from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches
from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
__all__ = ["DataParallelPRIMERewardModel"]
class DataParallelPRIMERewardModel:
def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):
self.config = config
self.reward_module = reward_module
self.ref_module = ref_module
self.reward_optimizer = reward_optimizer
self.use_remove_padding = self.config.model.get("use_remove_padding", False)
print(f"Reward model use_remove_padding={self.use_remove_padding}")
self.use_fused_kernels = self.config.model.get("use_fused_kernels", False)
print(f"Reward model use_fused_kernels={self.use_fused_kernels}")
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
def _forward_micro_batch(self, micro_batch, prompt_length):
input_ids = micro_batch["input_ids"]
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]
num_actions = micro_batch["input_ids"].shape[-1] - prompt_length
max_positions = micro_batch["attention_mask"][:, prompt_length:].sum(-1)
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask
) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(
rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
).transpose(0, 1)
# for computing the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size
)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
output = self.reward_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False,
return_dict=self.use_fused_kernels,
)
if self.use_fused_kernels:
rm_log_labels = output.log_probs.squeeze(0) # (total_nnz,)
rm_log_labels = rm_log_labels.to(torch.float32)
else:
rm_output_logits = output.logits.squeeze(0)
rm_log_labels = verl_F.logprobs_from_logits(
logits=rm_output_logits,
labels=input_ids_rmpad_rolled,
)
if self.ulysses_sequence_parallel_size > 1:
rm_log_labels = gather_outputs_and_unpad(
rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
rm_log_labels = pad_input(
hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)[:, -num_actions - 1 : -1]
else:
output = self.reward_module(
input_ids=micro_batch["input_ids"],
attention_mask=micro_batch["attention_mask"],
position_ids=micro_batch["position_ids"],
use_cache=False,
return_dict=self.use_fused_kernels,
)
if self.use_fused_kernels:
rm_log_labels = output.log_probs[:, :-1] # (bsz, seq_length)
rm_log_labels = rm_log_labels.to(torch.float32)
else:
rm_output_logits = output.logits
rm_log_prob = torch.nn.functional.log_softmax(
rm_output_logits[:, :-1, :], dim=-1
) # (batch_size, seq_length, vocab_size)
rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze(
-1
) # (batch, seq_length)
if self.ref_module is not None:
# do not have to pad again
with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
ref_output = self.ref_module(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False,
)
if self.use_fused_kernels:
ref_log_labels = ref_output.log_probs.squeeze(0) # (total_nnz,)
ref_log_labels = ref_log_labels.to(torch.float32)
else:
ref_output_logits = ref_output.logits.squeeze(0)
ref_log_labels = verl_F.logprobs_from_logits(
logits=ref_output_logits, labels=input_ids_rmpad_rolled
)
ref_log_labels = gather_outputs_and_unpad(
ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size
)
ref_log_labels = pad_input(
hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
).squeeze(-1)[:, -num_actions - 1 : -1]
else:
ref_output = self.ref_module(
input_ids=micro_batch["input_ids"],
attention_mask=micro_batch["attention_mask"],
position_ids=micro_batch["position_ids"],
use_cache=False,
)
if self.use_fused_kernels:
ref_log_labels = ref_output.log_probs[:, :-1] # (batch_size, seq_length)
ref_log_labels = ref_log_labels.to(torch.float32)
else:
ref_output_logits = ref_output.logits
ref_log_prob = torch.nn.functional.log_softmax(
ref_output_logits[:, :-1, :], dim=-1
) # (batch_size, seq_length, vocab_size)
ref_log_labels = ref_log_prob.gather(
dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)
).squeeze(-1) # (batch, seq_length)
else:
ref_log_labels = micro_batch["old_log_probs"]
ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)  # .to() is not in-place; keep the cast result
q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q
# trim unnecessary logprobs here
for i in range(micro_batch["input_ids"].shape[0]):
q[i, max_positions[i] :] = 0
# reward computation does not need gradients; only q does
with torch.no_grad():
# generalized estimation of r should go before the reward filling. r means process reward for policy
# model, or the advantage of reward model.
lam = self.config.get("lambda", 0.0)
beta = self.config.model.get("beta_train", 0.05)
if lam == 0.0:
r = q * beta
else:
# the reward coefficient has no effect here
acc = micro_batch["acc"]
q_ = q * beta
r = torch.zeros_like(q)
lastgaelam = 0
# change the last token and mask out all paddings to make this process easier if we rely on
# outcome reward to calculate V
for i in range(q.shape[0]):
if self.config.prime_use_gt:
q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum()
q_[i, max_positions[i] :] = 0
for t in reversed(range(num_actions)):
delta = q_[:, t]
lastgaelam = delta + lam * lastgaelam
r[:, t] = lastgaelam
token_level_score = torch.zeros_like(q)
if self.config.prime_granularity == "token":
for i in range(micro_batch["input_ids"].shape[0]):
token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1]
elif self.config.prime_granularity == "whole":
for i in range(micro_batch["input_ids"].shape[0]):
token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]].sum()  # collapse the whole-sequence score onto the last valid token
else:
raise NotImplementedError
return token_level_score, q
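# NOTE on the quantities returned by _forward_micro_batch:
#   q: per-token log pi_rm(response) - log pi_ref(response), i.e. the implicit (process) reward of
#      the PRIME reward model before scaling by beta_train;
#   token_level_score: the per-token score handed to the trainer. With lambda == 0 it is beta * q;
#      otherwise it is a reverse exponentially weighted accumulation of beta * q with decay lambda,
#      where prime_use_gt optionally rewrites the last valid token so the per-sequence sum matches
#      the outcome accuracy. The result is written out according to prime_granularity (per-token, or
#      collapsed onto the last valid token).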
def _optimizer_step(self):
assert self.config.model.optim.grad_clip is not None
if isinstance(self.reward_module, FSDP):
grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip
)
self.reward_optimizer.step()
return grad_norm
def prime_norm(self, token_level_scores):
if self.config.prime_norm == "batch_norm":
reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])
token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
return token_level_scores
def compute_rm_score(self, data: DataProto):
self.reward_module.eval()
self.ref_module.eval()
micro_batch_size = data.meta_info["micro_batch_size"]
select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"]
batch = data.select(batch_keys=select_keys).batch
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1]
if use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
rm_scores_lst = []
q_lst = []
for micro_batch in micro_batches:
with torch.no_grad():
rm_score, q = self._forward_micro_batch(micro_batch, prompt_length)
rm_scores_lst.append(rm_score)
q_lst.append(q)
rm_scores = torch.concat(rm_scores_lst, dim=0)
q = torch.concat(q_lst, dim=0)
rm_scores = self.prime_norm(rm_scores)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
rm_scores = rm_scores[revert_indices]
return (
rm_scores,
q.detach(),
{
"reward_model/reward": rm_scores.sum(dim=-1).mean().item(),
"reward_model/raw_reward": q.sum(dim=-1).mean().item(),
},
)
def update_rm(self, data: DataProto):
# make sure we are in training mode
self.reward_module.train()
metrics = {}
beta = self.config.model.get("beta_train", 0.05)
select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "acc", "prompts"]
for key in ["Q_bc", "acc_bc"]:
if key in data.batch.keys():
select_keys.append(key)
batch = data.select(batch_keys=select_keys).batch
# Split into a minibatch iterator for updating the reward model
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.mini_batch_size)
rm_scores_lst = []
q_lst = []
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)
self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu
self.reward_optimizer.zero_grad()
for data in micro_batches:
data = data.to(get_device_name())
attention_mask = data["attention_mask"]
acc = data["acc"]
prompt_ids = data["prompts"]
prompt_length = prompt_ids.shape[-1]
response_mask = attention_mask[:, prompt_length:]
rm_score, q = self._forward_micro_batch(data, prompt_length)
rm_scores_lst.append(rm_score)
q_lst.append(q.detach())
if self.config.model.loss_type == "ce":
dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)
elif self.config.model.loss_type == "dpo":
# the implementation of dpo is actually detached, which means we have to know the average
# value of w/l reward before the update.
dpo_loss = compute_detach_dpo_loss_rm(
q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta
)
elif self.config.model.loss_type == "bon_acc":
# change the original distribution of each sample to BoN distribution, then update reward model
dpo_loss = compute_detach_dpo_loss_rm(
q,
acc,
Q_bc=data["Q_bc"],
acc_bc=data["acc_bc"],
response_mask=response_mask,
beta=beta,
bon_mode="bon_acc",
)
elif self.config.model.loss_type == "bon_rm":
dpo_loss = compute_detach_dpo_loss_rm(
q,
acc,
Q_bc=data["Q_bc"],
acc_bc=data["acc_bc"],
response_mask=response_mask,
beta=beta,
bon_mode="bon_rm",
)
else:
raise NotImplementedError
data = {"reward_model/dpo_loss": dpo_loss.detach().item()}
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = dpo_loss / self.gradient_accumulation
loss.backward()
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
data = {"reward_model/grad_norm": grad_norm.detach().item()}
append_to_dict(metrics, data)
self.reward_optimizer.zero_grad()
rm_scores = torch.cat(rm_scores_lst, dim=0)
q = torch.concat(q_lst, dim=0)
rm_scores = self.prime_norm(rm_scores)
metrics.update(
{
"reward_model/reward": rm_scores.sum(dim=-1).mean().item(),
"reward_model/raw_reward": q.sum(dim=-1).mean().item(),
}
)
return rm_scores, metrics
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
import torch
import torch.distributed
from omegaconf import OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from verl import DataProto
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.device import get_device_id, get_device_name, get_nccl_backend
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class PRIMERewardModelWorker(Worker):
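"""FSDP worker hosting the PRIME implicit process reward model plus a frozen
reference model. Exposes score computation, online reward-model updates, and
checkpoint save/load through the single-controller worker API."""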
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=get_nccl_backend())
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
)
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# set FSDP offload params
self._is_offload_param = self.config.model.fsdp_config.param_offload
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize batch sizes to per-data-parallel-rank values
self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
if self.config.micro_batch_size is not None:
self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0
def _build_reward_ref_model_optimizer(self, config):
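"""Load the reward and reference models, wrap both in FSDP with the configured
mixed precision, and build the AdamW optimizer with a constant-with-warmup
LR schedule."""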
# local imports: FSDP, optimizer, and model utilities used only when building the reward model
from torch import optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from verl.utils.model import print_model_size
from verl.utils.torch_dtypes import PrecisionType
local_path = copy_local_path_from_hdfs(config.model.path)
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False))
override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {})))
override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f"Reward model overriding config {override_config_kwargs}")
torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig, AutoModelForCausalLM
trust_remote_code = False
reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
reward_model_config.num_labels = 1
init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
reward_model_config.classifier_dropout = 0.0
reward_model_config.hidden_dropout = "0"
reward_module = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
fused_kernel_options = config.model.get("fused_kernel_options", None)
fused_kernels_backend = (
fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None
)
apply_monkey_patch(
model=reward_module,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
use_remove_padding=config.model.get("use_remove_padding", False),
use_fused_kernels=config.model.get("use_fused_kernels", False),
fused_kernels_backend=fused_kernels_backend,
)
# some parameters may not be in torch_dtype, so cast the whole module explicitly
reward_module.to(torch_dtype)
if config.model.get("enable_gradient_checkpointing", False):
reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if self.rank == 0:
print_model_size(reward_module)
self.reward_model_config = reward_model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get("mixed_precision", None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
else:
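# defaults: bf16 parameters with fp32 gradient reduction and fp32 buffers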
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy)
log_gpu_memory_usage("Before reward model FSDP", logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
reward_model_config.classifier_dropout = 0.0
reward_model_config.hidden_dropout = "0"
ref_module = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path),
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation="flash_attention_2",
trust_remote_code=trust_remote_code,
)
# some parameters may not be in torch_dtype, so cast the whole module explicitly
ref_module.to(torch_dtype)
reward_module = FSDP(
reward_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None,
)
log_gpu_memory_usage("After reward FSDP", logger=None)
ref_module = FSDP(
ref_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None,
)
reward_optimizer = optim.AdamW(
reward_module.parameters(),
lr=config.model.optim.lr,
betas=config.model.optim.get("betas", (0.9, 0.999)),
weight_decay=config.model.optim.get("weight_decay", 1e-2),
)
total_steps = config.model.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1))
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.model.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
from verl.utils.torch_functional import get_constant_schedule_with_warmup
reward_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps
)
return reward_module, ref_module, reward_optimizer, reward_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
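"""Instantiate the reward/reference models and optimizer on this worker,
optionally offloading parameters and optimizer state to CPU, and set up the
FSDP checkpoint manager."""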
# import any user-specified external libraries so custom models/tokenizers are registered with HuggingFace
import_external_libs(self.config.model.get("external_lib", None))
from .prime_dp_rm import DataParallelPRIMERewardModel
self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = (
self._build_reward_ref_model_optimizer(config=self.config)
)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.reward_optimizer)
self.rm = DataParallelPRIMERewardModel(
config=self.config,
reward_module=self.reward_module,
ref_module=self.ref_module,
reward_optimizer=self.reward_optimizer,
)
self.flops_counter = FlopsCounter(self.reward_model_config)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.reward_module,
optimizer=self.reward_optimizer,
lr_scheduler=self.reward_lr_scheduler,
tokenizer=self.tokenizer,
)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
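"""Compute implicit process reward scores for a batch and report DPO accuracy metrics."""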
data = data.to(get_device_name())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
load_fsdp_model_to_gpu(self.ref_module)
micro_batch_size = self.config.micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
rm_scores, q, metrics = self.rm.compute_rm_score(data=data)
prompt_length = data.batch["prompts"].shape[-1]
response_mask = data.batch["attention_mask"][:, prompt_length:]
acc = data.batch["acc"]
dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"])
metrics["reward_model/dpo_acc"] = dpo_acc.detach().item()
metrics["reward_model/dpo_acc_abs"] = dpo_acc_abs.detach().item()
output = DataProto.from_dict(tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to("cpu")
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_rm(self, data: DataProto):
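"""Update the reward model on a batch, step the LR scheduler, and return rm_scores with metrics."""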
data = data.to(get_device_name())
if self._is_offload_param:
load_fsdp_model_to_gpu(self.ref_module)
load_fsdp_model_to_gpu(self.reward_module)
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id())
# perform the forward/backward update under the Ulysses sharding context
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
rm_scores, metrics = self.rm.update_rm(data=data)
self.reward_lr_scheduler.step()
lr = self.reward_lr_scheduler.get_last_lr()[0]
metrics["rm/lr"] = lr
prompt_length = data.batch["prompts"].shape[-1]
response_mask = data.batch["attention_mask"][:, prompt_length:]
acc = data.batch["acc"]
dpo_acc_before = compute_dpo_accuracy(
rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]
)
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"])
metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item()
metrics["reward_model/dpo_acc_abs_before"] = dpo_acc_abs.detach().item()
output = DataProto.from_dict(tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.reward_optimizer)
output = output.to("cpu")
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
self.checkpoint_manager.save_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, del_local_after_load=True):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)