Commit 496acb03 authored by chenych

v0.3.0

parent d8de2ca8
@@ -18,16 +18,16 @@ from typing import Dict
from mathruler.grader import grade_answer
def format_reward(predict: str) -> float:
def format_reward(predict_str: str) -> float:
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
format_match = re.fullmatch(pattern, predict)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0
def accuracy_reward(predict: str, ground_truth: str) -> float:
def accuracy_reward(predict_str: str, ground_truth: str) -> float:
try:
content_match = re.search(r"<answer>(.*?)</answer>", predict)
given_answer = content_match.group(1).strip() if content_match else predict.strip()
content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(given_answer, ground_truth.strip()):
return 1.0
@@ -37,9 +37,9 @@ def accuracy_reward(predict: str, ground_truth: str) -> float:
return 0.0
def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
format_score = format_reward(predict)
accuracy_score = accuracy_reward(predict, ground_truth)
def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
format_score = format_reward(predict_str)
accuracy_score = accuracy_reward(predict_str, ground_truth)
return {
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score,
@@ -13,8 +13,8 @@ pyarrow>=15.0.0
pylatexenc
qwen-vl-utils
ray[default]
tensordict
torchdata
transformers>=4.51.0
vllm>=0.7.3
wandb
orjson
tensorboard
\ No newline at end of file
@@ -18,17 +18,9 @@ import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple
import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
PretrainedConfig,
PreTrainedModel,
)
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
@@ -42,23 +34,14 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
raise ValueError(f"Unsupported placement: {placement}")
def upload_model_to_huggingface(local_path: str, remote_path: str):
# Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
args = parser.parse_args()
local_dir: str = args.local_dir
assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."
assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
local_dir = args.local_dir
# copy rank zero to find the shape of (dp, fsdp)
rank = 0
@@ -68,26 +51,22 @@ if __name__ == "__main__":
if match:
world_size = match.group(1)
break
assert world_size, "No model file with the proper format"
assert world_size, "No model file with the proper format."
rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
state_dict = torch.load(
os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
)
pivot_key = sorted(state_dict.keys())[0]
weight = state_dict[pivot_key]
if isinstance(weight, DTensor):
# get sharding info
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
else:
# for non-DTensor
mesh = np.array([int(world_size)], dtype=np.int64)
mesh_dim_names = ("fsdp",)
assert isinstance(weight, torch.distributed._tensor.DTensor)
# get sharding info
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}"
if "tp" in mesh_dim_names:
# fsdp * tp
@@ -98,12 +77,13 @@ if __name__ == "__main__":
total_shards = mesh.shape[-1]
mesh_shape = (mesh.shape[-1],)
print(f"Processing {total_shards} model shards in total.")
print(f"Processing model shards with {total_shards} {mesh_shape} in total")
model_state_dict_lst = []
model_state_dict_lst.append(state_dict)
model_state_dict_lst.extend([""] * (total_shards - 1))
def process_one_shard(rank, model_state_dict_lst):
def process_one_shard(rank):
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
model_state_dict_lst[rank] = state_dict
@@ -111,9 +91,8 @@ if __name__ == "__main__":
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
for rank in range(1, total_shards):
executor.submit(process_one_shard, rank, model_state_dict_lst)
state_dict: Dict[str, List[torch.Tensor]] = {}
executor.submit(process_one_shard, rank)
state_dict = {}
param_placements: Dict[str, List[Placement]] = {}
keys = set(model_state_dict_lst[0].keys())
for key in keys:
@@ -122,8 +101,8 @@ if __name__ == "__main__":
try:
tensor = model_state_dict.pop(key)
except Exception:
print(f"Cannot find key {key} in rank {rank}.")
print("-" * 30)
print(model_state_dict)
if isinstance(tensor, DTensor):
state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
@@ -136,7 +115,7 @@ if __name__ == "__main__":
else:
assert param_placements[key] == placements
else:
state_dict[key].append(tensor.bfloat16())
state_dict[key] = tensor.bfloat16()
del model_state_dict_lst
@@ -144,44 +123,43 @@ if __name__ == "__main__":
if not isinstance(state_dict[key], list):
print(f"No need to merge key {key}")
continue
if key in param_placements:
# merge shards
placements: Tuple[Shard] = param_placements[key]
if len(mesh_shape) == 1:
# 1-D list, FSDP without TP
assert len(placements) == 1
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
else:
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet.")
# merge shards
placements: Tuple[Shard] = param_placements[key]
if len(mesh_shape) == 1:
# 1-D list, FSDP without TP
assert len(placements) == 1
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
else:
state_dict[key] = torch.cat(state_dict[key], dim=0)
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet")
print("Merge completed.")
print("Writing to local disk")
hf_path = os.path.join(local_dir, "huggingface")
config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
architectures: List[str] = getattr(config, "architectures", ["Unknown"])
if "ForTokenClassification" in architectures[0]:
AutoClass = AutoModelForTokenClassification
elif "ForCausalLM" in architectures[0]:
AutoClass = AutoModelForCausalLM
elif "ForConditionalGeneration" in architectures[0]:
AutoClass = AutoModelForVision2Seq
config = AutoConfig.from_pretrained(hf_path)
if "ForTokenClassification" in config.architectures[0]:
auto_model = AutoModelForTokenClassification
elif "ForCausalLM" in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif "ForConditionalGeneration" in config.architectures[0]:
auto_model = AutoModelForVision2Seq
else:
raise NotImplementedError(f"Unknown architecture {architectures}.")
raise NotImplementedError(f"Unknown architecture {config.architectures}")
with torch.device("meta"):
model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
assert isinstance(model, PreTrainedModel)
model.to_empty(device="cpu")
print(f"Saving model to {hf_path}...")
print(f"Saving model to {hf_path}")
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict, model
del state_dict
del model
if args.hf_upload_path:
upload_model_to_huggingface(hf_path, args.hf_upload_path)
# Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
@@ -20,7 +20,7 @@ from .transformers.qwen2_vl import qwen2_vl_attn_forward
def apply_ulysses_patch(model_type: str) -> None:
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
elif model_type in ("qwen2_vl", "qwen2_5_vl"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
@@ -129,7 +129,7 @@ class Worker(WorkerHelper):
self._rank = rank
self._world_size = world_size
# DCU support
## DCU support
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("HIP_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
@@ -43,7 +43,6 @@ class DataConfig:
rollout_batch_size: int = 512
val_batch_size: int = -1
format_prompt: Optional[str] = None
override_chat_template: Optional[str] = None
shuffle: bool = True
seed: int = 1
max_pixels: int = 4194304
@@ -52,7 +51,7 @@ class DataConfig:
def post_init(self):
if self.format_prompt is not None:
if os.path.exists(self.format_prompt): # ray job uses absolute path
if os.path.exists(self.format_prompt):
self.format_prompt = os.path.abspath(self.format_prompt)
else:
self.format_prompt = None
@@ -74,7 +73,7 @@ class AlgorithmConfig:
@dataclass
class TrainerConfig:
total_epochs: int = 10
total_episodes: int = 10
max_steps: Optional[int] = None
project_name: str = "easy_r1"
experiment_name: str = "demo"
@@ -95,7 +94,7 @@ class TrainerConfig:
if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job uses absolute path
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
if self.load_checkpoint_path is not None:
self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
@@ -20,7 +20,7 @@ from omegaconf import OmegaConf
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager
from ..workers.reward import FunctionRewardManager
from .config import PPOConfig
from .data_loader import create_dataloader
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
@@ -38,13 +38,11 @@ class Runner:
# instantiate tokenizer
tokenizer = get_tokenizer(
config.worker.actor.model.model_path,
override_chat_template=config.data.override_chat_template,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
processor = get_processor(
config.worker.actor.model.model_path,
override_chat_template=config.data.override_chat_template,
trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True,
)
@@ -67,18 +65,12 @@ class Runner:
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
if config.worker.reward.reward_type == "sequential":
RewardManager = SequentialFunctionRewardManager
elif config.worker.reward.reward_type == "batch":
RewardManager = BatchFunctionRewardManager
else:
raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.")
RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)
train_dataloader, val_dataloader = create_dataloader(
config=config.data, tokenizer=tokenizer, processor=processor
)
trainer = RayPPOTrainer(
config=config,
@@ -119,13 +111,14 @@ def main():
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
}
}
# this is for local ray cluster
ray.init(runtime_env=runtime_env) # this is for local ray cluster
if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(),
ray.init(num_gpus=torch.cuda.device_count(), ## for dcu devices
ignore_reinit_error=True,
runtime_env=runtime_env)
else:
ray.init(runtime_env=runtime_env)
runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
@@ -19,14 +19,16 @@ This trainer supports model-agnostic model initialization with huggingface
import os
import uuid
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import numpy as np
import ray
import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -38,10 +40,9 @@ from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str, timer
from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
@@ -161,6 +162,14 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma:
return data
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
class RayPPOTrainer:
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
@@ -176,8 +185,8 @@ class RayPPOTrainer:
role_worker_mapping: dict[Role, Type[Worker]],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
reward_fn: Optional[FunctionRewardManager] = None,
val_reward_fn: Optional[FunctionRewardManager] = None,
reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
):
self.tokenizer = tokenizer
self.processor = processor
@@ -247,7 +256,7 @@ class RayPPOTrainer:
if config.trainer.max_steps is not None:
self.training_steps = config.trainer.max_steps
else:
self.training_steps = len(train_dataloader) * config.trainer.total_epochs
self.training_steps = len(train_dataloader) * config.trainer.total_episodes
config.worker.actor.optim.training_steps = self.training_steps
config.worker.critic.optim.training_steps = self.training_steps
@@ -298,6 +307,7 @@ class RayPPOTrainer:
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
print("validation generation end")
# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
@@ -307,7 +317,7 @@ class RayPPOTrainer:
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch))
reward_tensor, reward_metrics = self.val_reward_fn(test_batch)
# Store scores
scores = reward_tensor.sum(-1).cpu().tolist()
@@ -473,7 +483,7 @@ class RayPPOTrainer:
if self.config.trainer.val_only:
return
for _ in tqdm(range(self.config.trainer.total_epochs), desc="Epoch", position=0):
for _ in tqdm(range(self.config.trainer.total_episodes), desc="Episode", position=0):
for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
self.global_step += 1
if self.global_step > self.training_steps:
@@ -494,20 +504,20 @@ class RayPPOTrainer:
non_tensor_batch_keys=["raw_prompt_ids"],
)
with timer("step", timing_raw):
with _timer("step", timing_raw):
# generate a batch
with timer("gen", timing_raw): # wg: worker group
with _timer("gen", timing_raw): # wg: worker group
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == "remax":
with timer("gen_max", timing_raw):
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["temperature"] = 0
gen_baseline_batch.meta_info["n"] = 1
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch))
reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
@@ -522,6 +532,19 @@ class RayPPOTrainer:
batch = batch.union(gen_batch_output)
batch.non_tensor_batch.pop("multi_modal_data", None)
# compute reward
with _timer("reward", timing_raw):
if self.use_reward_model:
raise NotImplementedError("Reward model is not supported yet.")
# we combine with rule-based rm
reward_tensor, reward_metrics = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
}
metrics.update(reward_metrics)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
@@ -530,38 +553,30 @@ class RayPPOTrainer:
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# compute reward
with timer("reward", timing_raw):
reward_ref = self.reward_fn.compute_reward.remote(batch)
# recompute old_log_probs
with timer("old", timing_raw):
with _timer("old", timing_raw):
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs)
# compute ref_log_probs
if self.use_reference_policy:
with timer("ref", timing_raw):
with _timer("ref", timing_raw):
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs)
# compute values
if self.use_critic:
with timer("values", timing_raw):
with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with timer("adv", timing_raw):
# get token level scores
reward_tensor, reward_metrics = ray.get(reward_ref)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()}
metrics.update(reward_metrics)
with _timer("adv", timing_raw):
# apply kl penalty if available
if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
# apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
@@ -576,7 +591,7 @@ class RayPPOTrainer:
# update critic
if self.use_critic:
with timer("update_critic", timing_raw):
with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
@@ -584,7 +599,7 @@ class RayPPOTrainer:
# update actor
if self.config.trainer.critic_warmup <= self.global_step:
with timer("update_actor", timing_raw):
with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
@@ -596,13 +611,13 @@ class RayPPOTrainer:
and self.config.trainer.val_freq > 0
and self.global_step % self.config.trainer.val_freq == 0
):
with timer("validation", timing_raw):
with _timer("validation", timing_raw):
val_metrics = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
with timer("save_checkpoint", timing_raw):
with _timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
@@ -55,13 +55,11 @@ class FSDPCheckpointManager(BaseCheckpointManager):
# every rank download its own checkpoint
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading model from {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}.")
print(f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}.")
extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
model_state_dict = torch.load(model_path, weights_only=False)
optim_state_dict = torch.load(optim_path, weights_only=False)
extra_state_dict = torch.load(extra_path, weights_only=False)
extra_state_dict = torch.load(extra_state_path, weights_only=False)
state_dict_options = StateDictOptions(cpu_offload=True)
set_state_dict(
@@ -93,7 +91,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
torch.save(model_state_dict, model_path)
torch.save(optim_state_dict, optim_path)
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
from transformers.models.llama.configuration_llama import LlamaConfig
VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}
VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
def get_device_flops(unit: str = "T") -> float:
@@ -17,13 +17,11 @@ Contain small python utility functions
import importlib.util
import re
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Dict, List, Union
import numpy as np
import yaml
from codetiming import Timer
from yaml import Dumper
@@ -103,11 +101,3 @@ def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") ->
def convert_dict_to_str(data: Dict[str, Any]) -> str:
return yaml.dump(data, indent=2)
@contextmanager
def timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
@@ -18,11 +18,9 @@ from typing import Optional
from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer:
def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
"""Create a huggingface pretrained tokenizer."""
tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
if override_chat_template is not None:
tokenizer.chat_template = override_chat_template
if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
# the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance.
@@ -37,11 +35,12 @@ def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None,
return tokenizer
def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]:
def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
"""Create a huggingface pretrained processor."""
processor = AutoProcessor.from_pretrained(model_path, **kwargs)
if override_chat_template is not None:
processor.chat_template = override_chat_template
try:
processor = AutoProcessor.from_pretrained(model_path, **kwargs)
except Exception:
processor = None
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
@@ -33,7 +33,7 @@ class ModelConfig:
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path
if self.model_path is not None and os.path.exists(self.model_path):
self.model_path = os.path.abspath(self.model_path)
if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
@@ -15,8 +15,6 @@
The main entry point to run the PPO algorithm
"""
import os
from typing import Literal, Optional, Union
import numpy as np
@@ -73,7 +71,6 @@ class FSDPWorker(Worker):
self.role = role
if not dist.is_initialized():
self.print_rank0("Initializing distributed process group...")
dist.init_process_group(backend="nccl")
# improve numerical stability
@@ -284,7 +281,7 @@ class FSDPWorker(Worker):
if self._is_actor or self._is_critic:
if optim_config.strategy == "adamw":
self.optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
self.fsdp_module.parameters(),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
@@ -292,7 +289,7 @@ class FSDPWorker(Worker):
)
elif optim_config.strategy == "adamw_bf16":
self.optimizer = AnyPrecisionAdamW(
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
self.fsdp_module.parameters(),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
@@ -13,7 +13,7 @@
# limitations under the License.
from .config import RewardConfig
from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager
from .function import FunctionRewardManager
__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"]
__all__ = ["FunctionRewardManager", "RewardConfig"]
@@ -22,22 +22,21 @@ from typing import Optional
@dataclass
class RewardConfig:
reward_type: str = "batch"
reward_function: Optional[str] = None
reward_function_kwargs: dict = field(default_factory=dict)
reward_type: str = "function"
score_function: Optional[str] = None
score_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True
num_cpus: int = 1
"""auto keys"""
reward_function_name: Optional[str] = field(default=None, init=False)
score_function_name: Optional[str] = field(default=None, init=False)
def post_init(self):
if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main
if ":" not in self.reward_function:
self.reward_function_name = "main"
if self.score_function is not None:
if ":" not in self.score_function:
self.score_function_name = "main"
else:
self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)
self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
if os.path.exists(self.reward_function): # ray job uses absolute path
self.reward_function = os.path.abspath(self.reward_function)
if os.path.exists(self.score_function):
self.score_function = os.path.abspath(self.score_function)
else:
self.reward_function = None
self.score_function = None
@@ -15,8 +15,8 @@
import importlib.util
import os
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -33,86 +33,54 @@ class RewardScore(TypedDict):
accuracy: Optional[float]
SequentialRewardFunction = Callable[[str, str], RewardScore]
ScoreFunction = Callable[[str, str], RewardScore]
BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]
@dataclass
class FunctionRewardManager:
config: RewardConfig
tokenizer: PreTrainedTokenizer
class FunctionRewardManager(ABC):
"""Reward manager for rule-based reward."""
def __post_init__(self):
"""Load score function."""
if self.config.score_function is None:
raise ValueError("Score function is not provided.")
def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
if config.reward_function is None:
raise ValueError("Reward function is not provided.")
if not os.path.exists(self.config.score_function):
raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
if not os.path.exists(config.reward_function):
raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")
spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function)
spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_reward_fn"] = module
sys.modules["custom_score_fn"] = module
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Failed to load reward function: {e}")
if not hasattr(module, config.reward_function_name):
raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
reward_fn = getattr(module, config.reward_function_name)
print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
self.config = config
self.tokenizer = tokenizer
raise RuntimeError(f"Failed to load score function: {e}")
@abstractmethod
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
"""Compute reward for a batch of data."""
...
if not hasattr(module, self.config.score_function_name):
raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")
score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
self.score_fn = partial(score_fn, **self.config.score_function_kwargs)
class SequentialFunctionRewardManager(FunctionRewardManager):
reward_fn: SequentialRewardFunction
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list)
response_ids = data.batch["responses"]
response_length = data.batch["response_mask"].sum(dim=-1)
for i in range(len(data)):
valid_response_ids = response_ids[i][: response_length[i]]
data_item = data[i] # DataProtoItem
response_ids = data_item.batch["responses"]
response_mask = data_item.batch["response_mask"]
valid_response_length = response_mask.sum()
valid_response_ids = response_ids[:valid_response_length]
response_str = self.tokenizer.decode(
valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
)
ground_truth = data.non_tensor_batch["ground_truth"][i]
ground_truth = data_item.non_tensor_batch["ground_truth"]
score = self.reward_fn(response_str, ground_truth)
reward_tensor[i, response_length[i] - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)
return reward_tensor, reward_metrics
class BatchFunctionRewardManager(FunctionRewardManager):
reward_fn: BatchRewardFunction
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
response_str, ground_truth = [], []
response_ids = data.batch["responses"]
response_length = data.batch["response_mask"].sum(dim=-1)
for i in range(len(data)):
valid_response_ids = response_ids[i][: response_length[i]]
response_str.append(
self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
)
ground_truth.append(data.non_tensor_batch["ground_truth"][i])
scores = self.reward_fn(response_str, ground_truth)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list)
for i, score in enumerate(scores):
reward_tensor[i, response_length[i] - 1] = score["overall"]
score = self.score_fn(response_str, ground_truth)
reward_tensor[i, valid_response_length - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)
@@ -69,7 +69,7 @@ class vLLMRollout(BaseRollout):
self.inference_engine = LLM(
model=model_path,
skip_tokenizer_init=False,
trust_remote_code=True,
trust_remote_code=config.trust_remote_code,
load_format="dummy",
dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
seed=config.seed,
@@ -85,11 +85,11 @@ class vLLMRollout(BaseRollout):
disable_mm_preprocessor_cache=True,
enable_chunked_prefill=config.enable_chunked_prefill,
enable_sleep_mode=False, # only support GPUs
cpu_offload_gb=64,
cpu_offload_gb=64
)
# Offload vllm model to reduce peak memory usage
# self.inference_engine.sleep(level=1)
self.inference_engine.sleep(level=1)
sampling_kwargs = {
"max_tokens": config.response_length,
@@ -101,7 +101,6 @@ class FSDPVLLMShardingManager(BaseShardingManager):
print_gpu_memory_usage("Before vllm offload in sharding manager")
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# self.inference_engine.sleep(level=1)
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print_gpu_memory_usage("After vllm offload in sharding manager")