"src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py" did not exist on "bc108e15333cb0e8a092647320cbb4d70d6d0f03"
Commit b84161d1 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #2716 canceled with stages
data:
path: /tmp/math_Qwen2-7B-Instruct.parquet
prompt_key: prompt
response_key: responses
data_source_key: data_source
reward_model_key: reward_model
custom_reward_function:
path: null
name: compute_score
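# A custom scorer can be plugged in via path/name above. Hedged sketch of the expected
# interface: main_eval calls the function as fn(data_source, response, ground_truth), and any
# reward_kwargs given in this section are forwarded as extra keyword arguments. Illustrative only:
#   def compute_score(data_source, solution_str, ground_truth, **kwargs):
#       # hypothetical exact-match scorer
#       return 1.0 if str(solution_str).strip() == str(ground_truth).strip() else 0.0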
trainer:
nnodes: 1
n_gpus_per_node: 8
data:
path: ~/data/rlhf/math/test.parquet
prompt_key: prompt
n_samples: 5
output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet
batch_size: 128
model:
path: ~/models/Qwen2-7B-Instruct
external_lib: null
rollout:
name: vllm
temperature: 1.0
top_k: 50 # 0 for hf rollout, -1 for vllm rollout
top_p: 0.7
prompt_length: 1536
response_length: 512
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 1
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: 8
# for fire vllm rollout
use_fire_sampling: False # enable FIRE https://arxiv.org/abs/2410.21236
# for hf rollout
do_sample: True
disable_log_stats: True
enable_chunked_prefill: True
n: 1
actor:
strategy: fsdp # This is for backward-compatibility
ulysses_sequence_parallel_size: 1 # sp size
fsdp_config:
fsdp_size: -1
data:
tokenizer: null
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
gen_batch_size: ${data.train_batch_size}
val_batch_size: null # DEPRECATED: validation data is sent to the inference engines as one whole batch, and the engines schedule memory themselves
return_raw_input_ids: False # Set to True when the policy and reward-model tokenizers differ
return_raw_chat: False
shuffle: True
filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming. You can set filter_overlong_prompts_workers to use multiprocessing to speed it up.
filter_overlong_prompts_workers: 1
truncation: error
custom_cls:
path: null
name: null
actor_rollout_ref:
hybrid_engine: True
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: {}
enable_gradient_checkpointing: False
gradient_checkpointing_kwargs:
## Activation Checkpointing
activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_granularity: null # 'selective' or 'full'
# 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention
activations_checkpoint_num_layers: null # not used with 'selective'
actor:
strategy: megatron # This is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: False
use_torch_compile: True # False to disable torch compile
# pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
clip_ratio_low: 0.2
clip_ratio_high: 0.2
clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
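# Rough sketch of how these clipping knobs combine (hedged; the exact implementation lives in core_algos):
#   pg_losses1 = -advantages * ratio
#   pg_losses2 = -advantages * clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
#   ppo_loss   = max(pg_losses1, pg_losses2)
#   dual-clip: for advantages < 0, the per-token loss is additionally capped at -advantages * clip_ratio_c (= clip_ratio_c * |advantages|)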
loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
# NOTE: "token-mean" is the default behavior
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
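# low_var_kl is commonly the k3 estimator (http://joschu.net/blog/kl-approx.html); hedged sketch:
#   kl ~= exp(logp_ref - logp) - (logp_ref - logp) - 1
# which is non-negative and lower-variance than the naive (logp - logp_ref) estimate.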
ppo_epochs: 1
data_loader_seed: null
shuffle: False
optim:
lr: 1e-6
clip_grad: 1.0
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by the program
weight_decay: 0.01
megatron:
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
load_weight: True
checkpoint:
contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; by default only sharded model checkpoints are saved to save space
ref:
strategy: megatron
megatron:
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
load_weight: True
param_offload: False
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
prompt_length: ${data.max_prompt_length} # for xperf_gpt
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_megatron
tensor_model_parallel_size: 1
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
disable_log_stats: True
enable_chunked_prefill: False # enabling it could give higher throughput
# for hf rollout
do_sample: True
layer_name_map:
qkv_layer_name: qkv
gate_proj_layer_name: gate_up
# number of responses (i.e. number of samples per prompt)
n: 1
engine_kwargs: # inference engine parameters
swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB
val_kwargs:
# sampling parameters for validation
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1.0
temperature: 0
n: 1
do_sample: False # greedy decoding by default for validation
critic:
rollout_n: ${actor_rollout_ref.rollout.n}
strategy: megatron
optim:
lr: 1e-5
clip_grad: 1.0
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by the program
weight_decay: 0.01
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: {}
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
gradient_checkpointing_kwargs:
## Activation Checkpointing
activations_checkpoint_method: null
activations_checkpoint_granularity: null
activations_checkpoint_num_layers: null
megatron:
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
load_weight: True
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed}
shuffle: ${actor_rollout_ref.actor.shuffle}
cliprange_value: 0.5
kl_ctrl:
type: fixed
kl_coef: 0.001
checkpoint:
contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; by default only sharded model checkpoints are saved to save space
reward_model:
enable: False
strategy: megatron
megatron:
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests
context_parallel_size: 1
sequence_parallel: True
use_distributed_optimizer: True
use_dist_checkpointing: False
dist_checkpointing_path: null
seed: 1
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
load_weight: True
param_offload: False
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
micro_batch_size_per_gpu: null
use_dynamic_bsz: ${critic.use_dynamic_bsz}
max_length: null
custom_reward_function:
path: null
name: compute_score
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
trainer:
balance_batch: True
total_epochs: 30
total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: ['console', 'wandb']
log_val_generations: 0
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
# auto: find the last ckpt to resume from; if none is found, start from scratch
resume_mode: auto # or disable or resume_path if resume_from_path is set
resume_from_path: null
del_local_ckpt_after_load: False
val_before_train: True
test_freq: 2
critic_warmup: 0
default_hdfs_dir: null
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
max_actor_ckpt_to_keep: null
max_critic_ckpt_to_keep: null
# The timeout for ray worker group to wait for the register center to be ready
ray_wait_register_center_timeout: 300
data:
tokenizer: null
train_files: ~/data/rlhf/gsm8k/train.parquet
val_files: ~/data/rlhf/gsm8k/test.parquet
prompt_key: prompt
reward_fn_key: data_source
max_prompt_length: 512
max_response_length: 512
train_batch_size: 1024
val_batch_size: null # DEPRECATED: validation data is sent to the inference engines as one whole batch, and the engines schedule memory themselves
return_raw_input_ids: False # Set to True when the policy and reward-model tokenizers differ
return_raw_chat: False
shuffle: True
filter_overlong_prompts: False # for large-scale datasets, filtering overlong prompts can be time-consuming. You can set filter_overlong_prompts_workers to use multiprocessing to speed it up.
filter_overlong_prompts_workers: 1
truncation: error
image_key: images
video_key: videos
custom_cls:
path: null
name: null
actor_rollout_ref:
hybrid_engine: True
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: False
use_liger: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
use_dynamic_bsz: False
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
# pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high)
clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified
clip_ratio_low: 0.2
clip_ratio_high: 0.2
clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean"
entropy_coeff: 0.001
use_kl_loss: False # True for GRPO
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
checkpoint:
contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; by default only sharded model checkpoints are saved to save space
optim:
lr: 1e-6
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by the program
weight_decay: 0.01
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
strategy: fsdp
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
rollout:
name: vllm
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
use_fire_sampling: False # https://arxiv.org/abs/2410.21236
prompt_length: ${data.max_prompt_length} # not used for opensource
response_length: ${data.max_response_length}
# for vllm rollout
dtype: bfloat16 # should align with FSDP
gpu_memory_utilization: 0.5
ignore_eos: False
enforce_eager: True
free_cache_engine: True
load_format: dummy_dtensor
tensor_model_parallel_size: 2
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
disable_log_stats: True
enable_chunked_prefill: True # may give higher throughput when set to True; when enabled, increase max_num_batched_tokens or decrease max_model_len
# for hf rollout
do_sample: True
# number of responses (i.e. number of samples per prompt)
n: 1 # > 1 for grpo
engine_kwargs: # inference engine parameters
swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB
val_kwargs:
# sampling parameters for validation
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1.0
temperature: 0
n: 1
do_sample: False # greedy decoding by default for validation
critic:
rollout_n: ${actor_rollout_ref.rollout.n}
strategy: fsdp
optim:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1 # must be overridden by the program
weight_decay: 0.01
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: True
use_remove_padding: False
fsdp_config:
param_offload: False
optimizer_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
fsdp_size: -1
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
ppo_micro_batch_size_per_gpu: null
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: 1 # sp size
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
cliprange_value: 0.5
checkpoint:
contents: ['model', 'optimizer', 'extra'] # add 'hf_model' to also save the whole model in HF format; by default only sharded model checkpoints are saved to save space
reward_model:
enable: False
strategy: fsdp
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
wrap_policy:
min_num_params: 0
param_offload: False
fsdp_size: -1
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
micro_batch_size_per_gpu: null # set a number
max_length: null
ulysses_sequence_parallel_size: 1 # sp size
use_dynamic_bsz: ${critic.use_dynamic_bsz}
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
reward_manager: naive
custom_reward_function:
path: null
name: compute_score
algorithm:
gamma: 1.0
lam: 1.0
adv_estimator: gae
use_kl_in_reward: False
kl_penalty: kl # how to estimate kl divergence
kl_ctrl:
type: fixed
kl_coef: 0.001
horizon: 10000
target_kl: 0.1
trainer:
balance_batch: True
total_epochs: 30
total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: [ 'console', 'wandb' ]
log_val_generations: 0
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
# auto: find the last ckpt to resume from; if none is found, start from scratch
resume_mode: auto # or disable or resume_path if resume_from_path is set
resume_from_path: null
val_before_train: True
test_freq: -1
critic_warmup: 0
default_hdfs_dir: null
del_local_ckpt_after_load: False
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
max_actor_ckpt_to_keep: null
max_critic_ckpt_to_keep: null
# The timeout for ray worker group to wait for the register center to be ready
ray_wait_register_center_timeout: 300
data:
train_batch_size: 128
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
micro_batch_size_per_gpu: 4 # this is also val batch size
train_files: ~/data/gsm8k/train.parquet
val_files: ~/data/gsm8k/test.parquet
# Single-turn settings
prompt_key: question
response_key: answer
prompt_dict_keys: ['question']
response_dict_keys: ['answer']
# Multi-turn settings
multiturn:
enable: false # Set to true to use multi-turn dataset
messages_key: messages # Key for messages list in multi-turn mode
max_length: 1024
truncation: error
balance_dp_token: False
chat_template: null
custom_cls:
path: null
name: null
model:
partial_pretrain: ~/models/gemma-1.1-7b-it
fsdp_config:
wrap_policy:
min_num_params: 1
cpu_offload: False
offload_params: False
external_lib: null
enable_gradient_checkpointing: False
trust_remote_code: True
lora_rank: 32 # Set to positive value to enable LoRA (e.g., 32)
lora_alpha: 16 # LoRA scaling factor
target_modules: all-linear # Target modules for LoRA adaptation
use_liger: False
optim:
lr: 1e-5
betas: [0.9, 0.95]
weight_decay: 0.01
warmup_steps_ratio: 0.1
clip_grad: 1.0
lr_scheduler: cosine
ulysses_sequence_parallel_size: 1
use_remove_padding: False
trainer:
default_local_dir: /tmp/sft_model
default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
resume_path: null
project_name: gsm8k-sft
experiment_name: test
total_epochs: 1
total_training_steps: 10
logger: ['console']
seed: 1
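# Example launch (hedged sketch; the module path matches the trainer code below, but the exact
# torchrun flags and paths are illustrative, and fields are overridden via standard Hydra dotted keys):
#   torchrun --standalone --nnodes=1 --nproc_per_node=8 -m verl.trainer.fsdp_sft_trainer \
#     data.train_files=$HOME/data/gsm8k/train.parquet \
#     data.val_files=$HOME/data/gsm8k/test.parquet \
#     model.partial_pretrain=$HOME/models/gemma-1.1-7b-it \
#     trainer.total_epochs=1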
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A lightweight one-file FSDP SFT Trainer
TODO(zhangchi.usc1992)
- Add calculation of MFU (model FLOPs utilization)
- Add validation
"""
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import logging
import re
from contextlib import nullcontext
import torch
import torch.distributed
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
from tensordict import TensorDict
from torch.utils.data import DataLoader, DistributedSampler
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from verl.utils.dataset import SFTDataset
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.fs import copy_to_local
from verl.utils.tracking import Tracking
from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group
from torch.distributed.device_mesh import DeviceMesh
import verl.utils.hdfs_io as hdfs_io
from verl.utils.debug import log_gpu_memory_usage
from peft import LoraConfig, TaskType, get_peft_model
from verl.workers.sharding_manager import FSDPUlyssesShardingManager
from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
from verl import DataProto
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN'))
def extract_step(path):
match = re.search(r'global_step_(\d+)', path)
if match:
return int(match.group(1))
return None
def convert_to_regular_types(obj):
"""Convert Hydra configs and other special types to regular Python types."""
from omegaconf import ListConfig, DictConfig
if isinstance(obj, (ListConfig, DictConfig)):
return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj)
elif isinstance(obj, (list, tuple)):
return [convert_to_regular_types(x) for x in obj]
elif isinstance(obj, dict):
return {k: convert_to_regular_types(v) for k, v in obj.items()}
return obj
class FSDPSFTTrainer(object):
def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh):
self.config = config
self.device_mesh = device_mesh
self.ulysses_device_mesh = ulysses_device_mesh
self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# build tokenizer first
local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
from verl.utils import hf_tokenizer
self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
if self.config.data.chat_template is not None:
raise ValueError('Apply Chat template from config is not supported yet.')
# normalize dp size
self._normalize_config_bsz()
# Set sequence parallel size
self.config.ulysses_sequence_parallel_size = getattr(self.config, 'ulysses_sequence_parallel_size', 1)
self.use_remove_padding = getattr(self.config, 'use_remove_padding', False)
if self.device_mesh.get_rank() == 0:
print(f'Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}')
print(f'Using remove padding: {self.use_remove_padding}')
self._build_dataloader()
# build model
self._build_model_optimizer()
# TODO: add checkpoint manager
if self.device_mesh.get_rank() == 0:
print(self.config)
def _normalize_config_bsz(self):
dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0)
if self.device_mesh.get_rank() == 0:
print(f'Normalize batch size by dp {dp_size}')
assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}"
self.config.data.train_batch_size //= dp_size
assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0
def _build_dataloader(self):
config = self.config
# build dataset
from verl.utils.import_utils import load_extern_type
# First check if a custom dataset class is specified
if config.data.custom_cls.get("path", None):
dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name)
# Then check if multi-turn dataset should be used
elif config.data.get('multiturn', {}).get('enable', False):
dataset_cls = MultiTurnSFTDataset
# Default to single-turn dataset
else:
dataset_cls = SFTDataset
# Create datasets based on the selected class
self.train_dataset = dataset_cls(parquet_files=config.data.train_files,
tokenizer=self.tokenizer,
config=config.data)
self.val_dataset = dataset_cls(parquet_files=config.data.val_files,
tokenizer=self.tokenizer,
config=config.data)
# build dataloader
# Use data parallel rank and size instead of global rank and world size
# If doing SP, we need to use the local rank and size
if self.config.ulysses_sequence_parallel_size > 1:
rank = self.ulysses_device_mesh.get_local_rank('dp')
world_size = self.ulysses_device_mesh.size(0)
if self.ulysses_device_mesh.get_rank() == 0:
print(f'Using SP rank {rank} and size {world_size} for data distribution')
print('Each SP group gets its own shard of data; ranks within the same SP group receive the same data')
else:
rank = self.device_mesh.get_rank()
world_size = self.device_mesh.size()
if self.device_mesh.get_rank() == 0:
print(f'Using FSDP rank {rank} and size {world_size} for data distribution')
self.train_sampler = DistributedSampler(self.train_dataset,
shuffle=True,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.train_dataloader = DataLoader(dataset=self.train_dataset,
batch_size=config.data.train_batch_size,
sampler=self.train_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)
self.val_sampler = DistributedSampler(self.val_dataset,
shuffle=False,
num_replicas=world_size,
rank=rank,
drop_last=True)
self.val_dataloader = DataLoader(dataset=self.val_dataset,
batch_size=config.data.micro_batch_size_per_gpu,
sampler=self.val_sampler,
num_workers=8,
pin_memory=True,
drop_last=True)
def _build_model_optimizer(self):
# TODO (zhangchi.usc1992):
# 1. support pretrain from random weights
# 2. support init directly from sharded weights
local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
if self.config.model.get('external_lib', None) is not None:
# This is used to import external_lib into the huggingface systems
import importlib
importlib.import_module(self.config.model.external_lib)
log_gpu_memory_usage('Before model allocation', logger=logger)
trust_remote_code = self.config.model.trust_remote_code
# load config first
config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
if self.config.ulysses_sequence_parallel_size > 1:
assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"
# This may be very large
init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings,
mesh=self.device_mesh)
with init_context():
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(local_model_path,
config=config,
torch_dtype=torch.float32,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)
# Apply Liger kernel if use_liger is enabled
if self.config.model.get('use_liger', False):
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
_apply_liger_kernel_to_instance(model=self.model)
if self.config.model.get('lora_rank', 0) > 0:
self.model.enable_input_require_grads()
# Convert config to regular Python types before creating PEFT model
lora_config = {
'task_type': TaskType.CAUSAL_LM,
'r': self.config.model.lora_rank,
'lora_alpha': self.config.model.lora_alpha,
'target_modules': convert_to_regular_types(self.config.model.target_modules),
'bias': "none"
}
self.model = get_peft_model(self.model, LoraConfig(**lora_config))
if self.config.model.enable_gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
log_gpu_memory_usage('After model allocation', logger=logger)
mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32)
auto_wrap_policy = get_fsdp_wrap_policy(self.model,
config=self.config.model.fsdp_config.wrap_policy,
is_lora=self.config.model.get('lora_rank', 0) > 0)
if self.device_mesh.get_rank() == 0:
print(auto_wrap_policy)
if not self.config.model.fsdp_config.cpu_offload:
cpu_offload = None
else:
cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)
self.fsdp_model = FSDP(module=self.model,
auto_wrap_policy=auto_wrap_policy,
param_init_fn=init_fn,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
device_mesh=self.device_mesh,
sync_module_states=True,
device_id=torch.cuda.current_device(),
cpu_offload=cpu_offload,
use_orig_params=False)
log_gpu_memory_usage('After FSDP wrapping', logger=logger)
self.optimizer = optim.AdamW(self.fsdp_model.parameters(),
lr=self.config.optim.lr,
betas=self.config.optim.betas,
weight_decay=self.config.optim.weight_decay)
log_gpu_memory_usage('After initialize optimizer', logger=logger)
self.steps_per_epoch = len(self.train_dataloader)
self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs
if self.device_mesh.get_rank() == 0:
print(
f'Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}'
)
num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio)
if not hasattr(self.config.optim, 'lr_scheduler') or self.config.optim.lr_scheduler == 'cosine':
self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=self.total_steps)
elif self.config.optim.lr_scheduler == 'wsd':
self.lr_scheduler = get_wsd_schedule_with_warmup(optimizer=self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=self.total_steps)
else:
raise ValueError(f'Unknown lr scheduler: {self.config.optim.lr_scheduler}')
def _compute_loss_and_backward(self, batch, do_backward=True):
"""Compute loss with optional sequence parallelism and remove padding features"""
use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1
# Move inputs to GPU and prepare loss mask
input_ids = batch['input_ids'].cuda()
attention_mask = batch['attention_mask'].cuda()
position_ids = batch['position_ids'].cuda()
loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda()
loss_fct = nn.CrossEntropyLoss(reduction='none')
# Context manager for sequence parallel if needed
context = self.sharding_manager if use_sp else nullcontext()
with context:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
if not use_sp:
# Standard forward pass without sequence parallel
labels = input_ids[:, 1:].contiguous()
output = self.fsdp_model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False)
logits = output.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels.contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, self.model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = loss * loss_mask.to(loss.device)
else:
# IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks
# i.e., each GPU has <1 sequence, and each SP group has 1 sequence
# 1. All SP ranks will receive the *SAME* batch
# 2. Different SP groups will receive *DIFFERENT* batches
# This is implemented by the DistributedSampler
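# Illustration: with sp_size=2 and a single 10-token (post-rmpad) sequence, both SP ranks in a
# group receive the same sequence and each processes a 5-token slice after
# ulysses_pad_and_slice_inputs; a different SP group receives a different batch via the sampler.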
batch_size, seqlen = input_ids.shape
# Remove padding
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# Unpad position_ids to align rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# Pad and slice inputs for sequence parallelism
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
# For computing loss
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz)
input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size())
input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad)
# Forward pass
output = self.fsdp_model(
input_ids=input_ids_rmpad_sliced,
attention_mask=None, # Not needed with flash attention varlen
position_ids=position_ids_rmpad_padded,
use_cache=False)
# Compute loss locally then aggregate
logits_rmpad = output.logits.squeeze(0)
input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device)
loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled)
# Gather and unpad for sequence parallelism
loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size)
# This is the loss collected from all ulysses ranks
full_loss = pad_input(hidden_states=loss.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss
full_loss = full_loss.reshape(-1)
loss_mask = loss_mask.to(full_loss.device)
loss = full_loss * loss_mask
valid_token_this_rank = torch.sum(loss_mask)
if self.config.data.balance_dp_token:
torch.distributed.all_reduce(valid_token_this_rank)
dp_size = self.ulysses_device_mesh.size('dp') if use_sp else torch.distributed.get_world_size()
else:
dp_size = 1
loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size
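# Note: with balance_dp_token, valid_token_this_rank holds the GLOBAL valid-token count after the
# all_reduce, so dividing by it and scaling by dp_size makes the subsequent gradient averaging
# across DP ranks equivalent to a global token-mean loss.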
if do_backward:
loss.backward()
return loss
def training_step(self, batch: TensorDict):
self.fsdp_model.train()
log_gpu_memory_usage('Before optimizer zero_grad', logger=logger)
self.optimizer.zero_grad()
log_gpu_memory_usage('After optimizer zero_grad', logger=logger)
micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
n_micro_batches = len(micro_batches)
step_loss = 0
for micro_batch in micro_batches:
loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches
step_loss += loss.item()
grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
log_gpu_memory_usage('Before optimizer step', logger=logger)
# if grad_norm is not finite, skip the update
if not torch.isfinite(grad_norm):
print(f"WARN: grad_norm is not finite: {grad_norm}")
self.optimizer.zero_grad()
else:
self.optimizer.step()
log_gpu_memory_usage('After optimizer step', logger=logger)
self.lr_scheduler.step()
# reduce loss across dp ranks
lr = self.lr_scheduler.get_last_lr()[0]
log_gpu_memory_usage('After offload weights', logger=logger)
step_loss = torch.tensor(step_loss).cuda()
torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}
def validation_step(self, batch: TensorDict):
self.fsdp_model.eval()
with torch.no_grad():
loss = self._compute_loss_and_backward(batch, do_backward=False)
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
return loss
def save_checkpoint(self, step):
# save checkpoint
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
state_dict = self.fsdp_model.state_dict()
path = os.path.join(self.config.trainer.default_local_dir, f'global_step_{step}')
# save huggingface model
if self.device_mesh.get_rank() == 0:
os.makedirs(path, exist_ok=True)
self.model.save_pretrained(path, state_dict=state_dict)
self.tokenizer.save_pretrained(path)
if self.config.trainer.default_hdfs_dir:
hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
torch.distributed.barrier()
def fit(self):
rank = self.device_mesh.get_rank()
# TODO: add a unified tracking
if rank == 0:
tracking = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger)
global_step = 0
# compute the total training steps.
# the total training steps in SFT is mainly for early exit
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
# TODO (zhangchi.usc1992) add back checkpoint manager. Currently, it blocks when uploading to hdfs. So very slow.
for epoch in range(self.config.trainer.total_epochs):
self.train_sampler.set_epoch(epoch=epoch)
for data in tqdm(self.train_dataloader,
total=self.steps_per_epoch,
desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"):
global_step += 1
data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda()
metric = self.training_step(data)
if rank == 0:
tracking.log(data=metric, step=global_step)
# for early exit validation
if global_step >= self.total_training_steps:
# Perform final validation
val_losses = []
for val_data in self.val_dataloader:
val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda()
val_loss = self.validation_step(val_data)
val_losses.append(val_loss)
if rank == 0:
avg_val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': avg_val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
# Save final checkpoint
self.save_checkpoint(step=global_step)
return
# validation
val_losses = []
for data in self.val_dataloader:
data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda()
val_loss = self.validation_step(data)
val_losses.append(val_loss)
if rank == 0:
val_loss = torch.mean(torch.stack(val_losses))
metric = {'val/loss': val_loss.detach().item()}
tracking.log(data=metric, step=global_step)
torch.distributed.barrier()
# save checkpoint
self.save_checkpoint(step=global_step)
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
import hydra
from torch.distributed.device_mesh import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
@hydra.main(config_path='config', config_name='sft_trainer', version_base=None)
def main(config):
local_rank, rank, world_size = initialize_global_process_group()
device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))
dp_size = world_size // config.ulysses_sequence_parallel_size
ulysses_device_mesh = init_device_mesh(device_type='cuda',
mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
mesh_dim_names=('dp', 'sp'))
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
trainer.fit()
if __name__ == '__main__':
main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline evaluation of a generation file using a reward model and a ground-truth verifier.
The input is a parquet file that contains N generated sequences per prompt and (optionally) the ground truth.
"""
import hydra
from verl.utils.fs import copy_to_local
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import ray
def get_custom_reward_fn(config):
import importlib.util, os, sys
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
if not file_path:
return None
if not os.path.exists(file_path):
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
spec = importlib.util.spec_from_file_location("custom_module", file_path)
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_module"] = module
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}': {e}")
function_name = reward_fn_config.get("name")
if not hasattr(module, function_name):
raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
print(f"using customized reward function '{function_name}' from '{file_path}'")
raw_fn = getattr(module, function_name)
reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
def wrapped_fn(*args, **kwargs):
return raw_fn(*args, **kwargs, **reward_kwargs)
return wrapped_fn
@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
ground_truth = reward_data['ground_truth']
score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
return data_source, np.mean(score_lst)
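# Expected input parquet schema (column names come from the evaluation config above):
#   <prompt_key>       -> the prompt for each row
#   <response_key>     -> a list of N generated responses per prompt
#   <data_source_key>  -> tag used to group test_score metrics
#   <reward_model_key> -> dict that must contain 'ground_truth'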
@hydra.main(config_path='config', config_name='evaluation', version_base=None)
def main(config):
local_path = copy_to_local(config.data.path)
dataset = pd.read_parquet(local_path)
prompts = dataset[config.data.prompt_key]
responses = dataset[config.data.response_key]
data_sources = dataset[config.data.data_source_key]
reward_model_data = dataset[config.data.reward_model_key]
total = len(dataset)
# Initialize Ray
if not ray.is_initialized():
ray.init()
# evaluate test_score based on data source
data_source_reward = defaultdict(list)
compute_score = get_custom_reward_fn(config)
# Create remote tasks
remote_tasks = [
process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
]
# Process results as they come in
with tqdm(total=total) as pbar:
while len(remote_tasks) > 0:
# Use ray.wait to get completed tasks
done_ids, remote_tasks = ray.wait(remote_tasks)
for result_id in done_ids:
data_source, score = ray.get(result_id)
data_source_reward[data_source].append(score)
pbar.update(1)
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f'test_score/{data_source}'] = np.mean(rewards)
print(metric_dict)
if __name__ == '__main__':
main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""
import ray
import numpy as np
import hydra
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'
from verl.utils.model import compute_position_id_with_mask
import pandas as pd
from transformers import AutoTokenizer
from verl import DataProto
from verl.utils.fs import copy_to_local
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@hydra.main(config_path='config', config_name='generation', version_base=None)
def main(config):
run_generation(config)
def run_generation(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}})
ray.get(main_task.remote(config))
@ray.remote(num_cpus=1)
def main_task(config):
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
local_path = copy_to_local(config.model.path)
from verl.utils import hf_tokenizer
trust_remote_code = config.data.get('trust_remote_code', False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
if config.rollout.temperature == 0.:
assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'
# read dataset. Note that each prompt should already be in chat-message format (e.g., a list of dictionaries)
dataset = pd.read_parquet(config.data.path)
chat_lst = dataset[config.data.prompt_key].tolist()
chat_lst = [chat.tolist() for chat in chat_lst]
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
wg.init_model()
total_samples = len(dataset)
# real_batch_size = data.batch['input_ids'].shape[0]
config_batch_size = config.data.batch_size
dispatch_dp_size = wg.world_size
num_batch = -(-total_samples // config_batch_size)
output_lst = [[] for _ in range(config.data.n_samples)]
for batch_idx in range(num_batch):
print(f'[{batch_idx+1}/{num_batch}] Start to process.')
batch_chat_lst = chat_lst[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size]
inputs = tokenizer.apply_chat_template(batch_chat_lst,
add_generation_prompt=True,
padding=True,
truncation=True,
max_length=config.rollout.prompt_length,
return_tensors='pt',
return_dict=True,
tokenize=True)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
position_ids = compute_position_id_with_mask(attention_mask)
batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}
data = DataProto.from_dict(batch_dict)
real_batch_size = data.batch['input_ids'].shape[0]
if real_batch_size % dispatch_dp_size != 0:
dummy_data_size = dispatch_dp_size - real_batch_size % dispatch_dp_size
if dummy_data_size <= real_batch_size:
dummy_data = data[:dummy_data_size]
else:
dummy_data = data.repeat(-(-dummy_data_size // real_batch_size))[:dummy_data_size]
data = DataProto.concat([data, dummy_data])
print(
f'real_batch_size {real_batch_size} is not divisible by dispatch_dp_size {dispatch_dp_size}, add {dummy_data_size} dummy data'
)
batch_size = data.batch['input_ids'].shape[0]
assert batch_size % dispatch_dp_size == 0, f'batch_size {batch_size} is not divisible by dispatch_dp_size {dispatch_dp_size}'
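# Example: real_batch_size=13 with dispatch_dp_size=8 -> 3 dummy rows are appended (repeating the
# head of the batch), the padded batch of 16 splits evenly across workers, and the dummy outputs
# are dropped again below via output[:real_batch_size].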
print(f'[{batch_idx+1}/{num_batch}] Start to generate.')
# START TO GENERATE FOR n_samples TIMES
for i in range(config.data.n_samples):
output = wg.generate_sequences(data)
# remove dummy data
output = output[:real_batch_size]
output_text = tokenizer.batch_decode(output.batch['input_ids'][:, -config.rollout.response_length:],
skip_special_tokens=False)
# remove the padding
pad_token = tokenizer.pad_token
output_text_unpad = []
for text in output_text:
output_text_unpad.append(text.replace(pad_token, ''))
output_lst[i].extend(output_text_unpad)
# convert output_lst from (n_samples, n_data) to (n_data, n_samples)
output_lst = np.array(output_lst, dtype=object)
output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()
# add to the data frame
dataset[f'responses'] = output_lst
# write to a new parquet
output_dir = os.path.dirname(config.data.output_path)
makedirs(output_dir, exist_ok=True)
dataset.to_parquet(config.data.output_path)
return output_text
if __name__ == '__main__':
main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine this main with ray_trainer, because ray_trainer is also used by other entry points.
"""
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
import os
import ray
import hydra
def get_custom_reward_fn(config):
import importlib.util, sys
reward_fn_config = config.get("custom_reward_function") or {}
file_path = reward_fn_config.get("path")
if not file_path:
return None
if not os.path.exists(file_path):
raise FileNotFoundError(f"Reward function file '{file_path}' not found.")
spec = importlib.util.spec_from_file_location("custom_module", file_path)
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_module"] = module
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading module from '{file_path}': {e}")
function_name = reward_fn_config.get("name")
if not hasattr(module, function_name):
raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.")
print(f"using customized reward function '{function_name}' from '{file_path}'")
raw_fn = getattr(module, function_name)
reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {}))
def wrapped_fn(*args, **kwargs):
return raw_fn(*args, **kwargs, **reward_kwargs)
return wrapped_fn
@hydra.main(config_path='config', config_name='ppo_trainer', version_base=None)
def main(config):
run_ppo(config)
def run_ppo(config) -> None:
# TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
# isolation, will solve in the future
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={
'env_vars': {
'TOKENIZERS_PARALLELISM': 'true',
'NCCL_DEBUG': 'WARN',
'VLLM_LOGGING_LEVEL': 'WARN'
}
})
runner = TaskRunner.remote()
ray.get(runner.run.remote(config))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):
from verl.utils.fs import copy_to_local
# print initial config
from pprint import pprint
from omegaconf import OmegaConf
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
# download the checkpoint from hdfs
local_path = copy_to_local(config.actor_rollout_ref.model.path)
# instantiate tokenizer
from verl.utils import hf_tokenizer, hf_processor
trust_remote_code = config.data.get('trust_remote_code', False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup
else:
raise NotImplementedError
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
role_worker_mapping = {
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
Role.Critic: ray.remote(CriticWorker),
}
global_pool_id = 'global_pool'
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
}
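# Example: with trainer.nnodes=1 and trainer.n_gpus_per_node=8, resource_pool_spec is
# {'global_pool': [8]} -- a single pool of 8 GPUs shared by ActorRollout and Critic (and by the
# reward model / reference policy workers registered below when they are enabled).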
# we should adopt a multi-source reward function here
# - for rule-based rm, we directly call a reward score
# - for model-based rm, we call a model
# - for code related prompt, we send to a sandbox if there are test cases
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
mapping[Role.RewardModel] = global_pool_id
# use reference model
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
mapping[Role.RefPolicy] = global_pool_id
reward_manager_name = config.reward_model.get("reward_manager", "naive")
if reward_manager_name == 'naive':
from verl.workers.reward_manager import NaiveRewardManager
reward_manager_cls = NaiveRewardManager
elif reward_manager_name == 'prime':
from verl.workers.reward_manager import PrimeRewardManager
reward_manager_cls = PrimeRewardManager
elif reward_manager_name == 'batch':
from verl.workers.reward_manager import BatchRewardManager
reward_manager_cls = BatchRewardManager
elif reward_manager_name == 'dapo':
from verl.workers.reward_manager import DAPORewardManager
reward_manager_cls = DAPORewardManager
else:
raise NotImplementedError
compute_score = get_custom_reward_fn(config)
reward_kwargs = dict(config.reward_model.get("reward_kwargs", {}))
reward_fn = reward_manager_cls(tokenizer=tokenizer,
num_examine=0,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key,
**reward_kwargs)
# Note that we always use function-based RM for validation
val_reward_fn = reward_manager_cls(tokenizer=tokenizer,
num_examine=1,
compute_score=compute_score,
reward_fn_key=config.data.reward_fn_key)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn)
trainer.init_workers()
trainer.fit()
if __name__ == '__main__':
main()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file are used by trainers with different distributed strategies to
implement PPO.
"""
import numpy as np
import torch
from collections import defaultdict
import verl.utils.torch_functional as verl_F
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
def __init__(self, init_kl_coef, target_kl, horizon):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef):
self.value = kl_coef
def update(self, current_kl, n_steps):
pass
def get_kl_controller(kl_ctrl):
if kl_ctrl.type == 'fixed':
return FixedKLController(kl_coef=kl_ctrl.kl_coef)
elif kl_ctrl.type == 'adaptive':
assert kl_ctrl.horizon > 0, f'horizon must be larger than 0. Got {kl_ctrl.horizon}'
return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
else:
raise NotImplementedError
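# Illustrative sketch (not part of the original module; names and numbers are made up for
# demonstration): how the two KL controllers behave. In training, `current_kl` is the masked
# mean KL between the actor and the reference policy and `n_steps` is the batch size.
def _example_kl_controllers():
    adaptive = AdaptiveKLController(init_kl_coef=0.2, target_kl=0.1, horizon=10000)
    adaptive.update(current_kl=0.3, n_steps=256)  # observed KL above target -> coefficient grows
    fixed = FixedKLController(kl_coef=0.05)
    fixed.update(current_kl=0.3, n_steps=256)  # no-op: the coefficient never changes
    return adaptive.value, fixed.value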
def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torch.Tensor, response_mask: torch.Tensor,
                                 gamma: float, lam: float):
    """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        values: `(torch.Tensor)`
            shape: (bs, response_length)
        response_mask: `(torch.Tensor)`
            shape: (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
        gamma: `(float)`
            discount factor used in RL
        lam: `(float)`
            lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """
with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, response_mask)
return advantages, returns
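# Illustrative sketch (not part of the original module; the reward/value numbers are arbitrary):
# expected shapes for GAE with an outcome reward on the last response token.
def _example_gae():
    rewards = torch.tensor([[0.0, 0.0, 1.0]])  # (bs=1, response_length=3)
    values = torch.tensor([[0.1, 0.2, 0.3]])  # critic values, same shape
    mask = torch.ones_like(rewards)  # all three response tokens are valid
    advantages, returns = compute_gae_advantage_return(rewards, values, mask, gamma=1.0, lam=0.95)
    # advantages are whitened over the masked tokens; returns = advantages + values (pre-whitening)
    return advantages.shape, returns.shape  # both (1, 3)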
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6):
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
        returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
                id2std[idx] = torch.std(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1) * response_mask
return scores, scores
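# Illustrative sketch (not part of the original module; numbers are arbitrary): two responses
# sharing the same prompt uid are normalized against their group mean/std and the resulting
# scalar advantage is broadcast over the response mask.
def _example_grpo():
    rewards = torch.tensor([[0.0, 1.0], [0.0, 0.0]])  # outcome rewards on the last token
    mask = torch.ones_like(rewards)
    index = np.array(["prompt-0", "prompt-0"], dtype=object)  # both responses belong to one group
    advantages, _ = compute_grpo_outcome_advantage(rewards, mask, index)
    return advantages  # first row positive, second row negative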
def compute_reinforce_plus_plus_baseline_outcome_advantage(token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
"""
    Compute advantage for REINFORCE++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
        returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = scores[i] - id2mean[index[i]]
scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask
scores = verl_F.masked_whiten(scores, response_mask)
return scores, scores
def compute_rloo_outcome_advantage(token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6):
"""
Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
        returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
response_num = len(id2score[index[i]])
if response_num > 1:
scores[i] = scores[i] * response_num / (response_num -
1) - id2mean[index[i]] * response_num / (response_num - 1)
scores = scores.unsqueeze(-1) * response_mask
return scores, scores
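# Illustrative sketch (not part of the original module; numbers are arbitrary): RLOO compares
# each response against the mean reward of the *other* responses from the same prompt.
def _example_rloo():
    rewards = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]])  # summed scores: 1, 0, 0
    mask = torch.ones_like(rewards)
    index = np.array(["prompt-0"] * 3, dtype=object)  # one group of three responses
    advantages, _ = compute_rloo_outcome_advantage(rewards, mask, index)
    # row 0: 1 - mean(0, 0) = 1.0; rows 1 and 2: 0 - mean(1, 0) = -0.5, broadcast over the mask
    return advantages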
def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, response_mask: torch.Tensor,
gamma: torch.Tensor):
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
        returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * response_mask[:, t]
advantages = verl_F.masked_whiten(returns, response_mask)
advantages = advantages * response_mask
return advantages, returns
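# Illustrative sketch (not part of the original module; numbers are arbitrary): REINFORCE++
# turns token-level rewards into discounted returns-to-go, then whitens them into advantages.
def _example_reinforce_plus_plus():
    rewards = torch.tensor([[0.0, 0.0, 1.0]])  # outcome reward on the final token
    mask = torch.ones_like(rewards)
    advantages, returns = compute_reinforce_plus_plus_outcome_advantage(rewards, mask, gamma=0.9)
    # with gamma=0.9 the discounted returns-to-go are [[0.81, 0.9, 1.0]]
    return advantages, returns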
def compute_remax_outcome_advantage(token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor,
response_mask: torch.Tensor):
"""
    Compute advantage for ReMax, operating only on Outcome reward
    (with only one scalar reward for each response).
    This implementation is based on the paper: https://arxiv.org/abs/2310.10505
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
reward_baselines: `(torch.Tensor)`
shape: (bs,)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
        returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1) * response_mask
return advantages, returns
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
kl = old_log_prob - ref_log_prob
return token_level_scores - kl * kl_ratio
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str):
"""
Aggregate the loss matrix into a scalar.
Args:
loss_mat: `(torch.Tensor)`
shape: (bs, response_length)
loss_mask: `(torch.Tensor)`
shape: (bs, response_length)
loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean"
"token-mean" is the default behavior
Returns:
loss: `a scalar torch.Tensor`
aggregated loss
"""
if loss_agg_mode == "token-mean":
loss = verl_F.masked_mean(loss_mat, loss_mask)
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
loss = torch.mean(seq_losses) # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean
loss = torch.mean(seq_losses) # seq-mean
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
return loss
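# Illustrative sketch (not part of the original module; numbers are arbitrary): the three
# aggregation modes on a 2x2 loss matrix where the second sequence has only one valid token.
def _example_agg_loss():
    loss_mat = torch.tensor([[1.0, 3.0], [2.0, 0.0]])
    loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])
    token_mean = agg_loss(loss_mat, loss_mask, "token-mean")  # (1 + 3 + 2) / 3 = 2.0
    seq_mean_token_sum = agg_loss(loss_mat, loss_mask, "seq-mean-token-sum")  # mean(4, 2) = 3.0
    seq_mean_token_mean = agg_loss(loss_mat, loss_mask, "seq-mean-token-mean")  # mean(2, 2) = 2.0
    return token_mean, seq_mean_token_sum, seq_mean_token_mean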
def compute_policy_loss(old_log_prob,
log_prob,
advantages,
response_mask,
cliprange=None,
cliprange_low=None,
cliprange_high=None,
clip_ratio_c=3.0,
loss_agg_mode="token-mean"):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
old_log_prob: `(torch.Tensor)`
shape: (bs, response_length)
log_prob: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange: (float)
The clip range used in PPO. See https://arxiv.org/abs/1707.06347
cliprange_low: (float)
The lower clip range used in PPO.
cliprange_high: (float)
The higher clip range used in PPO.
clip_ratio_c: (float) default: 3.0
            The lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729
loss_agg_mode: (str) choices: "token-mean" / "seq-mean-token-sum" / "seq-mean-token-mean"
"token-mean" is the default behavior
Returns:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via PPO
pg_clipfrac: (float)
the fraction of policy gradient loss being clipped
ppo_kl: (float)
the estimated KL divergence between the latest updating policy and the old sampling policy
pg_clipfrac_lower: (float)
the fraction of policy gradient loss being clipped when the advantage is negative
"""
    assert clip_ratio_c > 1.0, f"The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0, but got the value: {clip_ratio_c}."
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
pg_losses1 = -advantages * ratio
if cliprange_low is None:
cliprange_low = cliprange
if cliprange_high is None:
cliprange_high = cliprange
pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low,
1 + cliprange_high) # - clip(ratio, 1-cliprange, 1+cliprange) * A
clip_pg_losses1 = torch.maximum(pg_losses1,
pg_losses2) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)
pg_losses3 = -advantages * clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
pg_clipfrac_lower = verl_F.masked_mean(
torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask)
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
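# Illustrative sketch (not part of the original module; log-probabilities are arbitrary):
# the clipped PPO policy loss. The second token's ratio exp(0.5) exceeds 1 + cliprange with a
# positive advantage, so it is clipped and counted in pg_clipfrac.
def _example_policy_loss():
    old_log_prob = torch.tensor([[-1.0, -1.0]])
    log_prob = torch.tensor([[-1.0, -0.5]])  # ratio = [1.0, exp(0.5) ~= 1.65]
    advantages = torch.tensor([[1.0, 1.0]])
    response_mask = torch.ones_like(advantages)
    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(
        old_log_prob, log_prob, advantages, response_mask, cliprange=0.2)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower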
def compute_entropy_loss(logits, response_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
response_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=response_mask)
return entropy_loss
def compute_value_loss(vpreds, returns, values, response_mask, cliprange_value):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args:
vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
        returns (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`)
Returns:
vf_loss: a scalar (`torch.FloatTensor`):
value function loss
vf_clipfrac: a float
The ratio of vf being clipped
"""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns)**2
vf_losses2 = (vpredclipped - returns)**2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), response_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask)
return vf_loss, vf_clipfrac
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
    Args:
        logprob: log-probabilities of the tokens under the current policy, shape (bs, response_length)
        ref_logprob: log-probabilities of the same tokens under the reference policy, shape (bs, response_length)
        kl_penalty: penalty type, one of "kl", "abs", "mse", "low_var_kl" or "full"
    Returns:
        the token-level KL penalty, shape (bs, response_length)
    """
if kl_penalty == "kl":
return logprob - ref_logprob
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
    # J. Schulman. Approximating KL divergence, 2020.
    # URL: http://joschu.net/blog/kl-approx.html
if kl_penalty == 'low_var_kl':
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
raise NotImplementedError
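# Illustrative sketch (not part of the original module; numbers are arbitrary): the supported
# token-level KL penalty estimators applied to the same pair of log-probabilities. Each call
# returns a (bs, response_length) tensor.
def _example_kl_penalty():
    logprob = torch.tensor([[-1.0, -2.0]])
    ref_logprob = torch.tensor([[-1.5, -1.5]])
    k1 = kl_penalty(logprob, ref_logprob, "kl")  # logprob - ref_logprob
    k_abs = kl_penalty(logprob, ref_logprob, "abs")  # |logprob - ref_logprob|
    k_mse = kl_penalty(logprob, ref_logprob, "mse")  # 0.5 * (logprob - ref_logprob) ** 2
    k3 = kl_penalty(logprob, ref_logprob, "low_var_kl")  # Schulman's low-variance estimator, clamped
    return k1, k_abs, k_mse, k3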
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics related to the PPO trainer.
"""
import torch
from typing import Any, Dict, List, Callable
import numpy as np
from verl import DataProto
from collections import Counter, defaultdict
from functools import partial
def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
for key, val in metrics.items():
metrics[key] = np.mean(val)
return metrics
def _compute_response_info(batch: DataProto) -> Dict[str, Any]:
response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
response_mask = batch.batch['attention_mask'][:, -response_length:]
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)
return dict(
response_mask=response_mask,
prompt_length=prompt_length,
response_length=response_length,
)
def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> Dict[str, Any]:
# TODO: add response length
sequence_score = batch.batch['token_level_scores'].sum(-1)
sequence_reward = batch.batch['token_level_rewards'].sum(-1)
advantages = batch.batch['advantages']
returns = batch.batch['returns']
max_response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch['values']
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
'critic/score/mean':
torch.mean(sequence_score).detach().item(),
'critic/score/max':
torch.max(sequence_score).detach().item(),
'critic/score/min':
torch.min(sequence_score).detach().item(),
# reward
'critic/rewards/mean':
torch.mean(sequence_reward).detach().item(),
'critic/rewards/max':
torch.max(sequence_reward).detach().item(),
'critic/rewards/min':
torch.min(sequence_reward).detach().item(),
# adv
'critic/advantages/mean':
torch.mean(valid_adv).detach().item(),
'critic/advantages/max':
torch.max(valid_adv).detach().item(),
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
# returns
'critic/returns/mean':
torch.mean(valid_returns).detach().item(),
'critic/returns/max':
torch.max(valid_returns).detach().item(),
'critic/returns/min':
torch.min(valid_returns).detach().item(),
**({
# values
'critic/values/mean': torch.mean(valid_values).detach().item(),
'critic/values/max': torch.max(valid_values).detach().item(),
'critic/values/min': torch.min(valid_values).detach().item(),
# vf explained var
'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
} if use_critic else {}),
# response length
'response_length/mean':
torch.mean(response_length).detach().item(),
'response_length/max':
torch.max(response_length).detach().item(),
'response_length/min':
torch.min(response_length).detach().item(),
'response_length/clip_ratio':
torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
# prompt length
'prompt_length/mean':
torch.mean(prompt_length).detach().item(),
'prompt_length/max':
torch.max(prompt_length).detach().item(),
'prompt_length/min':
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
num_response_tokens = torch.sum(response_info['response_length']).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
'gen': num_response_tokens,
**{
name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']
},
}
return {
**{
f'timing_s/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token_ms/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys(
)) & set(timing_raw.keys())
},
}
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
total_num_tokens = sum(batch.meta_info['global_token_num'])
time = timing_raw['step']
# estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)
    # f'Actual TFLOPs/s/GPU': estimated_flops/(n_gpus),
    # f'Theoretical TFLOPs/s/GPU': promised_flops,
return {
'perf/total_num_tokens': total_num_tokens,
'perf/time_per_step': time,
'perf/throughput': total_num_tokens / (time * n_gpus),
}
def bootstrap_metric(data: list[Any],
subset_size: int,
reduce_fns: list[Callable[[np.ndarray], float]],
n_bootstrap: int = 1000,
seed: int = 42) -> list[tuple[float, float]]:
np.random.seed(seed)
bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]
for _ in range(n_bootstrap):
bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True)
bootstrap_data = [data[i] for i in bootstrap_idxs]
for i, reduce_fn in enumerate(reduce_fns):
bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data))
return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts]
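# Illustrative sketch (not part of the original module; scores are arbitrary): bootstrap the
# mean/std of best-of-2 and worst-of-2 over a small list of per-response scores.
def _example_bootstrap_metric():
    scores = [0.0, 1.0, 1.0, 0.0]
    (best2_mean, best2_std), (worst2_mean, worst2_std) = bootstrap_metric(
        data=scores, subset_size=2, reduce_fns=[np.max, np.min], n_bootstrap=100)
    return best2_mean, best2_std, worst2_mean, worst2_std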
def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> float:
"""
Calculate the majority voting metric
"""
vote2vals = defaultdict(list)
for d in data:
vote2vals[d[vote_key]].append(d[val_key])
vote2cnt = {k: len(v) for k, v in vote2vals.items()}
maj_vote = max(vote2cnt, key=vote2cnt.get)
maj_val = vote2vals[maj_vote][0]
return maj_val
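# Illustrative sketch (not part of the original module; the samples are made up): majority
# voting over three sampled responses, two of which predict "4", so the value attached to the
# majority prediction is returned.
def _example_calc_maj_val():
    samples = [
        {"pred": "4", "val": 1.0},
        {"pred": "5", "val": 0.0},
        {"pred": "4", "val": 1.0},
    ]
    return calc_maj_val(samples, vote_key="pred", val_key="val")  # -> 1.0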
def process_validation_metrics(data_sources: list[str],
sample_inputs: list[str],
infos_dict: dict[str, list[Any]],
seed: int = 42) -> dict[str, dict[str, dict[str, float]]]:
"""Process validation metrics into a structured format.
Args:
data_sources: Array of data source identifiers for each sample
sample_inputs: List of input prompts
infos_dict: variable name -> list of values for each sample
Returns:
        dict[str, dict[str, dict[str, float]]]: data source -> variable name -> metric name -> metric value
"""
# Group metrics by data source, prompt and variable
data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for sample_idx, data_source in enumerate(data_sources):
prompt = sample_inputs[sample_idx]
var2vals = data_src2prompt2var2vals[data_source][prompt]
for var_name, var_vals in infos_dict.items():
var2vals[var_name].append(var_vals[sample_idx])
# Calculate metrics for each group
data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
for data_source, prompt2var2vals in data_src2prompt2var2vals.items():
for prompt, var2vals in prompt2var2vals.items():
for var_name, var_vals in var2vals.items():
if isinstance(var_vals[0], str):
continue
metric = {}
n_resps = len(var_vals)
metric[f"mean@{n_resps}"] = np.mean(var_vals)
metric[f"std@{n_resps}"] = np.std(var_vals)
ns = []
n = 2
while n < n_resps:
ns.append(n)
n *= 2
ns.append(n_resps)
for n in ns:
# Best/Worst-of-N
[(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(data=var_vals,
subset_size=n,
reduce_fns=[np.max, np.min],
seed=seed)
metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
# Majority voting
if var2vals.get("pred", None) is not None:
vote_data = [{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"])]
[(maj_n_mean, maj_n_std)
] = bootstrap_metric(data=vote_data,
subset_size=n,
reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
seed=seed)
metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
data_src2prompt2var2metric[data_source][prompt][var_name] = metric
# Aggregate metrics across prompts
data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for data_source, prompt2var2metric in data_src2prompt2var2metric.items():
for prompt, var2metric in prompt2var2metric.items():
for var_name, metric in var2metric.items():
for metric_name, metric_val in metric.items():
data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val)
data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items():
for var_name, metric2prompt_vals in var2metric2prompt_vals.items():
for metric_name, prompt_vals in metric2prompt_vals.items():
data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals)
return data_src2var2metric2val
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import os
import uuid
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Type, Dict
from copy import deepcopy
from collections import defaultdict
from functools import partial
from tqdm import tqdm
import ray
import numpy as np
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics, bootstrap_metric, calc_maj_val, process_validation_metrics
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.tracking import ValidationGenerationsLogger
from torch.utils.data import Dataset, RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
WorkerType = Type[Worker]
class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""
Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6
class AdvantageEstimator(str, Enum):
"""
Using an enumeration class to avoid spelling errors in adv_estimator
"""
GAE = 'gae'
GRPO = 'grpo'
REINFORCE_PLUS_PLUS = 'reinforce_plus_plus'
REINFORCE_PLUS_PLUS_BASELINE = 'reinforce_plus_plus_baseline'
REMAX = 'remax'
RLOO = 'rloo'
@dataclass
class ResourcePoolManager:
"""
    Define a resource pool specification and the mapping from roles to resource pools.
    Resource pools will be initialized first.
"""
resource_pool_spec: dict[str, list[int]]
mapping: dict[Role, str]
resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
def create_resource_pool(self):
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
# max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
            # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
            # For the Megatron backend, we recommend max_colocate_count>1, which can use a different WorkerGroup for each model.
resource_pool = RayResourcePool(process_on_nodes=process_on_nodes,
use_gpu=True,
max_colocate_count=1,
name_prefix=resource_pool_name)
self.resource_pool_dict[resource_pool_name] = resource_pool
self._check_resource_available()
def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker_cls"""
return self.resource_pool_dict[self.mapping[role]]
def get_n_gpus(self) -> int:
"""Get the number of gpus in this cluster."""
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
def _check_resource_available(self):
"""Check if the resource pool can be satisfied in this ray cluster."""
node_available_resources = ray.state.available_resources_per_node()
node_available_gpus = {node: node_info.get('GPU', 0) for node, node_info in node_available_resources.items()}
# check total required gpus can be satisfied
total_available_gpus = sum(node_available_gpus.values())
total_required_gpus = sum(
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
if total_available_gpus < total_required_gpus:
raise ValueError(
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}")
# check each resource pool can be satisfied, O(#resource_pools * #nodes)
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
for node, available_gpus in node_available_gpus.items():
if available_gpus >= num_gpus:
node_available_gpus[node] -= num_gpus
num_nodes -= 1
if num_nodes == 0:
break
if num_nodes > 0:
raise ValueError(
f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this ray cluster"
)
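# Illustrative sketch (not part of the original module; the pool name and GPU count are made
# up): a single global pool of one node with 8 GPUs shared by every role. Creating the actual
# Ray resource pools requires a running Ray cluster, so create_resource_pool() is not called here.
def _example_resource_pool_manager():
    resource_pool_spec = {"global_pool": [8]}  # pool name -> GPUs per node, one entry per node
    mapping = {
        Role.ActorRollout: "global_pool",
        Role.Critic: "global_pool",
        Role.RefPolicy: "global_pool",
        Role.RewardModel: "global_pool",
    }
    manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
    return manager.get_n_gpus()  # -> 8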
import torch
from verl.utils.torch_functional import masked_mean
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty='kl'):
responses = data.batch['responses']
response_length = responses.size(1)
token_level_scores = data.batch['token_level_scores']
batch_size = data.batch.batch_size[0]
attention_mask = data.batch['attention_mask']
response_mask = attention_mask[:, -response_length:]
# compute kl between ref_policy and current policy
# When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
kld = core_algos.kl_penalty(data.batch['old_log_probs'], data.batch['ref_log_prob'],
kl_penalty=kl_penalty) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
token_level_rewards = token_level_scores - beta * kld
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch['token_level_rewards'] = token_level_rewards
metrics = {'actor/reward_kl_penalty': current_kl, 'actor/reward_kl_penalty_coeff': beta}
return data, metrics
def compute_response_mask(data: DataProto):
responses = data.batch['responses']
response_length = responses.size(1)
attention_mask = data.batch['attention_mask']
return attention_mask[:, -response_length:]
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
# Back-compatible with trainers that do not compute response mask in fit
if "response_mask" not in data.batch.keys():
data.batch['response_mask'] = compute_response_mask(data)
# prepare response group
# TODO: add other ways to estimate advantages
if adv_estimator == AdvantageEstimator.GAE:
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=data.batch['token_level_rewards'],
values=data.batch['values'],
response_mask=data.batch['response_mask'],
gamma=gamma,
lam=lam)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.GRPO:
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
response_mask=data.batch['response_mask'],
index=data.non_tensor_batch['uid'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE:
advantages, returns = core_algos.compute_reinforce_plus_plus_baseline_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
response_mask=data.batch['response_mask'],
index=data.non_tensor_batch['uid'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.REINFORCE_PLUS_PLUS:
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
response_mask=data.batch['response_mask'],
gamma=gamma)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.REMAX:
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
reward_baselines=data.batch['reward_baselines'],
response_mask=data.batch['response_mask'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == AdvantageEstimator.RLOO:
advantages, returns = core_algos.compute_rloo_outcome_advantage(
token_level_rewards=data.batch['token_level_rewards'],
response_mask=data.batch['response_mask'],
index=data.non_tensor_batch['uid'])
data.batch['advantages'] = advantages
data.batch['returns'] = returns
else:
raise NotImplementedError
return data
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
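# Illustrative sketch (not part of the original module; the timed bodies are placeholders):
# each _timer block writes its wall-clock duration in seconds into the shared timing_raw dict,
# which later feeds compute_timing_metrics.
def _example_timer():
    timing_raw = {}
    with _timer('gen', timing_raw):
        sum(range(1000))  # stand-in for rollout generation
    with _timer('adv', timing_raw):
        sum(range(1000))  # stand-in for advantage computation
    return timing_raw  # e.g. {'gen': ..., 'adv': ...}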
class RayPPOTrainer(object):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None):
# assert torch.cuda.is_available(), 'cuda must be available on driver'
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
        assert self.hybrid_engine, 'Currently, only the hybrid engine is supported'
if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f'{role_worker_mapping.keys()=}'
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_rm = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
self.validation_generations_logger = ValidationGenerationsLogger()
# define in-reward KL control
        # KL loss control is currently not supported
if config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
self.use_critic = True
elif self.config.algorithm.adv_estimator in [
AdvantageEstimator.GRPO, AdvantageEstimator.REINFORCE_PLUS_PLUS, AdvantageEstimator.REMAX,
AdvantageEstimator.RLOO, AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE
]:
self.use_critic = False
else:
raise NotImplementedError
self._validate_config()
self._create_dataloader()
def _validate_config(self):
config = self.config
# number of GPUs total
n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
# 1. Check total batch size for data correctness
real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
assert real_train_batch_size % n_gpus == 0, \
f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
# A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
# We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
settings = {
"actor_rollout_ref.actor": "micro_batch_size",
"critic": "micro_batch_size",
"reward_model": "micro_batch_size",
"actor_rollout_ref.ref": "log_prob_micro_batch_size",
"actor_rollout_ref.rollout": "log_prob_micro_batch_size",
}
if name in settings:
param = settings[name]
param_per_gpu = f"{param}_per_gpu"
if mbs is None and mbs_per_gpu is None:
raise ValueError(
f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
if mbs is not None and mbs_per_gpu is not None:
raise ValueError(
f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. "
f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)."
)
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
# actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size,
config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
"actor_rollout_ref.actor")
if self.use_reference_policy:
# reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size,
config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.ref")
# The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
"actor_rollout_ref.rollout")
if self.use_critic and not config.critic.use_dynamic_bsz:
# Check for critic micro-batch size conflicts
check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu,
"critic")
# Check for reward model micro-batch size conflicts
if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu,
"reward_model")
# Actor
# check if train_batch_size is larger than ppo_mini_batch_size
# if NOT dynamic_bsz, we must ensure:
# ppo_mini_batch_size is divisible by ppo_micro_batch_size
# ppo_micro_batch_size * sequence_parallel_size >= n_gpus
if not config.actor_rollout_ref.actor.use_dynamic_bsz:
assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
sp_size = config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1)
if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
assert config.actor_rollout_ref.actor.loss_agg_mode in [
"token-mean", "seq-mean-token-sum", "seq-mean-token-mean"
], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"
if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
            print("NOTICE: You have enabled both in-reward KL and KL loss.")
# critic
if self.use_critic and not config.critic.use_dynamic_bsz:
assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
sp_size = config.critic.get('ulysses_sequence_parallel_size', 1)
if config.critic.ppo_micro_batch_size is not None:
assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
# Check if use_remove_padding is enabled when using sequence parallelism for fsdp
if config.actor_rollout_ref.actor.strategy == 'fsdp':
if config.actor_rollout_ref.actor.get('ulysses_sequence_parallel_size', 1) > 1 or \
config.actor_rollout_ref.ref.get('ulysses_sequence_parallel_size', 1) > 1:
assert config.actor_rollout_ref.model.use_remove_padding, \
"When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
if self.use_critic and config.critic.strategy == 'fsdp':
if config.critic.get('ulysses_sequence_parallel_size', 1) > 1:
assert config.critic.model.use_remove_padding, \
"When using sequence parallelism for critic, you must enable `use_remove_padding`."
if config.data.get('val_batch_size', None) is not None:
print(
                "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines as a whole batch, and the engines schedule memory themselves."
)
# check eval config
if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
assert config.actor_rollout_ref.rollout.temperature > 0, \
"validation gen temperature should be greater than 0 when enabling do_sample"
print("[validate_config] All configuration checks passed successfully!")
def _create_dataloader(self):
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.utils.import_utils import load_extern_type
if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None:
dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from "
f"'{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset")
else:
dataset_cls = RLHFDataset
self.train_dataset = dataset_cls(
data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
)
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
batch_size=self.config.data.get('gen_batch_size',
self.config.data.train_batch_size),
num_workers=8,
drop_last=True,
collate_fn=collate_fn,
sampler=sampler)
self.val_dataset = dataset_cls(
data_files=self.config.data.val_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
# Validation datasets are sent to inference engines as a whole batch,
            # and the engines schedule memory for the batch themselves.
batch_size=len(self.val_dataset),
num_workers=8,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
assert len(self.train_dataloader) >= 1
assert len(
self.val_dataloader
        ) == 1, "Validation dataloader must contain a single batch; inference engines schedule memory for the whole batch themselves."
print(f'Size of train dataloader: {len(self.train_dataloader)}')
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.trainer.total_training_steps is not None:
total_training_steps = self.config.trainer.total_training_steps
self.total_training_steps = total_training_steps
print(f'Total training steps: {self.total_training_steps}')
OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
self.config.critic.optim.total_training_steps = total_training_steps
def _maybe_log_val_generations(self, inputs, outputs, scores):
"""Log a table of validation samples to the configured logger (wandb or swanlab)"""
generations_to_log = self.config.trainer.log_val_generations
if generations_to_log == 0:
return
import numpy as np
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling
rng = np.random.RandomState(42)
rng.shuffle(samples)
# Take first N samples after shuffling
samples = samples[:generations_to_log]
# Log to each configured logger
self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
def _validate(self):
data_source_lst = []
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
# Lists to collect samples for the table
sample_inputs = []
sample_outputs = []
sample_scores = []
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
# repeat test batch
test_batch = test_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n,
interleave=True)
# we only do validation on rule-based rm
if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
return {}
# Store original inputs
input_ids = test_batch.batch['input_ids']
# TODO: Can we keep special tokens except for padding tokens?
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
if 'multi_modal_inputs' in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
)
else:
test_gen_batch = test_batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids'],
)
test_gen_batch.meta_info = {
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
'recompute_log_prob': False,
'do_sample': self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
'validate': True,
}
print(f'test_gen_batch meta info: {test_gen_batch.meta_info}')
# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print('validation generation end')
# Store generated outputs
output_ids = test_output_gen_batch.batch['responses']
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
result = self.val_reward_fn(test_batch, return_dict=True)
reward_tensor = result["reward_tensor"]
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_extra_infos_dict["reward"].extend(scores)
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)
data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
for key_info, lst in reward_extra_infos_dict.items():
assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
data_sources = np.concatenate(data_source_lst, axis=0)
data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
metric_dict = {}
for data_source, var2metric2val in data_src2var2metric2val.items():
core_var = "acc" if "acc" in var2metric2val else "reward"
for var_name, metric2val in var2metric2val.items():
n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
for metric_name, metric_val in metric2val.items():
if (var_name == core_var) and any(
metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) and (f"@{n_max}"
in metric_name):
metric_sec = "val-core"
else:
metric_sec = "val-aux"
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
metric_dict[pfx] = metric_val
return metric_dict
def init_workers(self):
"""Init resource pool and worker group"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
if self.hybrid_engine:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role='actor_rollout')
self.resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls
else:
raise NotImplementedError
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role='ref')
self.resource_pool_to_cls[resource_pool]['ref'] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]['rm'] = rm_cls
# initialize WorkerGroup
        # NOTE: if you want to use a different resource pool for each role, which can support different parallel sizes,
        # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pools to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
self.wg_dicts = []
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
**wg_kwargs)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
            # keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
self.wg_dicts.append(wg_dict)
if self.use_critic:
self.critic_wg = all_wg['critic']
self.critic_wg.init_model()
if self.use_reference_policy:
self.ref_policy_wg = all_wg['ref']
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = all_wg['rm']
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()
def _save_checkpoint(self):
# path: given_path + `/global_step_{global_steps}` + `/actor`
local_global_step_folder = os.path.join(self.config.trainer.default_local_dir,
f'global_step_{self.global_steps}')
print(f'local_global_step_folder: {local_global_step_folder}')
actor_local_path = os.path.join(local_global_step_folder, 'actor')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'actor')
remove_previous_ckpt_in_save = self.config.trainer.get('remove_previous_ckpt_in_save', False)
if remove_previous_ckpt_in_save:
print(
'Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead'
)
max_actor_ckpt_to_keep = self.config.trainer.get('max_actor_ckpt_to_keep',
None) if not remove_previous_ckpt_in_save else 1
max_critic_ckpt_to_keep = self.config.trainer.get('max_critic_ckpt_to_keep',
None) if not remove_previous_ckpt_in_save else 1
self.actor_rollout_wg.save_checkpoint(actor_local_path,
actor_remote_path,
self.global_steps,
max_ckpt_to_keep=max_actor_ckpt_to_keep)
if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, 'critic')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, f'global_step_{self.global_steps}', 'critic')
self.critic_wg.save_checkpoint(critic_local_path,
critic_remote_path,
self.global_steps,
max_ckpt_to_keep=max_critic_ckpt_to_keep)
# save dataloader
dataloader_local_path = os.path.join(local_global_step_folder, 'data.pt')
dataloader_state_dict = self.train_dataloader.state_dict()
torch.save(dataloader_state_dict, dataloader_local_path)
# latest checkpointed iteration tracker (for atomic usage)
local_latest_checkpointed_iteration = os.path.join(self.config.trainer.default_local_dir,
'latest_checkpointed_iteration.txt')
with open(local_latest_checkpointed_iteration, 'w') as f:
f.write(str(self.global_steps))
def _load_checkpoint(self):
if self.config.trainer.resume_mode == 'disable':
return 0
# load from hdfs
if self.config.trainer.default_hdfs_dir is not None:
raise NotImplementedError('load from hdfs is not implemented yet')
else:
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
if not os.path.isabs(checkpoint_folder):
working_dir = os.getcwd()
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
# find global_step_folder
if self.config.trainer.resume_mode == 'auto':
if global_step_folder is None:
print('Training from scratch')
return 0
else:
if self.config.trainer.resume_mode == "resume_path":
assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
assert 'global_step_' in self.config.trainer.resume_from_path, "resume ckpt must specify the global_steps"
global_step_folder = self.config.trainer.resume_from_path
if not os.path.isabs(global_step_folder):
working_dir = os.getcwd()
global_step_folder = os.path.join(working_dir, global_step_folder)
print(f'Load from checkpoint folder: {global_step_folder}')
# set global step
self.global_steps = int(global_step_folder.split('global_step_')[-1])
print(f'Setting global step to {self.global_steps}')
print(f'Resuming from {global_step_folder}')
actor_path = os.path.join(global_step_folder, 'actor')
critic_path = os.path.join(global_step_folder, 'critic')
# load actor
self.actor_rollout_wg.load_checkpoint(actor_path,
del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
# load critic
if self.use_critic:
self.critic_wg.load_checkpoint(critic_path,
del_local_after_load=self.config.trainer.del_local_ckpt_after_load)
# load dataloader,
# TODO: from remote not implemented yet
dataloader_local_path = os.path.join(global_step_folder, 'data.pt')
if os.path.exists(dataloader_local_path):
dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
self.train_dataloader.load_state_dict(dataloader_state_dict)
else:
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
def _balance_batch(self, batch: DataProto, metrics, logging_prefix='global_seqlen'):
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
attention_mask = batch.batch['attention_mask']
batch_size = attention_mask.shape[0]
global_seqlen_lst = batch.batch['attention_mask'].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
world_size = self.actor_rollout_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(global_seqlen_lst,
k_partitions=world_size,
equal_size=True)
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
batch.reorder(global_idx)
global_balance_stats = log_seqlen_unbalance(seqlen_list=global_seqlen_lst,
partitions=global_partition_lst,
prefix=logging_prefix)
metrics.update(global_balance_stats)
def fit(self):
"""
The training loop of PPO.
        The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
        The lightweight advantage computation is done on the driver process.
"""
from verl.utils.tracking import Tracking
from omegaconf import OmegaConf
logger = Tracking(project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
if 'multi_modal_inputs' in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids', 'multi_modal_data', 'multi_modal_inputs'],
)
else:
gen_batch = batch.pop(
batch_keys=['input_ids', 'attention_mask', 'position_ids'],
non_tensor_batch_keys=['raw_prompt_ids'],
)
is_last_step = self.global_steps >= self.total_training_steps
with _timer('step', timing_raw):
# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with _timer('gen_max', timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info['do_sample'] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch['reward_baselines'] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
dtype=object)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
batch.batch['response_mask'] = compute_response_mask(batch)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
                    # Take care when implementing group-based advantage computation such as GRPO and RLOO.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
# recompute old_log_probs
with _timer('old_log_prob', timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer('adv', timing_raw):
                        # compute scores. Support both model-based and function-based rewards.
                        # We first compute the scores using the reward model. Then, we call reward_fn to combine
                        # the reward-model results with rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
try:
reward_result = self.reward_fn(batch, return_dict=True)
reward_tensor = reward_result['reward_tensor']
reward_extra_infos_dict = reward_result['reward_extra_info']
except Exception as e:
print(f'Error in reward_fn: {e}')
reward_tensor = self.reward_fn(batch)
reward_extra_infos_dict = {}
batch.batch['token_level_scores'] = reward_tensor
print(f'{list(reward_extra_infos_dict.keys())=}')
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl_in_reward,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
# compute advantages, executed on the driver process
batch = compute_advantage(batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n)
# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
(is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and ( is_last_step or \
self.global_steps % self.config.trainer.save_freq == 0):
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflops and theoretical tflops
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                # TODO: make a canonical logger that supports various backends
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f'Final validation metrics: {last_val_metrics}')
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1
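# The `_timer(name, timing_raw)` context manager used throughout `fit` is defined elsewhere
# in verl. A minimal sketch of the pattern, written here purely for illustration and not
# the repository's actual implementation, accumulates per-phase wall-clock time:
from contextlib import contextmanager
import time
@contextmanager
def _timer_sketch(name: str, timing_raw: dict):
    """Accumulate the wall-clock duration of the wrapped block under timing_raw[name]."""
    start = time.perf_counter()
    yield
    timing_raw[name] = timing_raw.get(name, 0.0) + (time.perf_counter() - start)
# Usage mirroring the training loop above:
#   timing_raw = {}
#   with _timer_sketch('gen', timing_raw):
#       ...  # generate sequences
#   with _timer_sketch('adv', timing_raw):
#       ...  # compute advantages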
working_dir: ./
excludes: ["/.git/"]
env_vars:
TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
VLLM_ATTENTION_BACKEND: "XFORMERS"
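# The YAML block above is a Ray runtime environment spec (working_dir, excludes, env_vars).
# A minimal sketch of how such a spec could be handed to Ray is shown below; the file name
# `runtime_env.yaml` and the standalone-script context are assumptions for illustration.
import ray
import yaml
def _init_ray_with_runtime_env_demo(path: str = "runtime_env.yaml") -> None:
    with open(path) as f:
        runtime_env = yaml.safe_load(f)
    # Ray ships the working_dir to workers and injects env_vars such as
    # VLLM_ATTENTION_BACKEND into every worker process.
    ray.init(runtime_env=runtime_env)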
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import tokenizer
from .tokenizer import hf_tokenizer, hf_processor
__all__ = tokenizer.__all__
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from filelock import FileLock
import tempfile
from typing import Union
import torch
import torch.distributed
from transformers import PreTrainedTokenizer, ProcessorMixin
import numpy as np
import random
import re
class BaseCheckpointManager:
"""
A checkpoint manager that saves and loads
- model
- optimizer
- lr_scheduler
- extra_states
    in an SPMD way.
We save
- sharded model states and optimizer states
- full lr_scheduler states
- huggingface tokenizer and config for ckpt merge
"""
def __init__(self,
model,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
checkpoint_contents: list = ['model', 'optimizer', 'extra']):
self.previous_global_step = None
self.previous_saved_paths = []
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.processing_class = processing_class
self.checkpoint_contents = checkpoint_contents
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False):
raise NotImplementedError
def save_checkpoint(self,
local_path: str,
hdfs_path: str = None,
global_step: int = 0,
max_ckpt_to_keep: int = None):
raise NotImplementedError
@staticmethod
def checkpath(local_path: str, hdfs_path: str):
assert local_path is not None or hdfs_path is not None, "local_path and hdfs_path cannot be both None"
        return local_path is not None, local_path if local_path is not None else hdfs_path
def remove_previous_save_local_path(self, path):
if isinstance(path, str):
path = [path]
for p in path:
abs_path = os.path.abspath(p)
print(f'Checkpoint manager remove previous save local path: {abs_path}')
if not os.path.exists(abs_path):
continue
shutil.rmtree(abs_path, ignore_errors=True)
@staticmethod
def local_mkdir(path):
if not os.path.isabs(path):
working_dir = os.getcwd()
path = os.path.join(working_dir, path)
# Using hash value of path as lock file name to avoid long file name
lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock"
lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
try:
with FileLock(lock_path, timeout=60): # Add timeout
# make a new dir
os.makedirs(path, exist_ok=True)
except Exception as e:
print(f"Warning: Failed to acquire lock for {path}: {e}")
# Even if the lock is not acquired, try to create the directory
os.makedirs(path, exist_ok=True)
return path
@staticmethod
def get_rng_state():
rng_state = {
'cpu': torch.get_rng_state(),
'cuda': torch.cuda.get_rng_state(),
'numpy': np.random.get_state(),
'random': random.getstate(),
}
return rng_state
@staticmethod
def load_rng_state(rng_state):
torch.set_rng_state(rng_state['cpu'])
torch.cuda.set_rng_state(rng_state['cuda'])
np.random.set_state(rng_state['numpy'])
random.setstate(rng_state['random'])
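def _rng_round_trip_demo():
    # Hypothetical usage sketch, not part of the original module: restoring a saved RNG
    # snapshot reproduces the same random draws, which is what makes sampling and
    # dataloader shuffling resumable. Requires a CUDA device because get_rng_state()
    # also captures the CUDA generator state.
    snapshot = BaseCheckpointManager.get_rng_state()
    first_draw = torch.randn(3)
    BaseCheckpointManager.load_rng_state(snapshot)
    second_draw = torch.randn(3)
    assert torch.equal(first_draw, second_draw)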
def find_latest_ckpt_path(path, directory_format="global_step_{}"):
if path is None:
return None
tracker_file = get_checkpoint_tracker_filename(path)
if not os.path.exists(tracker_file):
        print(f"Checkpoint tracker file does not exist: {tracker_file}")
return None
with open(tracker_file, "rb") as f:
iteration = int(f.read().decode())
ckpt_path = os.path.join(path, directory_format.format(iteration))
if not os.path.exists(ckpt_path):
        print(f"Checkpoint does not exist: {ckpt_path}")
        return None
    print(f"Found checkpoint: {ckpt_path}")
return ckpt_path
def get_checkpoint_tracker_filename(root_path: str):
"""
    The tracker file records the latest checkpoint during training to restart from.
"""
return os.path.join(root_path, "latest_checkpointed_iteration.txt")
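# A minimal sketch, added here for illustration only, of how the tracker file and
# find_latest_ckpt_path fit together: after saving global_step_{N}, the trainer would
# write N into latest_checkpointed_iteration.txt, and resume code then asks
# find_latest_ckpt_path(root_path) for the matching directory.
def _write_checkpoint_tracker_demo(root_path: str, global_step: int) -> str:
    ckpt_dir = os.path.join(root_path, f"global_step_{global_step}")
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(get_checkpoint_tracker_filename(root_path), "w") as f:
        f.write(str(global_step))
    # find_latest_ckpt_path(root_path) now resolves to ckpt_dir.
    return ckpt_dir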
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
import os
import warnings
from typing import Union
import torch
import torch.distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp import ShardedStateDictConfig, ShardedOptimStateDictConfig
from verl.utils.fs import copy_to_local, is_non_local
from transformers import PreTrainedTokenizer, ProcessorMixin
from .checkpoint_manager import BaseCheckpointManager
class FSDPCheckpointManager(BaseCheckpointManager):
"""
A checkpoint manager that saves and loads
- model
- optimizer
- lr_scheduler
- extra_states
    in an SPMD way.
We save
- sharded model states and optimizer states
- full lr_scheduler states
- huggingface tokenizer/processor and config for ckpt merge
"""
def __init__(self,
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
processing_class: Union[PreTrainedTokenizer, ProcessorMixin] = None,
checkpoint_contents: list = ['model', 'optimizer', 'extra'],
**kwargs):
if processing_class is None:
assert "tokenizer" in kwargs, "tokenizer or processor must be provided"
warnings.warn("`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning)
processing_class = kwargs.pop("tokenizer")
assert "model" in checkpoint_contents and "optimizer" in checkpoint_contents and "extra" in checkpoint_contents, f"FSDPCheckpointManager must include ['model', 'optimizer', 'extra'], got {checkpoint_contents}"
super().__init__(model,
optimizer,
lr_scheduler=lr_scheduler,
processing_class=processing_class,
checkpoint_contents=checkpoint_contents)
def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
if local_path is None:
return
# every rank download its own checkpoint
remote_model_path = os.path.join(local_path, f'model_world_size_{self.world_size}_rank_{self.rank}.pt')
remote_optim_path = os.path.join(local_path, f'optim_world_size_{self.world_size}_rank_{self.rank}.pt')
remote_extra_state_path = os.path.join(local_path,
f'extra_state_world_size_{self.world_size}_rank_{self.rank}.pt')
print(
f'[rank-{self.rank}]: Loading from {remote_model_path} and {remote_optim_path} and {remote_extra_state_path}'
)
local_model_path = copy_to_local(remote_model_path)
local_optim_path = copy_to_local(remote_optim_path)
local_extra_state_path = copy_to_local(remote_extra_state_path)
model_state_dict = torch.load(local_model_path, weights_only=False)
optimizer_state_dict = torch.load(local_optim_path, weights_only=False)
extra_state_dict = torch.load(local_extra_state_path, weights_only=False)
if del_local_after_load:
try:
                if is_non_local(local_model_path):
                    os.remove(local_model_path)
                if is_non_local(local_optim_path):
                    os.remove(local_optim_path)
                if is_non_local(local_extra_state_path):
                    os.remove(local_extra_state_path)
except Exception as e:
print(
f'[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored'
)
lr_scheduler_state_dict = extra_state_dict['lr_scheduler']
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
# recover random state
if 'rng' in extra_state_dict:
# 'rng' may not exist for backward compatibility
self.load_rng_state(extra_state_dict['rng'])
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
if local_path is None:
return
# record the previous global step
self.previous_global_step = global_step
# remove previous local_path
if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(
self.previous_saved_paths) >= max_ckpt_to_keep:
keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
self.previous_saved_paths = self.previous_saved_paths[keep_start:]
local_path = self.local_mkdir(local_path)
torch.distributed.barrier()
# every rank will save its own model and optim shard
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
model_state_dict = self.model.state_dict()
if self.optimizer is not None:
optimizer_state_dict = self.optimizer.state_dict()
else:
optimizer_state_dict = None
if self.lr_scheduler is not None:
lr_scheduler_state_dict = self.lr_scheduler.state_dict()
else:
lr_scheduler_state_dict = None
extra_state_dict = {
'lr_scheduler': lr_scheduler_state_dict,
'rng': self.get_rng_state(),
}
model_path = os.path.join(local_path, f'model_world_size_{self.world_size}_rank_{self.rank}.pt')
optim_path = os.path.join(local_path, f'optim_world_size_{self.world_size}_rank_{self.rank}.pt')
extra_path = os.path.join(local_path, f'extra_state_world_size_{self.world_size}_rank_{self.rank}.pt')
print(f'[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}')
                print(f'[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}')
print(f'[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}')
torch.save(model_state_dict, model_path)
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
torch.save(extra_state_dict, extra_path)
if "hf_model" in self.checkpoint_contents:
# wait for everyone to dump to local
torch.distributed.barrier()
if self.rank == 0:
hf_local_path = os.path.join(local_path, 'huggingface')
os.makedirs(hf_local_path, exist_ok=True)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
self.processing_class.save_pretrained(hf_local_path)
torch.distributed.barrier()
self.previous_saved_paths.append(local_path)
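# Hypothetical usage sketch (the variable names and setup are assumptions, not from this
# repo): inside an SPMD worker where torch.distributed and the FSDP wrapper are already
# initialized, each rank saves and later reloads its own shard.
def _fsdp_checkpoint_usage_demo(fsdp_model, optimizer, lr_scheduler, tokenizer):
    manager = FSDPCheckpointManager(model=fsdp_model,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    processing_class=tokenizer)
    manager.save_checkpoint(local_path="checkpoints/global_step_10",
                            global_step=10,
                            max_ckpt_to_keep=2)
    # Resuming expects the same world size, since each rank reads back the
    # model_world_size_{W}_rank_{r}.pt shard written above.
    manager.load_checkpoint(local_path="checkpoints/global_step_10")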
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
import os
import random
import numpy as np
import warnings
from typing import Union
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as torchDDP
from verl.utils.fs import copy_to_local, is_non_local
from verl.models.weight_loader_registry import get_weight_saver
from verl.models.weight_loader_registry import get_weight_loader
from verl.utils.model import load_megatron_model_weights
from verl.utils.megatron_utils import TransformerConfig, get_model_checkpoint_path, get_hf_model_checkpoint_path, get_optimizer_checkpoint_path, get_rng_states_checkpoint_path, unwrap_model
from .checkpoint_manager import BaseCheckpointManager
from transformers import AutoModelForCausalLM
from megatron.core import mpu, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedObject
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallel as LocalDDP
class MegatronCheckpointManager(BaseCheckpointManager):
"""
A checkpoint manager that saves and loads
- model
- optimizer
- lr_scheduler
- extra_states
    in an SPMD way.
We save
- sharded model states and optimizer states
- full lr_scheduler states
- huggingface tokenizer/processor and config for ckpt merge
"""
def __init__(self,
config,
model_config,
role,
model: torch.nn.ModuleList,
arch: str,
hf_config,
param_dtype: torch.dtype,
share_embeddings_and_output_weights: bool,
tokenizer,
optimizer,
use_distributed_optimizer: bool,
checkpoint_contents: list = ['model', 'optimizer', 'extra'],
**kwargs):
super().__init__(model,
optimizer=optimizer,
lr_scheduler=None,
processing_class=tokenizer,
checkpoint_contents=checkpoint_contents)
self.arch = arch
self.config = config
self.role = role
self.is_value_model = False
if self.role in ["reward", "critic"]:
self.is_value_model = True
self.model_config = model_config
self.hf_config = hf_config
self.param_dtype = param_dtype
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.model_path = self.config.model.path
self.use_distributed_optimizer = use_distributed_optimizer
self.rank = torch.distributed.get_rank()
self.weight_saver = get_weight_saver(self.arch)
def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False):
""" collect rng state across data parallel ranks """
rng_state = {
'random_rng_state': random.getstate(),
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()
}
rng_state_list = None
if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]
if use_dist_ckpt:
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()
cp_size = mpu.get_context_parallel_world_size()
rng_state_list = ShardedObject('rng_state',
rng_state_list, (pp_size, tp_size, cp_size), (pp_rank, tp_rank, cp_rank),
replica_id=mpu.get_data_parallel_rank(with_context_parallel=True))
return rng_state_list
def get_checkpoint_name(self,
checkpoints_path,
pipeline_parallel=None,
tensor_rank=None,
pipeline_rank=None,
cp_rank=None,
expert_parallel=None,
expert_rank=None,
return_base_dir=True,
basename="model.pt"):
"""Determine the directory name for this rank's checkpoint."""
# Use both the tensor and pipeline MP rank.
if pipeline_parallel is None:
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if cp_rank is None:
cp_rank = mpu.get_context_parallel_rank()
if expert_parallel is None:
expert_parallel = (mpu.get_expert_model_parallel_world_size() > 1)
if expert_rank is None:
expert_rank = mpu.get_expert_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
# data parallel rank.
        # Because models are identical across cp ranks, the cp rank is not used in the checkpoint path.
if not pipeline_parallel:
common_path = os.path.join(checkpoints_path, f'mp_rank_{tensor_rank:02d}')
else:
common_path = os.path.join(checkpoints_path, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
if expert_parallel:
common_path = common_path + f'_{expert_rank:03d}'
os.makedirs(common_path, exist_ok=True)
if return_base_dir:
return common_path
return os.path.join(common_path, basename)
def load_optimizer(self, ckpt_path):
# TODO: Check Optimizer format and distributed optimizer
optimizer_path = get_optimizer_checkpoint_path(ckpt_path)
print(f"Loading optimizer from {optimizer_path}")
self.optimizer.load_parameter_state(optimizer_path)
def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False):
rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False)
print(f"Loading rng states from {rng_state_path}")
rng_state = torch.load(rng_state_path, weights_only=False)
# access rng_state for data parallel rank
if not use_dist_ckpt:
if data_parallel_random_init:
rng_state = rng_state[mpu.get_data_parallel_rank()]
else:
rng_state = rng_state[0]
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
# Check for empty states array
if not rng_state['rng_tracker_states']:
            raise KeyError('rng_tracker_states is empty in the loaded rng state')
tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states'])
def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
if local_path is None:
return
if 'model' in self.checkpoint_contents:
model_path = get_model_checkpoint_path(local_path)
ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
            state_dicts = torch.load(ckpt_name, weights_only=False)
assert len(state_dicts) == len(
self.model), f'state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}'
for vpp_rank, (state_dict, model) in enumerate(zip(state_dicts, self.model)):
model.load_state_dict(state_dict)
print(f'Loaded sharded model checkpoint from {model_path}')
if 'optimizer' in self.checkpoint_contents:
self.load_optimizer(local_path)
if 'extra' in self.checkpoint_contents:
self.load_rng_states(local_path)
if del_local_after_load:
try:
                if is_non_local(local_path):
                    os.remove(local_path)
except Exception as e:
print(
f'[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored'
)
def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
# record the previous global step
self.previous_global_step = global_step
# remove previous local_path
if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(
self.previous_saved_paths) >= max_ckpt_to_keep:
keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
self.previous_saved_paths = self.previous_saved_paths[keep_start:]
local_path = self.local_mkdir(local_path)
# Save Model
if 'model' in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0:
state_dicts = []
for vpp_rank, model in enumerate(self.model):
state_dict = model.state_dict()
state_dicts.append(state_dict)
print(f'Saving sharded model checkpoint to {local_path}')
model_ckpt_path = get_model_checkpoint_path(local_path)
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False)
            torch.save(state_dicts, ckpt_name)
self.processing_class.save_pretrained(hf_model_ckpt_path) # tokenizer will be saved to hf_model_ckpt_path
print(f'Saved checkpoint to {model_ckpt_path}')
if hdfs_path is not None:
print(f'Uploading checkpoint to {hdfs_path}')
from verl.utils import hdfs_io
hdfs_io.makedirs(hdfs_path, exist_ok=True)
hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
if 'hf_model' in self.checkpoint_contents:
# wait for everyone to dump to local
state_dict = self.weight_saver(self.model,
self.hf_config,
dtype=self.param_dtype,
is_value_model=self.is_value_model,
tie_word_embeddings=self.share_embeddings_and_output_weights)
torch.distributed.barrier()
print(f'self.param_dtype: {self.param_dtype}')
for key in state_dict.keys():
print(f'state_dict[key].dtype: {key} {state_dict[key].dtype}')
torch.distributed.barrier()
if self.rank == 0:
hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
from accelerate import init_empty_weights
import warnings
with init_empty_weights(), warnings.catch_warnings():
warnings.simplefilter("ignore")
if 'mistral7b-rm' in self.config.model.path:
from transformers import MistralForSequenceClassification
model = MistralForSequenceClassification.from_pretrained(
self.config.model.path) # use score head instead of lm_head
state_dict['score.weight'] = state_dict['score.weight']
else:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto")
model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
if hdfs_path is not None:
print(f'Uploading checkpoint to {hdfs_path}')
from verl.utils import hdfs_io
hdfs_io.makedirs(hdfs_path, exist_ok=True)
hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)
# Save Optimizer
if 'optimizer' in self.checkpoint_contents:
torch.distributed.barrier()
optimizer_path = get_optimizer_checkpoint_path(local_path)
self.optimizer.save_parameter_state(optimizer_path)
if self.rank == 0:
print(f"saving optimizer state to {optimizer_path}")
# Save RNG States
if 'extra' in self.checkpoint_contents:
torch.distributed.barrier()
rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False)
rng_state = self.get_rng_state()
torch.save(rng_state, rng_state_path)
print(f"Rank {self.rank} saving rng states to {rng_state_path}")
self.previous_saved_paths.append(local_path)
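# Illustrative sketch, not part of the original module: the directory naming scheme that
# get_checkpoint_name above produces, reproduced without megatron's mpu so it can be read
# (and unit-tested) in isolation.
def _mp_rank_dirname_demo(tensor_rank: int,
                          pipeline_rank: int = 0,
                          pipeline_parallel: bool = False,
                          expert_parallel: bool = False,
                          expert_rank: int = 0) -> str:
    name = (f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}"
            if pipeline_parallel else f"mp_rank_{tensor_rank:02d}")
    if expert_parallel:
        name += f"_{expert_rank:03d}"
    return name
# _mp_rank_dirname_demo(1) == 'mp_rank_01'
# _mp_rank_dirname_demo(1, 2, pipeline_parallel=True) == 'mp_rank_01_002'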
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from omegaconf import DictConfig
def update_dict_with_config(dictionary: Dict, config: DictConfig):
for key in dictionary:
if hasattr(config, key):
dictionary[key] = getattr(config, key)
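# Minimal usage sketch (the values are illustrative assumptions): keys that exist in both
# the plain dict and the DictConfig are overwritten with the config's values, while keys
# absent from the config are left untouched.
def _update_dict_with_config_demo() -> None:
    from omegaconf import OmegaConf
    defaults = {"temperature": 1.0, "top_p": 0.7, "n": 1}
    cfg = OmegaConf.create({"temperature": 0.6, "top_p": 0.9})
    update_dict_with_config(defaults, cfg)
    assert defaults == {"temperature": 0.6, "top_p": 0.9, "n": 1}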