"examples/contrib/run_openai_gpt.py" did not exist on "edcb56fd96958984fc4bce93e931d5bac61d41c4"
Commit 7f6cc211 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #2874 failed with stages
in 0 seconds
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""
import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from tqdm import tqdm
from verl import DataProto
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.utils.profiler import simple_timer
class RayEntropyTrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
timing_raw = defaultdict(float)
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
num_gen_batches += 1
# pop those keys for generation
if "multi_modal_inputs" in new_batch.non_tensor_batch.keys():
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
is_last_step = self.global_steps >= self.total_training_steps
with simple_timer("step", timing_raw):
# generate a batch
# with simple_timer("gen", timing_raw):
# gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
with simple_timer("gen", timing_raw):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with simple_timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
new_batch = new_batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(new_batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
new_batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
new_batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)
with simple_timer("reward", timing_raw):
# compute scores; supports both model-based and function-based rewards.
# We first compute the scores using the reward model, then call reward_fn to combine
# the reward-model results with the rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(reward_tensor)
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(new_batch, return_dict=True)
reward_tensor = reward_result["reward_tensor"]
reward_extra_infos_dict = reward_result["reward_extra_info"]
except Exception as e:
print(f"Error in reward_fn: {e}")
reward_tensor = self.reward_fn(new_batch)
reward_extra_infos_dict = {}
new_batch.batch["token_level_scores"] = reward_tensor
print(f"{list(reward_extra_infos_dict.keys())=}")
if reward_extra_infos_dict:
new_batch.non_tensor_batch.update(
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
)
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(
new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(
kl_metrics
) # TODO: This will be cleared if we use multiple generation batches
else:
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
if not self.config.algorithm.filter_groups.enable:
batch = new_batch
else: # NOTE: When the number of prompts left after filtering is smaller than the train batch size,
# we skip to the next generation batch
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = (
new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
)
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = (
new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
)
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
):
prompt_uid2metric_vals[uid].append(metric_val)
prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
kept_prompt_uids = [
uid
for uid, std in prompt_uid2metric_std.items()
if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
]
num_prompt_in_batch += len(kept_prompt_uids)
kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
new_batch = new_batch[kept_traj_idxs]
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")
continue
else:
raise ValueError(
f"{num_gen_batches=} >= {max_num_gen_batches=}."
+ " Generated too many. Please check if your data are too difficult."
+ " You could also try set max_num_gen_batches=0 to enable endless trials."
)
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
print(
f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. "
f"Collecting finished."
)
batch = batch[:traj_bsz]
# === Updating ===
batch.batch["response_mask"] = compute_response_mask(batch)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group-based advantage computation such as GRPO and RLOO
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with simple_timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with simple_timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with simple_timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with simple_timer("adv", timing_raw):
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
# update critic
if self.use_critic:
with simple_timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with simple_timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with simple_timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with simple_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual TFLOPs and theoretical TFLOPs
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
timing_raw = defaultdict(float) # clear timing
metrics["train/num_gen_batches"] = num_gen_batches
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
# TODO: make a canonical logger that supports various backends
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1
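# Hedged illustration of the group-filtering rule implemented above (DAPO-style dynamic
# sampling): a prompt is kept only if its n rollouts do not all receive the same reward.
# A minimal sketch, assuming `uids` repeats each prompt uid n times and `seq_rewards`
# holds one scalar reward per trajectory (these names are illustrative, not part of this file):
#
#     import numpy as np
#     from collections import defaultdict
#
#     uid2rewards = defaultdict(list)
#     for uid, r in zip(uids, seq_rewards):
#         uid2rewards[uid].append(r)
#     kept_uids = {u for u, rs in uid2rewards.items() if np.std(rs) > 0 or len(rs) == 1}
#     kept_idxs = [i for i, u in enumerate(uids) if u in kept_uids]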
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine this main entry point with ray_trainer, since ray_trainer is also used by other mains.
"""
import hydra
import ray
from .entropy_ray_trainer import RayEntropyTrainer
from .reward import load_reward_manager
@hydra.main(config_path="config", config_name="entropy_trainer", version_base=None)
def main(config):
run_ppo(config)
def run_ppo(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
ray.init(
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "WARN",
"WANDB_API_KEY": "YOUR_WANDB_API_KEY",
}
},
num_cpus=config.ray_init.num_cpus,
)
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
def merge_dict(a: dict, b: dict) -> dict:
"""Return a new dict that has `a` updated with `b` (b wins on conflicts).
Example::
>>> d1 = {"x": 1, "y": 2}
>>> d2 = {"y": 20, "z": 3}
>>> new_dict = merge_dict(d1, d2)
>>> print(new_dict) # {'x': 1, 'y': 20, 'z': 3}
>>> print(d1) # {"x": 1, "y": 2} (unchanged)
>>> print(d2) # {"y": 20, "z": 3} (unchanged)
"""
return a | b
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
from verl.utils.fs import copy_to_local
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
print(f"{config.actor_rollout_ref.model.path}")
# instantiate tokenizer
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLMs; may be None
# define worker classes
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
assert config.critic.strategy in {"fsdp", "fsdp2"}
from verl.single_controller.ray import RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
actor_rollout_cls = (
AsyncActorRolloutRefWorker
if config.actor_rollout_ref.rollout.mode == "async"
else ActorRolloutRefWorker
)
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == "megatron":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
actor_rollout_cls = ActorRolloutRefWorker
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(actor_rollout_cls),
Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
}
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
reward_kwargs = {
"max_resp_len": config.data.max_response_length,
"overlong_buffer_cfg": config.reward_model.overlong_buffer,
}
cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {})
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs)
)
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
from verl.utils.dataset.rl_dataset import collate_fn
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)
trainer = RayEntropyTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
train_sampler=train_sampler,
)
trainer.init_workers()
trainer.fit()
def create_rl_dataset(data_paths, data_config, tokenizer, processor):
"""Create a dataset.
Arguments:
data_config: The data config.
tokenizer (Tokenizer): The tokenizer.
processor (Processor): The processor.
Returns:
dataset (Dataset): The dataset.
"""
from torch.utils.data import Dataset
from verl.utils.dataset.rl_dataset import RLHFDataset
if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
from verl.utils.import_utils import load_extern_type
dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(
f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' "
f"must inherit from torch.utils.data.Dataset"
)
else:
dataset_cls = RLHFDataset
print(f"Using dataset class: {dataset_cls.__name__}")
dataset = dataset_cls(
data_files=data_paths,
tokenizer=tokenizer,
processor=processor,
config=data_config,
)
return dataset
def create_rl_sampler(data_config, dataset):
"""Create a sampler for the dataset.
Arguments:
data_config: The data config.
dataset (Dataset): The dataset.
Returns:
sampler (Sampler): The sampler.
"""
import torch
from torch.utils.data import RandomSampler, SequentialSampler
# use sampler for better ckpt resume
if data_config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=dataset)
return sampler
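# Hedged usage sketch: wiring the dataset and sampler into a dataloader, roughly the way the
# PPO trainer does internally (the batch-size key used here is an assumption, not taken from this file):
#
#     from torch.utils.data import DataLoader
#     from verl.utils.dataset.rl_dataset import collate_fn
#
#     train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
#     train_sampler = create_rl_sampler(config.data, train_dataset)
#     loader = DataLoader(train_dataset, batch_size=config.data.train_batch_size,
#                         sampler=train_sampler, collate_fn=collate_fn, drop_last=True)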
if __name__ == "__main__":
main()
# Copyright 2025 Individual Contributor: Thibaut Barroyer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from functools import partial
import ray
from verl import DataProto
from verl.trainer.ppo.reward import compute_reward, get_custom_reward_fn
from .reward_score import _default_compute_score
def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs):
"""
Load and initialize a reward manager based on the configuration.
Args:
config: PPO trainer configuration object containing reward_model fields.
tokenizer: Tokenizer object used for processing text.
num_examine: Number of samples to examine.
**reward_kwargs: Additional keyword arguments for the reward manager.
Returns:
An instance of the specified reward manager class.
"""
from verl.workers.reward_manager import get_reward_manager_cls
# The pre-defined reward managers are defined in `verl/workers/reward_manager/`:
# naive: NaiveRewardManager
# prime: PrimeRewardManager
# batch: BatchRewardManager
# dapo: DAPORewardManager
# Note(haibin.lin): For custom reward managers, please make sure they are imported and
# registered via `verl.workers.reward_manager.register`
# By default reward_manager is set to naive (NaiveRewardManager)
reward_manager_name = config.reward_model.get("reward_manager", "naive")
reward_manager_cls = get_reward_manager_cls(reward_manager_name)
# Try to get a custom reward function based on the configuration
compute_score = get_custom_reward_fn(config)
final_compute_score = compute_score
if compute_score is None:
sandbox_config = config.reward_model.get("sandbox_fusion")
sandbox_url = sandbox_config.get("url") if sandbox_config else None
if sandbox_url:
sandbox_manager = multiprocessing.Manager()
# Create a semaphore to control concurrent access to the sandbox
_concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
final_compute_score = partial(
_default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore
)
else:
final_compute_score = _default_compute_score
# Instantiate and return the reward manager with the specified parameters
return reward_manager_cls(
tokenizer=tokenizer,
num_examine=num_examine,
compute_score=final_compute_score,
reward_fn_key=config.data.reward_fn_key,
**reward_kwargs,
)
@ray.remote(num_cpus=1)
def compute_reward_async(data: DataProto, config, tokenizer):
"""
Load the reward manager and compute the reward for a batch of data.
This is meant to be run in a separate Ray worker.
"""
reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}))
return compute_reward(data, reward_fn)
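# Hedged usage sketch (assumes `config`, `tokenizer`, and a DataProto `data` already exist):
#
#     reward_fn = load_reward_manager(
#         config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
#     )
#     reward_result = reward_fn(data, return_dict=True)  # as called by the trainer
#
#     # or offload the computation to a Ray worker and collect the result later:
#     future = compute_reward_async.remote(data, config, tokenizer)
#     reward_out = ray.get(future)  # return value of verl.trainer.ppo.reward.compute_reward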
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from . import gsm8k, math, prime_math, prime_code
import traceback
from . import entropy_math
def _default_compute_score(
data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None
):
try:
res = entropy_math.compute_score(solution_str, str(ground_truth))
# print(f"data_source: {data_source}")
# raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
if isinstance(res, dict):
return res
elif isinstance(res, int | float | bool):
return float(res)
else:
return float(res[0])
except Exception as e:
print(f"[ERROR] Error in process_completion for task : {str(e)}")
traceback.print_exc() # 打印完整堆栈
raise # 重新抛出异常以便上层捕获
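# Hedged example of the expected return shape (the solution string is made up for illustration):
#
#     >>> _default_compute_score("math", "The answer is \\boxed{7}.", 7)
#     {'score': 1.0, 'format_score': 1.0, 'acc': True, 'extracted_gt': '7'}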
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a math answer grading function with high recall.
Based on HF math_verify, verl, open reasoner zero, etc.
"""
import os
import re
import signal
from itertools import islice, zip_longest
from math import isclose
from typing import Optional
import sympy
from latex2sympy2_extended import latex2sympy
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
from pylatexenc import latex2text
from sympy import N, simplify
from sympy.parsing import sympy_parser
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
"""
This code is adapted from: Dr. GRPO (https://github.com/sail-sg/understand-r1-zero/blob/main/understand_r1_zero/math_grader.py).
"""
def timeout_ours(timeout_seconds: int = 8):
if os.name == "posix":
import signal
def decorator(func):
def handler(signum, frame):
raise TimeoutError("Operation timed out!")
def wrapper(*args, **kwargs):
old_handler = signal.getsignal(signal.SIGALRM)
signal.signal(signal.SIGALRM, handler)
signal.alarm(timeout_seconds)
try:
return func(*args, **kwargs)
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
return wrapper
return decorator
else:
raise NotImplementedError(f"Unsupported OS: {os.name}")
# Dan Hendrycks' code
def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:
if answer is None:
return None
answer = answer.strip()
try:
# Remove enclosing `\text{}`.
m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
if m is not None:
answer = m.group("text").strip()
return _strip_string(answer)
except Exception:
return answer
# units mainly from MathQA
unit_texts = [
"east",
"degree",
"mph",
"kmph",
"ft",
"m sqaure",
" m east",
"sq m",
"deg",
"mile",
"q .",
"monkey",
"prime",
"ratio",
"profit of rs",
"rd",
"o",
"gm",
"p . m",
"lb",
"tile",
"per",
"dm",
"lt",
"gain",
"ab",
"way",
"west",
"a .",
"b .",
"c .",
"d .",
"e .",
"f .",
"g .",
"h .",
"t",
"a",
"h",
"no change",
"men",
"soldier",
"pie",
"bc",
"excess",
"st",
"inches",
"noon",
"percent",
"by",
"gal",
"kmh",
"c",
"acre",
"rise",
"a . m",
"th",
"π r 2",
"sq",
"mark",
"l",
"toy",
"coin",
"sq . m",
"gallon",
"° f",
"profit",
"minw",
"yr",
"women",
"feet",
"am",
"pm",
"hr",
"cu cm",
"square",
"v â € ™",
"are",
"rupee",
"rounds",
"cubic",
"cc",
"mtr",
"s",
"ohm",
"number",
"kmph",
"day",
"hour",
"minute",
"min",
"second",
"man",
"woman",
"sec",
"cube",
"mt",
"sq inch",
"mp",
"∏ cm ³",
"hectare",
"more",
"sec",
"unit",
"cu . m",
"cm 2",
"rs .",
"rs",
"kg",
"g",
"month",
"km",
"m",
"cm",
"mm",
"apple",
"liter",
"loss",
"yard",
"pure",
"year",
"increase",
"decrease",
"d",
"less",
"Surface",
"litre",
"pi sq m",
"s .",
"metre",
"meter",
"inch",
]
unit_texts.extend([t + "s" for t in unit_texts])
def _strip_string(string):
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except Exception:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except Exception:
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
# linebreaks
string = string.replace("\n", "")
# print(string)
# remove inverse spaces
string = string.replace("\\!", "")
# print(string)
# replace \\ with \
string = string.replace("\\\\", "\\")
# print(string)
# matrix
string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
string = string.replace("bmatrix", "pmatrix")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge")
# print(string)
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# print(string)
# Remove unit: miles, dollars if after is not none
_string = re.sub(r"\\text{.*?}$", "", string).strip()
if _string != "" and _string != string:
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
string = _string
# Remove unit: texts
for _ in range(2):
for unit_text in unit_texts:
# use regex, the prefix should be either the start of the string or a non-alphanumeric character
# the suffix should be either the end of the string or a non-alphanumeric character
_string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
if _string != "":
string = _string
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
# Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
SUBSTITUTIONS = [
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
]
REMOVED_EXPRESSIONS = [
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"ft",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
]
def normalize_final_answer(final_answer: str) -> str:
"""
Normalize a final answer to a quantitative reasoning question.
This code comes from https://arxiv.org/pdf/2206.14858.pdf, page 18.
"""
# final_answer = final_answer.split("=")[-1]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, "")
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
# Normalize shorthand TeX:
# \fracab -> \frac{a}{b}
# \frac{abc}{bef} -> \frac{abc}{bef}
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
final_answer = final_answer.replace("$", "")
# Normalize 100,000 -> 100000
if final_answer.replace(",", "").isdigit():
final_answer = final_answer.replace(",", "")
return final_answer
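# Hedged examples of the normalization above (values traced by hand, not authoritative):
#
#     >>> normalize_final_answer(r"$\boxed{24}$")
#     '24'
#     >>> normalize_final_answer("10,000")
#     '10000'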
def repeatness(s: str):
def ranks(seq):
index = {v: i for i, v in enumerate(sorted(set(seq)))}
return [index[v] for v in seq]
def suffixArray(s):
line = ranks(s)
n, k, ans, sa = len(s), 1, line, [0] * len(s)
while k < n - 1:
line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
ans, k = line, k << 1
for i, k in enumerate(ans):
sa[k] = i
return ans, sa
def lcp(arr, suffixArr, inv_suff):
n, ans, k = len(arr), [0] * len(arr), 0
for i in range(n):
if inv_suff[i] == n - 1:
k = 0
continue
j = suffixArr[inv_suff[i] + 1]
while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
k += 1
ans[inv_suff[i]] = k
if k > 0:
k -= 1
return ans
arr = [ord(i) for i in s]
n = len(arr)
if n <= 1:
return 0
c, sa = suffixArray(arr)
cnt = sum(lcp(arr, sa, c))
return (cnt * 2 / (n * (n + 1))) > 0.2
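# Hedged note: `repeatness` flags highly repetitive strings; a suffix-array LCP coverage above
# 0.2 counts as repetitive. For example, repeatness("abcd" * 8) is expected to be True, while
# repeatness("abcdefg") is False.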
class timeout:
def __init__(self, seconds=1, error_message="Timeout"):
self.seconds = seconds
self.error_message = error_message
def handle_timeout(self, signum, frame):
raise TimeoutError(self.error_message)
def __enter__(self):
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)
def __exit__(self, type, value, traceback):
signal.alarm(0)
def latex_eval(latex):
sym = parse_latex(latex)
val = sym.evalf()
return sym, val
def numeric_equal(prediction: float, reference: float):
# Note that relative tolerance has significant impact
# on the result of the synthesized GSM-Hard dataset
# if reference.is_integer():
# return isclose(reference, round(prediction), abs_tol=1e-4)
# else:
# prediction = round(prediction, len(str(reference).split(".")[-1]))
return isclose(reference, prediction, rel_tol=1e-4)
@timeout_ours(timeout_seconds=5)
def symbolic_equal(a, b):
def _parse(s):
for f in [parse_latex, parse_expr, latex2sympy]:
try:
return f(s.replace("\\\\", "\\"))
except Exception:
try:
return f(s)
except Exception:
pass
return s
a = _parse(a)
b = _parse(b)
# direct equal
try:
if str(a) == str(b) or a == b:
return True
except Exception:
pass
# simplify equal
try:
if a.equals(b) or simplify(a - b) == 0:
return True
except Exception:
pass
# equation equal
try:
if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
return True
except Exception:
pass
try:
if numeric_equal(float(N(a)), float(N(b))):
return True
except Exception:
pass
# matrix
try:
# if a and b are matrices
if a.shape == b.shape:
_a = a.applyfunc(lambda x: round(x, 3))
_b = b.applyfunc(lambda x: round(x, 3))
if _a.equals(_b):
return True
except Exception:
pass
return False
def _is_latex_equal(str1, str2):
try:
sym1, val1 = latex_eval(str1)
sym2, val2 = latex_eval(str2)
if sym1 == sym2 or val1 == val2:
return True
else:
raise ValueError
except Exception: # noqa
try:
norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)
sym1, val1 = latex_eval(norm1)
sym2, val2 = latex_eval(norm2)
if sym1 == sym2 or val1 == val2:
return True
except Exception: # noqa
return norm1 == norm2
return False
def is_latex_equal(given_answer: str, ground_truth: str) -> bool:
try:
with timeout(1):
try:
if (len(given_answer) > 128 and repeatness(given_answer)) or (
len(ground_truth) > 128 and repeatness(ground_truth)
):
return False
# First conduct normalized string matching.
ground_truth_normalized = _normalize(ground_truth)
given_normalized = _normalize(given_answer)
if ground_truth_normalized is None:
return False
if ground_truth_normalized == given_normalized:
return True
# Next call math verify.
given_answer.replace("\n", "")
ground_truth.replace("\n", "")
if "$" not in given_answer:
given_answer = f"${given_answer}$"
if "$" not in ground_truth:
ground_truth = f"${ground_truth}$"
return verify(
parse(
ground_truth,
extraction_config=(
LatexExtractionConfig(boxed_match_priority=0),
ExprExtractionConfig(),
),
fallback_mode="no_fallback",
extraction_mode=["first_match"],
parsing_timeout=1,
),
parse(
given_answer,
extraction_config=(
LatexExtractionConfig(boxed_match_priority=0),
ExprExtractionConfig(),
),
fallback_mode="no_fallback",
extraction_mode=["first_match"],
parsing_timeout=1,
),
timeout_seconds=1,
)
# or symbolic_equal(ground_truth, given_answer)
except Exception:
return False
except TimeoutError:
return False
def is_value_equal(given_answer: str, ground_truth: str) -> bool:
assert ground_truth is not None
ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)
given_answer_normalized_mathd = mathd_normalize_answer(given_answer)
str_equal = ground_truth_normalized_mathd == given_answer_normalized_mathd
try:
number_equal = float(ground_truth_normalized_mathd) == float(given_answer_normalized_mathd)
return str_equal or number_equal
except Exception:
return str_equal
# sympy might hang -- we don't care about trying to be lenient in these cases
BAD_SUBSTRINGS = ["^{", "^("]
BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
TUPLE_CHARS = "()[]"
def _sympy_parse(expr: str):
"""Parses an expression with sympy."""
py_expr = expr.replace("^", "**")
return sympy_parser.parse_expr(
py_expr,
transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)),
)
def _parse_latex(expr: str) -> str:
"""Attempts to parse latex to an expression sympy can read."""
expr = expr.replace("\\tfrac", "\\frac")
expr = expr.replace("\\dfrac", "\\frac")
expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
expr = latex2text.LatexNodes2Text().latex_to_text(expr)
# Replace the specific characters that this parser uses.
expr = expr.replace("√", "sqrt")
expr = expr.replace("π", "pi")
expr = expr.replace("∞", "inf")
expr = expr.replace("∪", "U")
expr = expr.replace("·", "*")
expr = expr.replace("×", "*")
return expr.strip()
def _is_float(num: str) -> bool:
try:
float(num)
return True
except ValueError:
return False
def _is_int(x: float) -> bool:
try:
return abs(x - int(round(x))) <= 1e-7
except Exception:
return False
def _is_frac(expr: str) -> bool:
return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
def _str_is_int(x: str) -> bool:
try:
x = _strip_properly_formatted_commas(x)
x = float(x)
return abs(x - int(round(x))) <= 1e-7
except Exception:
return False
def _str_to_int(x: str) -> bool:
x = x.replace(",", "")
x = float(x)
return int(x)
def _inject_implicit_mixed_number(step: str):
"""
Automatically make a mixed number evalable
e.g. 7 3/4 => 7+3/4
"""
p1 = re.compile("([0-9]) +([0-9])")
step = p1.sub("\\1+\\2", step) ## implicit mults
return step
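# Hedged example (traced by hand): "7 3/4" becomes "7+3/4" so the mixed number can be evaluated.
#
#     >>> _inject_implicit_mixed_number("7 3/4")
#     '7+3/4'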
def _strip_properly_formatted_commas(expr: str):
# We want to be careful because we don't want to strip tuple commas
p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
while True:
next_expr = p1.sub("\\1\\3\\4", expr)
if next_expr == expr:
break
expr = next_expr
return next_expr
def _normalize(expr: str) -> str:
"""Normalize answer expressions."""
if expr is None:
return None
# Remove enclosing `\text{}`.
m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
if m is not None:
expr = m.group("text")
expr = expr.replace("\\%", "%")
expr = expr.replace("\\$", "$")
expr = expr.replace("$", "")
expr = expr.replace("%", "")
expr = expr.replace(" or ", " , ")
expr = expr.replace(" and ", " , ")
expr = expr.replace("million", "*10^6")
expr = expr.replace("billion", "*10^9")
expr = expr.replace("trillion", "*10^12")
for unit in [
"degree",
"cm",
"centimeter",
"meter",
"mile",
"second",
"minute",
"hour",
"day",
"week",
"month",
"year",
"foot",
"feet",
"inch",
"yard",
]:
expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
expr = re.sub("\^ *\\\\circ", "", expr)
if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
expr = expr[1:-1]
expr = re.sub(",\\\\! *", "", expr)
if _is_float(expr) and _is_int(float(expr)):
expr = str(int(round(float(expr))))
if "\\" in expr:
try:
expr = _parse_latex(expr)
except Exception:
pass
# edge case with mixed numbers and negative signs
expr = re.sub("- *", "-", expr)
expr = _inject_implicit_mixed_number(expr)
expr = expr.replace(" ", "")
# if we somehow still have latex braces here, just drop them
expr = expr.replace("{", "")
expr = expr.replace("}", "")
# don't be case sensitive for text answers
expr = expr.lower()
if _str_is_int(expr):
expr = str(_str_to_int(expr))
return expr
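# Hedged examples of the normalization above (traced by hand, not authoritative):
#
#     >>> _normalize("$10.0$")
#     '10'
#     >>> _normalize("10\\%")
#     '10'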
def count_unknown_letters_in_expr(expr: str):
expr = expr.replace("sqrt", "")
expr = expr.replace("frac", "")
letters_in_expr = set([x for x in expr if x.isalpha()])
return len(letters_in_expr)
def should_allow_eval(expr: str):
# we don't want to try parsing unknown text or functions of more than two variables
if count_unknown_letters_in_expr(expr) > 2:
return False
for bad_string in BAD_SUBSTRINGS:
if bad_string in expr:
return False
for bad_regex in BAD_REGEXES:
if re.search(bad_regex, expr) is not None:
return False
return True
@timeout_ours(timeout_seconds=5)
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
are_equal = False
try:
expr = f"({ground_truth_normalized})-({given_normalized})"
if should_allow_eval(expr):
sympy_diff = _sympy_parse(expr)
simplified = sympy.simplify(sympy_diff)
if simplified == 0:
are_equal = True
except Exception:
pass
return are_equal
def split_tuple(expr: str):
"""
Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
"""
expr = _strip_properly_formatted_commas(expr)
if len(expr) == 0:
return []
if (
len(expr) > 2
and expr[0] in TUPLE_CHARS
and expr[-1] in TUPLE_CHARS
and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
):
elems = [elem.strip() for elem in expr[1:-1].split(",")]
else:
elems = [expr]
return elems
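# Hedged examples (traced by hand):
#
#     >>> split_tuple("(1, 2, 3)")
#     ['1', '2', '3']
#     >>> split_tuple("42")
#     ['42']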
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
if right_brace_idx is None:
retval = None
else:
retval = string[idx : right_brace_idx + 1]
return retval
def remove_boxed(s):
left = "\\boxed{"
try:
assert s[: len(left)] == left
assert s[-1] == "}"
return s[len(left) : -1]
except Exception:
return None
def extract_boxed_answer(solution: str) -> str:
"""Extract the answer from inside a LaTeX \\boxed{} command"""
solution = last_boxed_only_string(solution)
solution = remove_boxed(solution)
return solution
def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
ground_truth_normalized = _normalize(ground_truth)
given_normalized = _normalize(given_answer)
if ground_truth_normalized is None:
return False
if ground_truth_normalized == given_normalized:
return True
if len(given_normalized) == 0:
return False
ground_truth_elems = split_tuple(ground_truth_normalized)
given_elems = split_tuple(given_normalized)
if len(ground_truth_elems) > 1 and (
ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]
):
is_correct = False
elif len(ground_truth_elems) != len(given_elems):
is_correct = False
else:
for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):
if _is_frac(ground_truth_elem) and _is_frac(given_elem):
# if fractions aren't reduced, then shouldn't be marked as correct
# so, we don't want to allow sympy.simplify in this case
is_correct = ground_truth_elem == given_elem
elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
# if the ground truth answer is an integer, we require the given answer to be a strict match
# (no sympy.simplify)
is_correct = False
else:
is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
if not is_correct:
break
return is_correct
def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:
ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)
given_answer_normalized_mathd = mathd_normalize_answer(given_answer)
# be at least as lenient as mathd
if ground_truth_normalized_mathd == given_answer_normalized_mathd:
return True
return False
def extract_answer(passage: str) -> str:
if "\\boxed" in passage:
return extract_boxed_answer(passage)
return None
def grade(model_answer: str, gt_answer: str, fast: bool = True):
if "\\boxed" in gt_answer:
gt_answer = extract_answer(gt_answer)
correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer)
if not fast:
# This mode additionally uses math_verify to recall answers that the fast graders miss.
# It is a bit slower and more sensitive to malformed inputs.
correct = correct or is_latex_equal(
model_answer,
gt_answer,
)
return correct
def compute_score(model_response, gt_answer, fast=False):
model_answer = extract_answer(model_response)
if model_answer is None:
return {
"score": 0.0,
"format_score": 0.0,
"acc": False,
"extracted_gt": gt_answer,
# "extracted_pred": None,
}
# return 0.0, 0.0 # Cannot even parse anything.
is_correct = False
if isinstance(gt_answer, float) or isinstance(gt_answer, int):
gt_answer = str(gt_answer)
if isinstance(gt_answer, str):
is_correct = grade(model_answer, gt_answer, fast)
elif isinstance(gt_answer, list):
is_correct = False
for gt in gt_answer:
is_correct |= grade(model_answer, gt, fast)
if is_correct:
return {
"score": 1.0,
"format_score": 1.0,
"acc": True,
"extracted_gt": gt_answer,
# "extracted_pred": None,
}
else:
return {
"score": 0.0,
"format_score": 1.0,
"acc": False,
"extracted_gt": gt_answer,
# "extracted_pred": None,
}
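# Hedged examples of the score dict returned above (the model responses are made up):
#
#     >>> compute_score("The area is \\boxed{42}.", "42")
#     {'score': 1.0, 'format_score': 1.0, 'acc': True, 'extracted_gt': '42'}
#     >>> compute_score("No boxed answer given.", "42")["score"]
#     0.0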
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
# Copyright (c) 2023 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
"""
import contextlib
import math
import re
from math import isclose
# sympy related
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
# verl related
from verl.utils.py_functional import timeout_limit
def is_digit(s):
try:
if "{,}" in str(s):
num = float(str(s).replace("{,}", ""))
return True, num
num = float(str(s).replace(",", ""))
return True, num
except ValueError:
return False, None
def normalize(answer, pi) -> str:
# checking if answer is $<number> and removing $ in that case to compare
if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)):
return answer[1:]
# checking if answer is <number>% or <number>\\% and removing %
if isinstance(answer, str) and (
bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))
):
return answer.replace("\\%", "").replace("%", "")
# handle base
answer = handle_base(answer)
# handle pi
answer = handle_pi(answer, pi)
return answer
def handle_base(x) -> str:
if isinstance(x, str) and "_" in x:
# Due to base
x = x.split("_")[0]
x = float(x)
return int(x)
return x
def handle_pi(string, pi):
if isinstance(string, str) and "\pi" in string:
# Find the first occurrence of "\pi"
idx = string.find("\pi")
# Iterate over the string and find all occurrences of "\pi" with a valid previous character
while idx != -1:
if idx > 0 and string[idx - 1].isdigit():
# Replace "\pi" with "*math.pi" if the previous character is a digit
string = string[:idx] + f"*{pi}" + string[idx + 3 :]
else:
# Replace "\pi" with "1*math.pi" if the previous character is not a digit
string = string[:idx] + f"1*{pi}" + string[idx + 3 :]
# Find the next occurrence of "\pi"
idx = string.find("\pi", idx + 1)
# Evaluate the expression using eval() function
with contextlib.suppress(Exception):
string = eval(string)
return string
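# Hedged example (traced by hand): a digit directly before "\pi" turns into a multiplication.
#
#     >>> import math
#     >>> handle_pi(r"2\pi", math.pi)
#     6.283185307179586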
def math_equal(
prediction: bool | float | str,
reference: float | str,
include_percentage: bool = True,
tolerance: float = 1e-4,
timeout: float = 10.0,
pi: float = math.pi,
) -> bool:
"""
Exact match of math if and only if:
1. numerical equal: both can convert to float and are equal
2. symbolic equal: both can convert to sympy expression and are equal
"""
prediction = normalize(prediction, pi)
reference = normalize(reference, pi)
if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases
prediction = prediction[:1000]
# 0. string comparison
if isinstance(prediction, str) and isinstance(reference, str):
if prediction.strip().lower() == reference.strip().lower():
return True
if prediction.replace(" ", "") == reference.replace(" ", ""):
return True
try: # 1. numerical equal
if is_digit(prediction)[0] and is_digit(reference)[0]:
prediction = is_digit(prediction)[1]
reference = is_digit(reference)[1]
# number questions
gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
for item in gt_result:
try:
if isclose(item, prediction, rel_tol=tolerance):
return True
except Exception:
continue
return False
except Exception:
pass
if not prediction and prediction not in [0, False]:
return False
# 2. symbolic equal
reference = str(reference).strip()
prediction = str(prediction).strip()
## deal with [], (), {}
prediction = format_intervals(prediction)
pred_str, ref_str = prediction, reference
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
):
pred_str = pred_str.strip("[]()")
ref_str = ref_str.strip("[]()")
for s in ["{", "}", "(", ")"]:
ref_str = ref_str.replace(s, "")
pred_str = pred_str.replace(s, "")
if pred_str == ref_str:
return True
## [a, b] vs. [c, d], return a==c and b==d
if (
prediction
and reference
and prediction[0] in "(["
and prediction[-1] in ")]"
and prediction[0] == reference[0]
and prediction[-1] == reference[-1]
):
pred_parts = prediction[1:-1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts) and all(
[
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
]
):
return True
if "," in prediction and "," in reference:
pred_parts = [item.strip() for item in prediction.split(",")]
ref_parts = [item.strip() for item in reference.split(",")]
if len(pred_parts) == len(ref_parts):
return bool(
all(
[
math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)
for i in range(len(pred_parts))
]
)
)
# if we have point == tuple of values
if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")":
pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
ref_parts = reference[1:-1].split(",")
if len(pred_parts) == len(ref_parts) and all(
[
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
]
):
return True
# if reference is a matrix
if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"):
try:
pred_matrix = parse_expr(prediction)
ref_matrix_items = reference.split()[1:-1:2]
if len(pred_matrix) == len(ref_matrix_items) and all(
[
math_equal(pred, ref, include_percentage, tolerance)
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
]
):
return True
except Exception:
pass
elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"):
if isinstance(eval(prediction), list):
try:
pred_matrix = eval(prediction)
# ref_matrix_items = reference.split()[1:-1:2]
ref_matrix_items = (
reference.lstrip("\\begin{pmatrix}") # noqa: B005
.lstrip("\begin{pmatrix}")
.rstrip("\\end{pmatrix}")
.rstrip("\end{pmatrix}")
) # noqa: B005
ref_matrix_items = ref_matrix_items.split("\\")
ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items]
if len(pred_matrix) == len(ref_matrix_items) and all(
[
math_equal(pred, ref, include_percentage, tolerance)
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
]
):
return True
except Exception:
pass
return symbolic_equal(prediction, reference, tolerance, timeout)
def symbolic_equal(a, b, tolerance, timeout=10.0):
def _parse(s):
for f in [parse_expr, parse_latex]:
try:
with timeout_limit(seconds=timeout):
return f(s)
except TimeoutError:
print(f"Parsing timed out for {s}")
continue
except Exception:
continue
return s
a = _parse(a)
b = _parse(b)
try:
with timeout_limit(seconds=timeout):
if simplify(a - b) == 0:
return True
except TimeoutError:
print(f"Simplification timed out for {a} - {b}")
pass
except Exception:
pass
try:
with timeout_limit(seconds=timeout):
if isclose(N(a), N(b), rel_tol=tolerance):
return True
except TimeoutError:
print(f"Numerical evaluation timed out for {a}, {b}")
pass
except Exception:
pass
return False
def format_intervals(prediction):
patterns = {
"Interval(": r"^Interval\((.*)\)$",
"Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$",
"Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$",
"Interval.open(": r"^Interval\.open\((.*)\)$",
}
for key, pattern in patterns.items():
match = re.match(pattern, prediction)
if match:
inner_content = match.group(1)
if key == "Interval(": # Intarval(a, b) == [a, b]
return f"[{inner_content}]"
elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b)
return f"[{inner_content})"
elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b]
return f"({inner_content}]"
elif key == "Interval.open(": # Intarval.open(a, b) == (a, b)
return f"({inner_content})"
return prediction
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py
"""
import re
from typing import Optional
def normalize_answer(answer: Optional[str]) -> Optional[str]:
if answer is None:
return None
answer = answer.strip()
try:
# Remove enclosing `\text{}`.
m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
if m is not None:
answer = m.group("text").strip()
return _strip_string(answer)
except: # noqa: E722
return answer
def _fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except: # noqa: E722
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except: # noqa: E722
return string
def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def _fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
# Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
# Generative Reward Model
## Scripts
### Step 1: Launch a vLLM Server (Optional)
Deploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external API service.
```bash
VLLM_SERVER_DEV_MODE=1 vllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo --enable-sleep-mode --dtype float32
```
Note that the wake_up and sleep operations for managing CUDA memory in vLLM are only available when both `VLLM_SERVER_DEV_MODE=1` and `--enable-sleep-mode` are set. This capability is particularly useful when the reward model server shares a machine with a training workload: the reward model service can be temporarily offloaded (to free GPU memory) during intensive training phases and reloaded when it is needed again. The relevant vLLM code can be found here:
[VLLM_SERVER_DEV_MODE](https://github.com/vllm-project/vllm/blob/5a19a6c6705fe83db2e3517a2d2f473586901743/vllm/entrypoints/openai/api_server.py#L971)
[sleep and wake_up mode](https://github.com/vllm-project/vllm/blob/5a19a6c6705fe83db2e3517a2d2f473586901743/vllm/entrypoints/openai/api_server.py#L994-L1003)
When the backend is configured as `SERVER_BACKEND="VLLM"`, the `USE_OFFLOAD` flag can be toggled between `True` and `False` (see `reward_function.py`).
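For reference, the offload toggle boils down to two HTTP calls against the server. Below is a minimal sketch mirroring the `vllm_execute_method` helper in `reward_function.py`; the URL is whatever `BASE_URL` you configured there.
```python
# Minimal sketch of the offload toggle (mirrors vllm_execute_method in reward_function.py).
# With VLLM_SERVER_DEV_MODE=1 and --enable-sleep-mode, the server exposes /sleep and /wake_up.
import requests

BASE_URL = "http://localhost:30000"  # match the BASE_URL configured in reward_function.py


def toggle_vllm(task: str) -> None:
    assert task in ("sleep", "wake_up"), f"Invalid task: {task}"
    response = requests.post(f"{BASE_URL}/{task}")
    response.raise_for_status()


# toggle_vllm("sleep")    # offload GPU memory during an intensive training phase
# toggle_vllm("wake_up")  # reload the reward model before scoring rollouts
```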
### Step 2: Perform RL using GenRM
```bash
bash recipe/api-genrm/run_genrm_remote.sh
```
The implementation works by passing a customized reward function (see `reward_function.py`).
For convenience, we run both the RL training and server on the same machine. To use an external server, configure the `BASE_URL` and `API_KEY` in `reward_function.py` first.
## Advanced: Customizing Your GenRM
You can use an SGLang server with data parallelism for faster inference:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4
```
Note that you should modify `BASE_URL` in `reward_function.py` to match your SGLang server address.
You can also build your own GenRM by implementing a custom reward function. Here are some tips for customizing your GenRM, based on `reward_function.py`:
- Design appropriate prompts for your GenRM
- Convert GenRM responses into RL rewards
- ...
Since these aspects are highly flexible, we only provide a demo implementation; the actual design and implementation of the GenRM are left to the user's discretion.
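For example, a minimal batch reward function with the same signature as `compute_score_batch` could look like the sketch below. The prompt wording and the True/False parsing are illustrative simplifications of what `reward_function.py` does, and the endpoint assumes an OpenAI-compatible server at `BASE_URL`.
```python
# Minimal sketch of a custom batch GenRM reward function (same signature as compute_score_batch).
import requests

BASE_URL = "http://localhost:30000"  # your vLLM/SGLang server
MODEL_NAME = "genrm-demo"


def my_compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos):
    scores = []
    for solution_str, extra_info in zip(solution_strs, extra_infos):
        # 1. Build a prompt for the GenRM from the problem and the candidate solution.
        prompt = (
            f"Problem: {extra_info['question']}\n"
            f"Solution: {solution_str}\n"
            "Review the solution and answer True or False."
        )
        # 2. Query the GenRM server via the OpenAI-compatible chat endpoint.
        resp = requests.post(
            f"{BASE_URL}/v1/chat/completions",
            json={"model": MODEL_NAME, "messages": [{"role": "user", "content": prompt}]},
            timeout=30,
        )
        verdict = resp.json()["choices"][0]["message"]["content"]
        # 3. Convert the GenRM verdict into a scalar RL reward.
        scores.append(1.0 if "True" in verdict else 0.0)
    return scores
```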
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import sleep
import requests
import tqdm
from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed
SERVER_BACKEND = "VLLM"
USE_OFFLOAD = True
BASE_URL = "http://localhost:30000"
API_KEY = "EMPTY"
MAX_RETRIES = 3
BASE_DELAY = 2
MAX_WORKERS = 32
MODEL_NAME = "genrm-demo"
GENRM_PROMPT_TEMPLATE = """
The following is a math problem and an AI solution:
[Math Problem]
{problem}
[AI Solution]
{solution}
Your task is to review and critique the solution step by step, and output whether the AI solution is correct.
Please put your final answer (i.e., 'True' or 'False') in \\boxed{{}}.
""".strip()
def vllm_execute_method(task="sleep"):
assert task in ["sleep", "wake_up"], f"Invalid task: {task}"
url_root = BASE_URL
response = requests.post(url_root + "/" + task)
assert response.status_code == 200
def get_response(problem, solution_str, ground_truth):
prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str)
messages = [{"role": "user", "content": prompt}]
for attempt in range(MAX_RETRIES):
try:
headers = {"Content-Type": "application/json"}
chat_url = f"{BASE_URL}/v1/chat/completions"
data = {"model": MODEL_NAME, "messages": messages}
output = requests.post(chat_url, headers=headers, json=data, timeout=30)
response = output.json()["choices"][0]["message"]["content"]
return response
except Exception as e:
if attempt < MAX_RETRIES - 1:
print("Exception: ", repr(e))
delay = BASE_DELAY * (2**attempt)
print(f"Retrying in {delay} seconds...")
sleep(delay)
else:
print(f"Failed after {MAX_RETRIES} attempts. Error: {e}")
raise ConnectionRefusedError(f"Failed to run the model for {prompt}!")
def compute_reward(response):
reward_score = 0.0
try:
boxed_result = last_boxed_only_string(response)
if boxed_result is not None:
result = remove_boxed(boxed_result)
reward_score = float(result == "True")
except Exception as e:
print(e)
return reward_score
def compute_score(data_source, solution_str, ground_truth, extra_info, index):
split = extra_info["split"]
from verl.utils.reward_score import default_compute_score
func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info)
if split == "test":
return func_rm_score, index
else:
problem = extra_info["question"]
response = get_response(problem, solution_str, ground_truth)
if response is not None:
reward_score = compute_reward(response)
else:
reward_score = 0.0
return reward_score, index
def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos):
results = []
indexes = list(range(len(data_sources)))
if SERVER_BACKEND == "VLLM" and USE_OFFLOAD:
vllm_execute_method("wake_up")
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
futures = []
for data_source, solution_str, ground_truth, extra_info, index in zip(
data_sources, solution_strs, ground_truths, extra_infos, indexes, strict=True
):
future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info, index)
time.sleep(0.001 * random.random())
futures.append(future)
for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
results.append(future.result())
results = sorted(results, key=lambda x: x[-1], reverse=False)
results = [result[0] for result in results]
if SERVER_BACKEND == "VLLM" and USE_OFFLOAD:
vllm_execute_method("sleep")
return results
# vllm server
# CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve verl-team/GenRM-CI-Test-1.5B --served_model_name genrm-demo
# sglang server
# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4
set -x
CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=${HOME}/data/gsm8k/train.parquet \
data.val_files=${HOME}/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.n=8 \
algorithm.use_kl_in_reward=False \
reward_model.reward_manager=batch \
custom_reward_function.path=recipe/genrm_remote/reward_function.py \
custom_reward_function.name=compute_score_batch \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_func_rm_example_gsm8k' \
trainer.experiment_name='qwen2_5_3b_gen_rm' \
trainer.n_gpus_per_node=4 \
trainer.val_before_train=True \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=10 \
trainer.resume_mode='disable'
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ref: https://python.langchain.com/docs/how_to/custom_chat_model/
"""
import asyncio
import json
import logging
import os
import uuid
from typing import Any, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_openai_messages,
)
from langchain_core.messages.tool import InvalidToolCall, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import StructuredTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import Field
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
from verl.experimental.agent_loop.tool_parser import ToolParser
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class MaxTokenExceededError(Exception):
"""Indicate that history chat messages + tool message exceeds LLM max_tokens."""
pass
class ChatModel(BaseChatModel):
model_name: str = Field(alias="model")
"""The name of the model"""
client: AsyncLLMServerManager
"""AsyncLLM server manager"""
tokenizer: Any
"""Tokenizer for the model"""
max_tokens: int
"""Max tokens to generate"""
tool_parser: str = "hermes"
"""Tool parser for the model"""
max_parallel_calls: int = 1
"""Max parallel tool calls"""
temperature: float = 1.0
"""Temperature for sampling"""
top_p: float = 1.0
"""Top p for sampling"""
repetition_penalty: float = 1.0
"""Repetition penalty for sampling"""
def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the model.
Args:
tools: Sequence of tools to bind to the model.
Returns:
A Runnable that returns a message.
"""
formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]
# used to remove system prompt prefix when encoding tool response
system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
kwargs["system_prompt"] = system_prompt
return self.bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: dict | type,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
"""Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/"""
raise NotImplementedError
def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError
async def _agenerate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> ChatResult:
"""Asynchronously generate chat completion message.
Args:
messages (list[BaseMessage]): List of messages.
stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings. Defaults to None.
Returns:
ChatResult: Chat result.
"""
request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)
sampling_params = {
"temperature": self.temperature,
"top_p": self.top_p,
"repetition_penalty": self.repetition_penalty,
}
if "sampling_params" in kwargs:
sampling_params.update(kwargs["sampling_params"])
response_ids = await self.client.generate(
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
)
message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return self.model_name
async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
"""Preprocess messages for chat completion.
To ensure strong consistency with the policy model, the AsyncLLM server generates responses token-in/token-out
instead of from a messages list.
But all agent frameworks use a messages list to represent chat history. To bridge the gap, we store the trajectory
(prompt_ids, response_mask) in the latest AIMessage.response_metadata.
1. Encode ToolMessage to token ids.
2. Retrieve trajectory (prompt_ids, response_mask) from the latest AIMessage.response_metadata.
3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.
Ref: https://python.langchain.com/docs/concepts/chat_history/
Args:
messages (list[BaseMessage]): List of messages.
Returns:
tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.
"""
# messages: [system], human, ai, human|tool, ai, human|tool, ...
assert messages[-1].type in ["human", "tool"], (
f"Last message must be human or tool, but got {messages[-1].type}"
)
loop = asyncio.get_running_loop()
# Case 1: initial chat completion: [system], human
if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
prompt_ids = await loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
convert_to_openai_messages(messages),
tools=kwargs.get("tools"),
add_generation_prompt=True,
tokenize=True,
),
)
return str(uuid.uuid4()), prompt_ids, []
# Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...
for i in range(len(messages) - 1, -1, -1):
if messages[i].type == "ai":
break
assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
assert "response_mask" in messages[i].response_metadata, (
"Last message must have response_mask in response_metadata"
)
# encode tool response
tool_responses = convert_to_openai_messages(messages[i + 1 :])
tool_response_ids = await loop.run_in_executor(
None,
lambda messages=tool_responses: self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True
),
)
tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
# stop generation if response length exceeds max response length
if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")
# append tool response to prompt
request_id = messages[i].response_metadata.pop("request_id")
prompt_ids = messages[i].response_metadata.pop("prompt_ids")
response_mask = messages[i].response_metadata.pop("response_mask")
prompt_ids += tool_response_ids
response_mask += [0] * len(tool_response_ids)
return request_id, prompt_ids, response_mask
async def _postprocess(
self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
) -> AIMessage:
"""Postprocess response_ids when chat completion is done.
1. Decode response_ids, parse tool calls to AIMessage.
2. Append response_ids to prompt_ids, and append 1 to response_mask.
3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.
Args:
request_id (str): Unique request id.
prompt_ids (list[int]): Input prompt token ids in this chat completion.
response_mask (list[int]): Response mask before this chat completion.
response_ids (list[int]): LLM generated token ids in this chat completion.
Returns:
AIMessage: Postprocessed message.
"""
prompt_ids += response_ids
response_mask += [1] * len(response_ids)
tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
content, function_calls = await tool_parser.extract_tool_calls(response_ids)
tool_calls, invalid_tool_calls = [], []
for function_call in function_calls:
try:
args = json.loads(function_call.arguments)
if not isinstance(args, dict):
raise json.JSONDecodeError(f"Invalid json tool arguments: {args}")
tool_call = ToolCall(
args=args,
name=function_call.name,
id=str(uuid.uuid4()),
)
tool_calls.append(tool_call)
except json.JSONDecodeError as e:
logger.warning(f"Invalid json tool arguments: {e}")
tool_call = InvalidToolCall(
args=function_call.arguments,
name=function_call.name,
error=f"Invalid json tool arguments: {e}",
)
invalid_tool_calls.append(tool_call)
message = AIMessage(
content=content,
tool_calls=tool_calls[: self.max_parallel_calls],
invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
response_metadata={
"request_id": request_id,
"prompt_ids": prompt_ids,
"response_mask": response_mask,
},
)
return message
class TruncateStructuredTool(StructuredTool):
"""Structured tool with response truncation."""
tool_response_truncate_side: str
"""truncate side of tool response: left, middle, right"""
max_tool_response_length: int
"""max length of tool response"""
async def _arun(
self,
*args: Any,
config: RunnableConfig,
**kwargs: Any,
) -> Any:
tool_response = await super()._arun(*args, config=config, **kwargs)
tool_response = str(tool_response)
if len(tool_response) > self.max_tool_response_length:
if self.tool_response_truncate_side == "left":
tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
elif self.tool_response_truncate_side == "right":
tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
else:
length = self.max_tool_response_length // 2
tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]
return tool_response
def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
"""Convert messages to AgentLoopOutput.
Args:
messages (List[BaseMessage]): List of messages, last message must be assistant
with response_metadata containing `prompt_ids` and `response_mask`.
response_length (int): Max length of response.
Returns:
AgentLoopOutput: agent loop output trajectory used for training.
"""
# skip last tool calls
for i in range(len(messages) - 1, -1, -1):
if messages[i].type != "tool":
break
last_message = messages[i]
assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
assert "response_mask" in last_message.response_metadata, (
"Last message must have response_mask in response_metadata"
)
num_turns = 0
for i in range(len(messages)):
if messages[i].type == "system":
continue
# parallel tool calls are in single turn
if i == 0 or messages[i].type != messages[i - 1].type:
num_turns += 1
prompt_ids = last_message.response_metadata["prompt_ids"]
response_mask = last_message.response_metadata["response_mask"]
response_ids = prompt_ids[-len(response_mask) :]
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[:response_length],
response_mask=response_mask[:response_length],
num_turns=num_turns,
metrics={},
)
return output
# MathExpression: LangGraph Agent Example
MathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/).
### Define react agent with tool
First, to force the ReactAgent to evaluate math expressions via a tool, we define a special operator `@`:
```python
@tool(parse_docstring=True)
def calculate(a: int, b: int, operand: str) -> int:
"""
Compute the results using operand with two integers
Args:
a: the first operand
b: the second operand
operand: '+' or '-' or '*' or '@'
"""
assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
if operand == "@":
return 3 * a - 2 * b
return eval(f"{a} {operand} {b}")
```
Without calling `calculate`, the ReactAgent cannot evaluate the math expression correctly (for example, `3 @ 6 = 3*3 - 2*6 = -3`, which ordinary arithmetic will not produce).
Then, we can equip the ReactAgent with the `calculate` tool:
```python
class MathExpressionReactAgentLoop(ReactAgentLoop):
@classmethod
def init_class(cls, config, tokenizer):
cls.tools = [calculate]
super().init_class(config, tokenizer)
```
We can define the agent loop config in a YAML file, which `AgentLoopWorker` uses to dynamically load the custom AgentLoop class.
```yaml
- name: math_expression
_target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
```
### Prepare dataset
Now, let's prepare two small datasets for training and evaluation:
```bash
python recipe/langgraph_agent/example/create_dataset.py
```
Note that the dataset should contain an `agent_name` column set to `math_expression`, which `AgentLoopWorker` uses to select the
agent loop class. A sketch of a single row is shown after the table below.
| prompt | reward_model | agent_name |
|--------------------------------------|------------------------------|-----------------|
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
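For reference, a single row can be sketched as the Python dict below; the values are illustrative, and `create_dataset.py` adds a few more columns such as `data_source`, `ability`, and `extra_info`.
```python
# One illustrative dataset row (see create_dataset.py for the full schema)
row = {
    "prompt": [{"role": "user", "content": "... Now answer the following questions:\nCompute 5 @ (3 @ 6 @ 8)"}],
    "reward_model": {"style": "lighteval/MATH", "ground_truth": "65"},
    "agent_name": "math_expression",  # selects MathExpressionReactAgentLoop via agent.yaml
}
```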
The generated math expressions look like the ones below, requiring the model to call `calculate` multiple times to solve sub-expressions.
```
(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6
(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6
9 @ 4
((2 @ 2) @ (3 @ 3)) @ 4
```
### Training
Hook all these up and start training:
```bash
bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log
```
After 39 training steps in total, the model should achieve 100% accuracy on the test dataset:
- val-aux/lighteval/MATH/reward: 1.0
- val-aux/num_turns/mean: 9.0, the average number of messages, including assistant and tool turns
We can see that the model solves the math expression by calling the `calculate` tool:
```
system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
We define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \boxed{} only in the last turn. Now answer the following questions:
Compute 5 @ (3 @ 6 @ 8)
assistant
Sure, let's break this down step by step. We need to compute the expression \(5 @ (3 @ 6 @ 8)\). First, we need to compute the inner expression \(3 @ 6 @ 8\).
Let's compute \(3 @ 6\) first using the external tool.
<tool_call>
{"name": "calculate", "arguments": {"a": 3, "b": 6, "operand": "@"}}
</tool_call>user
<tool_response>
-3
</tool_response>
assistant
We have computed that \(3 @ 6 = -3\). Now we need to compute \(-3 @ 8\).
Let's use the external tool to compute this.
<tool_call>
{"name": "calculate", "arguments": {"a": -3, "b": 8, "operand": "@"}}
</tool_call>user
<tool_response>
-25
</tool_response>
assistant
We have computed that \(-3 @ 8 = -25\). Now we can put this result into the original expression to get the final result of \(5 @ (3 @ 6 @ 8)\), which is \(5 @ (-25)\).
Let's compute \(5 @ (-25)\) using the external tool.
<tool_call>
{"name": "calculate", "arguments": {"a": 5, "b": -25, "operand": "@"}}
</tool_call>user
<tool_response>
65
</tool_response>
assistant
The final result of \(5 @ (3 @ 6 @ 8)\) is \(\boxed{65}\).
```
- name: math_expression
_target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create dataset for calculator
"""
import random
import pandas as pd
def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2):
"""
Generate a random mathematical expression with operators +, -, *, /, @, and parentheses.
Args:
min_terms (int): Minimum number of terms in the expression.
max_terms (int): Maximum number of terms in the expression.
max_number (int): Maximum value for numbers in the expression.
allow_decimals (bool): Whether to allow decimal numbers.
max_depth (int): Maximum nesting depth for parentheses.
Returns:
str: A valid mathematical expression as a string.
"""
def generate_number():
"""Generate a random number (integer or float)."""
assert min_number < max_number
num = random.uniform(min_number, max_number)
if not allow_decimals:
num = int(num)
else:
num = round(num, random.randint(0, 2)) # Round to 0-2 decimal places
return str(num)
def generate_term(depth=0):
"""Generate a term (number or parenthesized expression)."""
if depth < max_depth and random.random() < 0.5: # 50% chance to add parentheses
expr = generate_expression(depth + 1)
return f"({expr})"
else:
return generate_number()
def generate_expression(depth=0):
"""Generate a full expression with multiple terms and operators."""
num_terms = random.randint(min_terms, max_terms)
terms = [generate_term(depth) for _ in range(num_terms)]
# Randomly select operators
operators = ["+", "-", "*", "/", "@"]
expr = terms[0]
for i in range(1, num_terms):
# Always select '@' so every sub-expression requires the external tool
op = random.choices(
operators,
weights=[0, 0, 0, 0, 1],  # only '@' has a non-zero weight
)[0]
expr += f" {op} " + terms[i]
return expr
return generate_expression()
def test():
# Example 1: Basic integer expression
print(generate_math_expression())
# Output: (3 + 7) * 2 - 5
# Example 2: Expression with decimals
print(generate_math_expression(allow_decimals=True))
# Output: 4.5 / (2.1 + 3.7) - 1.2
# Example 3: More complex expression with higher depth
print(generate_math_expression(max_terms=6, max_depth=3))
# Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4
# Example 4: Simplified expression
print(generate_math_expression(min_terms=2, max_terms=3, max_number=5))
# Output: 4 - 2 * 3
def calculate(expression: str) -> float:
"""
Evaluate a mathematical expression with +, -, *, /, @, and parentheses.
The @ operator is defined as: a @ b = 3a - 2b.
Args:
expression (str): Input mathematical expression (e.g., "3@2+4").
Returns:
float: Result of the evaluated expression.
Raises:
ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero).
"""
def tokenize(s: str) -> list:
"""Convert the input string into tokens (numbers, operators, parentheses)."""
tokens = []
i = 0
while i < len(s):
if s[i].isdigit() or s[i] == ".":
# Parse number (integer or float)
j = i
while j < len(s) and (s[j].isdigit() or s[j] == "."):
j += 1
tokens.append(s[i:j])
i = j
elif s[i] in "+-*/@()":
# Operator or parenthesis
tokens.append(s[i])
i += 1
elif s[i].isspace():
# Skip whitespace
i += 1
else:
raise ValueError(f"Invalid character: {s[i]}")
return tokens
def infix_to_postfix(tokens: list) -> list:
"""Convert infix notation to postfix notation (Reverse Polish Notation)."""
output = []
stack = []
# Higher precedence for @ (between * and +)
precedence = {"@": 3, "*": 2, "/": 2, "+": 1, "-": 1}
for token in tokens:
if token.isdigit() or "." in token:
output.append(token)
elif token == "(":
stack.append(token)
elif token == ")":
while stack and stack[-1] != "(":
output.append(stack.pop())
if not stack or stack[-1] != "(":
raise ValueError("Mismatched parentheses")
stack.pop() # Discard '('
else: # Operator
while stack and stack[-1] != "(" and precedence.get(stack[-1], 0) >= precedence.get(token, 0):
output.append(stack.pop())
stack.append(token)
# Pop remaining operators
while stack:
if stack[-1] in "()":
raise ValueError("Mismatched parentheses")
output.append(stack.pop())
return output
def evaluate_postfix(postfix: list) -> float:
"""Evaluate postfix expression using a stack."""
stack = []
for token in postfix:
if token.isdigit() or "." in token:
stack.append(float(token))
else:
if len(stack) < 2:
raise ValueError("Invalid expression")
b = stack.pop()
a = stack.pop()
if token == "+":
res = a + b
elif token == "-":
res = a - b
elif token == "*":
res = a * b
elif token == "/":
if b == 0:
raise ValueError("Division by zero")
res = a / b
elif token == "@":
res = 3 * a - 2 * b # Custom @ operator implementation
else:
raise ValueError(f"Invalid operator: {token}")
stack.append(res)
if len(stack) != 1:
raise ValueError("Invalid expression")
return stack[0]
# Remove spaces and validate parentheses
expression = expression.replace(" ", "")
if expression.count("(") != expression.count(")"):
raise ValueError("Mismatched parentheses")
tokens = tokenize(expression)
postfix = infix_to_postfix(tokens)
result = evaluate_postfix(postfix)
# Convert integers to integer representation
if result.is_integer():
return int(result)
return result
def generate_data(total_num_dataset, split):
rl_dataset = {
"prompt": [],
"data_source": [],
"ability": [],
"reward_model": [],
"extra_info": [],
"agent_name": [],
}
for idx in range(total_num_dataset):
while True:
try:
expression: str = generate_math_expression(
min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1
)
num_plus = expression.count("+")
num_minus = expression.count("-")
num_mul = expression.count("*")
num_star = expression.count("@")
answer = str(calculate(expression))
# answer = str(eval(expression))
break
except Exception as e:
print(e)
continue
num_tool_calls = num_plus + num_minus + num_mul + num_star
prompt = (
f"We define a new math operator @, where you can only call an external tool to compute. "
f"Please put your final answer inside \\boxed{{}} only in the last turn. Now answer the "
f"following questions:\nCompute {expression}"
)
prompt_with_template = [
{
"role": "user",
"content": prompt,
}
]
rl_dataset["prompt"].append(prompt_with_template)
rl_dataset["data_source"].append("lighteval/MATH")
rl_dataset["ability"].append("math")
rl_dataset["reward_model"].append({"style": "lighteval/MATH", "ground_truth": answer})
rl_dataset["extra_info"].append(
{"index": idx, "expression": expression, "split": split, "expected_tool_calls": num_tool_calls}
)
rl_dataset["agent_name"].append("math_expression")
rl_dataset = pd.DataFrame(data=rl_dataset)
return rl_dataset
if __name__ == "__main__":
# print(calculate("3@2")) # Output: 5 (3*3 - 2*2)
# print(calculate("3@2+4")) # Output: 9 (5 + 4)
# print(calculate("3*(4@2)")) # Output: 24 (3 * 8)
# print(calculate("(5@3)*2")) # Output: 18 (9 * 2)
train_dataset = generate_data(total_num_dataset=5000, split="train")
test_dataset = generate_data(total_num_dataset=500, split="test")
train_dataset.to_parquet("train.parquet")
test_dataset.to_parquet("test.parquet")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from langchain_core.tools import tool
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
@tool(parse_docstring=True)
def calculate(a: int, b: int, operand: str) -> int:
"""
Compute the results using operand with two integers
Args:
a: the first operand
b: the second operand
operand: '+' or '-' or '*' or '@'
"""
assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
if operand == "@":
return 3 * a - 2 * b
return eval(f"{a} {operand} {b}")
class MathExpressionReactAgentLoop(ReactAgentLoop):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
cls.tools = [calculate]
super().init_class(config, tokenizer)
set -x
# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}
model_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct
train_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet
test_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet
# agent
agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml
# wandb
project_name=math_expression_tool
experiment_name=qwen2.5-3b
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name
# ================= algorithm =================
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_turns=8
max_prompt_length=1024
max_response_length=2048
actor_lr=1e-6
train_batch_size=128
ppo_mini_batch_size=16
n_resp_per_prompt=8
n_resp_per_prompt_val=1
# ================= performance =================
infer_tp=2 # vllm
train_sp=4 # train
offload=True
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=$adv_estimator \
algorithm.use_kl_in_reward=$use_kl_in_reward \
algorithm.kl_ctrl.kl_coef=$kl_coef \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=True \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=$max_prompt_length \
data.max_response_length=$max_response_length \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.optim.lr=$actor_lr \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
actor_rollout_ref.rollout.multi_turn.format=hermes \
actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
actor_rollout_ref.rollout.n=$n_resp_per_prompt \
actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
trainer.logger=['console','wandb'] \
trainer.project_name=$project_name \
trainer.experiment_name=$experiment_name \
trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \
trainer.val_before_train=True \
trainer.log_val_generations=50 \
trainer.nnodes=$ARNOLD_WORKER_NUM \
trainer.save_freq=-1 \
trainer.default_local_dir=$default_local_dir \
trainer.test_freq=5 \
trainer.total_epochs=1 $@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LangGraph React Agent Loop.
This implementation is exactly the same as `ToolAgentLoop`.
Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/
"""
from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from recipe.langgraph_agent.chat_model import (
ChatModel,
MaxTokenExceededError,
convert_to_agent_output,
)
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
async def call_model(state: MessagesState, config: RunnableConfig):
model = config["configurable"]["model"]
sampling_params = config["configurable"]["sampling_params"]
try:
message = await model.ainvoke(state["messages"], sampling_params=sampling_params)
return {"messages": [message]}
except MaxTokenExceededError:
# last message is ToolMessage
return {"messages": []}
def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]:
max_assistant_turns = config["configurable"]["max_assistant_turns"]
num_assistant_turns = 0
for message in state["messages"]:
if message.type == "ai":
num_assistant_turns += 1
last_message = state["messages"][-1]
# LLM call failed, e.g: max response length exceeded
if last_message.type == "tool":
return END
# max assistant turns exceeded
if max_assistant_turns and num_assistant_turns >= max_assistant_turns:
return END
# no tool calls
if not last_message.tool_calls:
return END
return "tools"
class ReactAgentLoop(AgentLoopBase):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
if cls._class_initialized:
return
cls._class_initialized = True
print("Performing class-level ReactAgentLoop initialization")
# build graph
cls.graph = cls.build_graph()
@classmethod
def build_graph(cls) -> StateGraph:
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(cls.tools))
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
should_continue,
{
"tools": "tools",
END: END,
},
)
workflow.add_edge("tools", "agent")
graph = workflow.compile()
return graph
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
model_path = self.config.actor_rollout_ref.model.path
model_name = "/".join(model_path.split("/")[-2:])
rollout = self.config.actor_rollout_ref.rollout
model = ChatModel(
model=model_name,
client=self.server_manager,
tokenizer=self.tokenizer,
max_tokens=rollout.response_length,
max_parallel_calls=rollout.multi_turn.max_parallel_calls,
tool_parser=rollout.multi_turn.format,
)
model = model.bind_tools(self.tools, tool_choice="any")
config = {
"configurable": {
"model": model,
"sampling_params": sampling_params,
"max_user_turns": rollout.multi_turn.max_user_turns,
"max_assistant_turns": rollout.multi_turn.max_assistant_turns,
}
}
# TODO: how to handle multiple trajectories in a graph invocation?
# Each graph node may have its own LLM calls and state, e.g.:
# https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart
state = await self.graph.ainvoke(input={"messages": messages}, config=config)
output = convert_to_agent_output(state["messages"], rollout.response_length)
return output
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import numpy as np
import pytest
import ray
from langchain_core.tools import tool
from omegaconf import DictConfig
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
from verl.protocol import DataProto
from verl.utils import hf_tokenizer
@pytest.fixture
def init_config() -> DictConfig:
from hydra import compose, initialize_config_dir
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
config = compose(config_name="ppo_trainer")
model_path = "Qwen/Qwen2.5-1.5B-Instruct"
config.actor_rollout_ref.model.path = model_path
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
config.actor_rollout_ref.rollout.mode = "async"
config.actor_rollout_ref.rollout.prompt_length = 4096
config.actor_rollout_ref.rollout.response_length = 4096
config.actor_rollout_ref.rollout.n = 4
config.actor_rollout_ref.rollout.agent.num_workers = 2
# test sleep/wake_up with fsdp offload
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
return config
@tool(parse_docstring=True)
def get_current_temperature(location: str, unit: str = "celsius"):
"""Get current temperature at a location.
Args:
location: The location to get the temperature for, in the format "City, State, Country".
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
Returns:
the temperature, the location, and the unit in a dict
"""
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
return {
"temperature": 26.1,
"location": location,
"unit": unit,
}
@tool(parse_docstring=True)
def get_temperature_date(location: str, date: str, unit: str = "celsius"):
"""Get temperature at a location and date.
Args:
location: The location to get the temperature for, in the format "City, State, Country".
date: The date to get the temperature for, in the format "Year-Month-Day".
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
Returns:
the temperature, the location, the date and the unit in a dict
"""
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
return {
"temperature": 25.9,
"location": location,
"date": date,
"unit": unit,
}
class TestReactAgentLoop(ReactAgentLoop):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
# TODO: find better way to configure tools
cls.tools = [get_current_temperature, get_temperature_date]
super().init_class(config, tokenizer, **kwargs)
def test_react_agent(init_config):
ray.init(
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"VLLM_USE_V1": "1",
}
}
)
# =========================== 1. Init rollout manager ===========================
agent_loop_config = [
{
"_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop",
"name": "react_agent",
},
]
agent_loop_config_path = "/tmp/agent_loop_config.json"
with open(agent_loop_config_path, "w") as f:
json.dump(agent_loop_config, f)
n = 2
init_config.actor_rollout_ref.rollout.n = n
# init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path
agent_loop_manager = init_agent_loop_manager(init_config)
# =========================== 2. Generate sequences ===========================
raw_prompts = [
[
{"role": "user", "content": "How are you?"},
],
[
{"role": "user", "content": "What's the temperature in Los Angeles now?"},
],
[
{"role": "user", "content": "What's the temperature in New York now?"},
],
[
{
"role": "system",
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
"Current Date: 2024-09-30",
},
{"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
],
]
batch = DataProto(
non_tensor_batch={
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
"agent_name": np.array(["react_agent"] * len(raw_prompts)),
},
)
batch = batch.repeat(n)
result = agent_loop_manager.generate_sequences(prompts=batch)
assert len(result) == len(raw_prompts) * n
# Check turns
num_turns = result.non_tensor_batch["__num_turns__"]
print(f"num_turns: {num_turns}")
for i in range(len(num_turns)):
if i // n == 0:
# [user, assistant]
assert num_turns[i] == 2
else:
# [user, assistant, tool, assistant]
assert num_turns[i] == 4
# Check response_mask
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
responses = result.batch["responses"]
response_mask = result.batch["response_mask"]
attention_mask = result.batch["attention_mask"]
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
response_length = response_mask.size(1)
for i in range(len(responses)):
# response with tool response
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
response_with_obs = tokenizer.decode(valid_tokens)
# response without tool response
valid_tokens = responses[i][response_mask[i].bool()]
response_without_obs = tokenizer.decode(valid_tokens)
assert "<tool_response>" not in response_without_obs, (
f"found <tool_response> in response: {response_without_obs}"
)
assert "</tool_response>" not in response_without_obs, (
f"found </tool_response> in response: {response_without_obs}"
)
print("=========================")
print(response_with_obs)
print("---")
print(response_without_obs)
print("Test passed!")
ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import os
import re
from typing import Optional
import datasets
import torch
from omegaconf import DictConfig, ListConfig
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import PreTrainedTokenizer, ProcessorMixin
import verl.utils.torch_functional as verl_F
from verl.utils.dataset.vision_utils import process_image
from verl.utils.model import compute_position_id_with_mask
logger = logging.getLogger(__name__)
def build_transform():
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
return transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
]
)
def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None):
if new_schema:
start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id)
end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id)
else:
start_cond = input_ids == tokenizer.im_start_id
end_cond = input_ids == tokenizer.im_end_id
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
if len(image_start_tokens) != len(image_end_tokens):
logger.error("image start token != image end tokens")
raise Exception("image start token != image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
else:
image_bound = []
return image_bound
def preprocess(
images_dict,
conversations,
tokenizer,
transform,
query_nums=64,
slice_config=None,
llm_type=None,
patch_size=14,
batch_vision=False,
max_length=2048,
truncation="error",
logger=None,
):
"""
Preprocess single or multiple images; the image(s) will be placed at the top of the conversation.
"""
conversations = copy.deepcopy(conversations)
assert conversations[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, dict)
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
new_schema = False
use_image_id = False
if llm_type == "qwen":
new_schema = True
use_image_id = True
image_placeholder_dict = {}
images = []
image_id_cnt = 0
for img_name, image in images_dict.items():
if slice_config:
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
if use_image_id:
image_placeholder = (
f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
)
image_id_cnt += 1
image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema)
image_placeholder_dict[img_name] = image_placeholder
else:
images.append(image)
if use_image_id:
image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
image_id_cnt += 1
else:
image_placeholder = default_image_placeholder
image_placeholder_dict[img_name] = image_placeholder
images = [transform(i) for i in images]
if len(images_dict) == 1 and "<image>" in images_dict:
if "<image>" in conversations[0]["content"]:
conversations[0]["content"] = conversations[0]["content"].replace("<image>", image_placeholder)
else:
conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"]
else:
pattern = r"<image_\d+>"
new_conversations = []
for conversation in conversations:
content = conversation["content"]
parts = re.split(f"({pattern})", content)
for i, part in enumerate(parts):
if not part.strip():
continue
if re.match(pattern, part):
if part in image_placeholder_dict:
parts[i] = image_placeholder_dict[part]
else:
raise Exception(f"not found {part} in image dict")
conversation["content"] = "\n".join(parts)
new_conversations.append(conversation)
conversations = new_conversations
# TODO change role in conversation for different llm
prompt_with_chat_template = tokenizer.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False)
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt_with_chat_template,
tokenizer=tokenizer,
max_length=max_length,
pad_token_id=tokenizer.pad_token_id,
left_pad=True,
truncation=truncation,
)
position_ids = compute_position_id_with_mask(attention_mask)
image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger)
input_dict = {
"input_ids": input_ids[0],
"attention_mask": attention_mask[0],
"position_ids": position_ids[0],
"image_bound": image_bound,
}
if batch_vision:
tgt_sizes = []
reshape_images = []
for image in images:
H, W = image.shape[1:]
reshape_image = reshape_by_patch(image, patch_size)
reshape_images.append(reshape_image)
tgt_sizes.append([H // patch_size, W // patch_size])
if tgt_sizes:
tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
input_dict["pixel_values"] = reshape_images
input_dict["tgt_sizes"] = tgt_sizes
else:
input_dict["pixel_values"] = images
input_dict["tgt_sizes"] = []
return input_dict
def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / (scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
best_grid = None
patches = []
if multiple <= 1 or never_split:
        # no need to slice; just upsample
best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
source_image = image.resize(best_size, Image.Resampling.BICUBIC)
else:
candidate_split_grids_nums = []
for i in [multiple - 1, multiple, multiple + 1]:
if i == 1 or i > max_slice_nums:
continue
candidate_split_grids_nums.append(i)
        # source image: down-sample and make the dimensions divisible by patch_size
best_resize = find_best_resize(original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
# find best grid
for split_grids_nums in candidate_split_grids_nums:
m = 1
while m <= split_grids_nums:
if split_grids_nums % m == 0:
candidate_grids.append([m, split_grids_nums // m])
m += 1
best_grid = [1, 1]
min_error = float("inf")
for grid in candidate_grids:
error = abs(log_ratio - math.log(grid[0] / grid[1]))
if error < min_error:
best_grid = grid
min_error = error
refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True)
refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
patches = split_to_patches(refine_image, best_grid)
return source_image, patches, best_grid
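# Illustrative sketch (not part of the original file): a 1792x896 image covers 8x the 448x448
# area budget, so slice_image returns a down-sampled source image plus a grid of patches,
# while a small image is only resized (no patches, no grid). The sizes are arbitrary examples.
def _example_slice_image():
    big = Image.new("RGB", (1792, 896))
    source, patches, grid = slice_image(big, max_slice_nums=9, scale_resolution=448, patch_size=14)
    assert grid is not None
    assert len(patches) == grid[1] and len(patches[0]) == grid[0]  # rows x cols follow best_grid

    small = Image.new("RGB", (224, 224))
    source, patches, grid = slice_image(small)
    assert patches == [] and grid is None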
def ensure_divide(length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
if (width * height > scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
best_width = ensure_divide(width, patch_size)
best_height = ensure_divide(height, patch_size)
return (best_width, best_height)
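# Illustrative worked example (not part of the original file): a 1024x768 image exceeds the
# 448x448 area budget, so it is scaled down while roughly preserving the aspect ratio and then
# snapped to multiples of patch_size (14), giving (518, 392).
def _example_find_best_resize():
    assert find_best_resize((1024, 768), scale_resolution=448, patch_size=14) == (518, 392)
    # a small image is left untouched unless allow_upscale=True
    assert find_best_resize((224, 224), scale_resolution=448, patch_size=14) == (224, 224)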
def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
grid_x, grid_y = grid
refine_width = ensure_divide(width, grid_x)
refine_height = ensure_divide(height, grid_y)
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
best_grid_size = find_best_resize(
(grid_width, grid_height),
scale_resolution,
patch_size,
allow_upscale=allow_upscale,
)
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
def split_to_patches(image, grid):
patches = []
width, height = image.size
grid_x = int(width / grid[0])
grid_y = int(height / grid[1])
for i in range(0, height, grid_y):
images = []
for j in range(0, width, grid_x):
box = (j, i, j + grid_x, i + grid_y)
patch = image.crop(box)
images.append(patch)
patches.append(images)
return patches
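# Illustrative sketch (not part of the original file): splitting a 56x28 image with a [2, 1]
# grid (2 columns, 1 row) yields one row of two 28x28 patches. The image size is arbitrary.
def _example_split_to_patches():
    rows = split_to_patches(Image.new("RGB", (56, 28)), grid=[2, 1])
    assert len(rows) == 1 and len(rows[0]) == 2
    assert rows[0][0].size == (28, 28)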
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
if new_schema:
image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
else:
image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
cols = grid[0]
rows = grid[1]
slices = []
for i in range(rows):
lines = []
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
if new_schema:
slice_placeholder = "\n".join(slices)
else:
slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
return slice_placeholder
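# Illustrative sketch (not part of the original file): with the new schema, a [2, 2] grid
# becomes two rows of two slice placeholders joined by a newline. The fake tokenizer strings
# are hypothetical stand-ins for the real special tokens.
def _example_get_grid_placeholder():
    class _FakeTokenizer:
        im_start, im_end = "<im>", "</im>"
        slice_start, slice_end = "<s>", "</s>"
        unk_token = "."

    placeholder = get_grid_placeholder(_FakeTokenizer(), grid=[2, 2], query_num=2, new_schema=True)
    assert placeholder == "<s>..</s><s>..</s>\n<s>..</s><s>..</s>"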
def reshape_by_patch(image_tensor, patch_size):
"""
:param image_tensor: shape [3, H, W]
:param patch_size:
:return: [3, patch_size, HW/patch_size]
"""
patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size))
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)
return patches
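# Illustrative sketch (not part of the original file): for a [3, 28, 28] tensor and patch_size
# 14, the image is cut into 2x2 patches of 14x14 and flattened to [3, 14, 28 * 28 / 14] = [3, 14, 56].
def _example_reshape_by_patch():
    out = reshape_by_patch(torch.zeros(3, 28, 28), patch_size=14)
    assert out.shape == (3, 14, 56)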
def init_minicpmo_config(processor, config):
"""Initialize MiniCPM-o specific configuration"""
minicpmo_config = {
"transform": build_transform(),
"patch_size": config.get("patch_size", 14),
"query_nums": config.get("query_nums", 64),
"slice_config": config.get(
"slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448}
),
"llm_type": config.get("llm_type", "qwen"),
"batch_vision": config.get("batch_vision", True),
}
return minicpmo_config
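# Illustrative sketch (not part of the original file): the defaults produced when the dataset
# config provides no overrides; the processor argument is unused by this helper, so None is
# passed here.
def _example_init_minicpmo_config():
    cfg = init_minicpmo_config(processor=None, config={})
    assert cfg["patch_size"] == 14 and cfg["query_nums"] == 64
    assert cfg["slice_config"]["max_slice_nums"] == 9 and cfg["llm_type"] == "qwen"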
def process_minicpmo_data(
row_dict, messages, tokenizer, minicpmo_config, image_key, max_prompt_length, truncation, logger
):
"""Process data for MiniCPM-o model"""
if len(row_dict[image_key]) == 1:
multi_modal_data = {}
image = process_image(row_dict.pop(image_key)[0])
multi_modal_data["image"] = [image]
images_dict = {"<image>": image}
else:
raise NotImplementedError
model_inputs = preprocess(
images_dict,
messages,
tokenizer,
minicpmo_config["transform"],
query_nums=minicpmo_config["query_nums"],
slice_config=minicpmo_config["slice_config"],
llm_type=minicpmo_config["llm_type"],
patch_size=minicpmo_config["patch_size"],
batch_vision=minicpmo_config["batch_vision"],
max_length=max_prompt_length,
truncation=truncation,
logger=logger,
)
raw_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
raw_prompt = raw_prompt.replace("<image>", "(<image>./</image>)")
return model_inputs, multi_modal_data, raw_prompt
class RLHFDataset(Dataset):
"""
Load and preprocess RLHF data from Parquet files.
- Caches files locally.
- Reads into a HuggingFace Dataset and tokenizes prompts.
- Optionally handles images/videos via a ProcessorMixin.
- Filters prompts over a max length.
- Supports resuming from checkpoints.
Args:
data_files (str or list): Path(s) to Parquet file(s).
        tokenizer (PreTrainedTokenizer): Tokenizer used to convert text into token IDs.
config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.
"""
def __init__(
self,
data_files: str | list[str],
tokenizer: PreTrainedTokenizer,
config: DictConfig,
processor: Optional[ProcessorMixin] = None,
):
if not isinstance(data_files, list | ListConfig):
data_files = [data_files]
self.data_files = copy.deepcopy(data_files)
self.original_data_files = copy.deepcopy(data_files) # use for resume
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
self.prompt_key = config.get("prompt_key", "prompt")
self.image_key = config.get("image_key", "images")
self.video_key = config.get("video_key", "videos")
self.max_prompt_length = config.get("max_prompt_length", 1024)
self.return_raw_chat = config.get("return_raw_chat", False)
self.return_full_prompt = config.get("return_full_prompt", False)
self.truncation = config.get("truncation", "error")
self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
self.num_workers = min(self.num_workers, os.cpu_count())
self.use_shm = config.get("use_shm", False)
self.chat_template_func = config.get("chat_template_func", None)
self.need_tools_kwargs = config.get("need_tools_kwargs", False)
self.filter_prompts = config.get("filter_prompts", True)
self.serialize_dataset = False
self.minicpmo_config = init_minicpmo_config(self.processor, config)
self._download()
self._read_files_and_tokenize()
def _download(self, use_origin_parquet=False):
from verl.utils.fs import copy_to_local
data_files = self.data_files if not use_origin_parquet else self.original_data_files
for i, parquet_file in enumerate(data_files):
self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)
def _read_files_and_tokenize(self):
dataframes = []
for parquet_file in self.data_files:
# read parquet files and cache
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
dataframes.append(dataframe)
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
print(f"dataset len: {len(self.dataframe)}")
def resume_dataset_state(self):
self.serialize_dataset = not hasattr(self, "original_data_files")
        # resume the dataframe from the original files unless the dataset itself was serialized into data.pt
if not self.serialize_dataset:
self._download(use_origin_parquet=True) # download and resume from original parquet files
self._read_files_and_tokenize()
else:
print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")
def __len__(self):
return len(self.dataframe)
def _build_messages(self, example: dict):
return example.pop(self.prompt_key)
def __getitem__(self, item):
"""
        Note that we also return the raw_input_ids so that they can be combined with other chat templates.
"""
row_dict: dict = self.dataframe[item]
messages = self._build_messages(row_dict)
model_inputs = {}
if self.processor is not None:
model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data(
row_dict,
messages,
self.tokenizer,
self.minicpmo_config,
self.image_key,
self.max_prompt_length,
self.truncation,
logger,
)
input_ids = model_inputs.pop("input_ids")
attention_mask = model_inputs.pop("attention_mask")
position_ids = model_inputs.pop("position_ids")
            # Note: multi_modal_inputs must be a plain dict here, not a transformers BatchFeature
row_dict["multi_modal_data"] = multi_modal_data
row_dict["multi_modal_inputs"] = dict(model_inputs)
else:
raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
input_ids = model_inputs.pop("input_ids")
attention_mask = model_inputs.pop("attention_mask")
position_ids = compute_position_id_with_mask(attention_mask)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
if len(raw_prompt_ids) > self.max_prompt_length:
if self.truncation == "left":
raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
elif self.truncation == "right":
raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
elif self.truncation == "middle":
left_half = self.max_prompt_length // 2
right_half = self.max_prompt_length - left_half
raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
elif self.truncation == "error":
raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
row_dict["raw_prompt_ids"] = raw_prompt_ids
        # optionally return the raw chat messages (no chat template applied)
if self.return_raw_chat:
row_dict["raw_prompt"] = messages
# get prompts with chat template
if self.return_full_prompt:
row_dict["full_prompts"] = raw_prompt # array of strings
# add index for each prompt
index = row_dict.get("extra_info", {}).get("index", 0)
tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {})
need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
if need_tools_kwargs and not tools_kwargs:
logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
row_dict["index"] = index
row_dict["tools_kwargs"] = tools_kwargs
row_dict["interaction_kwargs"] = interaction_kwargs
return row_dict
def __getstate__(self):
if not self.serialize_dataset:
state = self.__dict__.copy()
if "dataframe" in state:
del state["dataframe"]
return state
return self.__dict__.copy()
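# Illustrative usage sketch (not part of the original file). The model name, file path, and
# config keys below are hypothetical; adapt them to your setup. Kept commented out because it
# downloads a model and reads a local parquet file.
#
# from omegaconf import OmegaConf
# from transformers import AutoProcessor, AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)
# processor = AutoProcessor.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)
# config = OmegaConf.create(
#     {"prompt_key": "prompt", "image_key": "images", "max_prompt_length": 2048, "truncation": "left"}
# )
# dataset = RLHFDataset(data_files="train.parquet", tokenizer=tokenizer, config=config, processor=processor)
# sample = dataset[0]  # dict with input_ids, attention_mask, position_ids, multi_modal_data, ...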