Commit 496acb03 authored by chenych

v0.3.0

parent d8de2ca8
@@ -18,16 +18,16 @@ from typing import Dict
from mathruler.grader import grade_answer from mathruler.grader import grade_answer
def format_reward(predict: str) -> float: def format_reward(predict_str: str) -> float:
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL) pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
format_match = re.fullmatch(pattern, predict) format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0 return 1.0 if format_match else 0.0
def accuracy_reward(predict: str, ground_truth: str) -> float: def accuracy_reward(predict_str: str, ground_truth: str) -> float:
try: try:
content_match = re.search(r"<answer>(.*?)</answer>", predict) content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
given_answer = content_match.group(1).strip() if content_match else predict.strip() given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(given_answer, ground_truth.strip()): if grade_answer(given_answer, ground_truth.strip()):
return 1.0 return 1.0
@@ -37,9 +37,9 @@ def accuracy_reward(predict: str, ground_truth: str) -> float:
return 0.0 return 0.0
def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]: def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
format_score = format_reward(predict) format_score = format_reward(predict_str)
accuracy_score = accuracy_reward(predict, ground_truth) accuracy_score = accuracy_reward(predict_str, ground_truth)
return { return {
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score, "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score, "format": format_score,
......
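For reference, a minimal usage sketch of the reward functions above, assuming they are imported from this reward file; the sample strings are illustrative only:

predict = "<think>2 + 2 = 4</think> <answer>4</answer>"
scores = compute_score(predict, "4")
# The full <think>...</think><answer>...</answer> template matches, and mathruler should grade "4" as correct,
# so with the default format_weight=0.5 both components contribute 0.5 to the overall score.
print(scores["format"], scores["overall"])  # 1.0 1.0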
@@ -13,8 +13,8 @@ pyarrow>=15.0.0
pylatexenc pylatexenc
qwen-vl-utils qwen-vl-utils
ray[default] ray[default]
tensordict
torchdata torchdata
transformers>=4.51.0 transformers>=4.51.0
vllm>=0.7.3
wandb wandb
orjson
tensorboard
\ No newline at end of file
@@ -18,17 +18,9 @@ import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np
import torch import torch
from torch.distributed._tensor import DTensor, Placement, Shard from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import ( from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
PretrainedConfig,
PreTrainedModel,
)
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement): def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
@@ -42,23 +34,14 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
raise ValueError(f"Unsupported placement: {placement}") raise ValueError(f"Unsupported placement: {placement}")
def upload_model_to_huggingface(local_path: str, remote_path: str):
# Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model") parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
args = parser.parse_args() args = parser.parse_args()
local_dir: str = args.local_dir
assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface." assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
local_dir = args.local_dir
# copy rank zero to find the shape of (dp, fsdp) # copy rank zero to find the shape of (dp, fsdp)
rank = 0 rank = 0
@@ -68,26 +51,22 @@ if __name__ == "__main__":
if match: if match:
world_size = match.group(1) world_size = match.group(1)
break break
assert world_size, "No model file with the proper format"
assert world_size, "No model file with the proper format." state_dict = torch.load(
os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") )
state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
pivot_key = sorted(state_dict.keys())[0] pivot_key = sorted(state_dict.keys())[0]
weight = state_dict[pivot_key] weight = state_dict[pivot_key]
if isinstance(weight, DTensor): assert isinstance(weight, torch.distributed._tensor.DTensor)
# get sharding info # get sharding info
device_mesh = weight.device_mesh device_mesh = weight.device_mesh
mesh = device_mesh.mesh mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names mesh_dim_names = device_mesh.mesh_dim_names
else:
# for non-DTensor
mesh = np.array([int(world_size)], dtype=np.int64)
mesh_dim_names = ("fsdp",)
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}." assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}"
if "tp" in mesh_dim_names: if "tp" in mesh_dim_names:
# fsdp * tp # fsdp * tp
@@ -98,12 +77,13 @@ if __name__ == "__main__":
total_shards = mesh.shape[-1] total_shards = mesh.shape[-1]
mesh_shape = (mesh.shape[-1],) mesh_shape = (mesh.shape[-1],)
print(f"Processing {total_shards} model shards in total.") print(f"Processing model shards with {total_shards} {mesh_shape} in total")
model_state_dict_lst = [] model_state_dict_lst = []
model_state_dict_lst.append(state_dict) model_state_dict_lst.append(state_dict)
model_state_dict_lst.extend([""] * (total_shards - 1)) model_state_dict_lst.extend([""] * (total_shards - 1))
def process_one_shard(rank, model_state_dict_lst): def process_one_shard(rank):
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt") model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(model_path, map_location="cpu", weights_only=False) state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
model_state_dict_lst[rank] = state_dict model_state_dict_lst[rank] = state_dict
@@ -111,9 +91,8 @@ if __name__ == "__main__":
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
for rank in range(1, total_shards): for rank in range(1, total_shards):
executor.submit(process_one_shard, rank, model_state_dict_lst) executor.submit(process_one_shard, rank)
state_dict = {}
state_dict: Dict[str, List[torch.Tensor]] = {}
param_placements: Dict[str, List[Placement]] = {} param_placements: Dict[str, List[Placement]] = {}
keys = set(model_state_dict_lst[0].keys()) keys = set(model_state_dict_lst[0].keys())
for key in keys: for key in keys:
@@ -122,8 +101,8 @@ if __name__ == "__main__":
try: try:
tensor = model_state_dict.pop(key) tensor = model_state_dict.pop(key)
except Exception: except Exception:
print(f"Cannot find key {key} in rank {rank}.") print("-" * 30)
print(model_state_dict)
if isinstance(tensor, DTensor): if isinstance(tensor, DTensor):
state_dict[key].append(tensor._local_tensor.bfloat16()) state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements) placements = tuple(tensor.placements)
@@ -136,7 +115,7 @@ if __name__ == "__main__":
else: else:
assert param_placements[key] == placements assert param_placements[key] == placements
else: else:
state_dict[key].append(tensor.bfloat16()) state_dict[key] = tensor.bfloat16()
del model_state_dict_lst del model_state_dict_lst
@@ -144,44 +123,43 @@ if __name__ == "__main__":
if not isinstance(state_dict[key], list): if not isinstance(state_dict[key], list):
print(f"No need to merge key {key}") print(f"No need to merge key {key}")
continue continue
# merge shards
if key in param_placements: placements: Tuple[Shard] = param_placements[key]
# merge shards if len(mesh_shape) == 1:
placements: Tuple[Shard] = param_placements[key] # 1-D list, FSDP without TP
if len(mesh_shape) == 1: assert len(placements) == 1
# 1-D list, FSDP without TP shards = state_dict[key]
assert len(placements) == 1 state_dict[key] = merge_by_placement(shards, placements[0])
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
else:
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet.")
else: else:
state_dict[key] = torch.cat(state_dict[key], dim=0) # 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet")
print("Merge completed.") print("Writing to local disk")
hf_path = os.path.join(local_dir, "huggingface") hf_path = os.path.join(local_dir, "huggingface")
config: PretrainedConfig = AutoConfig.from_pretrained(hf_path) config = AutoConfig.from_pretrained(hf_path)
architectures: List[str] = getattr(config, "architectures", ["Unknown"])
if "ForTokenClassification" in config.architectures[0]:
if "ForTokenClassification" in architectures[0]: auto_model = AutoModelForTokenClassification
AutoClass = AutoModelForTokenClassification elif "ForCausalLM" in config.architectures[0]:
elif "ForCausalLM" in architectures[0]: auto_model = AutoModelForCausalLM
AutoClass = AutoModelForCausalLM elif "ForConditionalGeneration" in config.architectures[0]:
elif "ForConditionalGeneration" in architectures[0]: auto_model = AutoModelForVision2Seq
AutoClass = AutoModelForVision2Seq
else: else:
raise NotImplementedError(f"Unknown architecture {architectures}.") raise NotImplementedError(f"Unknown architecture {config.architectures}")
with torch.device("meta"): with torch.device("meta"):
model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16) model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
assert isinstance(model, PreTrainedModel)
model.to_empty(device="cpu") model.to_empty(device="cpu")
print(f"Saving model to {hf_path}...") print(f"Saving model to {hf_path}")
model.save_pretrained(hf_path, state_dict=state_dict) model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict, model del state_dict
del model
if args.hf_upload_path: if args.hf_upload_path:
upload_model_to_huggingface(hf_path, args.hf_upload_path) # Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
@@ -20,7 +20,7 @@ from .transformers.qwen2_vl import qwen2_vl_attn_forward
def apply_ulysses_patch(model_type: str) -> None: def apply_ulysses_patch(model_type: str) -> None:
if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"): if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
elif model_type in ("qwen2_vl", "qwen2_5_vl"): elif model_type in ("qwen2_vl", "qwen2_5_vl"):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
......
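For context, apply_ulysses_patch dispatches on the HuggingFace model_type string; a minimal sketch of how it could be invoked (the model id is illustrative and the actual call site in the workers is assumed, not shown in this diff):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # illustrative model id
apply_ulysses_patch(config.model_type)  # "qwen2_5_vl" routes to the Qwen2.5-VL flash-attention patch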
@@ -129,7 +129,7 @@ class Worker(WorkerHelper):
self._rank = rank self._rank = rank
self._world_size = world_size self._world_size = world_size
# DCU support ## DCU support
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("HIP_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("HIP_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK") os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0") cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
......
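Note that os.getenv returns None when a variable is unset, and assigning None into os.environ raises a TypeError, so the DCU mapping above assumes HIP_VISIBLE_DEVICES and RAY_LOCAL_RANK are always present. A defensive variant, as a sketch rather than the committed code:

import os

# Only forward the DCU/ROCm visibility settings when they are actually set,
# since os.environ[...] = None raises a TypeError.
hip_devices = os.getenv("HIP_VISIBLE_DEVICES")
if hip_devices is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = hip_devices

ray_local_rank = os.getenv("RAY_LOCAL_RANK")
if ray_local_rank is not None:
    os.environ["LOCAL_RANK"] = ray_local_rank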
@@ -43,7 +43,6 @@ class DataConfig:
rollout_batch_size: int = 512 rollout_batch_size: int = 512
val_batch_size: int = -1 val_batch_size: int = -1
format_prompt: Optional[str] = None format_prompt: Optional[str] = None
override_chat_template: Optional[str] = None
shuffle: bool = True shuffle: bool = True
seed: int = 1 seed: int = 1
max_pixels: int = 4194304 max_pixels: int = 4194304
@@ -52,7 +51,7 @@ class DataConfig:
def post_init(self): def post_init(self):
if self.format_prompt is not None: if self.format_prompt is not None:
if os.path.exists(self.format_prompt): # ray job uses absolute path if os.path.exists(self.format_prompt):
self.format_prompt = os.path.abspath(self.format_prompt) self.format_prompt = os.path.abspath(self.format_prompt)
else: else:
self.format_prompt = None self.format_prompt = None
@@ -74,7 +73,7 @@ class AlgorithmConfig:
@dataclass @dataclass
class TrainerConfig: class TrainerConfig:
total_epochs: int = 10 total_episodes: int = 10
max_steps: Optional[int] = None max_steps: Optional[int] = None
project_name: str = "easy_r1" project_name: str = "easy_r1"
experiment_name: str = "demo" experiment_name: str = "demo"
@@ -95,7 +94,7 @@ class TrainerConfig:
if self.save_checkpoint_path is None: if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name) self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job uses absolute path self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
if self.load_checkpoint_path is not None: if self.load_checkpoint_path is not None:
self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path) self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
......
@@ -20,7 +20,7 @@ from omegaconf import OmegaConf
from ..single_controller.ray import RayWorkerGroup from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager from ..workers.reward import FunctionRewardManager
from .config import PPOConfig from .config import PPOConfig
from .data_loader import create_dataloader from .data_loader import create_dataloader
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
@@ -38,13 +38,11 @@ class Runner:
# instantiate tokenizer # instantiate tokenizer
tokenizer = get_tokenizer( tokenizer = get_tokenizer(
config.worker.actor.model.model_path, config.worker.actor.model.model_path,
override_chat_template=config.data.override_chat_template,
trust_remote_code=config.worker.actor.model.trust_remote_code, trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True, use_fast=True,
) )
processor = get_processor( processor = get_processor(
config.worker.actor.model.model_path, config.worker.actor.model.model_path,
override_chat_template=config.data.override_chat_template,
trust_remote_code=config.worker.actor.model.trust_remote_code, trust_remote_code=config.worker.actor.model.trust_remote_code,
use_fast=True, use_fast=True,
) )
@@ -67,18 +65,12 @@ class Runner:
} }
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
if config.worker.reward.reward_type == "sequential": reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
RewardManager = SequentialFunctionRewardManager val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
elif config.worker.reward.reward_type == "batch":
RewardManager = BatchFunctionRewardManager
else:
raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.")
RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor) train_dataloader, val_dataloader = create_dataloader(
config=config.data, tokenizer=tokenizer, processor=processor
)
trainer = RayPPOTrainer( trainer = RayPPOTrainer(
config=config, config=config,
@@ -119,13 +111,14 @@ def main():
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False", "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
} }
} }
# this is for local ray cluster ray.init(runtime_env=runtime_env) # this is for local ray cluster
if torch.version.hip is not None: if torch.version.hip is not None:
ray.init(num_gpus=torch.cuda.device_count(), ray.init(num_gpus=torch.cuda.device_count(), ## for dcu devices
ignore_reinit_error=True, ignore_reinit_error=True,
runtime_env=runtime_env) runtime_env=runtime_env)
else: else:
ray.init(runtime_env=runtime_env) ray.init(runtime_env=runtime_env)
runner = Runner.remote() runner = Runner.remote()
ray.get(runner.run.remote(ppo_config)) ray.get(runner.run.remote(ppo_config))
......
@@ -19,14 +19,16 @@ This trainer supports model-agonistic model initialization with huggingface
import os import os
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Type from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import numpy as np import numpy as np
import ray import ray
import torch import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -38,10 +40,9 @@ from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str, timer from ..utils.py_functional import convert_dict_to_str
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos from . import core_algos
from .config import PPOConfig from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
@@ -161,6 +162,14 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma:
return data return data
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
class RayPPOTrainer: class RayPPOTrainer:
""" """
Note that this trainer runs on the driver process on a single CPU/GPU node. Note that this trainer runs on the driver process on a single CPU/GPU node.
@@ -176,8 +185,8 @@ class RayPPOTrainer:
role_worker_mapping: dict[Role, Type[Worker]], role_worker_mapping: dict[Role, Type[Worker]],
resource_pool_manager: ResourcePoolManager, resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup, ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
reward_fn: Optional[FunctionRewardManager] = None, reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
val_reward_fn: Optional[FunctionRewardManager] = None, val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor = processor self.processor = processor
@@ -247,7 +256,7 @@ class RayPPOTrainer:
if config.trainer.max_steps is not None: if config.trainer.max_steps is not None:
self.training_steps = config.trainer.max_steps self.training_steps = config.trainer.max_steps
else: else:
self.training_steps = len(train_dataloader) * config.trainer.total_epochs self.training_steps = len(train_dataloader) * config.trainer.total_episodes
config.worker.actor.optim.training_steps = self.training_steps config.worker.actor.optim.training_steps = self.training_steps
config.worker.critic.optim.training_steps = self.training_steps config.worker.critic.optim.training_steps = self.training_steps
@@ -298,6 +307,7 @@ class RayPPOTrainer:
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch) test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size) test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
print("validation generation end")
# Store generated outputs # Store generated outputs
output_ids = test_output_gen_batch.batch["responses"] output_ids = test_output_gen_batch.batch["responses"]
@@ -307,7 +317,7 @@ class RayPPOTrainer:
test_batch = test_batch.union(test_output_gen_batch) test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function # evaluate using reward_function
reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch)) reward_tensor, reward_metrics = self.val_reward_fn(test_batch)
# Store scores # Store scores
scores = reward_tensor.sum(-1).cpu().tolist() scores = reward_tensor.sum(-1).cpu().tolist()
@@ -473,7 +483,7 @@ class RayPPOTrainer:
if self.config.trainer.val_only: if self.config.trainer.val_only:
return return
for _ in tqdm(range(self.config.trainer.total_epochs), desc="Epoch", position=0): for _ in tqdm(range(self.config.trainer.total_episodes), desc="Episode", position=0):
for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1): for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
self.global_step += 1 self.global_step += 1
if self.global_step > self.training_steps: if self.global_step > self.training_steps:
@@ -494,20 +504,20 @@ class RayPPOTrainer:
non_tensor_batch_keys=["raw_prompt_ids"], non_tensor_batch_keys=["raw_prompt_ids"],
) )
with timer("step", timing_raw): with _timer("step", timing_raw):
# generate a batch # generate a batch
with timer("gen", timing_raw): # wg: worker group with _timer("gen", timing_raw): # wg: worker group
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == "remax": if self.config.algorithm.adv_estimator == "remax":
with timer("gen_max", timing_raw): with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["temperature"] = 0 gen_baseline_batch.meta_info["temperature"] = 0
gen_baseline_batch.meta_info["n"] = 1 gen_baseline_batch.meta_info["n"] = 1
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output) batch = batch.union(gen_baseline_output)
reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch)) reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
@@ -522,6 +532,19 @@ class RayPPOTrainer:
batch = batch.union(gen_batch_output) batch = batch.union(gen_batch_output)
batch.non_tensor_batch.pop("multi_modal_data", None) batch.non_tensor_batch.pop("multi_modal_data", None)
# compute reward
with _timer("reward", timing_raw):
if self.use_reward_model:
raise NotImplementedError("Reward model is not supported yet.")
# we combine with rule-based rm
reward_tensor, reward_metrics = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
}
metrics.update(reward_metrics)
# balance the number of valid tokens on each dp rank. # balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch. # Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo # Please take care when you implement group based adv computation such as GRPO and rloo
@@ -530,38 +553,30 @@ class RayPPOTrainer:
# compute global_valid tokens # compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# compute reward
with timer("reward", timing_raw):
reward_ref = self.reward_fn.compute_reward.remote(batch)
# recompute old_log_probs # recompute old_log_probs
with timer("old", timing_raw): with _timer("old", timing_raw):
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch) old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs) batch = batch.union(old_log_probs)
# compute ref_log_probs # compute ref_log_probs
if self.use_reference_policy: if self.use_reference_policy:
with timer("ref", timing_raw): with _timer("ref", timing_raw):
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch) ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs) batch = batch.union(ref_log_probs)
# compute values # compute values
if self.use_critic: if self.use_critic:
with timer("values", timing_raw): with _timer("values", timing_raw):
values = self.critic_wg.compute_values(batch) values = self.critic_wg.compute_values(batch)
batch = batch.union(values) batch = batch.union(values)
with timer("adv", timing_raw): with _timer("adv", timing_raw):
# get token level scores
reward_tensor, reward_metrics = ray.get(reward_ref)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()}
metrics.update(reward_metrics)
# apply kl penalty if available # apply kl penalty if available
if not self.config.algorithm.use_kl_loss and self.use_reference_policy: if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
# apply kl penalty to reward # apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty) batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics) metrics.update(kl_metrics)
else: else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
@@ -576,7 +591,7 @@ class RayPPOTrainer:
# update critic # update critic
if self.use_critic: if self.use_critic:
with timer("update_critic", timing_raw): with _timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch) critic_output = self.critic_wg.update_critic(batch)
critic_metrics = reduce_metrics(critic_output.non_tensor_batch) critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
@@ -584,7 +599,7 @@ class RayPPOTrainer:
# update actor # update actor
if self.config.trainer.critic_warmup <= self.global_step: if self.config.trainer.critic_warmup <= self.global_step:
with timer("update_actor", timing_raw): with _timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch) actor_output = self.actor_rollout_wg.update_actor(batch)
actor_metrics = reduce_metrics(actor_output.non_tensor_batch) actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
@@ -596,13 +611,13 @@ class RayPPOTrainer:
and self.config.trainer.val_freq > 0 and self.config.trainer.val_freq > 0
and self.global_step % self.config.trainer.val_freq == 0 and self.global_step % self.config.trainer.val_freq == 0
): ):
with timer("validation", timing_raw): with _timer("validation", timing_raw):
val_metrics = self._validate() val_metrics = self._validate()
metrics.update(val_metrics) metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0: if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
with timer("save_checkpoint", timing_raw): with _timer("save_checkpoint", timing_raw):
self._save_checkpoint() self._save_checkpoint()
# collect metrics # collect metrics
......
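For reference, the _timer helper introduced in this file accumulates wall-clock durations into a plain dict keyed by stage name; a minimal usage sketch outside the trainer (the sleep is a stand-in for real work):

import time

timing_raw = {}
with _timer("gen", timing_raw):
    time.sleep(0.1)  # stand-in for actor_rollout_wg.generate_sequences(gen_batch)
print(timing_raw)  # {"gen": ~0.1} -- seconds measured by codetiming.Timer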
@@ -55,13 +55,11 @@ class FSDPCheckpointManager(BaseCheckpointManager):
# every rank download its own checkpoint # every rank download its own checkpoint
model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Loading model from {os.path.abspath(model_path)}.") print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
print(f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}.")
print(f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}.")
model_state_dict = torch.load(model_path, weights_only=False) model_state_dict = torch.load(model_path, weights_only=False)
optim_state_dict = torch.load(optim_path, weights_only=False) optim_state_dict = torch.load(optim_path, weights_only=False)
extra_state_dict = torch.load(extra_path, weights_only=False) extra_state_dict = torch.load(extra_state_path, weights_only=False)
state_dict_options = StateDictOptions(cpu_offload=True) state_dict_options = StateDictOptions(cpu_offload=True)
set_state_dict( set_state_dict(
@@ -93,7 +91,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.") print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.") print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.") print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
torch.save(model_state_dict, model_path) torch.save(model_state_dict, model_path)
torch.save(optim_state_dict, optim_path) torch.save(optim_state_dict, optim_path)
......
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"} VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
def get_device_flops(unit: str = "T") -> float: def get_device_flops(unit: str = "T") -> float:
......
@@ -17,13 +17,11 @@ Contain small python utility functions
import importlib.util import importlib.util
import re import re
from contextlib import contextmanager
from functools import lru_cache from functools import lru_cache
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import numpy as np import numpy as np
import yaml import yaml
from codetiming import Timer
from yaml import Dumper from yaml import Dumper
@@ -103,11 +101,3 @@ def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") ->
def convert_dict_to_str(data: Dict[str, Any]) -> str: def convert_dict_to_str(data: Dict[str, Any]) -> str:
return yaml.dump(data, indent=2) return yaml.dump(data, indent=2)
@contextmanager
def timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
@@ -18,11 +18,9 @@ from typing import Optional
from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer: def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
"""Create a huggingface pretrained tokenizer.""" """Create a huggingface pretrained tokenizer."""
tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs) tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
if override_chat_template is not None:
tokenizer.chat_template = override_chat_template
if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>": if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
# the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance. # the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance.
...@@ -37,11 +35,12 @@ def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, ...@@ -37,11 +35,12 @@ def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None,
return tokenizer return tokenizer
def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]: def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
"""Create a huggingface pretrained processor.""" """Create a huggingface pretrained processor."""
processor = AutoProcessor.from_pretrained(model_path, **kwargs) try:
if override_chat_template is not None: processor = AutoProcessor.from_pretrained(model_path, **kwargs)
processor.chat_template = override_chat_template except Exception:
processor = None
# Avoid load tokenizer, see: # Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
......
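With this change, get_processor swallows the AutoProcessor failure and returns None instead of raising, so callers are expected to fall back to the tokenizer for text-only models. A sketch of that calling pattern, with an illustrative model path:

model_path = "Qwen/Qwen2.5-7B-Instruct"  # illustrative text-only model
tokenizer = get_tokenizer(model_path, use_fast=True)
processor = get_processor(model_path, use_fast=True)  # may be None when no processor is available

messages = [{"role": "user", "content": "Hello"}]
template_owner = processor if processor is not None else tokenizer
prompt = template_owner.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)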
@@ -33,7 +33,7 @@ class ModelConfig:
if self.tokenizer_path is None: if self.tokenizer_path is None:
self.tokenizer_path = self.model_path self.tokenizer_path = self.model_path
if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path if self.model_path is not None and os.path.exists(self.model_path):
self.model_path = os.path.abspath(self.model_path) self.model_path = os.path.abspath(self.model_path)
if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path): if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
......
@@ -15,8 +15,6 @@
The main entry point to run the PPO algorithm The main entry point to run the PPO algorithm
""" """
import os
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import numpy as np import numpy as np
@@ -73,7 +71,6 @@ class FSDPWorker(Worker):
self.role = role self.role = role
if not dist.is_initialized(): if not dist.is_initialized():
self.print_rank0("Initializing distributed process group...")
dist.init_process_group(backend="nccl") dist.init_process_group(backend="nccl")
# improve numerical stability # improve numerical stability
@@ -284,7 +281,7 @@ class FSDPWorker(Worker):
if self._is_actor or self._is_critic: if self._is_actor or self._is_critic:
if optim_config.strategy == "adamw": if optim_config.strategy == "adamw":
self.optimizer = torch.optim.AdamW( self.optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()), self.fsdp_module.parameters(),
lr=optim_config.lr, lr=optim_config.lr,
betas=optim_config.betas, betas=optim_config.betas,
weight_decay=optim_config.weight_decay, weight_decay=optim_config.weight_decay,
@@ -292,7 +289,7 @@ class FSDPWorker(Worker):
) )
elif optim_config.strategy == "adamw_bf16": elif optim_config.strategy == "adamw_bf16":
self.optimizer = AnyPrecisionAdamW( self.optimizer = AnyPrecisionAdamW(
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()), self.fsdp_module.parameters(),
lr=optim_config.lr, lr=optim_config.lr,
betas=optim_config.betas, betas=optim_config.betas,
weight_decay=optim_config.weight_decay, weight_decay=optim_config.weight_decay,
......
@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from .config import RewardConfig from .config import RewardConfig
from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager from .function import FunctionRewardManager
__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"] __all__ = ["FunctionRewardManager", "RewardConfig"]
@@ -22,22 +22,21 @@ from typing import Optional
@dataclass @dataclass
class RewardConfig: class RewardConfig:
reward_type: str = "batch" reward_type: str = "function"
reward_function: Optional[str] = None score_function: Optional[str] = None
reward_function_kwargs: dict = field(default_factory=dict) score_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True skip_special_tokens: bool = True
num_cpus: int = 1
"""auto keys""" """auto keys"""
reward_function_name: Optional[str] = field(default=None, init=False) score_function_name: Optional[str] = field(default=None, init=False)
def post_init(self): def post_init(self):
if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main if self.score_function is not None:
if ":" not in self.reward_function: if ":" not in self.score_function:
self.reward_function_name = "main" self.score_function_name = "main"
else: else:
self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1) self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
if os.path.exists(self.reward_function): # ray job uses absolute path if os.path.exists(self.score_function):
self.reward_function = os.path.abspath(self.reward_function) self.score_function = os.path.abspath(self.score_function)
else: else:
self.reward_function = None self.score_function = None
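To illustrate the parsing performed by post_init in the new RewardConfig (the file path below is a placeholder, not from the repository):

cfg = RewardConfig(score_function="./examples/math.py:compute_score")  # placeholder path
cfg.post_init()
# score_function_name == "compute_score"; score_function becomes an absolute path if the file exists,
# otherwise it is reset to None

cfg = RewardConfig(score_function="./examples/math.py")
cfg.post_init()
# no ":" in the value, so score_function_name falls back to "main"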
@@ -15,8 +15,8 @@
import importlib.util import importlib.util
import os import os
import sys import sys
from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -33,86 +33,54 @@ class RewardScore(TypedDict):
accuracy: Optional[float] accuracy: Optional[float]
SequentialRewardFunction = Callable[[str, str], RewardScore] ScoreFunction = Callable[[str, str], RewardScore]
BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]
@dataclass
class FunctionRewardManager:
config: RewardConfig
tokenizer: PreTrainedTokenizer
class FunctionRewardManager(ABC): def __post_init__(self):
"""Reward manager for rule-based reward.""" """Load score function."""
if self.config.score_function is None:
raise ValueError("Score function is not provided.")
def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer): if not os.path.exists(self.config.score_function):
if config.reward_function is None: raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
raise ValueError("Reward function is not provided.")
if not os.path.exists(config.reward_function): spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")
spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
try: try:
sys.modules["custom_reward_fn"] = module sys.modules["custom_score_fn"] = module
spec.loader.exec_module(module) spec.loader.exec_module(module)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load reward function: {e}") raise RuntimeError(f"Failed to load score function: {e}")
if not hasattr(module, config.reward_function_name):
raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
reward_fn = getattr(module, config.reward_function_name)
print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
self.config = config
self.tokenizer = tokenizer
@abstractmethod if not hasattr(module, self.config.score_function_name):
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]: raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")
"""Compute reward for a batch of data."""
...
score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
self.score_fn = partial(score_fn, **self.config.score_function_kwargs)
class SequentialFunctionRewardManager(FunctionRewardManager): def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
reward_fn: SequentialRewardFunction
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list) reward_metrics = defaultdict(list)
response_ids = data.batch["responses"]
response_length = data.batch["response_mask"].sum(dim=-1)
for i in range(len(data)): for i in range(len(data)):
valid_response_ids = response_ids[i][: response_length[i]] data_item = data[i] # DataProtoItem
response_ids = data_item.batch["responses"]
response_mask = data_item.batch["response_mask"]
valid_response_length = response_mask.sum()
valid_response_ids = response_ids[:valid_response_length]
response_str = self.tokenizer.decode( response_str = self.tokenizer.decode(
valid_response_ids, skip_special_tokens=self.config.skip_special_tokens valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
) )
ground_truth = data.non_tensor_batch["ground_truth"][i] ground_truth = data_item.non_tensor_batch["ground_truth"]
score = self.reward_fn(response_str, ground_truth) score = self.score_fn(response_str, ground_truth)
reward_tensor[i, response_length[i] - 1] = score["overall"] reward_tensor[i, valid_response_length - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)
return reward_tensor, reward_metrics
class BatchFunctionRewardManager(FunctionRewardManager):
reward_fn: BatchRewardFunction
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
response_str, ground_truth = [], []
response_ids = data.batch["responses"]
response_length = data.batch["response_mask"].sum(dim=-1)
for i in range(len(data)):
valid_response_ids = response_ids[i][: response_length[i]]
response_str.append(
self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
)
ground_truth.append(data.non_tensor_batch["ground_truth"][i])
scores = self.reward_fn(response_str, ground_truth)
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list)
for i, score in enumerate(scores):
reward_tensor[i, response_length[i] - 1] = score["overall"]
for key, value in score.items(): for key, value in score.items():
reward_metrics[key].append(value) reward_metrics[key].append(value)
......
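Putting the pieces together, the rewritten FunctionRewardManager is constructed directly (no Ray remote wrapper) and called like a function on a DataProto batch, matching the call sites in main.py and ray_trainer.py in this commit; names such as config, tokenizer, and batch come from that surrounding context:

reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
reward_tensor, reward_metrics = reward_fn(batch)  # batch: DataProto with "responses" and "response_mask"
batch.batch["token_level_scores"] = reward_tensor  # score written at the last valid response token
# reward_metrics maps each score key (e.g. "overall", "format") to a per-sample list of floats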
@@ -69,7 +69,7 @@ class vLLMRollout(BaseRollout):
self.inference_engine = LLM( self.inference_engine = LLM(
model=model_path, model=model_path,
skip_tokenizer_init=False, skip_tokenizer_init=False,
trust_remote_code=True, trust_remote_code=config.trust_remote_code,
load_format="dummy", load_format="dummy",
dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)), dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
seed=config.seed, seed=config.seed,
@@ -85,11 +85,11 @@ class vLLMRollout(BaseRollout):
disable_mm_preprocessor_cache=True, disable_mm_preprocessor_cache=True,
enable_chunked_prefill=config.enable_chunked_prefill, enable_chunked_prefill=config.enable_chunked_prefill,
enable_sleep_mode=False, # only support GPUs enable_sleep_mode=False, # only support GPUs
cpu_offload_gb=64, cpu_offload_gb=64
) )
# Offload vllm model to reduce peak memory usage # Offload vllm model to reduce peak memory usage
# self.inference_engine.sleep(level=1) self.inference_engine.sleep(level=1)
sampling_kwargs = { sampling_kwargs = {
"max_tokens": config.response_length, "max_tokens": config.response_length,
......
@@ -101,7 +101,6 @@ class FSDPVLLMShardingManager(BaseShardingManager):
print_gpu_memory_usage("Before vllm offload in sharding manager") print_gpu_memory_usage("Before vllm offload in sharding manager")
free_bytes_before_sleep = torch.cuda.mem_get_info()[0] free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# self.inference_engine.sleep(level=1) # self.inference_engine.sleep(level=1)
free_bytes_after_sleep = torch.cuda.mem_get_info()[0] free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print_gpu_memory_usage("After vllm offload in sharding manager") print_gpu_memory_usage("After vllm offload in sharding manager")
......