Commit f87b35b2 authored by jerrrrry

Initial commit
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An naive implementation of split placment example
"""
from pprint import pprint
from verl import DataProto
from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, _timer, compute_timing_metrics, AdvantageEstimator
from copy import deepcopy
import numpy as np
import torch
import uuid
def fit(self):
    """
    The training loop of PPO.
    The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The lightweight advantage computation is done on the driver process.
    """
    from verl.utils.tracking import Tracking
    from omegaconf import OmegaConf

    logger = Tracking(project_name=self.config.trainer.project_name,
                      experiment_name=self.config.trainer.experiment_name,
                      default_backend=self.config.trainer.logger,
                      config=OmegaConf.to_container(self.config, resolve=True))

    self.global_steps = 0

    # load checkpoint before doing anything
    self._load_checkpoint()

    # perform validation before training
    # currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')
        logger.log(data=val_metrics, step=self.global_steps)
        if self.config.trainer.get('val_only', False):
            return

    # we start from step 1
    self.global_steps += 1
    last_val_metrics = None

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = {}
            timing_raw = {}

            batch: DataProto = DataProto.from_single_dict(batch_dict)

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            is_last_step = self.global_steps >= self.total_training_steps

            with _timer('step', timing_raw):
                # generate a batch
                with _timer('gen', timing_raw):
                    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                    with _timer('gen_max', timing_raw):
                        gen_baseline_batch = deepcopy(gen_batch)
                        gen_baseline_batch.meta_info['do_sample'] = False
                        gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                        batch = batch.union(gen_baseline_output)
                        reward_baseline_tensor = self.reward_fn(batch)
                        reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                        batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))

                        batch.batch['reward_baselines'] = reward_baseline_tensor

                        del gen_baseline_batch, gen_baseline_output

                batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
                                                         dtype=object)
                # repeat to align with repeated responses in rollout
                batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                batch = batch.union(gen_batch_output)

                # balance the number of valid tokens on each dp rank.
                # Note that this breaks the order of data inside the batch.
                # Please take care when you implement group based adv computation such as GRPO and rloo
                self._balance_batch(batch, metrics=metrics)

                # compute global_valid tokens
                batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                # recompute old_log_probs
                with _timer('old_log_prob', timing_raw):
                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                    batch = batch.union(old_log_prob)

                if self.use_reference_policy:
                    # compute reference log_prob
                    with _timer('ref', timing_raw):
                        ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                        batch = batch.union(ref_log_prob)

                # compute values
                if self.use_critic:
                    with _timer('values', timing_raw):
                        values = self.critic_wg.compute_values(batch)
                        batch = batch.union(values)

                with _timer('adv', timing_raw):
                    # compute scores. Support both model and function-based.
                    # We first compute the scores using reward model. Then, we call reward_fn to combine
                    # the results from reward model and rule-based results.
                    if self.use_rm:
                        # we first compute reward model score
                        reward_tensor = self.rm_wg.compute_rm_score(batch)
                        batch = batch.union(reward_tensor)

                    # we combine with rule-based rm
                    reward_tensor = self.reward_fn(batch)
                    batch.batch['token_level_scores'] = reward_tensor

                    # compute rewards. apply_kl_penalty if available
                    if self.config.algorithm.use_kl_in_reward:
                        batch, kl_metrics = apply_kl_penalty(batch,
                                                             kl_ctrl=self.kl_ctrl_in_reward,
                                                             kl_penalty=self.config.algorithm.kl_penalty)
                        metrics.update(kl_metrics)
                    else:
                        batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                    # compute advantages, executed on the driver process
                    batch = compute_advantage(batch,
                                              adv_estimator=self.config.algorithm.adv_estimator,
                                              gamma=self.config.algorithm.gamma,
                                              lam=self.config.algorithm.lam,
                                              num_repeat=self.config.actor_rollout_ref.rollout.n)

                # update critic
                if self.use_critic:
                    with _timer('update_critic_call', timing_raw):
                        critic_output = self.critic_wg.update_critic(batch)

                # implement critic warmup
                if self.config.trainer.critic_warmup <= self.global_steps:
                    # update actor
                    with _timer('update_actor_call', timing_raw):
                        actor_output = self.actor_rollout_wg.update_actor(batch)

                # NOTE: make sure you set blocking=False in update_actor and update_critic in the worker class
                with _timer('update_actor_critic', timing_raw):
                    critic_output = critic_output.get()
                    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                    metrics.update(critic_output_metrics)

                    actor_output = actor_output.get()
                    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                    metrics.update(actor_output_metrics)

                # validate
                if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                        (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                    with _timer('testing', timing_raw):
                        val_metrics: dict = self._validate()
                        if is_last_step:
                            last_val_metrics = val_metrics
                    metrics.update(val_metrics)

                if self.config.trainer.save_freq > 0 and \
                        (is_last_step or self.global_steps % self.config.trainer.save_freq == 0):
                    with _timer('save_checkpoint', timing_raw):
                        self._save_checkpoint()

            # collect metrics
            metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
            metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

            # TODO: make a canonical logger that supports various backend
            logger.log(data=metrics, step=self.global_steps)

            if self.global_steps >= self.total_training_steps:
                pprint(f'Final validation metrics: {last_val_metrics}')
                return

            self.global_steps += 1
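

# ---------------------------------------------------------------------------
# The overlapped critic/actor update above only works because `update_critic`
# and `update_actor` are dispatched with blocking=False, so each call returns
# a future-like handle immediately and the driver collects both results later
# with .get(). Below is a minimal, self-contained sketch of that pattern,
# using concurrent.futures as a stand-in for verl's non-blocking worker-group
# RPC; the FakeWorkerGroup class and its timings are illustrative assumptions,
# not verl's actual API.
if __name__ == '__main__':
    import time
    from concurrent.futures import ThreadPoolExecutor

    class FakeWorkerGroup:
        """Pretends to be a remote worker group whose update calls return futures."""

        def __init__(self, name, seconds):
            self._name = name
            self._seconds = seconds
            self._pool = ThreadPoolExecutor(max_workers=1)

        def update(self, batch):
            # blocking=False style: dispatch the work and return a handle immediately
            return self._pool.submit(self._run, batch)

        def _run(self, batch):
            time.sleep(self._seconds)
            return {'worker': self._name, 'batch_size': len(batch)}

    critic_wg = FakeWorkerGroup('critic', seconds=1.0)
    actor_wg = FakeWorkerGroup('actor', seconds=1.0)

    start = time.time()
    critic_future = critic_wg.update(batch=list(range(8)))  # returns immediately
    actor_future = actor_wg.update(batch=list(range(8)))    # runs concurrently with the critic update
    print(critic_future.result(), actor_future.result())
    print(f'elapsed: {time.time() - start:.2f}s')  # ~1s instead of ~2s because the two updates overlap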
set -x
#export VLLM_ATTENTION_BACKEND=XFORMERS
gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/rlhf/math/test.parquet
model_path=Qwen/Qwen2.5-Coder-14B-Instruct
train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"
PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2_14b_function_rm' \
trainer.n_gpus_per_node=4 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=1 $@
set -x
gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet
gsm8k_val_path=$HOME/data/rlhf/math/test.parquet
model_path=Qwen/Qwen2-72B-Instruct
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$gsm8k_train_path \
data.val_files=$gsm8k_val_path \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.tensor_model_parallel_size=16 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='Qwen2_72B_Instruct' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=4 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=1 $@
set -x
#### important: vllm version must be >= 0.8.3
gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet
gsm8k_val_path=$HOME/data/rlhf/math/test.parquet
model_path=Qwen/Qwen2-72B-Instruct
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$gsm8k_train_path \
data.val_files=$gsm8k_val_path \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.tensor_model_parallel_size=16 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='Qwen2_72B_Instruct' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=4 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=1 $@
set -x
#export VLLM_ATTENTION_BACKEND=XFORMERS
gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/rlhf/math/test.parquet
model_path=Qwen/Qwen2-7B-Instruct
train_files="['$gsm8k_train_path']"
test_files="['$gsm8k_test_path']"
PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_grpo_example_gsm8k' \
trainer.experiment_name='qwen2_7b_function_rm' \
trainer.n_gpus_per_node=2 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
diff --git a/.gitignore b/.gitignore
index 5955b349..ade0cd51 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@ build
slurm*
logs
.vscode
+tests/*
+examples/*
diff --git a/build.sh b/build.sh
new file mode 100644
index 00000000..49d5361f
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,4 @@
+#! /bin/bash
+
+export PYTHONPATH=$PYTHONPATH:$(pwd)
+pip3 install regex ninja
diff --git a/megatron/__init__.py b/megatron/__init__.py
index c35de282..60896b47 100644
--- a/megatron/__init__.py
+++ b/megatron/__init__.py
@@ -2,7 +2,7 @@
import torch
-from .global_vars import get_args, get_retro_args
+from .global_vars import get_args, update_args, fork_args_namespace, get_retro_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 0ca8776e..9ef67624 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -59,6 +59,16 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
return args
def validate_args(args, defaults={}):
+ # Set input defaults.
+ for key in defaults:
+ if getattr(args, key, None) is not None:
+ if args.rank == 0 and defaults[key] != getattr(args, key):
+ print('WARNING: overriding default argument {key}:{v2} \
+ with {key}:{v}'.format(key=key, v=defaults[key],
+ v2=getattr(args, key)),
+ flush=True)
+
+ setattr(args, key, defaults[key])
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
@@ -125,19 +135,19 @@ def validate_args(args, defaults={}):
args.recompute_granularity = 'selective'
del args.recompute_activations
- # Set input defaults.
- for key in defaults:
- # For default to be valid, it should not be provided in the
- # arguments that are passed to the program. We check this by
- # ensuring the arg is set to None.
- if getattr(args, key, None) is not None:
- if args.rank == 0:
- print('WARNING: overriding default arguments for {key}:{v} \
- with {key}:{v2}'.format(key=key, v=defaults[key],
- v2=getattr(args, key)),
- flush=True)
- else:
- setattr(args, key, defaults[key])
+ # # Set input defaults.
+ # for key in defaults:
+ # # For default to be valid, it should not be provided in the
+ # # arguments that are passed to the program. We check this by
+ # # ensuring the arg is set to None.
+ # if getattr(args, key, None) is not None:
+ # if args.rank == 0:
+ # print('WARNING: overriding default arguments for {key}:{v} \
+ # with {key}:{v2}'.format(key=key, v=defaults[key],
+ # v2=getattr(args, key)),
+ # flush=True)
+ # else:
+ # setattr(args, key, defaults[key])
# Batch size.
assert args.micro_batch_size is not None
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py
index 29ee34df..fa590b16 100644
--- a/megatron/core/pipeline_parallel/p2p_communication.py
+++ b/megatron/core/pipeline_parallel/p2p_communication.py
@@ -130,32 +130,28 @@ def _batched_p2p_ops(
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
- get_pipeline_model_parallel_prev_rank(),
- group,
+ get_pipeline_model_parallel_prev_rank()
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
- get_pipeline_model_parallel_prev_rank(),
- group,
+ get_pipeline_model_parallel_prev_rank()
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
- get_pipeline_model_parallel_next_rank(),
- group,
+ get_pipeline_model_parallel_next_rank()
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
- get_pipeline_model_parallel_next_rank(),
- group,
+ get_pipeline_model_parallel_next_rank()
)
ops.append(recv_next_op)
if len(ops) > 0:
diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py
index 992da781..2eb78d52 100644
--- a/megatron/core/pipeline_parallel/schedules.py
+++ b/megatron/core/pipeline_parallel/schedules.py
@@ -78,6 +78,8 @@ def get_forward_backward_func():
transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
in the config is True. Otherwise, each microbatch in the current global batch size must use
this sequence length.
+
+ hidden_size (int, required): hidden size of the model
micro_batch_size (int, required): The number of sequences in a microbatch.
@@ -287,6 +289,7 @@ def forward_backward_no_pipelining(
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int, # unused
+ hidden_size: int, # unused
micro_batch_size: int, # unused
decoder_seq_length: int = None, # unused
forward_only: bool = False,
@@ -370,8 +373,10 @@ def forward_backward_pipelining_with_interleaving(
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
+ seq_length: int = None,
+ hidden_size: int = None,
+ micro_batch_size: int = None,
+ input_shapes: list = None,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
@@ -457,7 +462,7 @@ def forward_backward_pipelining_with_interleaving(
"Interleaving is not supported with a different decoder sequence length."
)
- tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+ tensor_shape = [seq_length, micro_batch_size, hidden_size]
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
@@ -944,6 +949,7 @@ def get_tensor_shapes(
rank: int,
model_type: ModelType,
seq_length: int,
+ hidden_size: int,
micro_batch_size: int,
decoder_seq_length: int,
config,
@@ -967,12 +973,12 @@ def get_tensor_shapes(
if model_type == ModelType.encoder_and_decoder:
if parallel_state.is_pipeline_stage_before_split(rank):
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
- tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((decoder_seq_length, micro_batch_size, hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
else:
- tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
+ tensor_shapes.append((seq_length, micro_batch_size, hidden_size))
return tensor_shapes
@@ -1050,8 +1056,10 @@ def forward_backward_pipelining_without_interleaving(
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
+ seq_length: int = None,
+ hidden_size: int = None,
+ micro_batch_size: int = None,
+ input_shapes: list = None,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
@@ -1127,22 +1135,34 @@ def forward_backward_pipelining_without_interleaving(
model_type = get_model_type(model)
rank = parallel_state.get_pipeline_model_parallel_rank()
- recv_tensor_shapes = get_tensor_shapes(
- rank=rank - 1,
- model_type=model_type,
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- decoder_seq_length=decoder_seq_length,
- config=config,
- )
- send_tensor_shapes = get_tensor_shapes(
- rank=rank,
- model_type=model_type,
- seq_length=seq_length,
- micro_batch_size=micro_batch_size,
- decoder_seq_length=decoder_seq_length,
- config=config,
- )
+
+ def get_recv_tensor_shapes(microbatch_id):
+ if input_shapes:
+ return [input_shapes[microbatch_id]]
+ recv_tensor_shapes = get_tensor_shapes(
+ rank=rank - 1,
+ model_type=model_type,
+ seq_length=seq_length,
+ hidden_size=hidden_size,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+ return recv_tensor_shapes
+
+ def get_send_tensor_shapes(microbatch_id):
+ if input_shapes:
+ return [input_shapes[microbatch_id]]
+ send_tensor_shapes = get_tensor_shapes(
+ rank=rank,
+ model_type=model_type,
+ seq_length=seq_length,
+ hidden_size=hidden_size,
+ micro_batch_size=micro_batch_size,
+ decoder_seq_length=decoder_seq_length,
+ config=config,
+ )
+ return send_tensor_shapes
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
@@ -1163,7 +1183,12 @@ def forward_backward_pipelining_without_interleaving(
else:
checkpoint_activations_microbatch = None
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup recv_forward begin...')
+ recv_tensor_shapes = get_recv_tensor_shapes(i) # fwd recv shape
input_tensor = recv_forward(recv_tensor_shapes, config)
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup recv_forward end & forward begin...')
output_tensor = forward_step(
forward_step_func,
data_iterator,
@@ -1175,7 +1200,13 @@ def forward_backward_pipelining_without_interleaving(
collect_non_loss_data,
checkpoint_activations_microbatch,
)
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: output tensor shape = {output_tensor[0].shape}, send_tensor_shapes={send_tensor_shapes}')
+ # print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup forward end & send_forward begin...')
+ send_tensor_shapes = get_send_tensor_shapes(i) # fwd send shape
send_forward(output_tensor, send_tensor_shapes, config)
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {i}: warmup send_forward end...')
if not forward_only:
input_tensors.append(input_tensor)
@@ -1186,11 +1217,16 @@ def forward_backward_pipelining_without_interleaving(
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
- input_tensor = recv_forward(recv_tensor_shapes, config)
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches}: 1f1b recv_forward begin...')
+ recv_tensor_shapes = get_recv_tensor_shapes(num_warmup_microbatches) # fwd recv shape
+ input_tensor = recv_forward(recv_tensor_shapes, config)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)
+ next_forward_k = num_warmup_microbatches + i + 1
+ backward_k = i
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
@@ -1199,7 +1235,8 @@ def forward_backward_pipelining_without_interleaving(
) >= config.num_microbatches_with_partial_activation_checkpoints
else:
checkpoint_activations_microbatch = None
-
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b recv_forward end & forward begin...')
output_tensor = forward_step(
forward_step_func,
data_iterator,
@@ -1213,12 +1250,23 @@ def forward_backward_pipelining_without_interleaving(
)
if forward_only:
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b forward end & send forward begin...')
+ send_tensor_shapes = get_send_tensor_shapes(next_forward_k - 1) # fwd send shape
send_forward(output_tensor, send_tensor_shapes, config)
if not last_iteration:
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b send forward end & recv forward begin...')
+ recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k) # fwd recv shape
input_tensor = recv_forward(recv_tensor_shapes, config)
+ else:
+ pass
+ # if torch.cuda.current_device() == 0 or torch.cuda.current_device() == 4:
+ # print(f'rank {torch.cuda.current_device()}: micro batch {num_warmup_microbatches + i}: 1f1b send forward end...')
else:
+ send_tensor_shapes = get_send_tensor_shapes(backward_k) # bwd recv shape
output_tensor_grad = send_forward_recv_backward(
output_tensor, send_tensor_shapes, config
)
@@ -1245,8 +1293,10 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
+ recv_tensor_shapes = get_recv_tensor_shapes(backward_k) # bwd send shape
send_backward(input_tensor_grad, recv_tensor_shapes, config)
else:
+ recv_tensor_shapes = get_recv_tensor_shapes(next_forward_k) # fwd recv shape
input_tensor = send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, config
)
@@ -1254,7 +1304,7 @@ def forward_backward_pipelining_without_interleaving(
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
-
+ backward_k = num_microbatches_remaining + i
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
@@ -1267,12 +1317,14 @@ def forward_backward_pipelining_without_interleaving(
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ send_tensor_shapes = get_send_tensor_shapes(backward_k) # bwd recv shape
output_tensor_grad = recv_backward(send_tensor_shapes, config)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
+ recv_tensor_shapes = get_recv_tensor_shapes(backward_k) # bwd send shape
send_backward(input_tensor_grad, recv_tensor_shapes, config)
# Launch any remaining grad reductions.
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index d4e042b2..c480d14e 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -55,8 +55,9 @@ def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
+# workaround: get_model_config to get megatron config (ModelParallelConfig)
def get_model_config(model):
- return get_attr_wrapped_model(model, 'config', allow_none=False)
+ return get_attr_wrapped_model(model, 'megatron_config', allow_none=False)
class GlobalMemoryBuffer:
diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index b1b4b043..9e23dea5 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -21,11 +21,48 @@ _GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
_GLOBAL_SIGNAL_HANDLER = None
-def get_args():
+DEFAULT_NAMESPACE = 'default'
+import contextlib
+
+@contextlib.contextmanager
+def fork_args_namespace(namespace):
+ """
+ Usage example:
+ update_args('vit', vit_config)
+ with fork_args_namespace('vit'):
+ do vit stuff here
+ """
+ # Check if we have added the args namespace
+ if namespace not in _GLOBAL_ARGS:
+ raise Exception('args namespace {} is not added'.format(namespace))
+ # Store current args namespace.
+ tmp = _GLOBAL_ARGS[DEFAULT_NAMESPACE]
+ # Set args namespace to the desired one
+ _GLOBAL_ARGS[DEFAULT_NAMESPACE] = _GLOBAL_ARGS[namespace]
+ # Do the stuff we wanted to do.
+ try:
+ yield
+ finally:
+ _GLOBAL_ARGS[DEFAULT_NAMESPACE] = tmp
+
+def get_args(namespace=DEFAULT_NAMESPACE):
"""Return arguments."""
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
- return _GLOBAL_ARGS
+ return _GLOBAL_ARGS[namespace]
+def set_args(args):
+ global _GLOBAL_ARGS
+ if _GLOBAL_ARGS is None:
+ _GLOBAL_ARGS = {}
+ _GLOBAL_ARGS[DEFAULT_NAMESPACE] = args
+
+def update_args(namespace, args):
+ _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
+ if namespace not in _GLOBAL_ARGS:
+ import copy
+ _GLOBAL_ARGS[namespace] = copy.deepcopy(_GLOBAL_ARGS[DEFAULT_NAMESPACE])
+ for k, v in args.items():
+ setattr(_GLOBAL_ARGS[namespace], k, v)
def get_retro_args():
"""Return retro arguments."""
@@ -87,7 +124,7 @@ def _set_signal_handler():
-def set_global_variables(args, build_tokenizer=True):
+def set_global_variables(args):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
assert args is not None
@@ -96,7 +133,7 @@ def set_global_variables(args, build_tokenizer=True):
set_args(args)
_build_num_microbatches_calculator(args)
- if build_tokenizer:
+ if args.vocab_file:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_wandb_writer(args)
@@ -107,11 +144,6 @@ def set_global_variables(args, build_tokenizer=True):
_set_signal_handler()
-def set_args(args):
- global _GLOBAL_ARGS
- _GLOBAL_ARGS = args
-
-
def set_retro_args(retro_args):
global _GLOBAL_RETRO_ARGS
_GLOBAL_RETRO_ARGS = retro_args
diff --git a/megatron/initialize.py b/megatron/initialize.py
index fb7866ab..01999622 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -39,7 +39,7 @@ def initialize_megatron(
if not allow_no_cuda:
# Make sure cuda is available.
assert torch.cuda.is_available(), "Megatron requires CUDA."
-
+ print('use open-source megatron initialize...')
# Parse arguments
args = parse_args(extra_args_provider, ignore_unknown_args)
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index c91a674e..bcb7bd7e 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -81,7 +81,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
if self.no_persist_layer_norm:
assert FusedLayerNormAffineFunction is not None, \
"FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
- return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
+ return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, False)
else:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py
index a04ae478..b64d22a5 100644
--- a/megatron/optimizer/distrib_optimizer.py
+++ b/megatron/optimizer/distrib_optimizer.py
@@ -366,7 +366,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
check_for_nan_in_grad, params_have_main_grad, fp16,
- bf16, params_dtype, grad_scaler, models):
+ bf16, params_dtype, grad_scaler, models, overlap_param_gather=False):
"""
See top of class definition for argument descriptions.
@@ -382,8 +382,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
check_for_nan_in_grad, params_have_main_grad,
fp16, bf16, params_dtype, grad_scaler, models)
- assert isinstance(optimizer, Adam), \
- "Only Adam currently supported, due to checkpointing requirements."
+ # assert isinstance(optimizer, Adam), \
+ # "Only Adam currently supported, due to checkpointing requirements."
+
+ if not isinstance(optimizer, Adam):
+ print("WARNING: the optimizer type is not Adam, and now Only Adam currently support checkpointing requirements!")
# Model grad buffer ranges.
self.model_gbuf_ranges = []
@@ -476,7 +479,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.param_buffer_copied.append(False)
self.num_all_gather_handles = len(self.all_gather_handle_index_to_bucket_index_map)
- self.overlap_param_gather = get_args().overlap_param_gather
+ self.overlap_param_gather = overlap_param_gather
if self.overlap_param_gather:
self.remove_pre_hook_handle = torch.nn.modules.module.register_module_forward_pre_hook(
self._make_forward_pre_hook())
diff --git a/megatron/training.py b/megatron/training.py
index 36f6c52e..73664509 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -430,6 +430,7 @@ def train_step(forward_step_func, data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
seq_length=args.seq_length,
+ hidden_size=args.hidden_size,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False)
diff --git a/tools/prebuild_kernels.py b/tools/prebuild_kernels.py
new file mode 100644
index 00000000..6f891b9e
--- /dev/null
+++ b/tools/prebuild_kernels.py
@@ -0,0 +1,13 @@
+import os
+from megatron.fused_kernels import load
+
+
+class FakeArgs:
+ rank = 0
+
+
+# 7.0 for V100
+# 8.0 for A100/A800
+os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX;8.0+PTX"
+
+load(FakeArgs)
\ No newline at end of file
# -------------------------------
# build-system
# -------------------------------
[build-system]
requires = [
"setuptools>=61.0",
"wheel"
]
build-backend = "setuptools.build_meta"
# -------------------------------
# project (PEP 621 metadata)
# -------------------------------
[project]
name = "verl"
# We'll mark the version as "dynamic" because it's read from the file "verl/version/version"
# (PEP 621 calls this "dynamic version").
# The actual version is specified in the [tool.setuptools.dynamic] section below.
dynamic = ["version", "dependencies", "optional-dependencies", "authors", "urls"]
description = "verl: Volcano Engine Reinforcement Learning for LLM"
license = {file = "LICENSE"} # or "Apache-2.0", if you prefer an SPDX identifier
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.8"
# -------------------------------
# tool.setuptools - Additional config
# -------------------------------
[tool.setuptools]
# True means `setuptools` will attempt to include all relevant files in package_data automatically.
# This corresponds to `include_package_data=True` in setup.py.
include-package-data = true
# We read the version from a file in 'verl/version/version'
[tool.setuptools.dynamic]
version = {file = "verl/version/version"}
# If you need to mimic `package_dir={'': '.'}`:
[tool.setuptools.package-dir]
"" = "."
# If you need to include specific non-Python data (like YAML files or version file):
# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']}
[tool.setuptools.package-data]
verl = [
"version/*",
"trainer/config/*.yaml"
]
[tool.pylint.message_control]
disable = [
"abstract-method",
"anomalous-backslash-in-string",
"arguments-differ",
"arguments-renamed",
"assignment-from-none",
"attribute-defined-outside-init",
"bad-str-strip-call",
"bare-except",
"broad-exception-caught",
"broad-exception-raised",
"cell-var-from-loop",
"chained-comparison",
"consider-iterating-dictionary",
"consider-using-enumerate",
"consider-using-f-string",
"consider-using-from-import",
"consider-using-generator",
"consider-using-in",
"consider-using-max-builtin",
"consider-using-set-comprehension",
"consider-using-sys-exit",
"consider-using-with",
"cyclic-import",
"dangerous-default-value",
"duplicate-code",
"eval-used",
"expression-not-assigned",
"f-string-without-interpolation",
"fixme",
"function-redefined",
"global-statement",
"global-variable-not-assigned",
"import-error",
"import-outside-toplevel",
"import-self",
"inconsistent-return-statements",
"invalid-character-zero-width-space",
"invalid-name",
"line-too-long",
"logging-fstring-interpolation",
"logging-not-lazy",
"missing-class-docstring",
"missing-final-newline",
"missing-function-docstring",
"missing-module-docstring",
"multiple-imports",
"no-else-continue",
"no-else-raise",
"no-else-return",
"no-member",
"no-self-argument",
"no-value-for-parameter",
"not-an-iterable",
"not-callable",
"notimplemented-raised",
"pointless-exception-statement",
"pointless-string-statement",
"pointless-statement",
"possibly-used-before-assignment",
"protected-access",
"raise-missing-from",
"raising-format-tuple",
"redefined-argument-from-local",
"redefined-builtin",
"redefined-outer-name",
"redundant-u-string-prefix",
"reimported",
"simplifiable-if-expression",
"simplifiable-if-statement",
"singleton-comparison",
"super-init-not-called",
"superfluous-parens",
"too-few-public-methods",
"too-many-arguments",
"too-many-boolean-expressions",
"too-many-branches",
"too-many-instance-attributes",
"too-many-lines",
"too-many-locals",
"too-many-positional-arguments",
"too-many-return-statements",
"too-many-statements",
"trailing-newlines",
"trailing-newlines",
"trailing-whitespace",
"unbalanced-tuple-unpacking",
"undefined-loop-variable",
"undefined-variable",
"ungrouped-imports",
"unidiomatic-typecheck",
"unnecessary-comprehension",
"unnecessary-lambda",
"unnecessary-lambda-assignment",
"unnecessary-pass",
"unspecified-encoding",
"unused-argument",
"unused-import",
"unused-variable",
"unused-wildcard-import",
"use-a-generator",
"use-dict-literal",
"used-before-assignment",
"useless-object-inheritance",
"useless-parent-delegation",
"useless-return",
"wildcard-import",
"wrong-import-order",
"wrong-import-position",
]
# DAPO Open-Source Implementation
> Open-Source Algorithm Implementation & Experiment Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211)
> [!IMPORTANT]
> **🔥 News!!!**
> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl?nw=u7n2j5sht28).
🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper](https://dapo-sia.github.io/static/pdf/dapo_paper.pdf) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/gm-tyx/puffin/main/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO)
> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO to the Qwen2.5-32B base model outperforms the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** fewer training steps.
>
> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png)
## Quickstart
1. Prepare the datasets **on the Ray cluster**:
```bash
bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default
```
2. Submit the job to the Ray cluster **from any machine**:
```bash
cd verl # Repo root
export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to
export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster
# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml
export RUNTIME_ENV="./verl/trainer/runtime_env.yaml"
bash recipe/dapo/run_dapo_qwen2.5_32b.sh
```
## Reproduction Runs
| Setup | AIME 2024 Acc. | Training Script | Training Record |
|-------|----------------------|-----------------|-----------------|
| DAPO w/o Token-level PG Loss & Dynamic Sampling | 44% | [run_dapo_early_qwen2.5_32b.sh](./run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl?nw=u7n2j5sht28) |
| DAPO | 50% | [run_dapo_qwen2.5_32b.sh](./run_dapo_qwen2.5_32b.sh) | W&B (Coming soon) |
## Configuration
> [!NOTE]
> Most experiments in the paper, including the best-performing one, are run without Overlong Filtering because it largely overlaps with Overlong Reward Shaping in terms of properly learning from the longest outputs, so we don't implement it here.
### Separated Clip Epsilons (-> Clip-Higher)
An example configuration:
```yaml
actor_rollout_ref:
actor:
clip_ratio_low: 0.2
clip_ratio_high: 0.28
```
`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text{low}}$ and $\varepsilon_{\text{high}}$ in the DAPO objective.
Core relevant code:
```python
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
```
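A quick way to see the effect of the decoupled clip range is a toy numeric check (the scalar `advantages` and `ratio` values below are made up for illustration): with `clip_ratio_high = 0.28`, a positive-advantage token whose importance ratio has already grown to 1.25 still sits inside the clip range and keeps receiving gradient, whereas a symmetric `0.2` clip would have frozen it.
```python
import torch

advantages = torch.tensor([1.0])  # positive advantage (made-up value)
ratio = torch.tensor([1.25])      # new_prob / old_prob after a few updates (made-up value)
cliprange_low, cliprange_high = 0.2, 0.28

pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)

print(pg_losses)  # tensor([-1.2500]): unclipped, gradient still flows through the ratio
# with a symmetric clip of 0.2, the clamp would cap the ratio at 1.2, the loss at -1.2,
# and the gradient w.r.t. the ratio would be zero
```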
### Dynamic Sampling (with Group Filtering)
An example configuration:
```yaml
data:
gen_batch_size: 1536
train_batch_size: 512
algorithm:
filter_groups:
enable: True
metric: acc # score / seq_reward / seq_final_reward / ...
max_num_gen_batches: 10 # Non-positive values mean no upper limit
```
Setting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` values are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or all 0.
The trainer keeps sampling with `gen_batch_size` until there are enough qualified groups to fill `train_batch_size`, or until it reaches the upper limit specified by `max_num_gen_batches`.
Core relevant code:
```python
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
    print(f'{num_prompt_in_batch=} < {prompt_bsz=}')
    num_gen_batches += 1
    max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
    if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
        print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...')
        continue
    else:
        raise ValueError(
            f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'
        )
else:
    # Align the batch
    traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
    batch = batch[:traj_bsz]
```
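The filtering itself reduces to keeping only the prompt groups whose rollouts do not all share the same metric value. A standalone sketch of that idea (the `uids` and `accs` arrays below are made-up inputs, not the trainer's internal data structures):
```python
from collections import defaultdict

import numpy as np

# made-up example: 3 prompt groups, 4 rollouts each, with per-rollout accuracy
uids = np.array(['p0'] * 4 + ['p1'] * 4 + ['p2'] * 4)
accs = np.array([1, 1, 1, 1,   # p0: all correct -> no learning signal, filter out
                 0, 1, 0, 1,   # p1: mixed       -> keep
                 0, 0, 0, 0])  # p2: all wrong   -> filter out

metric_vals = defaultdict(set)
for uid, acc in zip(uids, accs):
    metric_vals[uid].add(acc)

kept_uids = {uid for uid, vals in metric_vals.items() if len(vals) > 1}
kept_indices = [i for i, uid in enumerate(uids) if uid in kept_uids]

print(kept_uids)     # {'p1'}
print(kept_indices)  # [4, 5, 6, 7]
```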
### Flexible Loss Aggregation Mode (-> Token-level Policy Gradient Loss)
An example configuration:
```yaml
actor_rollout_ref:
actor:
loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
# NOTE: "token-mean" is the default behavior
```
Setting `loss_agg_mode` to `token-mean` averages the (policy gradient) loss over all tokens across all sequences in a mini-batch.
Core relevant code:
```python
if loss_agg_mode == "token-mean":
    pg_loss = verl_F.masked_mean(pg_losses, eos_mask)
elif loss_agg_mode == "seq-mean-token-sum":
    pg_loss = torch.sum(pg_losses * eos_mask, dim=-1)  # token-sum within each sequence
    pg_loss = torch.mean(pg_loss)  # mean over sequences
elif loss_agg_mode == "seq-mean-token-mean":
    pg_loss = torch.sum(pg_losses * eos_mask, dim=-1) / torch.sum(eos_mask, dim=-1)  # token-mean within each sequence
    pg_loss = torch.mean(pg_loss)  # mean over sequences
else:
    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
```
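The three modes only differ in how per-token losses are weighted across sequences of different lengths. A made-up two-sequence example (not the trainer code) makes the difference concrete:
```python
import torch

# two sequences: 2 valid tokens vs. 4 valid tokens (loss values are made up)
pg_losses = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                          [2.0, 2.0, 2.0, 2.0]])
eos_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                         [1.0, 1.0, 1.0, 1.0]])

token_mean = torch.sum(pg_losses * eos_mask) / torch.sum(eos_mask)
seq_mean_token_sum = torch.mean(torch.sum(pg_losses * eos_mask, dim=-1))
seq_mean_token_mean = torch.mean(torch.sum(pg_losses * eos_mask, dim=-1) / torch.sum(eos_mask, dim=-1))

print(token_mean)           # tensor(1.6667): every token weighted equally, so long sequences dominate
print(seq_mean_token_sum)   # tensor(5.): each sequence's summed loss averaged across sequences
print(seq_mean_token_mean)  # tensor(1.5000): each sequence weighted equally regardless of length
```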
### Overlong Reward Shaping
An example configuration:
```yaml
data:
max_response_length: 20480 # 16384 + 4096
reward_model:
overlong_buffer:
enable: True
len: 4096
penalty_factor: 1.0
```
Setting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit.
Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` as the output length grows from `max_response_length - overlong_buffer.len` to `max_response_length`.
Core relevant code:
```python
if self.overlong_buffer_cfg.enable:
    overlong_buffer_len = self.overlong_buffer_cfg.len
    expected_len = self.max_resp_len - overlong_buffer_len
    exceed_len = valid_response_length - expected_len
    overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
    overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
    reward += overlong_reward
```
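For example, with `max_response_length = 20480`, `overlong_buffer.len = 4096`, and `penalty_factor = 1.0`, the expected length is `20480 - 4096 = 16384` tokens; an 18432-token response exceeds it by 2048 tokens and receives an extra reward of `-2048 / 4096 * 1.0 = -0.5`, while responses at or below 16384 tokens are unaffected. A standalone sketch of the shaping function (the function name and default values are illustrative, not the reward manager's API):
```python
def overlong_penalty(valid_response_length, max_resp_len=20480, buffer_len=4096, penalty_factor=1.0):
    """Linear penalty from 0 to -penalty_factor over the overlong buffer (illustrative sketch)."""
    expected_len = max_resp_len - buffer_len
    exceed_len = valid_response_length - expected_len
    return min(-exceed_len / buffer_len * penalty_factor, 0)


print(overlong_penalty(16000))  # 0    -> within the expected length, no penalty
print(overlong_penalty(18432))  # -0.5 -> halfway into the 4096-token buffer
print(overlong_penalty(20480))  # -1.0 -> at the hard limit, full penalty
```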
#!/usr/bin/env bash
set -uxo pipefail
export VERL_HOME=${VERL_HOME:-"${HOME}/verl"}
export TRAIN_FILE=${TRAIN_FILE:-"${VERL_HOME}/data/dapo-math-17k.parquet"}
export TEST_FILE=${TEST_FILE:-"${VERL_HOME}/data/aime-2024.parquet"}
export OVERWRITE=${OVERWRITE:-0}
mkdir -p "${VERL_HOME}/data"
if [ ! -f "${TRAIN_FILE}" ] || [ "${OVERWRITE}" -eq 1 ]; then
wget -O "${TRAIN_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true"
fi
if [ ! -f "${TEST_FILE}" ] || [ "${OVERWRITE}" -eq 1 ]; then
wget -O "${TEST_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true"
fi
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Early-Qwen2.5-32B'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# An early version for DAPO
loss_agg_mode="seq-mean-token-mean"
enable_filter_groups=False
gen_prompt_bsz=512 # NOTE: no filtering here
train_prompt_bsz=512
train_prompt_mini_bsz=32
n_resp_per_prompt=16
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.src.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.filter_groups.enable=${enable_filter_groups} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=5 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 20))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
n_resp_per_prompt=16
train_prompt_mini_bsz=32
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-16}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
# Performance Related Parameter
sp_size=8
use_dynamic_bsz=True
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
offload=True
gen_tp=4
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.src.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
algorithm.filter_groups.metric=${filter_groups_metric} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=5 \
trainer.save_freq=5 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto
data:
tokenizer: null
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
gen_batch_size: ${data.train_batch_size}
train_batch_size: 1024
  val_batch_size: null # DEPRECATED: validation datasets are sent to the inference engines as a whole batch; the engines schedule memory themselves
  return_raw_input_ids: False # set to True when the policy and reward-model tokenizers differ
return_raw_chat: False
shuffle: True
  filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming; you should disable this and set `truncation='left'`
truncation: error
image_key: images
video_key: videos
custom_cls:
path: null
name: null
actor_rollout_ref:
hybrid_engine: True
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
# pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
clip_ratio_low: 0.2
clip_ratio_high: 0.2
clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
# NOTE: "token-mean" is the default behavior
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
      contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save the whole model in HF format; by default only sharded model checkpoints are saved to save space
optim:
lr: 1e-6
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
weight_decay: 0.01
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
use_fire_sampling: False # https://arxiv.org/abs/2410.21236
    prompt_length: ${data.max_prompt_length} # not used by the open-source rollout
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
disable_log_stats: True
    enable_chunked_prefill: True # may yield higher throughput when set to True; when enabled, please increase max_num_batched_tokens or decrease max_model_len
# for hf rollout
do_sample: True
    # number of responses (i.e. number of samples per prompt)
n: 1 # > 1 for grpo
val_kwargs:
# sampling parameters for validation
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1.0
temperature: 0
n: 1
      do_sample: False # greedy decoding by default for validation
critic:
rollout_n: ${actor_rollout_ref.rollout.n}
strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
    total_training_steps: -1 # must be overridden by the program
weight_decay: 0.01
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: True
use_remove_padding: False
fsdp_config:
param_offload: False
optimizer_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
checkpoint:
    contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save the whole model in HF format; by default only sharded model checkpoints are saved to save space
reward_model:
enable: False
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: False
fsdp_size: -1
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
micro_batch_size_per_gpu: null # set a number
max_length: null
ulysses_sequence_parallel_size: 1 # sp size
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
reward_manager: naive
overlong_buffer:
    enable: False # kept explicit so we don't forget to set it
len: 0
penalty_factor: 0.0
log: False
custom_reward_function:
path: null
name: compute_score
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
filter_groups:
    enable: False # kept explicit so we don't forget to set it
metric: null # acc / score / seq_reward / seq_final_reward / ...
max_num_gen_batches: 0 # Non-positive values mean no upper limit
trainer:
balance_batch: True
total_epochs: 30
total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: [ 'console', 'wandb' ]
log_val_generations: 0
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
  # auto: find the last ckpt to resume; if none is found, start from scratch
resume_mode: auto # or disable or resume_path if resume_from_path is set
resume_from_path: null
val_before_train: True
test_freq: -1
critic_warmup: 0
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
max_actor_ckpt_to_keep: null
max_critic_ckpt_to_keep: null
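The clip_ratio_low / clip_ratio_high / clip_ratio_c options above correspond to decoupled ("clip-higher") clipping plus the dual-clip PPO lower bound referenced in the actor comments. Below is a minimal sketch of how those knobs could enter the policy loss with loss_agg_mode="token-mean"; the function name and tensor shapes are illustrative assumptions, not verl's actual implementation.

# Minimal sketch (not verl's code) of the decoupled / dual-clip PPO policy loss implied by
# clip_ratio_low, clip_ratio_high and clip_ratio_c, with "token-mean" aggregation.
import torch

def dual_clip_pg_loss(log_prob, old_log_prob, advantages, response_mask,
                      clip_low=0.2, clip_high=0.28, clip_c=10.0):
    ratio = torch.exp(log_prob - old_log_prob)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_low, 1 + clip_high)
    clipped = torch.max(pg_losses1, pg_losses2)        # standard clipping with decoupled bounds
    lower_bound = -advantages * clip_c                 # dual-clip bound, only active for advantages < 0
    pg_losses = torch.where(advantages < 0, torch.min(clipped, lower_bound), clipped)
    # "token-mean": average the loss over all valid response tokens
    return (pg_losses * response_mask).sum() / response_mask.sum().clamp(min=1)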
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace
"""
import uuid
from pprint import pprint
from copy import deepcopy
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import torch
from verl import DataProto
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, _timer, apply_kl_penalty, compute_advantage, AdvantageEstimator
from verl.trainer.ppo.metric_utils import (compute_data_metrics, compute_throughout_metrics, compute_timing_metrics,
reduce_metrics)
class RayDAPOTrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
def fit(self):
"""
The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
timing_raw = defaultdict(float)
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
num_gen_batches += 1
# pop those keys for generation
if 'multi_modal_inputs' in new_batch.non_tensor_batch.keys():
gen_batch = new_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
)
else:
gen_batch = new_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids'],
)
is_last_step = self.global_steps >= self.total_training_steps
with _timer('step', timing_raw):
# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with _timer('gen_max', timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info['do_sample'] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
new_batch = new_batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(new_batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
new_batch.batch['reward_baselines'] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
new_batch.non_tensor_batch['uid'] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)
with _timer('reward', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(reward_tensor)
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(new_batch, return_dict=True)
reward_tensor = reward_result['reward_tensor']
reward_extra_infos_dict = reward_result['reward_extra_info']
except Exception as e:
print(f'Error in reward_fn: {e}')
reward_tensor = self.reward_fn(new_batch)
reward_extra_infos_dict = {}
new_batch.batch['token_level_scores'] = reward_tensor
print(f'{list(reward_extra_infos_dict.keys())=}')
if reward_extra_infos_dict:
new_batch.non_tensor_batch.update({
k: np.array(v) for k, v in reward_extra_infos_dict.items()
})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(new_batch,
kl_ctrl=self.kl_ctrl_in_reward,
kl_penalty=self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)  # TODO: this will be cleared if we use multiple generation batches
else:
new_batch.batch['token_level_rewards'] = new_batch.batch['token_level_scores']
if not self.config.algorithm.filter_groups.enable:
batch = new_batch
                    else:  # NOTE: when the number of prompts left after filtering is less than the train batch size, we skip to the next generation batch
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = new_batch.batch['token_level_rewards'].sum(
dim=-1).numpy()
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = new_batch.batch['token_level_scores'].sum(
dim=-1).numpy()
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(new_batch.non_tensor_batch['uid'],
new_batch.non_tensor_batch[metric_name]):
prompt_uid2metric_vals[uid].append(metric_val)
prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
kept_prompt_uids = [
uid for uid, std in prompt_uid2metric_std.items()
if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
]
num_prompt_in_batch += len(kept_prompt_uids)
kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch['uid']):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
new_batch = new_batch[kept_traj_idxs]
if batch is None:
batch = new_batch
else:
batch = DataProto.concat([batch, new_batch])
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f'{num_prompt_in_batch=} < {prompt_bsz=}')
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f'{num_gen_batches=}. Keep generating...')
continue
else:
raise ValueError(
f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.'
)
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
batch = batch[:traj_bsz]
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group-based advantage computation such as GRPO and RLOO
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
# recompute old_log_probs
with _timer('old_log_prob', timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer('adv', timing_raw):
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n)
# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
(is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (is_last_step or
self.global_steps % self.config.trainer.save_freq == 0):
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual TFLOPs and theoretical TFLOPs
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
timing_raw = defaultdict(float) # clear timing
metrics["train/num_gen_batches"] = num_gen_batches
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
                # TODO: make a canonical logger that supports various backends
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f'Final validation metrics: {last_val_metrics}')
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1
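The filter_groups branch above is DAPO's dynamic sampling: a prompt is kept only if the chosen metric (e.g. accuracy) varies across its n rollouts, so all-correct and all-wrong groups contribute no gradient signal and are regenerated. A self-contained sketch of just that filtering step on toy data follows; kept_indices is a hypothetical helper, not part of the trainer.

# Self-contained sketch of the DAPO group filter used above; kept_indices is a hypothetical helper.
from collections import defaultdict
import numpy as np

def kept_indices(uids, metric_vals):
    groups = defaultdict(list)
    for uid, val in zip(uids, metric_vals):
        groups[uid].append(val)
    # keep a prompt if its responses disagree on the metric (std > 0) or it has a single response
    kept_uids = {uid for uid, vals in groups.items() if np.std(vals) > 0 or len(vals) == 1}
    return [i for i, uid in enumerate(uids) if uid in kept_uids]

uids = ['p0', 'p0', 'p1', 'p1', 'p2', 'p2']
accs = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0]   # p1 all-correct, p2 all-wrong
print(kept_indices(uids, accs))          # -> [0, 1]: only p0 carries learning signal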
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main entry point with ray_trainer, since ray_trainer is used by other main scripts.
"""
from .dapo_ray_trainer import RayDAPOTrainer
import os
import ray
import hydra
def get_custom_reward_fn(config):
import importlib.util, os
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
if not file_path:
return None
if not os.path.exists(file_path):
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
spec = importlib.util.spec_from_file_location("custom_module", file_path)
module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}': {e}")
function_name = reward_fn_config.get("name")
if not hasattr(module, function_name):
raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
print(f"using customized reward function '{function_name}' from '{file_path}'")
return getattr(module, function_name)
@hydra.main(config_path='config', config_name='dapo_trainer', version_base=None)
def main(config):
run_ppo(config)
def run_ppo(config) -> None:
# TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
# isolation, will solve in the future
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={
'env_vars': {
'TOKENIZERS_PARALLELISM': 'true',
'NCCL_DEBUG': 'WARN',
'VLLM_LOGGING_LEVEL': 'WARN'
}
})
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
from verl.utils.fs import copy_to_local
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer, hf_processor
tokenizer = hf_tokenizer(local_path)
        processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLMs; may be None
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
Role.RefPolicy: ray.remote(ActorRolloutRefWorker)
}
global_pool_id = 'global_pool'
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == 'naive':
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == 'prime':
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
elif reward_manager_name == 'dapo':
from verl.workers.reward_manager import DAPORewardManager
reward_manager_cls = DAPORewardManager
else:
raise NotImplementedError
compute_score = get_custom_reward_fn(config)
reward_fn = reward_manager_cls(tokenizer=tokenizer,
num_examine=0,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer)
# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer,
num_examine=1,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
max_resp_len=config.data.max_response_length,
overlong_buffer_cfg=config.reward_model.overlong_buffer)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayDAPOTrainer(config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
main()
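get_custom_reward_fn above imports a function (custom_reward_function.name, compute_score by default) from the file given in custom_reward_function.path. A hedged example of such a file is sketched below; the parameter names are an assumption about how the configured reward manager invokes it, so check the manager you actually use (e.g. dapo) for the exact signature.

# Hypothetical my_reward.py, loaded via custom_reward_function.path with name=compute_score.
# The parameter names below are an assumption about the reward manager's calling convention.
def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    # toy rule-based reward: 1.0 on exact string match with the ground truth, else 0.0
    return 1.0 if solution_str.strip() == str(ground_truth).strip() else 0.0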
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7B-Math-Test'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 2))
enable_overlong_buffer=True
overlong_buffer_len=512
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
enable_filter_groups=True
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=512
gen_prompt_bsz=$((train_prompt_bsz * 3))
train_prompt_mini_bsz=32
n_resp_per_prompt=16
# Ray
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-4}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
# Mathematically equivalent
use_dynamic_bsz=True
infer_micro_batch_size=null
train_micro_batch_size=null
offload=False
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
-- python3 -m recipe.dapo.src.main_dapo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.gen_batch_size=${gen_prompt_bsz} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
algorithm.filter_groups.enable=${enable_filter_groups} \
algorithm.filter_groups.metric=${filter_groups_metric} \
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.model.path="${MODEL_PATH}" \
+actor_rollout_ref.model.override_config.attention_dropout=0. \
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k="${top_k}" \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
reward_model.reward_manager=dapo \
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
reward_model.overlong_buffer.len=${overlong_buffer_len} \
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node=8 \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=2 \
trainer.save_freq=2 \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=disable
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# the prime config will override default ppo_trainer.yaml
hydra:
searchpath:
- file://verl/trainer/config
defaults:
- ppo_trainer
- _self_
data:
filter_accuracy: True
accuracy_lower_bound: 0.2
accuracy_upper_bound: 0.8
  oversample_factor: 4.0 # sample more responses than the batch size; prompts satisfying the filter will be prioritized
filter_truncate: True
truncation: right
actor_rollout_ref:
hybrid_engine: True
model:
use_remove_padding: True
rollout:
# number of responses (i.e. num sample times)
n: 4
actor:
entropy_coeff: 0.001
reward_model:
enable: True
strategy: fsdp
model:
ref_path: ${reward_model.model.path}
use_remove_padding: True
tokenizer_path: ${actor_rollout_ref.model.path}
enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
ref_type: freeze
fsdp_config:
min_num_params: 0
param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
# grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload}
optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
update: before # ``before`` for double-forward, ``after`` for single-forward
optim:
lr: 1e-6
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null
warmup_style: constant
total_training_steps: -1 # must be overridden by program
weight_decay: 0.
grad_clip: 10.0
beta_train: 0.05
loss_type: ce # currently only supports ce loss
prime_granularity: token
prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train
mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
reward_manager: prime
algorithm:
adv_estimator: rloo
  # currently supports RLOO; it treats different sources of reward separately
kl_ctrl:
type: fixed
kl_coef: 0.000
reward_gt_coef: 5
reward_dpo_coef: 5
trainer:
project_name: prime
experiment_name: examples
val_before_train: False
balance_batch: False
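The data.filter_accuracy settings above describe PRIME's online prompt filtering: oversample prompts by oversample_factor, then keep only those whose mean accuracy over the sampled responses lies between accuracy_lower_bound and accuracy_upper_bound. A minimal sketch on toy values follows; whether the bounds are inclusive is an implementation detail, and strict inequalities are assumed here.

# Sketch of the accuracy-band prompt filter implied by the data settings above (illustrative only).
import numpy as np

def filter_prompts(mean_acc_per_prompt, lower=0.2, upper=0.8):
    mean_acc = np.asarray(mean_acc_per_prompt)
    keep = (mean_acc > lower) & (mean_acc < upper)   # drop prompts that are too easy or too hard
    return np.nonzero(keep)[0]

print(filter_prompts([0.0, 0.25, 0.5, 1.0]))   # -> [1 2]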
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main entry point with ray_trainer, since ray_trainer is used by other main scripts.
"""
from .prime_ray_trainer import RayPRIMETrainer
import ray
import hydra
@hydra.main(config_path='config', config_name='prime_trainer', version_base=None)
def main(config):
run_prime(config)
def run_prime(config, compute_score=None):
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={'env_vars': {
'TOKENIZERS_PARALLELISM': 'true',
'NCCL_DEBUG': 'WARN'
}},
)
ray.get(main_task.remote(config, compute_score))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
from verl.utils.fs import copy_local_path_from_hdfs
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer
tokenizer = hf_tokenizer(local_path)
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
}
global_pool_id = 'global_pool'
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
}
    # use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
if config.reward_model.enable:
from .prime_fsdp_workers import PRIMERewardModelWorker
role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == 'naive':
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == 'prime':
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
else:
raise NotImplementedError
reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score)
# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPRIMETrainer(config=config,
tokenizer=tokenizer,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
main()
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import verl
import verl.utils.torch_functional as verl_F
def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
    # calculate the RLOO reward for each reward source separately, then sum them
def masked_rloo(reward_tensor_original, mask_tensor):
reward_tensor = reward_tensor_original.clone()
reward_tensor[~mask_tensor] = 0
for start_pos in range(0, reward_tensor.shape[0], n_samples):
cur_rewards_mean = torch.cat([
reward_tensor[pos:pos + 1][mask_tensor[pos:pos + 1]].mean(dim=0, keepdim=True)
for pos in range(start_pos, start_pos + n_samples)
],
dim=0)
cur_rewards_sum = cur_rewards_mean.sum()
cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
reward_tensor[start_pos:start_pos + n_samples][
mask_tensor[start_pos:start_pos + n_samples]] = \
reward_tensor[start_pos:start_pos + n_samples][
mask_tensor[start_pos:start_pos + n_samples]] * (
n_samples / (n_samples - 1)) - cur_reward_baseline
return reward_tensor
reward_tensors = []
with torch.no_grad():
if 'rm_scores' in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.:
reward_tensor = data.batch['rm_scores']
reward_mask = response_mask.bool()
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
if 'acc' in data.batch.keys() and config.algorithm.reward_gt_coef != 0.:
reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
prompt_ids = data.batch['prompts']
prompt_length = prompt_ids.shape[-1]
valid_response_length = data.batch['attention_mask'][:, prompt_length:].sum(-1)
reward_mask[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1] = True
reward_tensor[
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
valid_response_length - 1] = data.batch['acc']
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)
final_reward_tensor = sum(reward_tensors)
returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns.clone()
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns
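masked_rloo above applies the leave-one-out (RLOO) baseline using the identity r_i - mean_{j != i} r_j = r_i * n / (n - 1) - sum_j r_j / (n - 1), which is exactly the scaling-and-subtraction it performs on the masked rewards. A quick numerical check of that identity (illustrative only):

# Numerical check of the leave-one-out identity used by masked_rloo above.
import torch

r = torch.tensor([1.0, 0.0, 0.5, 0.25])   # rewards of the n responses to one prompt
n = r.numel()
loo_baseline = (r.sum() - r) / (n - 1)     # mean reward of the other n-1 responses
lhs = r - loo_baseline
rhs = r * n / (n - 1) - r.sum() / (n - 1)  # the form computed inside masked_rloo
print(torch.allclose(lhs, rhs))            # True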
def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
return cur_dpo_loss
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode='none'):
# we always assume that the BoN size equals n_samples
# mode1: use acc as rm
# mode2: use Q as rm
cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta
other_Q = torch.zeros_like(cur_Q)
for i in range(token_level_scores.shape[0]):
if acc[i] > 0:
Q_chosen = Q_bc[i][acc_bc[i] < acc[i]]
else:
Q_chosen = Q_bc[i][acc_bc[i] > acc[i]]
if len(Q_chosen) > 0:
other_Q[i] = Q_chosen.mean() * beta
else:
other_Q[i] = 0
dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))
if bon_mode == 'none':
dpo_loss = dpo_loss.mean()
else:
weight = torch.zeros_like(dpo_loss)
n_samples = acc_bc.shape[1]
if bon_mode == 'bon_rm':
for i in range(token_level_scores.shape[0]):
weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)
elif bon_mode == 'bon_acc':
for i in range(token_level_scores.shape[0]):
weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)
else:
raise NotImplementedError
dpo_loss = (dpo_loss * weight).sum()
return dpo_loss
def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
dpo_acc = []
for start_id in range(0, token_level_scores.shape[0], n_samples):
cur_scores = (token_level_scores[start_id:start_id + n_samples] *
response_mask[start_id:start_id + n_samples]).sum(dim=1)
def get_upper_triangle(tensor_x):
diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
return diff_matrix[upper_tri_indices]
cur_acc_diff = get_upper_triangle(acc[start_id:start_id + n_samples]) # in range [-1,1]
cur_score_diff = get_upper_triangle(cur_scores) # in R
cur_score_prediction = (cur_score_diff > 0).float() # in [0,1]
if cur_acc_diff.abs().sum() == 0:
cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5
else:
cur_acc = (((cur_score_diff > 0) == (cur_acc_diff > 0)).float() *
cur_acc_diff.abs()).sum() / cur_acc_diff.abs().sum()
dpo_acc.append(cur_acc.unsqueeze(0))
return torch.cat(dpo_acc, dim=0).mean()
def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):
return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
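compute_dpo_accuracy above measures, within each group of n_samples responses, how often pairwise differences of the implicit reward (masked sum of token_level_scores) agree in sign with pairwise accuracy differences, weighting each pair by the absolute accuracy difference. A toy usage of the function defined above, with illustrative values:

# Toy usage of compute_dpo_accuracy defined above: one prompt with n_samples=4 responses.
import torch

token_level_scores = torch.tensor([[0.2, 0.3],     # implicit reward 0.5 (highest)
                                   [0.1, 0.1],     # 0.2
                                   [-0.2, 0.0],    # -0.2
                                   [-0.4, -0.1]])  # -0.5 (lowest)
acc = torch.tensor([1.0, 1.0, 0.0, 0.0])           # the two highest-scored responses are correct
response_mask = torch.ones_like(token_level_scores)
print(compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples=4))  # -> tensor(1.)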
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement a data-parallel PRIME reward model
"""
import itertools
from typing import Iterable
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm
from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
import verl.utils.torch_functional as verl_F
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
__all__ = ['DataParallelPRIMERewardModel']
class DataParallelPRIMERewardModel:
def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer):
self.config = config
self.reward_module = reward_module
self.ref_module = ref_module
self.reward_optimizer = reward_optimizer
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
print(f'Reward model use_remove_padding={self.use_remove_padding}')
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
def _forward_micro_batch(self, micro_batch, prompt_length):
from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
input_ids = micro_batch['input_ids']
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
num_actions = micro_batch['input_ids'].shape[-1] - prompt_length
max_positions = micro_batch['attention_mask'][:, prompt_length:].sum(-1)
if self.use_remove_padding:
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
            # for computing the log_prob
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
# pad and slice the inputs if sp > 1
if self.ulysses_sequence_parallel_size > 1:
input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
self.ulysses_sequence_parallel_size)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
rm_output_logits = self.reward_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False).logits.squeeze(
0) # copied. I don't really know why there is a squeeze
rm_log_labels = verl_F.logprobs_from_logits(logits=rm_output_logits, labels=input_ids_rmpad_rolled)
if self.ulysses_sequence_parallel_size > 1:
rm_log_labels = gather_outpus_and_unpad(rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size)
rm_log_labels = pad_input(hidden_states=rm_log_labels.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1]
else:
rm_output_logits = self.reward_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'],
use_cache=False).logits
rm_log_prob = torch.nn.functional.log_softmax(rm_output_logits[:, :-1, :],
dim=-1) # (batch_size, seq_length, vocab_size)
rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
-1) # (batch, seq_length)
if self.ref_module is not None:
# do not have to pad again
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding:
ref_output_logits = self.ref_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False).logits.squeeze(0)
ref_log_labels = verl_F.logprobs_from_logits(logits=ref_output_logits,
labels=input_ids_rmpad_rolled)
ref_log_labels = gather_outpus_and_unpad(ref_log_labels,
gather_dim=0,
unpad_dim=0,
padding_size=pad_size)
ref_log_labels = pad_input(hidden_states=ref_log_labels.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen).squeeze(-1)[:, -num_actions - 1:-1]
else:
ref_output_logits = self.ref_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'],
use_cache=False).logits
ref_log_prob = torch.nn.functional.log_softmax(ref_output_logits[:, :-1, :],
dim=-1) # (batch_size, seq_length, vocab_size)
ref_log_labels = ref_log_prob.gather(dim=-1,
index=micro_batch['input_ids'][:, 1:].unsqueeze(-1)).squeeze(
-1) # (batch, seq_length)
else:
ref_log_labels = micro_batch['old_log_probs']
        ref_log_labels = ref_log_labels.to(rm_log_labels.dtype)  # .to() is not in-place; keep the cast result
q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q
# trim unnecessary logprobs here
for i in range(micro_batch['input_ids'].shape[0]):
q[i, max_positions[i]:] = 0
        # reward computation does not need gradients; only q does
with torch.no_grad():
            # the generalized estimation of r should happen before reward filling; r is the process reward for the policy model, i.e. the advantage of the reward model
lam = self.config.get('lambda', 0.)
beta = self.config.model.get('beta_train', 0.05)
if lam == 0.:
r = q * beta
else:
                # the reward coefficient has no effect here
acc = micro_batch['acc']
q_ = q * beta
r = torch.zeros_like(q)
lastgaelam = 0
# change the last token and mask out all paddings to make this process easier if we rely on outcome reward to calculate V
for i in range(q.shape[0]):
if self.config.prime_use_gt:
q_[i, max_positions[i] - 1] = acc[i] - q_[i, :max_positions[i] - 1].sum()
q_[i, max_positions[i]:] = 0
for t in reversed(range(num_actions)):
delta = q_[:, t]
lastgaelam = delta + lam * lastgaelam
r[:, t] = lastgaelam
token_level_score = torch.zeros_like(q)
if self.config.prime_granularity == 'token':
for i in range(micro_batch['input_ids'].shape[0]):
token_level_score[i, :max_positions[i] - 1] = r[i, :max_positions[i] - 1]
elif self.config.prime_granularity == 'whole':
for i in range(micro_batch['input_ids'].shape[0]):
                    token_level_score[i, max_positions[i] - 1] = r[i, :max_positions[i]].sum()  # place the whole-sequence reward on the last valid token
else:
raise NotImplementedError
return token_level_score, q
def _optimizer_step(self):
assert self.config.model.optim.grad_clip is not None
if isinstance(self.reward_module, FSDP):
grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(self.reward_module.parameters(),
max_norm=self.config.model.optim.grad_clip)
self.reward_optimizer.step()
return grad_norm
def prime_norm(self, token_level_scores):
if self.config.prime_norm == 'batch_norm':
reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1])
token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6)
return token_level_scores
def compute_rm_score(self, data: DataProto):
self.reward_module.eval()
self.ref_module.eval()
micro_batch_size = data.meta_info['micro_batch_size']
select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'acc']
batch = data.select(batch_keys=select_keys).batch
use_dynamic_bsz = data.meta_info['use_dynamic_bsz']
prompt_length = data.batch['input_ids'].shape[-1] - data.batch['responses'].shape[-1]
if use_dynamic_bsz:
# split using dynamic bsz
max_token_len = data.meta_info['max_token_len'] * self.ulysses_sequence_parallel_size
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
else:
micro_batches = batch.split(micro_batch_size)
rm_scores_lst = []
q_lst = []
for micro_batch in micro_batches:
with torch.no_grad():
rm_score, q = self._forward_micro_batch(micro_batch, prompt_length)
rm_scores_lst.append(rm_score)
q_lst.append(q)
rm_scores = torch.concat(rm_scores_lst, dim=0)
q = torch.concat(q_lst, dim=0)
rm_scores = self.prime_norm(rm_scores)
if use_dynamic_bsz:
indices = list(itertools.chain.from_iterable(indices))
assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}"
revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
rm_scores = rm_scores[revert_indices]
return rm_scores, q.detach(), {
'reward_model/reward': rm_scores.sum(dim=-1).mean().item(),
'reward_model/raw_reward': q.sum(dim=-1).mean().item()
}
def update_rm(self, data: DataProto):
# make sure we are in training mode
self.reward_module.train()
metrics = {}
beta = self.config.model.get('beta_train', 0.05)
select_keys = ['input_ids', 'responses', 'attention_mask', 'position_ids', 'acc', 'prompts']
for key in ['Q_bc', 'acc_bc']:
if key in data.batch.keys():
select_keys.append(key)
batch = data.select(batch_keys=select_keys).batch
# Split to make minibatch iterator for updating the actor
# See PPO paper for details. https://arxiv.org/abs/1707.06347
dataloader = batch.split(self.config.mini_batch_size)
rm_scores_lst = []
q_lst = []
for batch_idx, data in enumerate(dataloader):
# split batch into micro_batches
mini_batch = data
if self.config.use_dynamic_bsz:
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu)
self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu
self.reward_optimizer.zero_grad()
for data in micro_batches:
data = data.cuda()
attention_mask = data['attention_mask']
acc = data['acc']
prompt_ids = data['prompts']
prompt_length = prompt_ids.shape[-1]
response_mask = attention_mask[:, prompt_length:]
rm_score, q = self._forward_micro_batch(data, prompt_length)
rm_scores_lst.append(rm_score)
q_lst.append(q.detach())
if self.config.model.loss_type == 'ce':
dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta)
elif self.config.model.loss_type == 'dpo':
                        # the DPO implementation is detached, i.e. we need the average winning/losing reward values before the update
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
response_mask=response_mask,
beta=beta)
elif self.config.model.loss_type == 'bon_acc':
# change the original distribution of each sample to BoN distribution, then update reward model
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
response_mask=response_mask,
beta=beta,
bon_mode='bon_acc')
elif self.config.model.loss_type == 'bon_rm':
dpo_loss = compute_detach_dpo_loss_rm(q,
acc,
Q_bc=data['Q_bc'],
acc_bc=data['acc_bc'],
response_mask=response_mask,
beta=beta,
bon_mode='bon_rm')
else:
raise NotImplementedError
data = {'reward_model/dpo_loss': dpo_loss.detach().item()}
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = dpo_loss / self.gradient_accumulation
loss.backward()
append_to_dict(metrics, data)
grad_norm = self._optimizer_step()
data = {'reward_model/grad_norm': grad_norm.detach().item()}
append_to_dict(metrics, data)
self.reward_optimizer.zero_grad()
rm_scores = torch.cat(rm_scores_lst, dim=0)
q = torch.concat(q_lst, dim=0)
rm_scores = self.prime_norm(rm_scores)
metrics.update({
'reward_model/reward': rm_scores.sum(dim=-1).mean().item(),
'reward_model/raw_reward': q.sum(dim=-1).mean().item()
})
return rm_scores, metrics
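When `lambda` is non-zero, _forward_micro_batch above turns the per-token implicit rewards q into process rewards r via the reversed accumulation loop, which is a lambda-weighted suffix sum: r_t = sum_{k >= t} lambda^(k - t) * q_k. A tiny standalone check of that equivalence on toy numbers (illustrative only):

# Toy check of the reversed lambda-accumulation used in _forward_micro_batch above:
# the loop produces r_t = sum_{k >= t} lam**(k - t) * q_k.
import torch

q = torch.tensor([[0.1, -0.2, 0.3, 0.05]])
lam = 0.9
r = torch.zeros_like(q)
lastgaelam = 0.0
for t in reversed(range(q.shape[1])):
    lastgaelam = q[:, t] + lam * lastgaelam
    r[:, t] = lastgaelam
closed_form = torch.tensor([[sum(lam ** (k - t) * q[0, k].item() for k in range(t, q.shape[1]))
                             for t in range(q.shape[1])]])
print(torch.allclose(r, closed_form))   # True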