# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
import warnings
import torch
import torch.distributed
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
from omegaconf import DictConfig, open_dict
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
from verl.utils import hf_tokenizer
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_model_to_cpu, load_fsdp_optimizer, \
load_fsdp_model_to_gpu
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.flops_counter import FlopsCounter
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
from codetiming import Timer
from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy
from .prime_core_algos import compute_dpo_accuracy, compute_dpo_abs_accuracy
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
class PRIMERewardModelWorker(Worker):
def __init__(self, config):
super().__init__()
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
self.config = config
# build device mesh for Ulysses Sequence Parallel
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.config.model.fsdp_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
self.ulysses_device_mesh = init_device_mesh('cuda',
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=['dp', 'sp'])
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# set FSDP offload params
self._is_offload_param = self.config.model.fsdp_config.param_offload
self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload
# normalize config
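# mini/micro batch sizes in the config are global; divide by the data-parallel size (world_size // sp_size) to get per-DP-rank sizes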
self.config.mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size)
if self.config.micro_batch_size is not None:
self.config.micro_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size)
self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0
def _build_reward_ref_model_optimizer(self, config):
# the following imports are necessary
from verl.utils.model import LambdaLayer, print_model_size, squeeze
from verl.utils.torch_dtypes import PrecisionType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
from torch import optim
local_path = copy_local_path_from_hdfs(config.model.path)
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
print(f'Reward model overriding config {override_config_kwargs}')
torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32')
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig, AutoModelForCausalLM
from torch import nn
trust_remote_code = False
reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
reward_model_config.num_labels = 1
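# meta-tensor init (enabled when word embeddings are untied) avoids materializing full weights on every rank before FSDP shards them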
init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(reward_model_config, 'classifier_dropout', 0.)
setattr(reward_model_config, 'hidden_dropout', '0')
reward_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
# some parameters may not be in torch_dtype
reward_module.to(torch_dtype)
if config.model.get('enable_gradient_checkpointing', False):
reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
if self.rank == 0:
print_model_size(reward_module)
self.reward_model_config = reward_model_config
fsdp_config = self.config.model.fsdp_config
mixed_precision_config = fsdp_config.get('mixed_precision', None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy)
log_gpu_memory_usage('Before reward model FSDP', logger=None)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
setattr(reward_model_config, 'classifier_dropout', 0.)
setattr(reward_model_config, 'hidden_dropout', '0')
ref_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=copy_local_path_from_hdfs(
config.model.ref_path),
torch_dtype=torch_dtype,
config=reward_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
# some parameters may not be in torch_dtype
ref_module.to(torch_dtype)
reward_module = FSDP(reward_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None)
log_gpu_memory_usage('After reward FSDP', logger=None)
ref_module = FSDP(ref_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
forward_prefetch=False,
device_mesh=self.device_mesh,
cpu_offload=None)
reward_optimizer = optim.AdamW(reward_module.parameters(),
lr=config.model.optim.lr,
betas=config.model.optim.get('betas', (0.9, 0.999)),
weight_decay=config.model.optim.get('weight_decay', 1e-2))
total_steps = config.model.optim.get('total_training_steps', 0)
num_warmup_steps = int(config.model.optim.get('lr_warmup_steps', -1))
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.model.optim.get('lr_warmup_steps_ratio', 0.)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')
from verl.utils.torch_functional import get_constant_schedule_with_warmup
reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=reward_optimizer,
num_warmup_steps=num_warmup_steps)
return reward_module, ref_module, reward_optimizer, reward_lr_scheduler
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
# This is used to import external_lib into the huggingface system
import_external_libs(self.config.model.get('external_lib', None))
from .prime_dp_rm import DataParallelPRIMERewardModel
self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = self._build_reward_ref_model_optimizer(
config=self.config)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.reward_optimizer)
self.rm = DataParallelPRIMERewardModel(config=self.config,
reward_module=self.reward_module,
ref_module=self.ref_module,
reward_optimizer=self.reward_optimizer)
self.flops_counter = FlopsCounter(self.reward_model_config)
self.checkpoint_manager = FSDPCheckpointManager(model=self.reward_module,
optimizer=self.reward_optimizer,
lr_scheduler=self.reward_lr_scheduler,
tokenizer=self.tokenizer)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
data = data.to('cuda')
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
load_fsdp_model_to_gpu(self.ref_module)
micro_batch_size = self.config.micro_batch_size_per_gpu
data.meta_info['micro_batch_size'] = micro_batch_size
data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu
data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz
# perform forward computation
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
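# forward-only pass of the implicit PRM: rm_scores feed the DPO-style accuracy metrics below, and q is returned for downstream reward computation (assumed to hold the per-token reward/reference comparison used by PRIME)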
rm_scores, q, metrics = self.rm.compute_rm_score(data=data)
prompt_length = data.batch['prompts'].shape[-1]
response_mask = data.batch['attention_mask'][:, prompt_length:]
acc = data.batch['acc']
dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
metrics['reward_model/dpo_acc'] = dpo_acc.detach().item()
metrics['reward_model/dpo_acc_abs'] = dpo_acc_abs.detach().item()
output = DataProto.from_dict(tensors={'rm_scores': rm_scores, 'q': q}, meta_info={'metrics': metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
output = output.to('cpu')
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_rm(self, data: DataProto):
data = data.to('cuda')
if self._is_offload_param:
load_fsdp_model_to_gpu(self.ref_module)
load_fsdp_model_to_gpu(self.reward_module)
if self._is_offload_optimizer:
load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=torch.cuda.current_device())
# perform forward and backward computation to update the reward model
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
rm_scores, metrics = self.rm.update_rm(data=data)
self.reward_lr_scheduler.step()
lr = self.reward_lr_scheduler.get_last_lr()[0]
metrics['rm/lr'] = lr
prompt_length = data.batch['prompts'].shape[-1]
response_mask = data.batch['attention_mask'][:, prompt_length:]
acc = data.batch['acc']
dpo_acc_before = compute_dpo_accuracy(rm_scores,
acc,
response_mask=response_mask,
n_samples=data.meta_info['n'])
dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info['n'])
metrics['reward_model/dpo_acc_before'] = dpo_acc_before.detach().item()
metrics['reward_model/dpo_acc_abs_before'] = dpo_acc_abs.detach().item()
output = DataProto.from_dict(tensors={'rm_scores': rm_scores}, meta_info={'metrics': metrics})
output = self.ulysses_sharding_manager.postprocess_data(data=output)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
offload_fsdp_model_to_cpu(self.ref_module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.reward_optimizer)
output = output.to('cpu')
return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
self.checkpoint_manager.save_checkpoint(local_path=local_path,
hdfs_path=hdfs_path,
global_step=global_step,
max_ckpt_to_keep=max_ckpt_to_keep)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, del_local_after_load=True):
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.reward_module)
self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.reward_module)
# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""
import os
import statistics
import uuid
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager, reduce_metrics, _timer
from verl.trainer.ppo.metric_utils import _compute_response_info
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from . import prime_core_algos
def compute_advantage(data: DataProto, adv_estimator, config):
if adv_estimator == 'rloo':
responses = data.batch['responses']
response_length = responses.size(-1)
attention_mask = data.batch['attention_mask']
response_mask = attention_mask[:, -response_length:]
advantages, returns = prime_core_algos.compute_rloo_advantage_return(data, response_mask,
config.actor_rollout_ref.rollout.n, config)
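# RLOO uses a leave-one-out baseline: each sample is compared against the mean reward of the other n-1 rollouts of the same prompt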
data.batch['advantages'] = advantages
data.batch['returns'] = returns
else:
raise NotImplementedError
return data
def compute_data_metrics(batch, use_critic=True):
advantages = batch.batch['advantages']
returns = batch.batch['returns']
max_response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch['values']
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# adv
'critic/advantages/mean':
torch.mean(valid_adv).detach().item(),
'critic/advantages/max':
torch.max(valid_adv).detach().item(),
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
# returns
'critic/returns/mean':
torch.mean(valid_returns).detach().item(),
'critic/returns/max':
torch.max(valid_returns).detach().item(),
'critic/returns/min':
torch.min(valid_returns).detach().item(),
**({
# values
'critic/values/mean': torch.mean(valid_values).detach().item(),
'critic/values/max': torch.max(valid_values).detach().item(),
'critic/values/min': torch.min(valid_values).detach().item(),
# vf explained var
'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
} if use_critic else {}),
# response length
'response_length/mean':
torch.mean(response_length).detach().item(),
'response_length/max':
torch.max(response_length).detach().item(),
'response_length/min':
torch.min(response_length).detach().item(),
'response_length/clip_ratio':
torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
# prompt length
'prompt_length/mean':
torch.mean(prompt_length).detach().item(),
'prompt_length/max':
torch.max(prompt_length).detach().item(),
'prompt_length/min':
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
num_response_tokens = torch.sum(response_info['response_length']).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
'gen': num_response_tokens,
**{
name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']
},
}
return {
**{
f'timing_s/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
class RayPRIMETrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None):
# assert torch.cuda.is_available(), 'cuda must be available on driver'
super().__init__(config, tokenizer, role_worker_mapping, resource_pool_manager, ray_worker_group_cls, reward_fn,
val_reward_fn)
self.use_critic = False
def _validate_config(self):
super()._validate_config()
# TODO: Additional config checks can be added here
config = self.config
def _create_dataloader(self):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
# TODO: we have to make sure the batch size is divisible by the dp size
self.train_dataset = RLHFDataset(data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
config=self.config.data)
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=int(self.config.data.train_batch_size *
self.config.data.oversample_factor),
drop_last=True,
collate_fn=collate_fn,
sampler=sampler)
self.val_dataset = RLHFDataset(data_files=self.config.data.val_files,
tokenizer=self.tokenizer,
config=self.config.data)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=len(self.val_dataset),
shuffle=True,
drop_last=True,
collate_fn=collate_fn)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f'Size of train dataloader: {len(self.train_dataloader)}')
print(f'Size of val dataloader: {len(self.val_dataloader)}')
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
def _save_checkpoint(self):
# path: given_path + `/global_step_{global_steps}` + `/actor`
local_global_step_folder = os.path.join(self.config.trainer.default_local_dir,
f'global_step_{self.global_steps}')
print(f'local_global_step_folder: {local_global_step_folder}')
actor_local_path = os.path.join(local_global_step_folder, 'actor')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path,
actor_remote_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
if self.use_rm:
reward_local_path = os.path.join(local_global_step_folder, 'reward')
reward_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'reward')
self.rm_wg.save_checkpoint(reward_local_path,
reward_remote_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt_in_save)
# save dataloader
dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
import dill
torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill)
# latest checkpointed iteration tracker (for atomic usage)
local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir,
'latest_checkpointed_iteration.txt')
with open(local_latest_checkpointed_iteration, 'w') as f:
f.write(str(self.global_steps))
def _load_checkpoint(self):
if self.config.trainer.resume_mode == 'disable':
return 0
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
raise NotImplementedError('load from hdfs is not implemented yet')
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
working_dir = os.getcwd()
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
# find global_step_folder
if self.config.trainer.resume_mode == 'auto':
if global_step_folder is None:
print('Training from scratch')
return 0
else:
if self.config.trainer.resume_mode == "resume_path":
assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
global_step_folder = self.config.trainer.resume_from_path
if not os.path.isabs(global_step_folder):
working_dir = os.getcwd()
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f'Load from checkpoint folder: {global_step_folder}')
# set global step
self.global_steps = int(global_step_folder.split('global_step_')[-1])
print(f'Setting global step to {self.global_steps}')
print(f'Resuming from {global_step_folder}')
actor_path = os.path.join(global_step_folder, 'actor')
reward_path = os.path.join(global_step_folder, 'reward')
# load actor
self.actor_rollout_wg.load_checkpoint(actor_path,
del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
# load rm
if self.use_rm:
self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
# load dataloader
# TODO: loading from remote is not implemented yet
dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
self.train_dataloader = torch.load(dataloader_local_path)
if isinstance(self.train_dataloader.dataset, RLHFDataset):
self.train_dataloader.dataset.resume_dataset_state()
def fit(self):
"""
The training loop of PPO.
The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
The light-weight advantage computation is done on the driver process.
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
return
# we start from step 1
self.global_steps += 1
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])
with _timer('step', timing_raw):
# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == 'remax':
with _timer('gen_max', timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info['do_sample'] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch['reward_baselines'] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
# self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
# verify
with _timer('verify', timing_raw):
scores = self.reward_fn.verify(batch)
metrics['acc'] = statistics.mean(scores)
# filter the batch: only 1/oversample_factor of the samples are kept; if a filter is enabled, prompts passing it are prioritized.
batch = self.filter_and_downsample(scores, batch)
batch.meta_info['n'] = self.config.actor_rollout_ref.rollout.n
n_samples = self.config.actor_rollout_ref.rollout.n
# recompute old_log_probs
with _timer('old_log_prob', timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
with _timer('adv', timing_raw):
if self.use_rm:
update_style = self.config.reward_model.model.get('update', 'none')
if update_style == 'none': # only run forward
reward_output = self.rm_wg.compute_rm_score(batch)
elif update_style == 'after': # update and directly return the reward
reward_output = self.rm_wg.update_rm(batch)
elif update_style == 'before': # update reward model, and then run forward
reward_output = self.rm_wg.update_rm(batch)
if 'metrics' in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
metrics.update(reward_output_metrics)
reward_output = self.rm_wg.compute_rm_score(batch)
elif update_style == 'reverse': # run forward to calculate statistics, then update reward model
reward_output = self.rm_wg.compute_rm_score(batch)
# broadcast each prompt group's q and acc tensors to every sample in the group
bc_td = DataProto.from_dict(
tensors={
'Q_bc':
reward_output.batch['q'].sum(dim=-1).view(-1, n_samples).unsqueeze(
1).expand(-1, n_samples, -1).reshape(-1, n_samples),
'acc_bc':
batch.batch['acc'].view(-1, n_samples).unsqueeze(1).expand(
-1, n_samples, -1).reshape(-1, n_samples)
})
batch = batch.union(bc_td)
reward_output = self.rm_wg.update_rm(batch)
else:
raise NotImplementedError
batch = batch.union(reward_output)
if 'metrics' in reward_output.meta_info.keys():
reward_output_metrics = reduce_metrics(reward_output.meta_info['metrics'])
metrics.update(reward_output_metrics)
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
config=self.config)
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
self.global_steps % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and \
self.global_steps % self.config.trainer.save_freq == 0:
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
self.global_steps += 1
if self.global_steps >= self.total_training_steps:
# perform validation after training
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f'Final validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.save_freq > 0 and \
(self.global_steps - 1) % self.config.trainer.save_freq != 0:
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()
return
def filter_and_downsample(self, scores, batch: DataProto):
"""
downsample the batch according to oversample_factor
samples passing the filters will be prioritized
"""
n_samples = int(self.config.actor_rollout_ref.rollout.n)
reward_matrix = torch.tensor(scores).reshape(-1, n_samples)
filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool)
if self.config.data.filter_accuracy:
acc_tensor = torch.mean(reward_matrix, dim=-1)
filter_mask[(acc_tensor > self.config.data.accuracy_upper_bound) |
(acc_tensor < self.config.data.accuracy_lower_bound)] = False
if self.config.data.filter_truncate:
length_matrix = batch.batch['attention_mask'][:, -batch.batch['responses'].shape[-1]:].sum(dim=-1).reshape(
-1, n_samples)
length_tensor = torch.max(length_matrix, dim=-1)[0]
filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False
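# sort prompt groups so those passing the filters come first, then expand each prompt index into its n_samples consecutive sample indices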
reorder_index = torch.argsort(filter_mask, descending=True)
reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1)
batch.reorder(reorder_index[:int(len(batch) //
self.config.data.oversample_factor)]) # this operation is inplace
return batch
set -x
export VLLM_ATTENTION_BACKEND=XFORMERS
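# force vLLM to use the xformers attention backend (a common workaround when the default flash-attention backend causes issues)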
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"
model_path=PRIME-RL/Eurus-2-7B-SFT
# model_path=Qwen/Qwen2.5-0.5B-Instruct
python3 -m recipe.prime.main_prime \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=64 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=3072 \
data.filter_overlong_prompts=True \
data.filter_accuracy=True \
data.accuracy_lower_bound=0.2 \
data.accuracy_upper_bound=0.8 \
data.oversample_factor=4 \
actor_rollout_ref.model.path=$model_path \
actor_rollout_ref.actor.optim.lr=5e-7 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
algorithm.adv_estimator=rloo \
algorithm.use_kl_in_reward=True \
algorithm.kl_penalty=kl \
algorithm.kl_ctrl.kl_coef=0.001 \
reward_model.model.path=$model_path \
reward_model.micro_batch_size_per_gpu=1 \
reward_model.model.update=before \
reward_model.model.beta_train=0.05 \
reward_model.model.optim.lr=1e-6 \
reward_model.model.optim.grad_clip=10.0 \
reward_model.model.input_tokenizer=null \
reward_model.mini_batch_size=64 \
trainer.val_before_train=False \
trainer.logger=['console','wandb'] \
trainer.project_name='prime_example' \
trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=64 \
trainer.test_freq=64 \
trainer.total_epochs=15 $@
# DeepSeek R1 Reproduction
This recipe is under development. If you are interested, check out the TODO list and join the project: https://github.com/volcengine/verl/issues/708
## Reproducing Evaluation
Eval Results of DeepSeek-R1-Distill-Qwen-1.5B (k=8)
Dataset | Test Results | Reported
-- | -- | --
GPQA Diamond | 35.3 | 33.8
LiveCodeBench | 16.9 | 16.9
AIME 2024 | 30.4 | 28.9
CNMO 2024 (en) | 45.1 | -
CNMO 2024 (zh) | 41.0 | -
---
Eval Results (DS-R1)
Dataset | Test Results (k=1) | Test Results (k=4) | Reported
-- | -- | -- | --
GPQA Diamond | 67.7 | 69.6 | 71.5
LiveCodeBench | 64.7 | 63.1 | 65.9
AIME 2024 | 86.7 | 79.2 | 79.8
CNMO 2024 | 75.0 | 78.5 | 78.8
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
data:
path: /tmp/math_Qwen2-7B-Instruct.parquet
prompt_key: prompt
response_key: responses
data_source_key: data_source
reward_model_key: reward_model
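# scoring function used by main_eval: path to a Python file and the name of the function it defines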
custom_reward_function:
path: null
name: compute_score
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the dataset to parquet format
"""
import os
from datasets import load_dataset, concatenate_datasets
from functools import partial
from verl.utils.hdfs_io import copy, makedirs
import argparse
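# Each task builder maps raw examples into the common verl schema:
#   data_source / prompt (chat messages) / ability / reward_model {style, ground_truth} / extra_info {split, index}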
def example_map_fn(example, idx, process_fn, data_source, ability, split):
question, solution = process_fn(example)
data = {
"data_source": data_source,
"prompt": [{
"role": "user",
"content": question
}],
"ability": ability,
"reward_model": {
"style": "rule",
"ground_truth": solution
},
"extra_info": {
'split': split,
'index': idx
}
}
return data
def build_aime2024_dataset():
def process_aime2024(example):
return example["Problem"], str(example["Answer"])
data_source = 'Maxwell-Jia/AIME_2024'
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, split="train")
map_fn = partial(example_map_fn,
process_fn=process_aime2024,
data_source=data_source,
ability="English",
split="test")
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)
return dataset
def build_gpqa_diamond_dataset():
import random
GPQA_QUERY_TEMPLATE = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
def process_gpqa_diamond(example):
choices = [example["Incorrect Answer 1"], example["Incorrect Answer 2"], example["Incorrect Answer 3"]]
random.shuffle(choices)
gold_index = random.randint(0, 3)
choices.insert(gold_index, example["Correct Answer"])
query_prompt = GPQA_QUERY_TEMPLATE.format(A=choices[0],
B=choices[1],
C=choices[2],
D=choices[3],
Question=example["Question"])
gold_choice = "ABCD"[gold_index]
return query_prompt, gold_choice
data_source = 'Idavidrein/gpqa'
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, "gpqa_diamond", split="train")
map_fn = partial(example_map_fn,
process_fn=process_gpqa_diamond,
data_source=data_source,
ability="Math",
split="test")
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names)
return dataset
def build_cnmo2024_dataset():
def process_cnmo2024(example):
return example["question"], example["answer"]
data_source = 'opencompass/LiveMathBench'
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test")
map_fn_en = partial(example_map_fn,
process_fn=process_cnmo2024,
data_source='opencompass/cnmo2024_en',
ability="Math",
split="test")
dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names)
dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test")
map_fn_zh = partial(example_map_fn,
process_fn=process_cnmo2024,
data_source='opencompass/cnmo2024_zh',
ability="Math",
split="test")
dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names)
dataset = concatenate_datasets([dataset_en, dataset_zh])
return dataset
def build_livecodebench_dataset():
import json, pickle, zlib, base64
def process_livecodebench(example):
# Construct Query Prompt
# From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140
query_prompt = (
"You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests.\n\n"
f"Question: {example['question_content']}\n\n")
if example["starter_code"]:
query_prompt += (
"You will use the following starter code to write the solution to the problem and enclose your code within delimiters.\n"
f"```python\n{example['starter_code']}\n```")
else:
query_prompt += (
"Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT."
f"```python\n# YOUR CODE HERE\n```")
# Construct test cases
public_test_cases = json.loads(example["public_test_cases"])
try:
private_test_cases = json.loads(example["private_test_cases"])
except Exception:
private_test_cases = json.loads(
pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))))
full_test_cases = public_test_cases + private_test_cases
metadata = json.loads(example["metadata"])
test_cases = {
"inputs": [t["input"] for t in full_test_cases],
"outputs": [t["output"] for t in full_test_cases],
"fn_name": metadata.get("func_name", None),
}
test_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8")
return query_prompt, test_cases_compressed
data_source = 'livecodebench/code_generation_lite'
print(f"Loading the {data_source} dataset from huggingface...", flush=True)
dataset = load_dataset(data_source, split="test")
# R1 evaluation uses LiveCodeBench 24.08-25.01
dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00")
map_fn = partial(example_map_fn,
process_fn=process_livecodebench,
data_source=data_source,
ability="Code",
split="test")
dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8)
return dataset
TASK2DATA = {
"aime2024": build_aime2024_dataset,
"gpqa_diamond": build_gpqa_dimond_dataset,
"cnmo2024": build_cnmo2024_dataset,
"livecodebench": build_livecodebench_dataset,
}
SUPPORTED_TASKS = TASK2DATA.keys()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--local_dir', default='~/data/r1')
parser.add_argument('--hdfs_dir', default=None)
parser.add_argument('--tasks', default="all")
args = parser.parse_args()
if args.tasks.lower() == "all":
args.tasks = SUPPORTED_TASKS
else:
args.tasks = [task.strip() for task in args.tasks.split(',') if task.strip()]
for task in args.tasks:
if task not in SUPPORTED_TASKS:
raise NotImplementedError(f"{task} has not been supported.")
datasets = []
for task in args.tasks:
datasets.append(TASK2DATA[task]())
test_dataset = concatenate_datasets(datasets)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline evaluate the performance of a generated file using reward model and ground truth verifier.
The input is a parquet file that contains N generated sequences and (optional) the ground truth.
"""
import hydra
from verl.utils.fs import copy_to_local
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import ray
def get_custom_reward_fn(config):
import importlib.util, os
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
if not file_path:
return None
if not os.path.exists(file_path):
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
spec = importlib.util.spec_from_file_location("custom_module", file_path)
module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}': {e}")
function_name = reward_fn_config.get("name")
if not hasattr(module, function_name):
raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
print(f"using customized reward function '{function_name}' from '{file_path}'")
return getattr(module, function_name)
@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
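# score every sampled response for one prompt and return the mean score together with its data source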
ground_truth = reward_data['ground_truth']
score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
return data_source, np.mean(score_lst)
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
local_path = copy_to_local(config.data.path)
dataset = pd.read_parquet(local_path)
prompts = dataset[config.data.prompt_key]
responses = dataset[config.data.response_key]
data_sources = dataset[config.data.data_source_key]
reward_model_data = dataset[config.data.reward_model_key]
total = len(dataset)
# Initialize Ray
if not ray.is_initialized():
ray.init()
# evaluate test_score based on data source
data_source_reward = defaultdict(list)
compute_score = get_custom_reward_fn(config)
# Create remote tasks
remote_tasks = [
process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
]
# Process results as they come in
with tqdm(total=total) as pbar:
while len(remote_tasks) > 0:
# Use ray.wait to get completed tasks
done_ids, remote_tasks = ray.wait(remote_tasks)
for result_id in done_ids:
data_source, score = ray.get(result_id)
data_source_reward[data_source].append(score)
pbar.update(1)
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f'test_score/{data_source}'] = np.mean(rewards)
print(metric_dict)
if __name__ == '__main__':
main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def reward_func(data_source, solution_str, ground_truth, extra_info=None):
if data_source in ['Maxwell-Jia/AIME_2024', "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]:
from recipe.r1.tasks import math
return math.compute_score(solution_str, ground_truth)
elif data_source == 'Idavidrein/gpqa':
from recipe.r1.tasks import gpqa
return gpqa.compute_score(solution_str, ground_truth)
elif data_source in ['livecodebench/code_generation_lite', 'livecodebench/code_generation']:
from recipe.r1.tasks import livecodebench
return livecodebench.compute_score(solution_str, ground_truth)
else:
raise NotImplementedError
MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
DATA_PATH=/workspace/datasets/r1_bench
# Eval Data Process
python3 -m recipe.r1.data_process \
--local_dir $DATA_PATH \
--tasks all
# Generation
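# (decoding follows DeepSeek-R1's recommended settings: temperature 0.6, top-p 0.95; 8 samples per prompt for the k=8 results)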
python3 -m verl.trainer.main_generation \
trainer.nnodes=1 \
trainer.n_gpus_per_node=8 \
data.path=$DATA_PATH/test.parquet \
data.prompt_key=prompt \
data.batch_size=1024 \
data.n_samples=8 \
data.output_path=$DATA_PATH/test-output-8.parquet \
model.path=$MODEL_PATH \
rollout.temperature=0.6 \
rollout.top_p=0.95 \
rollout.prompt_length=1024 \
rollout.response_length=32768 \
rollout.tensor_model_parallel_size=1 \
rollout.gpu_memory_utilization=0.9 \
rollout.max_num_batched_tokens=65536
# Evaluation
python3 -m recipe.r1.main_eval \
data.path=$DATA_PATH/test-output-8.parquet \
data.prompt_key=prompt \
data.response_key=responses \
custom_reward_function.path=recipe/r1/reward_score.py \
custom_reward_function.name=reward_func
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
# Extraction Template from https://github.com/openai/simple-evals/blob/90e3e821cabba2aeb6be651dcb662b253df04225/common.py#L25
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?"
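# matches "Answer: <LETTER>" case-insensitively, allowing an optional $...$ around the letter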
def compute_score(solution_str, ground_truth) -> float:
match = re.search(ANSWER_PATTERN_MULTICHOICE, solution_str)
extracted_answer = match.group(1) if match else None
score = 1.0 if extracted_answer == ground_truth else 0.0
return score
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import json
import pickle
import zlib
import base64
# Reuse `run_test` for convenience
from verl.utils.reward_score.prime_code.testing_util import run_test
def _temp_run(in_outs, generation, debug, result, metadata_list, timeout):
res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout)
result.append(res)
metadata_list.append(metadata)
def check_correctness(in_outs, generation, timeout, debug=True):
"""Check correctness of code generation with a global timeout.
The global timeout is to catch some extreme/rare cases not handled by the timeouts
inside `run_test`"""
manager = multiprocessing.Manager()
result = manager.list()
metadata_list = manager.list()
p = multiprocessing.Process(
target=_temp_run,
args=(in_outs, generation, debug, result, metadata_list, timeout),
)
p.start()
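# allow (timeout + 1) seconds per test case plus 5 seconds of slack before force-killing the subprocess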
p.join(timeout=(timeout + 1) * len(in_outs["inputs"]) + 5)
if p.is_alive():
p.kill()
if not result:
# consider that all tests failed
result = [[-1 for i in range(len(in_outs["inputs"]))]]
if debug:
print(f"global timeout")
return result[0], metadata_list[0]
def compute_score(completion, test_cases):
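# take the code after the last ```python fence and before the following ``` fence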
solution = completion.split('```python')[-1].split('```')[0]
# extract test cases
try:
in_outs = json.loads(test_cases)
except Exception:
in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode("utf-8")))))
success = False
try:
res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False)
success = all(map(lambda x: x == True, res))
except Exception as e:
pass
return success
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
except ImportError:
print("To use Math-Verify, please install it first by running `pip install math-verify`.")
def compute_score(model_output: str, ground_truth: str) -> bool:
verify_func = math_metric(
gold_extraction_target=(LatexExtractionConfig(),),
pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
)
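# verify_func returns a (score, extraction details) tuple; only the aggregated score is kept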
ret_score = 0.
# Wrap the ground truth in \boxed{} format for verification
ground_truth_boxed = "\\boxed{" + ground_truth + "}"
try:
ret_score, _ = verify_func([ground_truth_boxed], [model_output])
except Exception as e:
pass
return ret_score
# requirements.txt records the full set of dependencies for development
accelerate
codetiming
datasets
dill
flash-attn
hydra-core
numpy
pandas
peft
pyarrow>=15.0.0
pybind11
pylatexenc
ray[default]>=2.10
tensordict<=0.6.2
torchdata
torchvision
transformers
wandb
sglang[all]==0.4.4.post4
torch-memory-saver>=0.0.5
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Dict
import re
import os
import torch
import argparse
import warnings
import numpy as np
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from concurrent.futures import ThreadPoolExecutor
from safetensors.torch import load_file
from torch.distributed._tensor import Shard, Placement
from verl.utils.megatron_utils import get_model, convert_config
from megatron.core.models.gpt.gpt_model import ModelType
from megatron.core import parallel_state as mpu
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import StrictHandling
def _init_args():
parser = argparse.ArgumentParser()
parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
parser.add_argument('--output_path', type=str, required=True, help="The path for the output mcore model")
parser.add_argument('--test', action='store_true', help="Whether to test the conversion")
args = parser.parse_args()
return args
class MegatronConfig:
def __init__(self):
self.params_dtype = torch.bfloat16
class ModelConfig:
def __init__(self):
self.path = None
class Config:
def __init__(self):
self.model = ModelConfig()
def convert_hf_to_mcore(hf_model_path, output_path, test=False):
os.makedirs(output_path, exist_ok=True)
if len(os.listdir(output_path)) > 0 and not test:
print(f"Output path {output_path} is not empty, skipping conversion")
return
# init torch distributed and mpu
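# fake a single-process distributed environment so Megatron's parallel_state can be initialized with world size 1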
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
torch.distributed.init_process_group('nccl')
mpu.initialize_model_parallel(tensor_model_parallel_size=1,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=1,
expert_model_parallel_size=1)
# init hf config
hf_config = AutoConfig.from_pretrained(hf_model_path)
print(hf_config)
megatron_config = MegatronConfig()
cfg = Config()
cfg.model.path = hf_model_path
tfconfig = convert_config(hf_config, megatron_config)
tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
# init megatron model
def megatron_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_gptmodel_from_config
parallel_model = get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=tie_word_embeddings,
value=False)
return parallel_model
model = get_model(model_provider_func=megatron_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# init hf model
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)
ref_state_dict = hf_model.state_dict()
# load hf state dict to megatron model
from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel
load_state_dict_to_megatron_gptmodel(state_dict=ref_state_dict,
wrapped_models=model,
config=hf_config,
params_dtype=torch.bfloat16,
is_value_model=False)
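# Unwrap the DDP/precision wrappers to reach the underlying GPTModel and take its sharded
# state dict, which is the format megatron.core.dist_checkpointing serializes.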
ssd = model[0].module.module.sharded_state_dict()
del ref_state_dict, hf_model
# save megatron model
if len(os.listdir(output_path)) == 0:
dist_checkpointing.save(ssd, output_path, sharded_strategy=None, async_sharded_save=False)
if test:
########### test ###########
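# Round-trip check: reload the saved mcore checkpoint into a fresh model and verify every
# shared tensor matches, then load it into a value-model variant with StrictHandling.IGNORE_ALL
# to confirm the backbone weights are reusable.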
# load model
model_test = get_model(model_provider_func=megatron_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
ssd2 = model_test[0].module.module.sharded_state_dict()
dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.ASSUME_OK_UNEXPECTED)
sd = model[0].module.module.state_dict()
sd2 = model_test[0].module.module.state_dict()
for k in sd.keys():
if sd[k] is None:
continue
d1 = sd[k].data
if k in sd2:
d2 = sd2[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"
for k in sd2.keys():
if sd2[k] is None:
continue
d1 = sd2[k].data
if k in sd:
d2 = sd[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"
# load value model
def megatron_value_model_provider(pre_process, post_process):
from verl.utils.model import get_parallel_gptmodel_from_config
parallel_model = get_parallel_gptmodel_from_config(tfconfig,
hf_config,
pre_process,
post_process,
share_embeddings_and_output_weights=False,
value=True)
parallel_model.cuda()
return parallel_model
model_value = get_model(model_provider_func=megatron_value_model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=True)
ssd2 = model_value[0].module.module.sharded_state_dict()
dist_checkpointing.load(ssd2, output_path, strict=StrictHandling.IGNORE_ALL)
sd = model[0].module.module.state_dict()
sd2 = model_value[0].module.module.state_dict()
for k in sd.keys():
if sd[k] is None:
continue
d1 = sd[k].data
if k in sd2:
d2 = sd2[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"
for k in sd2.keys():
if sd2[k] is None:
continue
d1 = sd2[k].data
if k in sd:
d2 = sd[k].data
assert d1.shape == d2.shape, f'{k=} {d1.shape=} {d2.shape=}'
assert (d1 == d2).all(), f"{k} is not equal"
if __name__ == "__main__":
args = _init_args()
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.test)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diagnose script for checking OS/hardware/python/pip/verl/network.
The output of this script is a very useful starting point when reporting or debugging an issue.
"""
import subprocess
import psutil
import platform, sys, os
import socket, time
try:
from urllib.request import urlopen
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
from urllib2 import urlopen
import argparse
import importlib.metadata
import torch
URLS = {
'PYPI': 'https://pypi.python.org/pypi/pip',
}
REGIONAL_URLS = {
'cn': {
'PYPI(douban)': 'https://pypi.douban.com/',
'Conda(tsinghua)': 'https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/',
}
}
def test_connection(name, url, timeout=10):
"""Simple connection test"""
urlinfo = urlparse(url)
start = time.time()
try:
ip = socket.gethostbyname(urlinfo.netloc)
except Exception as e:
print('Error resolving DNS for {}: {}, {}'.format(name, url, e))
return
dns_elapsed = time.time() - start
start = time.time()
try:
_ = urlopen(url, timeout=timeout)
except Exception as e:
print("Error open {}: {}, {}, DNS finished in {} sec.".format(name, url, e, dns_elapsed))
return
load_elapsed = time.time() - start
print("Timing for {}: {}, DNS: {:.4f} sec, LOAD: {:.4f} sec.".format(name, url, dns_elapsed, load_elapsed))
def check_python():
print('----------Python Info----------')
print('Version :', platform.python_version())
print('Compiler :', platform.python_compiler())
print('Build :', platform.python_build())
print('Arch :', platform.architecture())
def check_pip():
print('------------Pip Info-----------')
try:
import pip
print('Version :', pip.__version__)
print('Directory :', os.path.dirname(pip.__file__))
except ImportError:
print('No corresponding pip installation found for the current Python.')
def _get_current_git_commit():
try:
result = subprocess.run(['git', 'rev-parse', 'HEAD'], capture_output=True, text=True, check=True)
return result.stdout.strip()
except subprocess.CalledProcessError as e:
print(f"Error running git command: {e.stderr.strip()}")
return None
except FileNotFoundError:
print("Did not find command: git")
return None
def check_verl():
print('----------verl Info-----------')
try:
sys.path.insert(0, os.getcwd())
import verl
print('Version :', verl.__version__)
verl_dir = os.path.dirname(verl.__file__)
print('Directory :', verl_dir)
try:
commit_hash = _get_current_git_commit()
print('Commit Hash :', commit_hash)
except AttributeError:
print('Commit hash not found. ')
except ImportError as e:
print(f'No verl installed: {e}')
except Exception as e:
import traceback
if not isinstance(e, IOError):
print("An error occurred trying to import verl.")
print("This is very likely due to missing missing or incompatible library files.")
print(traceback.format_exc())
def check_os():
print('----------Platform Info----------')
print('Platform :', platform.platform())
print('system :', platform.system())
print('node :', platform.node())
print('release :', platform.release())
print('version :', platform.version())
def check_hardware():
print('----------Hardware Info----------')
print('machine :', platform.machine())
print('processor :', platform.processor())
if sys.platform.startswith('darwin'):
pipe = subprocess.Popen(('sysctl', '-a'), stdout=subprocess.PIPE)
output = pipe.communicate()[0]
for line in output.split(b'\n'):
if b'brand_string' in line or b'features' in line:
print(line.strip())
elif sys.platform.startswith('linux'):
subprocess.call(['lscpu'])
elif sys.platform.startswith('win32'):
subprocess.call(['wmic', 'cpu', 'get', 'name'])
def check_network(args):
print('----------Network Test----------')
if args.timeout > 0:
print('Setting timeout: {}'.format(args.timeout))
socket.setdefaulttimeout(args.timeout)
for region in args.region.strip().split(','):
r = region.strip().lower()
if not r:
continue
if r in REGIONAL_URLS:
URLS.update(REGIONAL_URLS[r])
else:
import warnings
warnings.warn('Region {} does not need a specific test; please refer to the global sites.'.format(r))
for name, url in URLS.items():
test_connection(name, url, args.timeout)
def check_environment():
print('----------Environment----------')
for k, v in os.environ.items():
if k.startswith('VERL_') or k.startswith('OMP_') or k.startswith('KMP_') or k == 'CC' or k == 'CXX':
print('{}="{}"'.format(k, v))
def check_pip_package_versions():
packages = ['vllm', 'sglang', 'ray', 'torch']
for package in packages:
try:
version = importlib.metadata.version(package)
print(f"{package}\t : {version}")
except importlib.metadata.PackageNotFoundError:
print(f"{package}\t : not found.")
def check_cuda_versions():
if torch.cuda.is_available():
try:
cuda_runtime_version = torch.version.cuda
print(f"CUDA Runtime : {cuda_runtime_version}")
nvcc_output = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
cuda_compiler_version = next((line for line in nvcc_output.splitlines() if 'release' in line), None)
if cuda_compiler_version:
print(f"CUDA Compiler : {cuda_compiler_version.strip()}")
else:
print("Could not determine CUDA compiler version.")
except FileNotFoundError as e:
print(f"CUDA compiler : Not found: {e}")
except Exception as e:
print(f"An error occurred while checking CUDA versions: {e}")
else:
print("CUDA is not available.")
def _get_cpu_memory():
"""
Get the total CPU memory capacity in GB.
"""
memory = psutil.virtual_memory()
return memory.total / (1024**3)
def _get_gpu_info():
"""
Get GPU type, GPU memory, and GPU count using nvidia-smi command.
"""
try:
result = subprocess.run(['nvidia-smi', '--query-gpu=gpu_name,memory.total', '--format=csv,noheader,nounits'],
capture_output=True,
text=True,
check=True)
gpu_lines = result.stdout.strip().split('\n')
gpu_count = len(gpu_lines)
gpu_info = []
for line in gpu_lines:
gpu_name, gpu_memory = line.split(', ')
gpu_info.append({
'type': gpu_name,
'memory': float(gpu_memory) / 1024 # Convert to GB
})
return gpu_count, gpu_info
except subprocess.CalledProcessError:
print("Failed to execute nvidia-smi command.")
return 0, []
def _get_system_info():
"""
Get CPU memory capacity, GPU type, GPU memory, and GPU count.
"""
cpu_memory = _get_cpu_memory()
gpu_count, gpu_info = _get_gpu_info()
return {'cpu_memory': cpu_memory, 'gpu_count': gpu_count, 'gpu_info': gpu_info}
def check_system_info():
print('----------System Info----------')
system_info = _get_system_info()
print(f"CPU Memory\t: {system_info['cpu_memory']:.2f} GB")
print(f"GPU Count\t: {system_info['gpu_count']}")
for i, gpu in enumerate(system_info['gpu_info']):
print(f"GPU {i + 1}\tType : {gpu['type']}")
print(f"GPU {i + 1}\tMemory : {gpu['memory']:.2f} GB")
def parse_args():
"""Parse arguments."""
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description='Diagnose script for checking the current system.')
choices = ['python', 'pip', 'verl', 'system', 'os', 'environment']
for choice in choices:
parser.add_argument('--' + choice, default=1, type=int, help='Diagnose {}.'.format(choice))
parser.add_argument('--network', default=0, type=int, help='Diagnose network.')
parser.add_argument('--hardware', default=0, type=int, help='Diagnose hardware.')
parser.add_argument('--region',
default='',
type=str,
help="Additional sites in which region(s) to test. \
Specify 'cn' for example to test mirror sites in China.")
parser.add_argument('--timeout', default=10, type=int, help="Connection test timeout threshold, 0 to disable.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if args.python:
check_python()
if args.pip:
check_pip()
check_pip_package_versions()
if args.verl:
check_verl()
if args.os:
check_os()
if args.hardware:
check_hardware()
if args.network:
check_network(args)
if args.environment:
check_environment()
check_cuda_versions()
if args.system:
check_system_info()
#!/bin/bash
pip3 install --upgrade yapf
python3 -m yapf -ir -vv --style ./.style.yapf verl tests examples recipe scripts
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Dict
import re
import os
import torch
import argparse
import numpy as np
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from concurrent.futures import ThreadPoolExecutor
from safetensors.torch import load_file
from torch.distributed._tensor import Shard, Placement
try:
# for torch 2.5+
from torch.distributed.tensor import DTensor
except ImportError:
from torch.distributed._tensor import DTensor
parser = argparse.ArgumentParser()
parser.add_argument('--backend', type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"])
parser.add_argument('--tie-word-embedding', action='store_true', help="Whether to tie word embedding weights")
parser.add_argument('--is-value-model', action='store_true', help="Whether the model loaded as value model")
parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
parser.add_argument(
'--local_dir',
type=str,
required=True,
help=
"The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`."
)
parser.add_argument('--target_dir', required=False, default="tmp", type=str, help="The path for the target model")
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
parser.add_argument("--test", action="store_true", help="test correctness of hf_model")
parser.add_argument("--test_hf_dir",
type=str,
required=False,
help="test correctness of hf_model, , with hf_model in checkpoint.contents")
args = parser.parse_args()
os.makedirs(args.target_dir, exist_ok=True)
if args.test:
assert args.test_hf_dir is not None, 'You must first save a verl checkpoint with hf_model in checkpoint.contents, and provide that directory here'
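# Illustrative invocations (script path and checkpoint directories are hypothetical):
#   python model_merger.py --backend fsdp --hf_model_path <hf_model> --local_dir <fsdp_ckpt_dir> --target_dir <out_dir>
#   python model_merger.py --backend megatron --tie-word-embedding --hf_model_path <hf_model> --local_dir <megatron_ckpt_dir> --target_dir <out_dir>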
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
if placement.is_replicate():
return tensors[0]
elif placement.is_partial():
raise NotImplementedError("Partial placement is not supported yet")
elif placement.is_shard():
return torch.cat(tensors, dim=placement.dim).contiguous()
else:
raise ValueError(f"Unsupported placement: {placement}")
def upload_model_to_huggingface(hf_path):
# Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
def convert_fsdp_checkpoints_to_hfmodels():
local_dir = args.local_dir
# copy rank zero to find the shape of (dp, fsdp)
rank = 0
world_size = 0
for filename in os.listdir(local_dir):
match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
if match:
world_size = match.group(1)
break
assert world_size, "No model file with the proper format"
state_dict = torch.load(os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt'),
map_location='cpu',
weights_only=False)
pivot_key = sorted(list(state_dict.keys()))[0]
weight = state_dict[pivot_key]
if isinstance(weight, DTensor):
# get sharding info
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
else:
# for non-DTensor
mesh = np.array([int(world_size)], dtype=np.int64)
mesh_dim_names = ('fsdp',)
print(f'Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}')
assert mesh_dim_names in (('fsdp',), ('ddp', 'fsdp')), f'Unsupported mesh_dim_names {mesh_dim_names}'
if 'tp' in mesh_dim_names:
# fsdp * tp
total_shards = mesh.shape[-1] * mesh.shape[-2]
mesh_shape = (mesh.shape[-2], mesh.shape[-1])
else:
# fsdp
total_shards = mesh.shape[-1]
mesh_shape = (mesh.shape[-1],)
print(f'Processing {total_shards} model shards in total (mesh shape {mesh_shape})')
model_state_dict_lst = []
model_state_dict_lst.append(state_dict)
model_state_dict_lst.extend([""] * (total_shards - 1))
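# Pre-size the list with placeholders so each worker thread can write its shard at index `rank`.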
def process_one_shard(rank):
model_path = os.path.join(local_dir, f'model_world_size_{world_size}_rank_{rank}.pt')
state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
model_state_dict_lst[rank] = state_dict
return state_dict
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
futures = [executor.submit(process_one_shard, rank) for rank in range(1, total_shards)]
for future in futures:
# Propagate any exception raised while loading a shard instead of silently dropping it.
future.result()
state_dict = {}
param_placements: Dict[str, List[Placement]] = {}
keys = set(model_state_dict_lst[0].keys())
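# Collect each parameter's local shards from every rank; DTensor shards also record their
# placements so they can be concatenated along the correct sharded dim below.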
for key in keys:
state_dict[key] = []
for model_state_dict in model_state_dict_lst:
try:
tensor = model_state_dict.pop(key)
except KeyError:
# A shard is missing this key; dump the remaining keys for debugging and abort.
print("-" * 30)
print(model_state_dict)
raise
if isinstance(tensor, DTensor):
state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
# replicated placement at dp dimension can be discarded
if mesh_dim_names[0] == 'dp':
placements = placements[1:]
elif mesh_dim_names[0] == 'ddp':
placements = placements[1:]
if key not in param_placements:
param_placements[key] = placements
else:
assert param_placements[key] == placements
else:
state_dict[key].append(tensor.bfloat16())
del model_state_dict_lst
for key in sorted(state_dict):
if not isinstance(state_dict[key], list):
print(f"No need to merge key {key}")
continue
if key in param_placements:
# merge shards
placements: Tuple[Shard] = param_placements[key]
if len(mesh_shape) == 1:
# 1-D list, FSDP without TP
assert len(placements) == 1
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
else:
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet")
else:
state_dict[key] = torch.cat(state_dict[key], dim=0)
print('Writing to local disk')
if args.target_dir is None:
hf_path = os.path.join(local_dir, 'huggingface')
else:
hf_path = args.target_dir
config = AutoConfig.from_pretrained(args.hf_model_path)
if 'ForTokenClassification' in config.architectures[0]:
auto_model = AutoModelForTokenClassification
elif 'ForCausalLM' in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif 'ForConditionalGeneration' in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f'Unknown architecture {config.architectures}')
with torch.device('meta'):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model.to_empty(device='cpu')
print(f'Saving model to {hf_path}')
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict
del model
if args.hf_upload_path:
upload_model_to_huggingface(hf_path)
def get_tp_pp_rank_from_sharded_dir(sharded_dir):
match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir)
tp_rank = int(match.group(1))
pp_rank = int(match.group(2))
return tp_rank, pp_rank
def check_megatron_checkpoint_path(model_path):
sharded_dirs = sorted(os.listdir(model_path))
tp_size = 0
pp_size = 0
for sharded_dir in sharded_dirs:
match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir)
assert match, f"Invalid sharded dir {sharded_dir}"
assert f"model.pt" in os.listdir(os.path.join(model_path, sharded_dir)), f"model.pt not found in {sharded_dir}"
tp_rank = int(match.group(1))
pp_rank = int(match.group(2))
if tp_size < tp_rank + 1:
tp_size = tp_rank + 1
if pp_size < pp_rank + 1:
pp_size = pp_rank + 1
return sharded_dirs, tp_size, pp_size
def convert_megatron_checkpoints_to_hfmodels():
from verl.utils.megatron_utils import get_model_checkpoint_path, get_hf_model_checkpoint_path
local_path = args.local_dir
model_ckpt_path = get_model_checkpoint_path(local_path)
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
sharded_dirs, tp_size, pp_size = check_megatron_checkpoint_path(model_ckpt_path)
mp_size = len(sharded_dirs)
model_state_dict_lst = []
for i in range(pp_size):
model_state_dict_lst.append([])
for j in range(tp_size):
model_state_dict_lst[i].append("")
print(f'sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {mp_size}')
def process_one_shard(shard_dir):
model_path = os.path.join(model_ckpt_path, shard_dir, "model.pt")
state_dict = torch.load(model_path, map_location='cpu', weights_only=False)
tp_rank, pp_rank = get_tp_pp_rank_from_sharded_dir(shard_dir)
model_state_dict_lst[pp_rank][tp_rank] = state_dict
# with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
# for rank in range(1, mp_size):
# executor.submit(process_one_shard, sharded_dirs[rank])
for sharded_dir in sharded_dirs:
process_one_shard(sharded_dir)
state_dict = {}
config = AutoConfig.from_pretrained(args.hf_model_path)
if args.test:
ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors'))
def merge_across_tp(key, tp_data):
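# Megatron fuses gate/up into linear_fc1 and q/k/v into linear_qkv, so those are de-interleaved
# per TP rank before concatenation; layer norms (and the value head) are replicated and taken
# from rank 0; everything else is concatenated along its tensor-parallel dim (0 for
# column-parallel weights, 1 for row-parallel ones).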
if "linear_fc1.weight" in key:
# if the tensor is gate and proj
gate_lst = []
up_lst = []
for infer_param in tp_data:
gate, up = infer_param.chunk(2)
gate_lst.append(gate)
up_lst.append(up)
gate = torch.cat(gate_lst, dim=0)
up = torch.cat(up_lst, dim=0)
tp_data = [gate, up]
elif "self_attention.linear_qkv." in key and 'layer_norm' not in key:
# if the tensor is qkv, for each param on tp, split into q, k, v
# concat q, k, v separately.
q_lst = []
k_lst = []
v_lst = []
assert config.num_attention_heads % config.num_key_value_heads == 0
num_q_per_kv = config.num_attention_heads // config.num_key_value_heads
assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0
kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)
split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
for infer_param in tp_data:
num_query_groups_per_partition = config.num_key_value_heads // tp_size
for chunk in infer_param.chunk(num_query_groups_per_partition):
split_size = [
kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
kv_size_per_tp // num_query_groups_per_partition,
kv_size_per_tp // num_query_groups_per_partition
]
q, k, v = chunk.split(split_size)
q_lst.append(q)
k_lst.append(k)
v_lst.append(v)
q = torch.cat(q_lst, dim=0)
k = torch.cat(k_lst, dim=0)
v = torch.cat(v_lst, dim=0)
tp_data = [q, k, v]
elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and args.is_value_model:
tp_data = tp_data[0]
else:
dim = 0
if "linear_fc2.weight" in key or "self_attention.linear_proj" in key:
dim = 1
tp_data = torch.cat(tp_data, dim=dim)
return tp_data
vpp_size = len(model_state_dict_lst[0][0])
layers_cum = 0
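# Renumber each pipeline stage's local decoder-layer indices into global layer indices,
# then merge every parameter across tensor-parallel ranks.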
for vpp_rank in range(vpp_size):
for pp_rank in range(pp_size):
layers_handled = 0
keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()
for key in keys:
if "extra_state" in key:
continue
if args.tie_word_embedding and ("output_layer" in key):
print(f'skip lm_head and reward_head loading because of tie_word_embeddings')
continue
new_key = key
if "decoder.layers." in key:
local_layer_no = int(key.split('.')[2])
layers_handled = max(local_layer_no, layers_handled)
global_layer_no = local_layer_no + layers_cum
new_key_list = key.split('.')
new_key_list[2] = str(global_layer_no)
new_key = '.'.join(new_key_list)
tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]
merged = merge_across_tp(new_key, tp_data)
if not isinstance(merged, list):
state_dict[new_key] = merged
elif len(merged) == 3:
# split qkv
for n, d in zip(['q', 'k', 'v'], merged):
state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d
elif len(merged) == 2:
# split gate up
state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0]
state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1]
layers_cum += layers_handled + 1 # zero based
del model_state_dict_lst
params_mapping = [
# (megatron core gpt model name, vllm model name)
("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
("embedding.word_embeddings", "model.embed_tokens"),
("self_attention.linear_qkv", "self_attn.qkv_proj"),
("self_attention.linear_proj", "self_attn.o_proj"),
("pre_mlp_layernorm", "post_attention_layernorm"),
("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
("mlp.linear_fc1", "mlp.gate_up_proj"),
("mlp.linear_fc2", "mlp.down_proj"),
("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"),
("self_attention.linear_q", "self_attn.q_proj"),
("self_attention.linear_k", "self_attn.k_proj"),
("self_attention.linear_v", "self_attn.v_proj"),
]
if args.test:
for original_name, loaded_weight in state_dict.items():
name = _replace_name(original_name, params_mapping)
if not name or (name.endswith(".bias") and name not in ref_state_dict):
continue
if "rotary_emb.inv_freq" in name:
continue
if args.tie_word_embedding and "lm_head.weight" in name:
continue
if name not in ref_state_dict:
raise RuntimeError(f'key: {name} does not exist in the reference state_dict')
param = ref_state_dict[name]
assert loaded_weight.dtype == param.dtype
torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4)
print('Writing to local disk')
if args.target_dir is None:
hf_path = os.path.join(args.local_dir, 'huggingface')
else:
hf_path = args.target_dir
if 'ForTokenClassification' in config.architectures[0]:
auto_model = AutoModelForTokenClassification
elif 'ForCausalLM' in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif 'ForConditionalGeneration' in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f'Unknown architecture {config.architectures}')
with torch.device('meta'):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model.to_empty(device='cpu')
print(f'Saving model to {hf_path}')
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict
del model
if args.hf_upload_path:
upload_model_to_huggingface(hf_path)
def _replace_name(megatron_name, name_mapping):
for m_name, v_name in name_mapping:
if m_name not in megatron_name:
continue
if "layers" in megatron_name: # deal with decoder layers
megatron_name = megatron_name.replace("decoder", "model")
megatron_name_list = megatron_name.split(".")
if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
param_name_list = megatron_name_list[:3]
param_name_list.append(v_name)
param_name = ".".join(param_name_list)
else:
param_name_list = megatron_name_list[:3]
weight_or_bias = megatron_name_list[-1]
param_name_list.append(v_name)
param_name_list.append(weight_or_bias)
param_name = ".".join(param_name_list)
return param_name
else:
param_name = megatron_name.replace(m_name, v_name)
return param_name
if __name__ == '__main__':
if args.backend == "fsdp":
convert_fsdp_checkpoints_to_hfmodels()
elif args.backend == "megatron":
convert_megatron_checkpoints_to_hfmodels()
else:
raise NotImplementedError(f"{args.backend} not supported")