Commit 7f6cc211 authored by jerrrrry

Initial commit
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import os
import statistics
import uuid
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.core_algos import agg_loss
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.metric import reduce_metrics
from verl.utils.profiler.performance import simple_timer
from . import prime_core_algos
def compute_advantage(data: DataProto, adv_estimator, config):
if adv_estimator == "rloo":
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
advantages, returns = prime_core_algos.compute_rloo_advantage_return(
data, response_mask, config.actor_rollout_ref.rollout.n, config
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
else:
raise NotImplementedError
return data
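# For context, the leave-one-out baseline behind the "rloo" estimator can be sketched as
# below. This is only an illustration under simplifying assumptions (scalar outcome rewards,
# no process rewards, no masking); the real prime_core_algos.compute_rloo_advantage_return
# additionally folds in PRIME's implicit process rewards and works on per-token tensors.
def _rloo_advantage_sketch(rewards: torch.Tensor, n: int) -> torch.Tensor:
    """rewards: (batch,) scalar rewards laid out as n rollouts per prompt.

    Example: rewards = [1, 0, 1, 1] with n = 2 gives advantages [1, -1, 0, 0].
    """
    grouped = rewards.view(-1, n)  # (num_prompts, n)
    # baseline for each rollout: mean reward of the other n - 1 rollouts of the same prompt
    baseline = (grouped.sum(dim=-1, keepdim=True) - grouped) / (n - 1)
    return (grouped - baseline).view(-1)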
def compute_data_metrics(batch, use_critic=True):
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].shape[-1]
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_response_mask(data: DataProto):
responses = data.batch["responses"]
response_length = responses.size(1)
attention_mask = data.batch["attention_mask"]
return attention_mask[:, -response_length:]
def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
num_response_tokens = torch.sum(response_info["response_length"]).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
"gen": num_response_tokens,
**{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]},
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
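# Worked example of the per-token timing above: if "gen" took 12.0 s and the batch produced
# 6,000 response tokens, timing_per_token_ms/gen = 12.0 * 1000 / 6000 = 2.0 ms; every other
# section is normalized by prompt + response tokens instead of response tokens only.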
class RayPRIMETrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role having an individual ray_worker_group_cls,
# i.e., support different backends for different roles
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None,
device_name="cuda",
):
# assert get_torch_device().is_available(), 'cuda must be available on driver'
super().__init__(
config,
tokenizer,
role_worker_mapping,
resource_pool_manager,
ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
device_name=device_name,
)
self.use_critic = False
def _validate_config(self):
super()._validate_config()
# TODO: Additional config checks can be added here
def _create_dataloader(self, *args, **kwargs):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
# TODO: we have to make sure the batch size is divisible by the dp size
self.train_dataset = RLHFDataset(
data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data
)
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = DataLoader(
dataset=self.train_dataset,
batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor),
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
)
self.val_dataset = RLHFDataset(
data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data
)
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,
drop_last=True,
collate_fn=collate_fn,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f"Total training steps: {self.total_training_steps}")
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
def _save_checkpoint(self):
# path: given_path + `/global_step_{global_steps}` + `/actor`
local_global_step_folder = os.path.join(
self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
)
print(f"local_global_step_folder: {local_global_step_folder}")
actor_local_path = os.path.join(local_global_step_folder, "actor")
actor_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
)
self.actor_rollout_wg.save_checkpoint(
actor_local_path,
actor_remote_path,
self.global_steps,
)
if self.use_rm:
reward_local_path = os.path.join(local_global_step_folder, "reward")
reward_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward")
)
self.rm_wg.save_checkpoint(
reward_local_path,
reward_remote_path,
self.global_steps,
)
# save dataloader
dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
import dill
torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)
# latest checkpointed iteration tracker (for atomic usage)
local_latest_checkpointed_iteration = os.path.join(
self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
)
with open(local_latest_checkpointed_iteration, "w") as f:
f.write(str(self.global_steps))
def _load_checkpoint(self):
if self.config.trainer.resume_mode == "disable":
return 0
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
NotImplementedError("load from hdfs is not implemented yet")
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
working_dir = os.getcwd()
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
# find global_step_folder
if self.config.trainer.resume_mode == "auto":
if global_step_folder is None:
print("Training from scratch")
return 0
else:
if self.config.trainer.resume_mode == "resume_path":
assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
assert "global_step_" in self.config.trainer.resume_from_path, (
"resume ckpt must specify the global_steps"
)
global_step_folder = self.config.trainer.resume_from_path
if not os.path.isabs(global_step_folder):
working_dir = os.getcwd()
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f"Load from checkpoint folder: {global_step_folder}")
# set global step
self.global_steps = int(global_step_folder.split("global_step_")[-1])
print(f"Setting global step to {self.global_steps}")
print(f"Resuming from {global_step_folder}")
actor_path = os.path.join(global_step_folder, "actor")
reward_path = os.path.join(global_step_folder, "reward")
# load actor
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
# load rm
if self.use_rm:
self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
# load dataloader
# TODO: loading from remote is not implemented yet
dataloader_local_path = os.path.join(global_step_folder, "data.pt")
self.train_dataloader = torch.load(dataloader_local_path)
if isinstance(self.train_dataloader.dataset, RLHFDataset):
self.train_dataloader.dataset.resume_dataset_state()
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC to
construct the PPO dataflow. The lightweight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# we start from step 1
self.global_steps += 1
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"])
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
with simple_timer("step", timing_raw):
# generate a batch
with simple_timer("gen", timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
if self.config.algorithm.adv_estimator == "remax":
with simple_timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# verify
with simple_timer("verify", timing_raw):
scores = self.reward_fn.verify(batch)
metrics["acc"] = statistics.mean(scores)
# filter the batch: 1/oversample_factor of the samples will be kept.
# If there is a filter, prompts passing it will be prioritized.
batch = self.filter_and_downsample(scores, batch)
batch.meta_info["n"] = self.config.actor_rollout_ref.rollout.n
n_samples = self.config.actor_rollout_ref.rollout.n
# recompute old_log_probs
with simple_timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = compute_response_mask(batch)
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with simple_timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
with simple_timer("adv", timing_raw):
if self.use_rm:
update_style = self.config.reward_model.model.get("update", "none")
if update_style == "none": # only run forward
reward_output = self.rm_wg.compute_rm_score(batch)
elif update_style == "after": # update and directly return the reward
reward_output = self.rm_wg.update_rm(batch)
elif update_style == "before": # update reward model, and then run forward
reward_output = self.rm_wg.update_rm(batch)
if "metrics" in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"])
metrics.update(reward_output_metrics)
reward_output = self.rm_wg.compute_rm_score(batch)
elif (
update_style == "reverse"
): # run forward to calculate statistics, then update reward model
reward_output = self.rm_wg.compute_rm_score(batch)
# broadcast q and acc tensor to each result
bc_td = DataProto.from_dict(
tensors={
"Q_bc": reward_output.batch["q"]
.sum(dim=-1)
.view(-1, n_samples)
.unsqueeze(1)
.expand(-1, n_samples, -1)
.reshape(-1, n_samples),
"acc_bc": batch.batch["acc"]
.view(-1, n_samples)
.unsqueeze(1)
.expand(-1, n_samples, -1)
.reshape(-1, n_samples),
}
)
batch = batch.union(bc_td)
reward_output = self.rm_wg.update_rm(batch)
else:
raise NotImplementedError
batch = batch.union(reward_output)
if "metrics" in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"])
metrics.update(reward_output_metrics)
# compute advantages, executed on the driver process
batch = compute_advantage(
batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config
)
# update actor
with simple_timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and self.global_steps % self.config.trainer.test_freq == 0
):
with simple_timer("testing", timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:
with simple_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
self.global_steps += 1
if self.global_steps >= self.total_training_steps:
# perform validation after training
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f"Final validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if (
self.config.trainer.save_freq > 0
and (self.global_steps - 1) % self.config.trainer.save_freq != 0
):
with simple_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
return
def filter_and_downsample(self, scores, batch: DataProto):
"""
Downsample the batch according to oversample_factor.
Samples passing the filters are prioritized.
"""
n_samples = int(self.config.actor_rollout_ref.rollout.n)
reward_matrix = torch.tensor(scores).reshape(-1, n_samples)
filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool)
if self.config.data.filter_accuracy:
acc_tensor = torch.mean(reward_matrix, dim=-1)
filter_mask[
(acc_tensor > self.config.data.accuracy_upper_bound)
| (acc_tensor < self.config.data.accuracy_lower_bound)
] = False
if self.config.data.filter_truncate:
length_matrix = (
batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :]
.sum(dim=-1)
.reshape(-1, n_samples)
)
length_tensor = torch.max(length_matrix, dim=-1)[0]
filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False
reorder_index = torch.argsort(filter_mask, descending=True)
reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)
batch.reorder(
reorder_index[: int(len(batch) // self.config.data.oversample_factor)]
) # this operation is inplace
return batch
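# A toy walk-through of the reorder/downsample arithmetic in filter_and_downsample
# (all numbers are made up: 4 prompts, n_samples = 2 responses each, oversample_factor = 2,
# so half of the responses are kept and prompts that pass the filters come first):
#
#   filter_mask            = [0, 1, 0, 1]              # prompts 1 and 3 pass the filters
#   argsort (descending)   = [1, 3, 0, 2]              # passing prompts sort to the front (tie order may vary)
#   per-response indices   = [2, 3, 6, 7, 0, 1, 4, 5]  # prompt i expands to responses i*n .. i*n + n - 1
#   kept (first len(batch) // oversample_factor = 4)   -> responses [2, 3, 6, 7]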
set -x
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"
model_path=PRIME-RL/Eurus-2-7B-SFT
# model_path=Qwen/Qwen2.5-0.5B-Instruct
python3 -m recipe.prime.main_prime \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=64 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=3072 \
data.filter_overlong_prompts=True \
data.filter_accuracy=True \
data.accuracy_lower_bound=0.2 \
data.accuracy_upper_bound=0.8 \
data.oversample_factor=4 \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=5e-7 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
algorithm.adv_estimator=rloo \
algorithm.use_kl_in_reward=True \
algorithm.kl_penalty=kl \
algorithm.kl_ctrl.kl_coef=0.001 \
reward_model.model.path=$model_path \
reward_model.micro_batch_size_per_gpu=1 \
reward_model.model.update=before \
reward_model.model.beta_train=0.05 \
reward_model.model.optim.lr=1e-6 \
reward_model.model.optim.grad_clip=10.0 \
reward_model.model.input_tokenizer=null \
reward_model.mini_batch_size=64 \
trainer.val_before_train=False \
trainer.logger='["console","wandb"]' \
trainer.project_name='prime_example' \
trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=64 \
trainer.test_freq=64 \
trainer.total_epochs=15 $@
set -x
# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data
code_train_path=$HOME/data/code/train.parquet
code_test_path=$HOME/data/code/test.parquet
train_files="['$code_train_path']"
test_files="['$code_test_path']"
model_path=PRIME-RL/Eurus-2-7B-SFT
# model_path=Qwen/Qwen2.5-0.5B-Instruct
python3 -m recipe.prime.main_prime \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=64 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=3072 \
data.filter_overlong_prompts=True \
data.filter_accuracy=True \
data.accuracy_lower_bound=0.2 \
data.accuracy_upper_bound=0.8 \
data.oversample_factor=4 \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=5e-7 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
algorithm.adv_estimator=rloo \
algorithm.use_kl_in_reward=True \
algorithm.kl_penalty=kl \
algorithm.kl_ctrl.kl_coef=0.001 \
reward_model.model.path=$model_path \
reward_model.micro_batch_size_per_gpu=1 \
reward_model.model.update=before \
reward_model.model.beta_train=0.05 \
reward_model.model.optim.lr=1e-6 \
reward_model.model.optim.grad_clip=10.0 \
reward_model.model.input_tokenizer=null \
reward_model.mini_batch_size=64 \
trainer.val_before_train=False \
trainer.logger='["console","wandb"]' \
trainer.project_name='prime_example' \
trainer.experiment_name='Eurus-2-7B-SFT-code' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=64 \
trainer.test_freq=64 \
trainer.total_epochs=15 $@
# DeepSeek R1 Reproduction
This recipe is under development. If you are interested, check out the TODO list and join this project: https://github.com/volcengine/verl/issues/708
## Reproducing Evaluation
Eval Results of DS-R1-Distill-Qwen2.5-1.5B (k=8)
Dataset | Test Results | Reported
-- | -- | --
GPQA Diamond | 35.3 | 33.8
LiveCodeBench | 16.9 | 16.9
AIME 2024 | 30.4 | 28.9
CNMO 2024 (en) | 45.1 | -
CNMO 2024 (zh) | 41.0 | -
---
Eval Results (DS-R1)
Dataset | Test Results (k=1) | Test Results (k=4) | Reported
-- | -- | -- | --
GPQA Diamond | 67.7 | 69.6 | 71.5
LiveCodeBench | 64.7 | 63.1 | 65.9
AIME 2024 | 86.7 | 79.2 | 79.8
CNMO 2024 | 75.0 | 78.5 | 78.8
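Here k denotes the number of sampled responses per prompt. The evaluation entry point `recipe/r1/main_eval.py` (included below) averages the rule-based verifier score over the k samples of each prompt and then over all prompts of a data source, i.e. mean@k rather than pass@k. A minimal sketch of that aggregation with made-up scores:

```python
import numpy as np

# toy 0/1 verifier scores: two prompts with k = 4 samples each
scores_per_prompt = {"aime2024": [[1, 0, 1, 1], [0, 0, 1, 0]]}
metric = {src: float(np.mean([np.mean(s) for s in rows])) for src, rows in scores_per_prompt.items()}
print(metric)  # {'aime2024': 0.5}; main_eval reports this as test_score/aime2024
```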
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
data:
path: /tmp/math_Qwen2-7B-Instruct.parquet
prompt_key: prompt
response_key: responses
data_source_key: data_source
reward_model_key: reward_model
custom_reward_function:
path: null
name: compute_score
ray_init:
num_cpus: null # `None` means using all CPUs, which might cause a hang if CPUs are limited in systems like SLURM. In that case, set it to an allowed number.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the dataset to parquet format
"""
import argparse
import os
from functools import partial
from datasets import concatenate_datasets, load_dataset
from verl.utils.hdfs_io import copy, makedirs
def example_map_fn(example, idx, process_fn, data_source, ability, split):
question, solution = process_fn(example)
data = {
"data_source": data_source,
"prompt": [{"role": "user", "content": question}],
"ability": ability,
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {"split": split, "index": idx},
}
return data
def build_aime2024_dataset():
def process_aime2024(example):
return example["Problem"], str(example["Answer"])
data_source = "Maxwell-Jia/AIME_2024"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, split="train")
map_fn = partial(
example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test"
)
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)
return dataset
def build_gpqa_dimond_dataset():
import random
GPQA_QUERY_TEMPLATE = (
"Answer the following multiple choice question. The last line of your response should be of the following "
"format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before "
"answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
)
def process_gpqa_diamond(example):
choices = [example["Incorrect Answer 1"], example["Incorrect Answer 2"], example["Incorrect Answer 3"]]
random.shuffle(choices)
gold_index = random.randint(0, 3)
choices.insert(gold_index, example["Correct Answer"])
query_prompt = GPQA_QUERY_TEMPLATE.format(
A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"]
)
gold_choice = "ABCD"[gold_index]
return query_prompt, gold_choice
data_source = "Idavidrein/gpqa"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, "gpqa_diamond", split="train")
map_fn = partial(
example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test"
)
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)
return dataset
def build_cnmo2024_dataset():
def process_cnmo2024(example):
return example["question"], example["answer"]
data_source = "opencompass/LiveMathBench"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test")
map_fn_en = partial(
example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test"
)
dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names)
dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test")
map_fn_zh = partial(
example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test"
)
dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names)
dataset = concatenate_datasets([dataset_en, dataset_zh])
return dataset
def build_livecodebench_dataset():
import base64
import json
import pickle
import zlib
def process_livecodebench(example):
# Construct Query Prompt
# From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140
query_prompt = (
f"You will be given a question (problem specification) and will generate a correct Python program "
f"that matches the specification and passes all tests.\n\nQuestion: {example['question_content']}\n\n"
)
if example["starter_code"]:
query_prompt += (
f"You will use the following starter code to write the solution to the problem and enclose your "
f"code within delimiters.\n```python\n{example['starter_code']}\n```"
)
else:
query_prompt += (
"Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test "
"on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python "
"program runs, it reads the inputs, runs the algorithm and writes output to STDOUT."
"```python\n# YOUR CODE HERE\n```"
)
# Construct test cases
public_test_cases = json.loads(example["public_test_cases"])
try:
private_test_cases = json.loads(example["private_test_cases"])
except Exception as e:
print(f"Error loading private test cases: {e}")
private_test_cases = json.loads(
pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8"))))
)
full_test_cases = public_test_cases + private_test_cases
metadata = json.loads(example["metadata"])
test_cases = {
"inputs": [t["input"] for t in full_test_cases],
"outputs": [t["output"] for t in full_test_cases],
"fn_name": metadata.get("func_name", None),
}
text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8")
return query_prompt, text_cases_compressed
data_source = "livecodebench/code_generation_lite"
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, split="test")
# R1 evaluation uses LiveCodeBench 24.08-25.01
dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00")
map_fn = partial(
example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test"
)
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8)
return dataset
TASK2DATA = {
"aime2024": build_aime2024_dataset,
"gpqa_diamond": build_gpqa_dimond_dataset,
"cnmo2024": build_cnmo2024_dataset,
"livecodebench": build_livecodebench_dataset,
}
SUPPORTED_TASKS = TASK2DATA.keys()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/r1")
parser.add_argument("--hdfs_dir", default=None)
parser.add_argument("--tasks", default="all")
args = parser.parse_args()
if args.tasks.lower() == "all":
args.tasks = SUPPORTED_TASKS
else:
args.tasks = [task.strip() for task in args.tasks.split(",") if task.strip()]
for task in args.tasks:
if task not in SUPPORTED_TASKS:
raise NotImplementedError(f"{task} has not been supported.")
datasets = []
for task in args.tasks:
datasets.append(TASK2DATA[task]())
test_dataset = concatenate_datasets(datasets)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
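# Example invocation (output parquet lands under --local_dir; --tasks also accepts a
# comma-separated subset such as "aime2024,livecodebench"):
#   python3 -m recipe.r1.data_process --local_dir ~/data/r1 --tasks all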
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline evaluate the performance of a generated file using reward model and ground truth verifier.
The input is a parquet file that contains N generated sequences and (optional) the ground truth.
"""
from collections import defaultdict
import hydra
import numpy as np
import pandas as pd
import ray
from tqdm import tqdm
from verl.trainer.ppo.reward import get_custom_reward_fn
from verl.utils.fs import copy_to_local
@ray.remote
def process_item(config, data_source, response_lst, reward_data):
reward_fn = get_custom_reward_fn(config)
ground_truth = reward_data["ground_truth"]
score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
return data_source, np.mean(score_lst)
@hydra.main(config_path="config", config_name="evaluation", version_base=None)
def main(config):
local_path = copy_to_local(config.data.path)
dataset = pd.read_parquet(local_path)
responses = dataset[config.data.response_key]
data_sources = dataset[config.data.data_source_key]
reward_model_data = dataset[config.data.reward_model_key]
total = len(dataset)
# Initialize Ray
if not ray.is_initialized():
ray.init(num_cpus=config.ray_init.num_cpus)
# evaluate test_score based on data source
data_source_reward = defaultdict(list)
# Create remote tasks
remote_tasks = [
process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
]
# Process results as they come in
with tqdm(total=total) as pbar:
while len(remote_tasks) > 0:
# Use ray.wait to get completed tasks
done_ids, remote_tasks = ray.wait(remote_tasks)
for result_id in done_ids:
data_source, score = ray.get(result_id)
data_source_reward[data_source].append(score)
pbar.update(1)
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f"test_score/{data_source}"] = np.mean(rewards)
print(metric_dict)
if __name__ == "__main__":
main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def reward_func(data_source, solution_str, ground_truth, extra_info=None):
if data_source in ["Maxwell-Jia/AIME_2024", "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]:
from recipe.r1.tasks import math
return math.compute_score(solution_str, ground_truth)
elif data_source == "Idavidrein/gpqa":
from recipe.r1.tasks import gpqa
return gpqa.compute_score(solution_str, ground_truth)
elif data_source in ["livecodebench/code_generation_lite", "livecodebench/code_generation"]:
from recipe.r1.tasks import livecodebench
return livecodebench.compute_score(solution_str, ground_truth)
else:
raise NotImplementedError
MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
DATA_PATH=/workspace/datasets/r1_bench
# Eval Data Process
python3 -m recipe.r1.data_process \
--local_dir $DATA_PATH \
--tasks all
# Generation
python3 -m verl.trainer.main_generation \
trainer.nnodes=1 \
trainer.n_gpus_per_node=8 \
data.path=$DATA_PATH/test.parquet \
data.prompt_key=prompt \
data.batch_size=1024 \
data.n_samples=8 \
data.output_path=$DATA_PATH/test-output-8.parquet \
model.path=$MODEL_PATH \
rollout.temperature=0.6 \
rollout.top_p=0.95 \
rollout.prompt_length=1024 \
rollout.response_length=32768 \
rollout.tensor_model_parallel_size=1 \
rollout.gpu_memory_utilization=0.9 \
rollout.max_num_batched_tokens=65536
# Evaluation
python3 -m recipe.r1.main_eval \
data.path=$DATA_PATH/test-output-8.parquet \
data.prompt_key=prompt \
data.response_key=responses \
custom_reward_function.path=recipe/r1/reward_score.py \
custom_reward_function.name=reward_func
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
# Extraction Template from https://github.com/openai/simple-evals/blob/90e3e821cabba2aeb6be651dcb662b253df04225/common.py#L25
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"
def compute_score(solution_str, ground_truth) -> float:
match = re.search(ANSWER_PATTERN_MULTICHOICE, solution_str)
extracted_answer = match.group(1) if match else None
score = 1.0 if extracted_answer == ground_truth else 0.0
return score
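# A quick usage check of the extraction above (the response strings are made up):
if __name__ == "__main__":
    assert compute_score("Step 1 ...\nAnswer: C", "C") == 1.0
    # no "Answer:" line -> nothing extracted -> scored 0.0
    assert compute_score("I think it is probably B.", "B") == 0.0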
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import multiprocessing
import pickle
import zlib
# Reuse `run_test` for convenience
from verl.utils.reward_score.prime_code.testing_util import run_test
def _temp_run(in_outs, generation, debug, result, metadata_list, timeout):
res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout)
result.append(res)
metadata_list.append(metadata)
def check_correctness(in_outs, generation, timeout, debug=True):
"""Check correctness of code generation with a global timeout.
The global timeout is to catch some extreme/rare cases not handled by the timeouts
inside `run_test`"""
manager = multiprocessing.Manager()
result = manager.list()
metadata_list = manager.list()
p = multiprocessing.Process(
target=_temp_run,
args=(in_outs, generation, debug, result, metadata_list, timeout),
)
p.start()
p.join(timeout=(timeout + 1) * len(in_outs["inputs"]) + 5)
if p.is_alive():
p.kill()
if not result:
# consider that all tests failed
result = [[-1 for i in range(len(in_outs["inputs"]))]]
if debug:
print("global timeout")
return result[0], metadata_list[0]
def compute_score(completion, test_cases):
solution = completion.split("```python")[-1].split("```")[0]
# extract test cases
try:
in_outs = json.loads(test_cases)
except Exception as e:
print(f"Error loading test cases: {e}")
in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode("utf-8")))))
success = False
try:
res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False)
success = all(map(lambda x: x is True, res))
except Exception:
pass
return success
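# A minimal round-trip sketch of the compressed test-case encoding this scorer accepts,
# mirroring the packing done in recipe/r1/data_process.py (the toy test case is made up):
if __name__ == "__main__":
    toy_cases = {"inputs": ["1 2\n"], "outputs": ["3\n"], "fn_name": None}
    packed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(toy_cases)))).decode("utf-8")
    # json.loads(packed) fails, so compute_score falls back to exactly this decoding path
    unpacked = json.loads(pickle.loads(zlib.decompress(base64.b64decode(packed.encode("utf-8")))))
    assert unpacked == toy_cases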
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
try:
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
except ImportError:
print("To use Math-Verify, please install it first by running `pip install math-verify`.")
def compute_score(model_output: str, ground_truth: str) -> float:
verify_func = math_metric(
gold_extraction_target=(LatexExtractionConfig(),),
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
)
ret_score = 0.0
# Wrap the ground truth in \boxed{} format for verification
ground_truth_boxed = "\\boxed{" + ground_truth + "}"
with contextlib.suppress(Exception):
ret_score, _ = verify_func([ground_truth_boxed], [model_output])
return ret_score
# Retool
[ReTool: Reinforcement Learning for Strategic Tool Use in LLMs](https://arxiv.org/abs/2504.11536)
## Overview
- Base model: [Qwen/Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct)
- SFT dataset: [JoeYing/ReTool-SFT](https://huggingface.co/datasets/JoeYing/ReTool-SFT)
- RL dataset: [BytedTsinghua-SIA/DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k)
- Val dataset: [yentinglin/aime_2025](https://huggingface.co/datasets/yentinglin/aime_2025)
## SFT
1. Data preparation
```bash
python3 recipe/retool/retool_sft_preprocess.py
```
2. Training
```bash
bash recipe/retool/run_qwen2-32b_sft.sh
```
After 6 epochs, validation metrics:
- val-core/aime_2025/acc/mean@30: 0.24
- val-aux/num_turns/mean: 7.2
## RL
### GRPO
```bash
bash recipe/retool/run_qwen2-32b_dapo.sh
```
After 150 steps, validation metrics:
- val-core/aime_2025/acc/mean@30: 0.6
- val-aux/num_turns/mean: 10
### PPO
```bash
bash recipe/retool/run_qwen2-32b_ppo.sh
```
After 250 steps, validation metrics:
- val-core/aime_2025/acc/mean@30: 0.55
- val-aux/num_turns/mean: 8.3
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Any
import datasets
from verl.tools.base_tool import OpenAIFunctionToolSchema
from verl.tools.sandbox_fusion_tools import SandboxFusionTool
from verl.utils.dataset import RLHFDataset
from verl.utils.reward_score import math_dapo
from verl.utils.rollout_trace import rollout_trace_op
logger = logging.getLogger(__name__)
class CustomSandboxFusionTool(SandboxFusionTool):
def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
super().__init__(config, tool_schema)
self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
@rollout_trace_op
async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
code = parameters["code"]
matches = self.code_pattern.findall(code)
if matches:
code = matches[0].strip()
# NOTE: some scripts may not explicitly print the result, so wrap the last non-empty line in a print statement
lines = code.split("\n")
for i, line in reversed(list(enumerate(lines))):
if line == "":
continue
if not lines[i].startswith("print"):
lines[i] = f"print({line})"
break
code = "\n".join(lines)
timeout = parameters.get("timeout", self.default_timeout)
language = parameters.get("language", self.default_language)
if not isinstance(code, str):
code = str(code)
result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
# the sandbox returns no score or metrics, so return None for both
return result, None, None
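# For example, a tool call whose code ends in a bare expression:
#     x = 1
#     x + 1
# is rewritten above so that the last non-empty line becomes print(x + 1) before the code is
# sent to the sandbox; code whose last non-empty line already starts with print is left as-is.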
answer_format = """\nThe answer format must be: \\boxed{'The final answer goes here.'}"""
class CustomRLHFDataset(RLHFDataset):
"""Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets."""
def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.data_files:
# read parquet files and cache
dataframe = datasets.load_dataset(parquet_file)["train"]
data_source = "/".join(parquet_file.split("/")[-2:])
if data_source in ["Maxwell-Jia/AIME_2024", "yentinglin/aime_2025"]:
dataframe = dataframe.map(
self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names
)
else:
dataframe = dataframe.map(self.map_fn2, num_proc=16)
dataframes.append(dataframe)
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
print(f"dataset len: {len(self.dataframe)}")
def map_fn(self, row: dict, *, data_source: str = None):
if data_source == "Maxwell-Jia/AIME_2024":
problem, answer = row["Problem"], row["Answer"]
elif data_source == "yentinglin/aime_2025":
problem, answer = row["problem"], row["answer"]
prompt = problem + answer_format
data = {
"data_source": data_source.split("/")[1].lower(), # aime_2024, aime_2025
"prompt": [{"role": "user", "content": prompt}],
"ability": "MATH",
"reward_model": {"ground_truth": str(answer)},
"agent_name": "tool_agent",
}
return data
def map_fn2(self, row: dict):
content = row["prompt"][0]["content"]
row["prompt"][0]["content"] = content + answer_format
row["agent_name"] = "tool_agent"
return row
def compute_score(data_source, solution_str, ground_truth, extra_info):
# use \\boxed{...} answer
result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True)
# encourage model to call tools
num_turns = extra_info["num_turns"]
if result["score"] < 0:
tool_call_reward = (num_turns - 2) / 2 * 0.1
result["score"] = min(0, result["score"] + tool_call_reward)
if result["pred"] is None:
result["pred"] = ""
return result
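# Worked example of the shaping above (the base score of -1.0 for a wrong boxed answer is an
# assumption about math_dapo's convention; only the shaping arithmetic comes from this function):
#   num_turns = 6 -> tool_call_reward = (6 - 2) / 2 * 0.1 = 0.2
#   shaped score  = min(0, -1.0 + 0.2) = -0.8
# so extra tool-use turns soften the penalty for wrong answers but never lift the score above 0;
# correct answers keep their positive score untouched.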
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages.
"""
import json
import os
import re
from typing import Any
import datasets
from omegaconf import OmegaConf
code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
def extract_code_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "<code>", "</code>"
i = content.find(start)
if i == -1:
return None, content
j = content.find(stop)
assert j > i
code = content[i + len(start) : j]
matches = code_pattern.findall(code)
if matches:
code = matches[0].strip()
message = {
"role": "assistant",
"content": content[:i].strip(),
"tool_calls": [
{
"type": "function",
"function": {
"name": "code_interpreter",
"arguments": {"code": code},
},
},
],
}
return message, content[j + len(stop) :]
def extract_answer_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "<answer>", "</answer>"
i = content.find(start)
if i == -1:
return None, content
j = content.find(stop)
assert j > i
answer = content[:i] + content[i + len(start) : j]
message = {
"role": "assistant",
"content": answer.strip(),
}
return message, content[j + len(stop) :]
def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "<interpreter>", "</interpreter>"
i = content.find(start)
if i == -1:
return None, content
j = content.find(stop)
assert j > i
interpreter = content[i + len(start) : j]
message = {
"role": "tool",
"content": interpreter.strip(),
}
return message, content[j + len(stop) :]
def process(row: dict, *, tools: str):
messages = []
# extract problem
content = row["messages"][0]["content"]
start = "*user question:*"
i = content.find(start)
assert i != -1
prompt = content[i + len(start) :].replace("<answer>", "").replace("</answer>", "").strip()
messages.append(
{
"role": "user",
"content": prompt,
}
)
# extract multi turns
content = row["messages"][1]["content"]
role = "assistant"
while len(content) > 0:
if role == "assistant":
message, content = extract_code_message(content)
if message is None:
message, content = extract_answer_message(content)
assert message is not None
messages.append(message)
role = "tool"
else:
message, content = extract_interpreter_message(content)
assert message is not None
messages.append(message)
role = "assistant"
tools = json.loads(tools)
return {"messages": messages, "tools": tools}
if __name__ == "__main__":
tools_config_file = "recipe/retool/sandbox_fusion_tool_config.yaml"
tools_config = OmegaConf.load(tools_config_file)
tool_schema = OmegaConf.to_container(tools_config["tools"][0]["tool_schema"])
tools = json.dumps([tool_schema])
data = datasets.load_dataset("JoeYing/ReTool-SFT")["train"]
data = data.map(process, fn_kwargs={"tools": tools})
save_path = os.path.expanduser("~/ReTool-SFT/data/train-00000-of-00001.parquet")
data.to_parquet(save_path)
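# For orientation, each converted record roughly has the following shape (the content strings
# are made up; the real "tools" entry is the schema list loaded from
# recipe/retool/sandbox_fusion_tool_config.yaml):
#
#   {
#       "messages": [
#           {"role": "user", "content": "Solve ..."},
#           {"role": "assistant", "content": "Let me compute this.",
#            "tool_calls": [{"type": "function",
#                            "function": {"name": "code_interpreter",
#                                         "arguments": {"code": "print(1 + 1)"}}}]},
#           {"role": "tool", "content": "2"},
#           {"role": "assistant", "content": "The answer is \\boxed{2}."},
#       ],
#       "tools": [...tool schema from the YAML config...],
#   }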
set -x
# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k
aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024
aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025
model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-32b-instruct/global_step_372
train_files="['$dapo_math_17k']"
test_files="['$aime_2025']"
# tool
tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml
# wandb
project_name=wuxibin_retool
experiment_name=qwen2.5-32b_dapo
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
# ================= algorithm =================
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_turns=8
max_prompt_length=2048
max_response_length=16384
actor_lr=1e-6
train_batch_size=512
ppo_mini_batch_size=64
n_resp_per_prompt=16
n_resp_per_prompt_val=30
# ================= performance =================
infer_tp=4 # vllm
train_sp=8 # train
offload=True
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 ))
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=$adv_estimator \
algorithm.use_kl_in_reward=$use_kl_in_reward \
algorithm.kl_ctrl.kl_coef=$kl_coef \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=True \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=$max_prompt_length \
data.max_response_length=$max_response_length \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.custom_cls.path=recipe/retool/retool.py \
data.custom_cls.name=CustomRLHFDataset \
custom_reward_function.path=recipe/retool/retool.py \
custom_reward_function.name=compute_score \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.optim.lr=$actor_lr \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
actor_rollout_ref.rollout.n=$n_resp_per_prompt \
actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
trainer.logger=['console','wandb'] \
trainer.project_name=$project_name \
trainer.experiment_name=$experiment_name \
trainer.n_gpus_per_node=8 \
trainer.val_before_train=True \
trainer.log_val_generations=100 \
trainer.nnodes=2 \
trainer.save_freq=30 \
trainer.default_local_dir=$default_local_dir \
trainer.test_freq=5 \
trainer.total_epochs=1 $@
set -x
# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k
aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024
aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025
actor_model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-32b-instruct/global_step_372
critic_model_path=$actor_model_path
train_files="['$dapo_math_17k']"
test_files="['$aime_2025']"
# tool
tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml
# wandb
project_name=wuxibin_retool
experiment_name=qwen2.5-32b_ppo
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
# ================= algorithm =================
adv_estimator=gae
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_turns=8
max_prompt_length=2048
max_response_length=16384
actor_lr=1e-6
critic_lr=2e-6
gae_gamma=1.0
gae_lam=1.0
critic_warmup=20
train_batch_size=1024
ppo_mini_batch_size=256
n_resp_per_prompt_val=30
# ================= performance =================
infer_tp=4 # vllm
train_sp=4 # train
offload=True
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 2 ))
critic_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=$adv_estimator \
algorithm.use_kl_in_reward=$use_kl_in_reward \
algorithm.kl_ctrl.kl_coef=$kl_coef \
algorithm.gamma=$gae_gamma \
algorithm.lam=$gae_lam \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=True \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=$max_prompt_length \
data.max_response_length=$max_response_length \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.custom_cls.path=recipe/retool/retool.py \
data.custom_cls.name=CustomRLHFDataset \
custom_reward_function.path=recipe/retool/retool.py \
custom_reward_function.name=compute_score \
actor_rollout_ref.model.path=$actor_model_path \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.optim.lr=$actor_lr \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
critic.optim.lr=$critic_lr \
critic.model.use_remove_padding=True \
critic.model.path=$critic_model_path \
critic.model.enable_gradient_checkpointing=True \
critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \
critic.ulysses_sequence_parallel_size=$train_sp \
critic.model.fsdp_config.param_offload=$offload \
critic.model.fsdp_config.optimizer_offload=$offload \
trainer.critic_warmup=$critic_warmup \
trainer.logger=['console','wandb'] \
trainer.project_name=$project_name \
trainer.experiment_name=$experiment_name \
trainer.n_gpus_per_node=8 \
trainer.val_before_train=True \
trainer.log_val_generations=100 \
trainer.nnodes=2 \
trainer.save_freq=30 \
trainer.default_local_dir=$default_local_dir \
trainer.test_freq=5 \
trainer.total_epochs=1 $@
#!/bin/bash
set -x
nnodes=2
nproc_per_node=8
master_addr=
master_port=
experiment_name=multiturn-sft-qwen-2.5-32b-instruct
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
TRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet
EVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet
MODEL_PATH=$HDFS_ROOT/model/Qwen2.5-32B-Instruct
SAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name
torchrun --nnodes=$nnodes \
--nproc_per_node=$nproc_per_node \
--master-addr=$master_addr \
--master-port=$master_port \
--node-rank=$node_rank \
-m verl.trainer.fsdp_sft_trainer \
data.train_files=$TRAIN_DATA \
data.val_files=$EVAL_DATA \
data.max_length=16384 \
data.train_batch_size=32 \
data.multiturn.enable=true \
data.multiturn.messages_key=messages \
data.multiturn.tools_key=tools \
data.micro_batch_size_per_gpu=4 \
model.partial_pretrain=$MODEL_PATH \
model.strategy=fsdp \
trainer.default_local_dir=$SAVE_PATH \
trainer.project_name=wuxibin-multiturn-sft \
trainer.experiment_name=$experiment_name \
trainer.logger='["console","wandb"]' \
trainer.total_epochs=6 \
ulysses_sequence_parallel_size=4 \
use_remove_padding=true