Commit f92481f0 authored by chenych

First commit.

parent 7121d0b0
Pipeline #2435 failed with stages in 0 seconds
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PPO config
"""
import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Optional, Tuple
from verl.workers.config import WorkerConfig
def recursive_post_init(dataclass_obj):
if hasattr(dataclass_obj, "post_init"):
dataclass_obj.post_init()
for attr in fields(dataclass_obj):
if is_dataclass(getattr(dataclass_obj, attr.name)):
recursive_post_init(getattr(dataclass_obj, attr.name))
@dataclass
class DataConfig:
train_files: str = ""
val_files: str = ""
prompt_key: str = "prompt"
max_prompt_length: int = 512
max_response_length: int = 512
rollout_batch_size: int = 512
return_raw_input_ids: bool = False
return_raw_prompt: bool = False
system_prompt: str = r"Please reason step by step, and put your final answer within \boxed{}."
shuffle: bool = True
seed: int = 1
max_pixels: int = 4194304
min_pixels: int = 262144
@dataclass
class AlgorithmConfig:
gamma: float = 1.0
lam: float = 1.0
adv_estimator: str = "gae"
kl_penalty: str = "kl"
kl_type: str = "fixed"
kl_coef: float = 1e-3
kl_horizon: float = 0.0
kl_target: float = 0.0
@dataclass
class TrainerConfig:
total_episodes: int = 10
max_steps: Optional[int] = None
project_name: str = "easy_r1"
experiment_name: str = "demo"
logger: Tuple[str, ...] = ("console", "wandb")
val_generations_to_log_to_wandb: int = 0
nnodes: int = 1
n_gpus_per_node: int = 8
save_freq: int = -1
load_checkpoint_path: Optional[str] = None
val_before_train: bool = True
val_only: bool = False
test_freq: int = -1
critic_warmup: int = 0
remove_previous_ckpt: bool = False
del_local_ckpt_after_load: bool = False
save_checkpoint_path: Optional[str] = None
def post_init(self):
if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
@dataclass
class PPOConfig:
data: DataConfig = field(default_factory=DataConfig)
worker: WorkerConfig = field(default_factory=WorkerConfig)
algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
trainer: TrainerConfig = field(default_factory=TrainerConfig)
def post_init(self):
self.worker.rollout.prompt_length = self.data.max_prompt_length
self.worker.rollout.response_length = self.data.max_response_length
def deep_post_init(self):
recursive_post_init(self)
def to_dict(self):
return asdict(self)
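# A minimal usage sketch (hedged; the field values below are illustrative):
#
#   config = PPOConfig()
#   config.data.max_prompt_length = 1024
#   config.deep_post_init()  # cascades post_init(), e.g. syncs rollout prompt/response lengths
#   print(config.trainer.save_checkpoint_path)        # checkpoints/easy_r1/demo (set by post_init)
#   print(config.to_dict()["algorithm"]["adv_estimator"])  # "gae"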
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Core functions to implement PPO algorithms.
The functions implemented in this file should be used by trainers with different distributed strategies to
implement PPO.
"""
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple
import numpy as np
import torch
import verl.utils.torch_functional as verl_F
if TYPE_CHECKING:
from verl.trainer.config import AlgorithmConfig
class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
https://arxiv.org/pdf/1909.08593.pdf
"""
def __init__(self, init_kl_coef: float, target_kl: float, horizon: float):
self.value = init_kl_coef
self.target = target_kl
self.horizon = horizon
def update(self, current_kl, n_steps):
target = self.target
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
mult = 1 + proportional_error * n_steps / self.horizon
self.value *= mult
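# The update rule above is multiplicative:
#   value <- value * (1 + clip(current_kl / target - 1, -0.2, 0.2) * n_steps / horizon)
# so the KL coefficient grows when the measured KL exceeds the target, shrinks when it falls below,
# and the adjustment is damped by `horizon`.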
class FixedKLController:
"""Fixed KL controller."""
def __init__(self, kl_coef: float):
self.value = kl_coef
def update(self, current_kl, n_steps):
pass
def get_kl_controller(algorithm_config: "AlgorithmConfig"):
if algorithm_config.kl_type == "fixed":
kl_ctrl = FixedKLController(kl_coef=algorithm_config.kl_coef)
elif algorithm_config.kl_type == "adaptive":
assert algorithm_config.kl_horizon > 0, f"horizon must be larger than 0. Got {algorithm_config.kl_horizon}."
kl_ctrl = AdaptiveKLController(
init_kl_coef=algorithm_config.kl_coef,
target_kl=algorithm_config.kl_target,
horizon=algorithm_config.kl_horizon,
)
else:
raise ValueError("Unknown kl_ctrl type")
return kl_ctrl
def compute_gae_advantage_return(
token_level_rewards: torch.Tensor,
values: torch.Tensor,
eos_mask: torch.Tensor,
gamma: torch.Tensor,
lam: torch.Tensor,
):
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
values: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length). [EOS] mask. Tokens after [EOS] have mask zero.
gamma: `(float)`
discount factor used in RL
lam: `(float)`
lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
lastgaelam = 0
advantages_reversed = []
gen_len = token_level_rewards.shape[-1]
for t in reversed(range(gen_len)):
nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t]
lastgaelam = delta + gamma * lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
advantages = verl_F.masked_whiten(advantages, eos_mask)
return advantages, returns
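# For reference, the loop above implements the GAE recursion (https://arxiv.org/abs/1506.02438):
#   delta_t = r_t + gamma * V_{t+1} - V_t
#   A_t     = delta_t + gamma * lam * A_{t+1}
#   R_t     = A_t + V_t
# with advantages whitened over the valid (unmasked) response tokens afterwards.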
# NOTE(sgm): this implementation only considers outcome supervision, where the reward is a scalar.
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, index: torch.Tensor, epsilon: float = 1e-6
):
"""
Compute advantage for GRPO, operating only on Outcome reward
(with only one scalar reward for each response).
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
index: `(np.ndarray)`
shape: (bs,). Group id (e.g. the prompt uid) used to normalize rewards within each group.
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
scores = token_level_rewards.sum(dim=-1)
id2score = defaultdict(list)
id2mean = {}
id2std = {}
with torch.no_grad():
bsz = scores.shape[0]
for i in range(bsz):
id2score[index[i]].append(scores[i])
for idx in id2score:
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor(id2score[idx]))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
return scores, scores
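# For reference: responses sharing the same prompt `index` form a group, and each outcome reward
# r_i is standardized within its group,
#   A_i = (r_i - mean(group)) / (std(group) + epsilon),
# then broadcast across response tokens via `eos_mask`. A group with a single sample falls back to
# mean 0 and std 1, so its raw score passes through (scaled by 1 / (1 + epsilon)).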
def compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards: torch.Tensor, eos_mask: torch.Tensor, gamma: torch.Tensor
):
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0
for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
# Reset after EOS
running_return = running_return * eos_mask[:, t]
advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask
return advantages, returns
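# For reference, the loop above computes plain discounted returns,
#   G_t = r_t + gamma * G_{t+1},
# with the running return zeroed outside the valid response via `eos_mask`; the advantages are the
# whitened (and masked) returns.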
def compute_remax_outcome_advantage(
token_level_rewards: torch.Tensor, reward_baselines: torch.Tensor, eos_mask: torch.Tensor
):
"""
Compute advantage for ReMax, operating only on Outcome reward
(with only one scalar reward for each response).
This implementation is based on the paper: https://arxiv.org/abs/2310.10505
Args:
token_level_rewards: `(torch.Tensor)`
shape: (bs, response_length)
reward_baselines: `(torch.Tensor)`
shape: (bs,)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
advantages: `(torch.Tensor)`
shape: (bs, response_length)
returns: `(torch.Tensor)`
shape: (bs, response_length)
"""
response_length = token_level_rewards.shape[-1]
# scores = token_level_rewards.sum(dim=-1)
with torch.no_grad():
returns = (token_level_rewards * eos_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
advantages = returns - reward_baselines.unsqueeze(-1).tile([1, response_length]) * eos_mask
return advantages, returns
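# For reference: ReMax uses the reward-to-go minus a per-prompt baseline,
#   A_t = sum_{s >= t} r_s - b,
# where `reward_baselines` holds b (in ray_trainer.py, the reward of a greedy rollout) broadcast
# over the valid response tokens.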
def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
kl = old_log_prob - ref_log_prob
return token_level_scores - kl * kl_ratio
def compute_policy_loss(
old_log_prob, log_prob, advantages, eos_mask, cliprange
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122
Args:
old_log_prob: `(torch.Tensor)`
shape: (bs, response_length)
log_prob: `(torch.Tensor)`
shape: (bs, response_length)
advantages: `(torch.Tensor)`
shape: (bs, response_length)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
cliprange: (float)
The clip range used in PPO. See https://arxiv.org/abs/1707.06347
Returns:
pg_loss: `a scalar torch.Tensor`
policy gradient loss computed via PPO
pg_clipfrac: (float)
the fraction of policy gradient terms that were clipped
ppo_kl: (float)
approximate KL divergence between the old and current policy, averaged over valid tokens
"""
negative_approx_kl = log_prob - old_log_prob
ratio = torch.exp(negative_approx_kl)
ppo_kl = verl_F.masked_mean(-negative_approx_kl, eos_mask)
pg_losses = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses).float(), eos_mask)
return pg_loss, pg_clipfrac, ppo_kl
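# For reference, this is the PPO clipped surrogate objective (https://arxiv.org/abs/1707.06347):
#   ratio_t = exp(log_prob_t - old_log_prob_t)
#   loss_t  = max(-A_t * ratio_t, -A_t * clip(ratio_t, 1 - cliprange, 1 + cliprange))
# averaged over valid tokens; pg_clipfrac is the fraction of tokens where the clipped term dominates,
# and ppo_kl is the masked mean of (old_log_prob - log_prob).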
def compute_entropy_loss(logits, eos_mask):
"""Compute Categorical entropy loss
Args:
logits: `(torch.Tensor)`
shape: (bs, response_length, vocab_size)
eos_mask: `(torch.Tensor)`
shape: (bs, response_length)
Returns:
entropy: a scalar torch.Tensor
"""
# compute entropy
entropy = verl_F.entropy_from_logits(logits) # (bs, response_len)
entropy_loss = verl_F.masked_mean(entropy, mask=eos_mask)
return entropy_loss
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
"""Compute the value loss. Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151
Args:
vpreds (`torch.FloatTensor`):
Predicted values of the value head, shape (`batch_size`, `response_length`)
values (`torch.FloatTensor`):
Old values of value head, shape (`batch_size`, `response_length`)
returns: (`torch.FloatTensor`):
Ground truth returns, shape (`batch_size`, `response_length`)
Returns:
vf_loss: a scalar (`torch.FloatTensor`):
value function loss
vf_clipfrac: a float
The ratio of vf being clipped
"""
vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
vf_losses1 = (vpreds - returns) ** 2
vf_losses2 = (vpredclipped - returns) ** 2
vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), eos_mask)
return vf_loss, vf_clipfrac
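# For reference: vpreds are clipped to stay within cliprange_value of the old values, and
#   vf_loss = 0.5 * mean(max((vpreds - returns)^2, (vpredclipped - returns)^2))
# over valid tokens, with vf_clipfrac measuring how often the clipped term dominates.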
def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.Tensor:
"""Compute KL divergence given logprob and ref_logprob.
Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104
Args:
logprob: log probabilities from the current policy, shape (bs, response_length)
ref_logprob: log probabilities from the reference policy, shape (bs, response_length)
Returns:
the per-token KL penalty, shape (bs, response_length)
"""
if kl_penalty == "kl":
return logprob - ref_logprob
if kl_penalty == "abs":
return (logprob - ref_logprob).abs()
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()
# J. Schulman. Approximating kl divergence, 2020.
# URL http://joschu.net/blog/kl-approx.html.
if kl_penalty == "low_var_kl":
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)
if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
raise NotImplementedError
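# For reference, the "low_var_kl" branch implements the low-variance k3 estimator from
# http://joschu.net/blog/kl-approx.html:
#   k3 = exp(ref_logprob - logprob) - (ref_logprob - logprob) - 1,
# clamped to [-10, 10] for numerical stability.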
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't merge this entry point into ray_trainer, because ray_trainer is reused by other entry points.
"""
import json
import ray
from omegaconf import OmegaConf
from verl.single_controller.ray import RayWorkerGroup
from verl.trainer.config import PPOConfig
from verl.trainer.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
from verl.utils import get_processor, get_tokenizer
from verl.workers.fsdp_workers import FSDPWorker
from verl.workers.reward import CustomRewardManager
def main():
cli_args = OmegaConf.from_cli()
file_config = OmegaConf.load(cli_args.config)
del cli_args.config
default_config = OmegaConf.structured(PPOConfig())
ppo_config = OmegaConf.merge(default_config, file_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
ray.get(main_task.remote(ppo_config))
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config: PPOConfig):
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))
# instantiate tokenizer
tokenizer = get_tokenizer(config.worker.actor.model.model_path)
processor = get_processor(config.worker.actor.model.model_path, use_fast=True)
# define worker classes
ray_worker_group_cls = RayWorkerGroup
role_worker_mapping = {
Role.ActorRollout: ray.remote(FSDPWorker),
Role.Critic: ray.remote(FSDPWorker),
Role.RefPolicy: ray.remote(FSDPWorker),
}
global_pool_id = "global_pool"
resource_pool_spec = {
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
}
mapping = {
Role.ActorRollout: global_pool_id,
Role.Critic: global_pool_id,
Role.RefPolicy: global_pool_id,
}
reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
val_reward_fn = CustomRewardManager(
tokenizer=tokenizer, num_examine=1, compute_score=config.worker.reward.compute_score
)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
processor=processor,
role_worker_mapping=role_worker_mapping,
resource_pool_manager=resource_pool_manager,
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
if __name__ == "__main__":
main()
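# Example launch (a hedged sketch; the module path, YAML path and override values are illustrative):
#
#   python3 -m verl.trainer.main \
#       config=examples/config.yaml \
#       data.train_files=/path/to/train.parquet \
#       worker.rollout.n=4
#
# OmegaConf.from_cli() parses dotted key=value overrides; precedence is
# structured PPOConfig defaults < YAML file < CLI overrides.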
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import os
import uuid
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Any, Dict, Optional, Type
import numpy as np
import torch
from codetiming import Timer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import PreTrainedTokenizer, ProcessorMixin
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer import core_algos
from verl.trainer.config import PPOConfig
from verl.utils.rl_dataset import RLHFDataset, collate_fn
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import Tracking
from verl.workers.fsdp_workers import FSDPWorker
WorkerType = Type[Worker]
class Role(Enum):
"""
To create more roles dynamically, you can subclass Role and add new members
"""
Actor = 0
Rollout = 1
ActorRollout = 2
Critic = 3
RefPolicy = 4
RewardModel = 5
ActorRolloutRef = 6
@dataclass
class ResourcePoolManager:
"""
Defines a resource pool specification and the mapping from roles to resource pools.
Resource pools are initialized first.
"""
resource_pool_spec: dict[str, list[int]]
mapping: dict[Role, str]
resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
def create_resource_pool(self):
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
# max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
# For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
# For the Megatron backend, we recommend max_colocate_count>1, which can use a different WorkerGroup for each model.
resource_pool = RayResourcePool(
process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
)
self.resource_pool_dict[resource_pool_name] = resource_pool
def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker_cls"""
return self.resource_pool_dict[self.mapping[role]]
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
responses = data.batch["responses"]
response_length = responses.size(1)
token_level_scores = data.batch["token_level_scores"]
batch_size = data.batch.batch_size[0]
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
# compute kl between ref_policy and current policy
if "ref_log_prob" in data.batch.keys():
kld = core_algos.kl_penalty(
data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
) # (batch_size, response_length)
kld = kld * response_mask
beta = kl_ctrl.value
else:
beta = 0
kld = torch.zeros_like(response_mask, dtype=torch.float32)
token_level_rewards = token_level_scores - beta * kld
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
current_kl = torch.mean(current_kl, dim=0).item()
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
data.batch["token_level_rewards"] = token_level_rewards
metrics = {"critic/kl": current_kl, "critic/kl_coeff": beta}
return data, metrics
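# For reference, apply_kl_penalty shapes the token-level reward as
#   token_level_rewards = token_level_scores - beta * KL(pi_theta || pi_ref)
# per token, where beta comes from the KL controller and is updated with the batch-mean KL.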
def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
# prepare response group
# TODO: add other ways to estimate advantages
if adv_estimator == "gae":
values = data.batch["values"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
token_level_rewards = data.batch["token_level_rewards"]
advantages, returns = core_algos.compute_gae_advantage_return(
token_level_rewards=token_level_rewards, values=values, eos_mask=response_mask, gamma=gamma, lam=lam
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == "grpo":
token_level_rewards = data.batch["token_level_rewards"]
index = data.non_tensor_batch["uid"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
advantages, returns = core_algos.compute_grpo_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, index=index
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == "reinforce_plus_plus":
token_level_rewards = data.batch["token_level_rewards"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
elif adv_estimator == "remax":
token_level_rewards = data.batch["token_level_rewards"]
index = data.non_tensor_batch["uid"]
responses = data.batch["responses"]
response_length = responses.size(-1)
attention_mask = data.batch["attention_mask"]
response_mask = attention_mask[:, -response_length:]
reward_baselines = data.batch["reward_baselines"]
advantages, returns = core_algos.compute_remax_outcome_advantage(
token_level_rewards=token_level_rewards, reward_baselines=reward_baselines, eos_mask=response_mask
)
data.batch["advantages"] = advantages
data.batch["returns"] = returns
else:
raise NotImplementedError
return data
def reduce_metrics(metrics: Dict[str, Any]):
for key, val in metrics.items():
metrics[key] = np.mean(val)
return metrics
def _compute_response_info(batch: DataProto):
response_length = batch.batch["responses"].shape[-1]
prompt_mask = batch.batch["attention_mask"][:, :-response_length]
response_mask = batch.batch["attention_mask"][:, -response_length:]
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)
return dict(
response_mask=response_mask,
prompt_length=prompt_length,
response_length=response_length,
)
def compute_data_metrics(batch: DataProto, use_critic: bool = True):
# TODO: add response length
sequence_score = batch.batch["token_level_scores"].sum(-1)
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
advantages = batch.batch["advantages"]
returns = batch.batch["returns"]
max_response_length = batch.batch["responses"].shape[-1]
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
max_prompt_length = prompt_mask.size(-1)
response_info = _compute_response_info(batch)
prompt_length = response_info["prompt_length"]
response_length = response_info["response_length"]
valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
if use_critic:
values = batch.batch["values"]
valid_values = torch.masked_select(values, response_mask)
return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)
metrics = {
# score
"critic/score/mean": torch.mean(sequence_score).detach().item(),
"critic/score/max": torch.max(sequence_score).detach().item(),
"critic/score/min": torch.min(sequence_score).detach().item(),
# reward
"critic/rewards/mean": torch.mean(sequence_reward).detach().item(),
"critic/rewards/max": torch.max(sequence_reward).detach().item(),
"critic/rewards/min": torch.min(sequence_reward).detach().item(),
# adv
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
"critic/advantages/max": torch.max(valid_adv).detach().item(),
"critic/advantages/min": torch.min(valid_adv).detach().item(),
# returns
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
"critic/returns/max": torch.max(valid_returns).detach().item(),
"critic/returns/min": torch.min(valid_returns).detach().item(),
**(
{
# values
"critic/values/mean": torch.mean(valid_values).detach().item(),
"critic/values/max": torch.max(valid_values).detach().item(),
"critic/values/min": torch.min(valid_values).detach().item(),
# vf explained var
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
if use_critic
else {}
),
# response length
"response_length/mean": torch.mean(response_length).detach().item(),
"response_length/max": torch.max(response_length).detach().item(),
"response_length/min": torch.min(response_length).detach().item(),
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
.detach()
.item(),
# prompt length
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
"prompt_length/max": torch.max(prompt_length).detach().item(),
"prompt_length/min": torch.min(prompt_length).detach().item(),
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
}
return metrics
def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info["prompt_length"]).item()
num_response_tokens = torch.sum(response_info["response_length"]).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens
num_tokens_of_section = {
"gen": num_response_tokens,
**{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]},
}
return {
**{f"timing_s/{name}": value for name, value in timing_raw.items()},
**{
f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
class RayPPOTrainer:
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
# TODO: support each role have individual ray_worker_group_cls,
# i.e., support different backend of different role
def __init__(
self,
config: PPOConfig,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
reward_fn=None,
val_reward_fn=None,
):
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.worker.hybrid_engine
assert self.hybrid_engine, "Currently, only the hybrid engine is supported."
if self.hybrid_engine:
assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()}"
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = Role.RefPolicy in role_worker_mapping
self.use_reward_model = Role.RewardModel in role_worker_mapping
self.ray_worker_group_cls = ray_worker_group_cls
# define KL control
if self.use_reference_policy:
self.kl_ctrl = core_algos.get_kl_controller(config.algorithm)
else:
self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.0)
if self.config.algorithm.adv_estimator == "gae":
self.use_critic = True
elif self.config.algorithm.adv_estimator == "grpo":
self.use_critic = False
elif self.config.algorithm.adv_estimator == "reinforce_plus_plus":
self.use_critic = False
elif self.config.algorithm.adv_estimator == "remax":
self.use_critic = False
else:
raise NotImplementedError
self._create_dataloader()
def _create_dataloader(self):
self.train_dataset = RLHFDataset(
data_path=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
system_prompt=self.config.data.system_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
# use sampler for better ckpt resume
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.seed)
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
self.train_dataloader = DataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.rollout_batch_size,
num_workers=8,
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
)
self.val_dataset = RLHFDataset(
data_path=self.config.data.val_files,
tokenizer=self.tokenizer,
processor=self.processor,
prompt_key=self.config.data.prompt_key,
max_prompt_length=self.config.data.max_prompt_length,
truncation="right",
system_prompt=self.config.data.system_prompt,
min_pixels=self.config.data.min_pixels,
max_pixels=self.config.data.max_pixels,
)
self.val_dataloader = DataLoader(
dataset=self.val_dataset,
batch_size=len(self.val_dataset),
num_workers=8,
shuffle=False,
drop_last=False,
collate_fn=collate_fn,
)
assert len(self.train_dataloader) >= 1
assert len(self.val_dataloader) >= 1
print(f"Size of train dataloader: {len(self.train_dataloader)}")
print(f"Size of val dataloader: {len(self.val_dataloader)}")
if self.config.trainer.max_steps is not None:
training_steps = self.config.trainer.max_steps
else:
training_steps = len(self.train_dataloader) * self.config.trainer.total_episodes
self.training_steps = training_steps
self.config.worker.actor.optim.training_steps = training_steps
self.config.worker.critic.optim.training_steps = training_steps
print(f"Total training steps: {self.training_steps}")
def _maybe_log_val_generations_to_wandb(self, inputs, outputs, scores):
"""Log a table of validation samples to wandb"""
generations_to_log = self.config.trainer.val_generations_to_log_to_wandb
if generations_to_log == 0:
return
if generations_to_log > 0 and "wandb" not in self.config.trainer.logger:
print("WARNING: `val_generations_to_log_to_wandb` is set, but no wandb logger is found.")
return
import wandb
# Create tuples of (input, output, score) and sort by input text
samples = list(zip(inputs, outputs, scores))
samples.sort(key=lambda x: x[0]) # Sort by input text
# Use fixed random seed for deterministic shuffling
rng = np.random.RandomState(42)
rng.shuffle(samples)
# Take first N samples after shuffling
samples = samples[:generations_to_log]
# Create column names for all samples
columns = ["step"] + sum(
[[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
)
if not hasattr(self, "validation_table"):
# Initialize the table on first call
self.validation_table = wandb.Table(columns=columns)
# Create a new table with same columns and existing data
# Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
new_table = wandb.Table(columns=columns, data=self.validation_table.data)
# Add new row with all data
row_data = []
row_data.append(self.global_steps)
for sample in samples:
row_data.extend(sample)
new_table.add_data(*row_data)
# Update reference and log
wandb.log({"val/generations": new_table}, step=self.global_steps)
self.validation_table = new_table
def _validate(self):
reward_tensor_lst = []
data_source_lst = []
# Lists to collect samples for the table
sample_inputs = []
sample_outputs = []
sample_scores = []
for test_data in self.val_dataloader:
test_batch = DataProto.from_single_dict(test_data)
# Store original inputs
input_ids = test_batch.batch["input_ids"]
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)
if "pixel_values" in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["pixel_values", "image_grid_thw", "raw_prompt_ids", "images"],
)
else:
test_gen_batch = test_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
test_gen_batch.meta_info = {"do_sample": False}
# pad to be divisible by dp_size
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(
test_gen_batch, self.actor_rollout_wg.world_size
)
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
# unpad
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
print("validation generation end")
# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
sample_outputs.extend(output_texts)
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
reward_tensor = self.val_reward_fn(test_batch)
# Store scores
scores = reward_tensor.sum(-1).cpu().tolist()
sample_scores.extend(scores)
reward_tensor_lst.append(reward_tensor)
data_source_lst.append(
test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])
)
self._maybe_log_val_generations_to_wandb(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,)
data_sources = np.concatenate(data_source_lst, axis=0)
# evaluate test_score based on data source
data_source_reward = {}
for i in range(reward_tensor.shape[0]):
data_source = data_sources[i]
if data_source not in data_source_reward:
data_source_reward[data_source] = []
data_source_reward[data_source].append(reward_tensor[i].item())
metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f"val/test_score/{data_source}"] = np.mean(rewards)
return metric_dict
def init_workers(self):
"""Init resource pool and worker group"""
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
# create actor and rollout
if self.hybrid_engine:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.worker, role="actor_rollout"
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
else:
raise NotImplementedError
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.Critic], config=self.config.worker, role="critic"
)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy], config=self.config.worker, role="ref"
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_reward_model:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.RewardModel], config=self.config.worker, role="reward"
)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
self.wg_dicts = []
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
# keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
self.wg_dicts.append(wg_dict)
if self.use_critic:
self.critic_wg: FSDPWorker = all_wg["critic"]
self.critic_wg.init_model()
if self.use_reference_policy:
self.ref_policy_wg: FSDPWorker = all_wg["ref"]
self.ref_policy_wg.init_model()
if self.use_reward_model:
self.rm_wg: FSDPWorker = all_wg["rm"]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg: FSDPWorker = all_wg["actor_rollout"]
self.actor_rollout_wg.init_model()
def _save_checkpoint(self):
# path: {save_checkpoint_path}/global_step_{global_steps}/actor
local_global_step_folder = os.path.join(
self.config.trainer.save_checkpoint_path, f"global_step_{self.global_steps}"
)
actor_local_path = os.path.join(local_global_step_folder, "actor")
self.actor_rollout_wg.save_checkpoint(
actor_local_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt,
)
if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, "critic")
self.critic_wg.save_checkpoint(
critic_local_path,
self.global_steps,
remove_previous_ckpt=self.config.trainer.remove_previous_ckpt,
)
local_latest_checkpointed_iteration = os.path.join(
self.config.trainer.save_checkpoint_path, "latest_checkpointed_iteration.txt"
)
with open(local_latest_checkpointed_iteration, "w") as f:
f.write(str(self.global_steps))
def _load_checkpoint(self):
if self.config.trainer.load_checkpoint_path is None:
return
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}")
actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
if self.use_critic:
self.critic_wg.load_checkpoint(
critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
)
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker groups through RPC to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=self.config.to_dict(),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.val_before_train:
val_metrics = self._validate()
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.val_only:
return
for _ in range(self.config.trainer.total_episodes):
for batch_dict in self.train_dataloader:
self.global_steps += 1
if self.global_steps >= self.training_steps:
break
metrics = {}
timing_raw = {}
batch: DataProto = DataProto.from_single_dict(batch_dict)
# pop those keys for generation
if "pixel_values" in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["pixel_values", "image_grid_thw", "raw_prompt_ids", "images"],
)
else:
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
with _timer("step", timing_raw):
# generate a batch
with _timer("gen", timing_raw): # wg: worker group
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
# self._balance_batch(batch, metrics=metrics) # TODO: re-enable balance batch
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with _timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with _timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer("adv", timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_reward_model:
raise NotImplementedError
# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
# compute rewards. apply_kl_penalty if available
if not self.config.worker.actor.use_kl_loss: # not grpo
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.worker.rollout.n,
)
# update critic
if self.use_critic:
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and self.global_steps % self.config.trainer.test_freq == 0
):
with _timer("testing", timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0:
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)
# perform validation after training
if self.val_reward_fn is not None:
val_metrics = self._validate()
pprint(f"Final validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
self._save_checkpoint()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tokenizer import get_processor, get_tokenizer
__all__ = ["get_processor", "get_tokenizer"]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import random
import shutil
import tempfile
import numpy as np
import torch
import torch.distributed
from filelock import FileLock
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import PreTrainedTokenizer, ProcessorMixin
class BaseCheckpointManager:
"""
A checkpoint manager that saves and loads
- model
- optimizer
- lr_scheduler
- extra_states
in a SPMD way.
We save
- sharded model states and optimizer states
- full lr_scheduler states
- huggingface tokenizer and config for ckpt merge
"""
def __init__(
self,
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin
):
self.previous_global_step = None
self.previous_save_local_path = None
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.processor = processor
assert isinstance(self.model, FSDP)
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
def load_checkpoint(self, *args, **kwargs):
raise NotImplementedError
def save_checkpoint(self, *args, **kwargs):
raise NotImplementedError
def remove_previous_save_local_path(self):
if not self.previous_save_local_path:
return
abs_path = os.path.abspath(self.previous_save_local_path)
print(f"Checkpoint manager remove previous save local path: {abs_path}")
if not os.path.exists(abs_path):
return
# remove previous local_path
shutil.rmtree(abs_path, ignore_errors=True)
@staticmethod
def local_mkdir(path):
if not os.path.isabs(path):
working_dir = os.getcwd()
path = os.path.join(working_dir, path)
# Use a deterministic digest of the path as the lock file name to avoid long file names
# (the built-in hash() is salted per process, so it cannot be shared across processes)
lock_filename = f"ckpt_{hashlib.sha256(path.encode()).hexdigest()[:16]}.lock"
lock_path = os.path.join(tempfile.gettempdir(), lock_filename)
try:
with FileLock(lock_path, timeout=60): # Add timeout
# make a new dir
os.makedirs(path, exist_ok=True)
except Exception as e:
print(f"Warning: Failed to acquire lock for {path}: {e}")
# Even if the lock is not acquired, try to create the directory
os.makedirs(path, exist_ok=True)
return path
@staticmethod
def get_rng_state():
rng_state = {
"cpu": torch.get_rng_state(),
"cuda": torch.cuda.get_rng_state(),
"numpy": np.random.get_state(),
"random": random.getstate(),
}
return rng_state
@staticmethod
def load_rng_state(rng_state):
torch.set_rng_state(rng_state["cpu"])
torch.cuda.set_rng_state(rng_state["cuda"])
np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"])
def find_latest_ckpt_path(path, directory_format="global_step_{}"):
if path is None:
return None
tracker_file = get_checkpoint_tracker_filename(path)
if not os.path.exists(tracker_file):
print("Checkpoint tracker file does not exist: %s", tracker_file)
return None
with open(tracker_file, "rb") as f:
iteration = int(f.read().decode())
ckpt_path = os.path.join(path, directory_format.format(iteration))
if not os.path.exists(ckpt_path):
print("Checkpoint does not exist: %s", ckpt_path)
return None
print("Found checkpoint: %s", ckpt_path)
return ckpt_path
def get_checkpoint_tracker_filename(root_path: str):
"""
The tracker file records the latest checkpoint during training to restart from.
"""
return os.path.join(root_path, "latest_checkpointed_iteration.txt")
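# Checkpoint layout implied by the helpers above and by RayPPOTrainer._save_checkpoint
# (directory names follow directory_format="global_step_{}"; the contents shown are a sketch):
#
#   {save_checkpoint_path}/
#       latest_checkpointed_iteration.txt   # read by find_latest_ckpt_path
#       global_step_{N}/
#           actor/   ...                    # per-rank shards written by the FSDP checkpoint manager
#           critic/  ...                    # only when a critic is used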
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import torch
import torch.distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
from transformers import PreTrainedTokenizer, ProcessorMixin
from .checkpoint_manager import BaseCheckpointManager
class FSDPCheckpointManager(BaseCheckpointManager):
"""
A checkpoint manager that saves and loads
- model
- optimizer
- lr_scheduler
- extra_states
in a SPMD way.
We save
- sharded model states and optimizer states
- full lr_scheduler states
- huggingface tokenizer and config for ckpt merge
"""
def __init__(
self,
model: FSDP,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin,
*args,
**kwargs,
):
super().__init__(model, optimizer, lr_scheduler, tokenizer, processor)
def load_checkpoint(self, path=None, *args, **kwargs):
if path is None:
return
# every rank download its own checkpoint
local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(
f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}"
)
model_state_dict = torch.load(local_model_path)
optimizer_state_dict = torch.load(local_optim_path)
extra_state_dict = torch.load(local_extra_state_path)
lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
self.model.load_state_dict(model_state_dict)
if self.optimizer is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
# recover random state
if "rng" in extra_state_dict:
# 'rng' may not exist for backward compatibility
self.load_rng_state(extra_state_dict["rng"])
if self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
# record the previous global step
self.previous_global_step = global_step
# remove previous local_path
# TODO: shall we remove previous ckpt every save?
if remove_previous_ckpt:
self.remove_previous_save_local_path()
local_path = self.local_mkdir(local_path)
torch.distributed.barrier()
# every rank will save its own model and optim shard
state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
model_state_dict = self.model.state_dict()
if self.optimizer is not None:
optimizer_state_dict = self.optimizer.state_dict()
else:
optimizer_state_dict = None
if self.lr_scheduler is not None:
lr_scheduler_state_dict = self.lr_scheduler.state_dict()
else:
lr_scheduler_state_dict = None
extra_state_dict = {
"lr_scheduler": lr_scheduler_state_dict,
"rng": self.get_rng_state(),
}
model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")
torch.save(model_state_dict, model_path)
torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None
torch.save(extra_state_dict, extra_path)
# wait for everyone to dump to local
torch.distributed.barrier()
if self.rank == 0:
hf_local_path = os.path.join(local_path, "huggingface")
os.makedirs(hf_local_path, exist_ok=True)
self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
if self.processor:
self.processor.save_pretrained(hf_local_path)
else:
self.tokenizer.save_pretrained(hf_local_path)
torch.distributed.barrier()
self.previous_save_local_path = local_path
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import LlamaConfig, PretrainedConfig, Qwen2Config
VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
def get_device_flops(unit="T"):
def unit_convert(number, level):
units = ["B", "K", "M", "G", "T", "P"]
if number <= 0:
return number
ptr = 0
while ptr < len(units) and units[ptr] != level:
number /= 1000
ptr += 1
return number
device_name = torch.cuda.get_device_name()
flops = float("inf") # INF flops for unkown gpu type
if "H100" in device_name or "H800" in device_name:
flops = 989e12
elif "A100" in device_name or "A800" in device_name:
flops = 312e12
elif "L40" in device_name:
flops = 181.05e12
elif "L20" in device_name:
flops = 119.5e12
elif "H20" in device_name:
flops = 148e12
elif "910B" in device_name:
flops = 354e12
flops_unit = unit_convert(flops, unit)
return flops_unit
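# For reference (derived from the table above): on an H100 the promised peak is 989e12
# FLOPS, so get_device_flops("T") returns 989.0; on an unrecognized GPU it returns
# float("inf"), which drives the reported MFU towards zero.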
class FlopsCounter:
"""
    Used to estimate MFU (model FLOPS utilization) during the training loop.
Example:
flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(batch_seqlens, delta_time)
"""
def __init__(self, config: PretrainedConfig):
if not isinstance(config, VALID_CONFIG_TYPE):
print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. MFU will always be zero.")
self.estimate_func = {"qwen2": self._estimate_qwen2_flops, "llama": self._estimate_qwen2_flops}
self.config = config
def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
return 0
def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
assert isinstance(self.config, (Qwen2Config, LlamaConfig))
hidden_size = self.config.hidden_size
vocab_size = self.config.vocab_size
num_hidden_layers = self.config.num_hidden_layers
num_key_value_heads = self.config.num_key_value_heads
num_attention_heads = self.config.num_attention_heads
intermediate_size = self.config.intermediate_size
head_dim = hidden_size // num_attention_heads
q_size = num_attention_heads * head_dim
k_size = num_key_value_heads * head_dim
v_size = num_key_value_heads * head_dim
        # non-attention parameter count per layer
        # Qwen2/Llama use SwiGLU: the MLP has gate, up, and down linear projections
mlp_N = hidden_size * intermediate_size * 3
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
emd_and_lm_head_N = vocab_size * hidden_size * 2
        # non-attention parameter count across all layers, plus embeddings and LM head
        dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
        # forward + backward FLOPs of the non-attention parameters over all tokens
        dense_N_flops = 6 * dense_N * tokens_sum
        # forward + backward FLOPs of attention over all layers and tokens
seqlen_square_sum = 0
for seqlen in batch_seqlens:
seqlen_square_sum += seqlen * seqlen
attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
# all_layer & all_token fwd & bwd flops
flops_all_token = dense_N_flops + attn_qkv_flops
flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
return flops_achieved
def estimate_flops(self, batch_seqlens, delta_time):
"""
Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
Args:
            batch_seqlens (List[int]): A list where each element is the number of valid tokens in one sequence of the current batch.
delta_time (float): The time taken to process the batch, in seconds.
Returns:
estimated_flops (float): The estimated FLOPS based on the input tokens and time.
promised_flops (float): The expected FLOPS of the current device.
"""
tokens_sum = sum(batch_seqlens)
func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
promised_flops = get_device_flops()
return estimated_flops, promised_flops
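# Minimal usage sketch (illustrative, not part of the training loop): build the counter
# from a supported HF config and compute MFU as achieved / promised FLOPS. The default
# Qwen2Config below is a stand-in for the real model config, and get_device_flops()
# needs a visible CUDA device to look up the promised peak.
if __name__ == "__main__":
    import time
    counter = FlopsCounter(Qwen2Config())
    batch_seqlens = [512, 1024, 768]  # valid tokens per sequence in one batch
    start = time.time()
    time.sleep(0.1)  # stand-in for one forward/backward step
    estimated, promised = counter.estimate_flops(batch_seqlens, time.time() - start)
    mfu = estimated / promised if promised > 0 else 0.0
    print(f"estimated TFLOPS: {estimated:.2f}, promised TFLOPS: {promised:.2f}, MFU: {mfu:.4f}")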
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from functools import partial
from typing import Callable, Union
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _lazy_init
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.optim import Optimizer
from transformers import PreTrainedModel
from transformers.trainer_pt_utils import get_module_class_from_name
def get_init_fn(model: nn.Module, device: Union[str, torch.device]) -> Callable[[nn.Module], None]:
param_occurrence = defaultdict(int)
for _, param in model.named_parameters(remove_duplicate=False):
param_occurrence[param] += 1
duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1}
materialized_params = {}
def init_fn(module: nn.Module):
for name, param in module.named_parameters(recurse=False):
if param in duplicated_params:
module._parameters[name] = materialized_params.setdefault(
param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad)
)
else:
module._parameters[name] = nn.Parameter(
torch.empty_like(param.data, device=device), requires_grad=param.requires_grad
)
return init_fn
def get_fsdp_wrap_policy(model: PreTrainedModel):
"""Get FSDP wrap policy for the model.
Args:
module: The module to get wrap policy for
"""
transformer_cls_to_wrap = set()
for module in model._no_split_modules:
transformer_cls = get_module_class_from_name(model, module)
if transformer_cls is None:
raise Exception(f"Cannot find {module} in pretrained model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
return partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap)
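# A hedged sketch of how the two helpers above are typically combined when wrapping a
# HuggingFace model with FSDP. The constructor arguments are illustrative assumptions
# (in this codebase they come from the worker config), and torch.distributed must
# already be initialized before calling it.
def build_fsdp_model_sketch(model: PreTrainedModel) -> FSDP:
    return FSDP(
        model,
        param_init_fn=get_init_fn(model, device="cuda"),  # materialize empty params on GPU
        auto_wrap_policy=get_fsdp_wrap_policy(model),  # shard each transformer block separately
        device_id=torch.cuda.current_device(),
        use_orig_params=False,  # assumption: flat-parameter mode, as the offload utils below expect
    )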
@torch.no_grad()
def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
# lazy init FSDP model
_lazy_init(model, model)
assert model._is_root, "Only support root model offloading to CPU"
for handle in model._all_handles:
if handle._offload_params:
continue
flat_param = handle.flat_param
assert (
flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
and id(flat_param.data) != id(flat_param._local_shard)
and flat_param.data.size() == flat_param._local_shard.size()
)
handle.flat_param_to("cpu", non_blocking=True)
# the following still keeps id(._local_shard) != id(.data)
flat_param._local_shard = flat_param.data
assert id(flat_param._local_shard) != id(flat_param.data)
if empty_cache:
torch.cuda.empty_cache()
@torch.no_grad()
def load_fsdp_model(model: FSDP):
# lazy init FSDP model
_lazy_init(model, model)
assert model._is_root, "Only support root model loading to GPU"
for handle in model._all_handles:
if handle._offload_params:
continue
flat_param = handle.flat_param
handle.flat_param_to("cuda", non_blocking=True)
# the following still keeps id(._local_shard) != id(.data)
flat_param._local_shard = flat_param.data
@torch.no_grad()
def offload_fsdp_optimizer(optimizer: Optimizer):
if not optimizer.state:
return
for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
for key, value in state.items():
if isinstance(value, torch.Tensor):
state[key] = value.to("cpu", non_blocking=True)
@torch.no_grad()
def load_fsdp_optimizer(optimizer: Optimizer):
if not optimizer.state:
return
for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
for key, value in state.items():
if isinstance(value, torch.Tensor):
state[key] = value.to("cuda", non_blocking=True)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A Ray logger will receive logging info from different processes.
"""
import numbers
from typing import Dict
def concat_dict_to_str(data: Dict, step):
    output = [f"step {step}:"]
    for k, v in data.items():
if isinstance(v, numbers.Number):
output.append(f"{k}:{v:.3f}")
output_str = " - ".join(output)
return output_str
class LocalLogger:
def __init__(self, remote_logger=None, enable_wandb=False, print_to_console=False):
self.print_to_console = print_to_console
if print_to_console:
print("Using LocalLogger is deprecated. The constructor API will change.")
def flush(self):
pass
def log(self, data, step):
if self.print_to_console:
print(concat_dict_to_str(data, step=step), flush=True)
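# Minimal usage sketch: print step metrics to the console (metric names are illustrative).
if __name__ == "__main__":
    logger = LocalLogger(print_to_console=True)
    logger.log({"loss": 0.4213, "reward": 0.87}, step=1)
    # prints: step 1: - loss:0.421 - reward:0.870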
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to create common models
"""
from typing import Optional
import torch
from torch import nn
def get_model_size(model: nn.Module, scale="auto"):
n_params = sum(p.numel() for p in model.parameters())
if scale == "auto":
if n_params > 1e9:
scale = "B"
elif n_params > 1e6:
scale = "M"
elif n_params > 1e3:
scale = "K"
else:
scale = ""
if scale == "B":
n_params = n_params / 1e9
elif scale == "M":
n_params = n_params / 1e6
elif scale == "K":
n_params = n_params / 1e3
elif scale == "":
pass
else:
raise NotImplementedError(f"Unknown scale {scale}")
return n_params, scale
def print_model_size(model: nn.Module, name: Optional[str] = None):
n_params, scale = get_model_size(model, scale="auto")
if name is None:
name = model.__class__.__name__
print(f"{name} contains {n_params:.2f}{scale} parameters")
def compute_position_id_with_mask(mask):
return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
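# Worked examples for the helpers above. With a left-padded attention mask, position ids
# start counting at the first valid token and the padded slots are clipped to 0.
if __name__ == "__main__":
    print_model_size(nn.Linear(1024, 1024))  # Linear contains 1.05M parameters
    mask = torch.tensor([0, 0, 1, 1, 1])
    print(compute_position_id_with_mask(mask))  # tensor([0, 0, 0, 1, 2])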
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
def log_gpu_memory_usage(head: str, rank: int = 0):
if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank):
memory_allocated = torch.cuda.memory_allocated() / 1024**3
memory_reserved = torch.cuda.memory_reserved() / 1024**3
print(f"{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}.")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contain small python utility functions
"""
from typing import Any, Dict, List
def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
"""Union two dict. Will throw an error if there is an item not the same object with the same key."""
for key, value in dict2.items():
if key in dict1:
assert dict1[key] != value, f"{key} in meta_dict1 and meta_dict2 are not the same object"
dict1[key] = value
return dict1
def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
for key, val in new_data.items():
if key not in data:
data[key] = []
data[key].append(val)
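# Worked examples: union_two_dict merges when overlapping keys agree (and asserts when
# they do not); append_to_dict accumulates per-key lists of values across calls.
if __name__ == "__main__":
    print(union_two_dict({"a": 1}, {"a": 1, "b": 2}))  # {'a': 1, 'b': 2}
    metrics: Dict[str, List[Any]] = {}
    append_to_dict(metrics, {"loss": 0.5})
    append_to_dict(metrics, {"loss": 0.4})
    print(metrics)  # {'loss': [0.5, 0.4]}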
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .math import math_compute_score
from .r1v import r1v_compute_score
__all__ = ["math_compute_score", "r1v_compute_score"]
from mathruler.grader import extract_boxed_content, grade_answer
def math_compute_score(predict_str: str, ground_truth: str) -> float:
answer = extract_boxed_content(predict_str)
if answer == "None":
return 0.0 # no answer
if grade_answer(answer, ground_truth):
return 1.0 # correct answer
return 0.1 # wrong answer
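# Illustrative scoring tiers, assuming mathruler extracts and grades boxed answers as
# relied upon above: a correct \boxed{...} answer scores 1.0, a wrong one 0.1, and a
# reply without any boxed content 0.0.
if __name__ == "__main__":
    print(math_compute_score(r"The answer is \boxed{42}.", "42"))  # expected 1.0
    print(math_compute_score("no boxed answer here", "42"))  # expected 0.0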
import re
from mathruler.grader import grade_answer
def r1v_format_reward(predict_str: str) -> float:
pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
match = re.fullmatch(pattern, predict_str, re.DOTALL)
return 1.0 if match else 0.0
def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
try:
ground_truth = ground_truth.strip()
content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
pred_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(pred_answer, ground_truth):
return 1.0
except Exception:
pass
return 0.0
def r1v_compute_score(predict_str: str, ground_truth: str) -> float:
acc_reward = r1v_accuracy_reward(predict_str, ground_truth)
format_reward = r1v_format_reward(predict_str)
reward = acc_reward + format_reward
reward /= 2
return reward
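# Worked example: a well-formatted response with the correct answer earns the full
# reward (accuracy 1.0 + format 1.0, averaged to 1.0); omitting the <think> block keeps
# the accuracy term but halves the total. Exact grading is delegated to mathruler.
if __name__ == "__main__":
    response = "<think>2 + 1 = 3</think> <answer>3</answer>"
    print(r1v_compute_score(response, "3"))  # expected 1.0
    print(r1v_compute_score("<answer>3</answer>", "3"))  # expected 0.5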
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import defaultdict
from typing import Any, Dict, List, Optional
import torch
from datasets import load_dataset
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
import verl.utils.torch_functional as verl_F
from verl.models.transformers.qwen2_5_vl import get_rope_index
def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
tensors = defaultdict(list)
non_tensors = defaultdict(list)
for feature in features:
for key, value in feature.items():
if isinstance(value, torch.Tensor):
tensors[key].append(value)
else:
non_tensors[key].append(value)
for key, value in tensors.items():
if key not in ["pixel_values", "image_grid_thw"]:
tensors[key] = torch.stack(value, dim=0)
return {**tensors, **non_tensors}
def process_image(image: ImageObject, max_pixels: int, min_pixels: int) -> ImageObject:
if (image.width * image.height) > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
if (image.width * image.height) < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.Resampling.NEAREST)
if image.mode != "RGB":
image = image.convert("RGB")
return image
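# Worked example: a 4000x3000 image has 12,000,000 pixels; with max_pixels=4194304 the
# resize factor is sqrt(4194304 / 12000000) ~= 0.59, so the image is downscaled to about
# 2364x1773 (~4.19M pixels) before it reaches the image processor.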
class RLHFDataset(Dataset):
"""
    We assume the dataset contains a prompt column, plus optional extra columns such as images.
"""
def __init__(
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
processor: Optional[ProcessorMixin],
prompt_key="prompt",
max_prompt_length=1024,
truncation="error",
system_prompt=None,
max_pixels=None,
min_pixels=None,
):
self.tokenizer = tokenizer
self.processor = processor
self.prompt_key = prompt_key
self.max_prompt_length = max_prompt_length
self.truncation = truncation
self.system_prompt = system_prompt
self.max_pixels = max_pixels
self.min_pixels = min_pixels
if "@" in data_path:
data_path, data_split = data_path.split("@")
else:
data_split = "train"
self.dataset = load_dataset(data_path, split=data_split)
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
"""
        Note that we also return raw_prompt_ids so that they can be combined with other chat templates.
"""
row_dict = self.dataset[index]
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": row_dict[self.prompt_key]},
]
prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
if "images" in row_dict: # expand image token
raw_prompt = prompt.replace("<image>", "<|vision_start|><|image_pad|><|vision_end|>")
row_dict["images"] = [
process_image(image, self.max_pixels, self.min_pixels) for image in row_dict["images"]
]
image_inputs = self.processor.image_processor(row_dict["images"], return_tensors="pt")
image_grid_thw = image_inputs["image_grid_thw"]
row_dict.update(image_inputs)
if image_grid_thw is not None:
merge_length = self.processor.image_processor.merge_size**2
index = 0
while "<image>" in prompt:
prompt = prompt.replace(
"<image>",
"<|vision_start|>"
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
+ "<|vision_end|>",
1,
)
index += 1
prompt = prompt.replace("<|placeholder|>", self.processor.image_token)
else:
raw_prompt = prompt
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
pad_token_id=self.tokenizer.pad_token_id,
left_pad=True,
truncation=self.truncation,
)
if "images" in row_dict:
position_ids = get_rope_index(
self.processor,
input_ids=input_ids,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
) # (3, seq_len)
else:
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None) # (seqlen,)
row_dict["input_ids"] = input_ids
row_dict["attention_mask"] = attention_mask
row_dict["position_ids"] = position_ids
row_dict["raw_prompt_ids"] = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
return row_dict
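# Minimal usage sketch (model id, dataset path, and prompt column are illustrative
# placeholders, not verified values): build the dataset with a HF tokenizer/processor
# and batch it with the collate_fn defined above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoProcessor, AutoTokenizer
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumption: a chat model that ships a processor
    dataset = RLHFDataset(
        data_path="some_user/some_math_dataset@train",  # placeholder "<hub dataset>@<split>" path
        tokenizer=AutoTokenizer.from_pretrained(model_path),
        processor=AutoProcessor.from_pretrained(model_path),
        prompt_key="problem",  # assumption: depends on the dataset schema
        max_prompt_length=1024,
        system_prompt=r"Please reason step by step, and put your final answer within \boxed{}.",
        max_pixels=4194304,
        min_pixels=262144,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["attention_mask"].shape)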
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for tokenization."""
from typing import Optional
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
def get_tokenizer(model_path, correct_pad_token=True, correct_gemma=True, **kwargs) -> PreTrainedTokenizer:
"""Create a huggingface pretrained tokenizer.
Args:
        model_path (str): The path or name of the pretrained model used to load the tokenizer.
correct_pad_token (bool): Whether to correct the pad token id.
correct_gemma (bool): Whether to correct the gemma tokenizer.
**kwargs: The keyword arguments for the tokenizer.
Returns:
transformers.PreTrainedTokenizer: The pretrained tokenizer.
"""
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
if correct_gemma and getattr(config, "model_type", None) in ["gemma", "gemma2"]:
        # the EOS token in gemma2 is ambiguous, which may worsen RL performance.
# https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
tokenizer.eos_token = "<end_of_turn>"
if correct_pad_token:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def get_processor(model_path, **kwargs) -> Optional[ProcessorMixin]:
try:
processor = AutoProcessor.from_pretrained(model_path, **kwargs)
except Exception:
processor = None
    # AutoProcessor may fall back to returning a tokenizer; discard it in that case, see:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
if processor is not None and "Processor" not in processor.__class__.__name__:
processor = None
return processor
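# Usage sketch: load the tokenizer (and the processor, if the model ships one) from a
# single path. The model id below is illustrative.
if __name__ == "__main__":
    model_path = "Qwen/Qwen2.5-7B-Instruct"
    tokenizer = get_tokenizer(model_path)
    processor = get_processor(model_path)  # None for text-only models without a Processor class
    print(type(tokenizer).__name__, type(processor).__name__ if processor else None)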