Commit 6065b946 authored by chenych

Update 0427

parent 2369eb2b
......@@ -43,8 +43,6 @@ EasyR1 is built on **[HybridEngine](https://arxiv.org/abs/2409.19256)** and the latest ...
>
> Training uses wandb; after installing the environment, log in to wandb first.
## Tutorial: Train Qwen2.5-VL with GRPO on the [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) dataset in just three steps.
![image](assets/qwen2_5_vl_7b_geo.png)
......@@ -91,7 +89,7 @@ python: 3.10
torch: 2.4.1
deepspeed: 0.14.2+das.opt2.dtk2504
flash-attn: 2.6.1+das.opt4.dtk2504
vllm: 0.7.2
vllm: 0.8.3
```
`Tips: the dtk driver, python, torch, and other DCU-related tool versions listed above must match exactly one-to-one.`
......@@ -106,13 +104,13 @@ pip install -e .
### GRPO Training
```bash
bash examples/qwen2_5_7b_math_grpo.sh
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
```
### Merge Checkpoints into Hugging Face Format
```bash
python3 scripts/model_merger.py --local_dir path_to_your_last_actor_checkpoint
python3 scripts/model_merger.py --local_dir path_to_your_actor_checkpoint
```
> [!NOTE]
......@@ -131,7 +129,8 @@ python3 scripts/model_merger.py --local_dir path_to_your_last_actor_checkpoint
## Other Baselines
- [CLEVR-70k-Counting](examples/run_qwen2_5_vl_2b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on counting problems.
- [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on counting problems.
- [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on GeoQA problems.
### Known Issues
......
......@@ -33,19 +33,16 @@ EasyR1 is efficient and scalable due to the design of **[HybridEngine](https://a
### Software Requirements
- Python 3.9+
- transformers>=4.49.0
- transformers>=4.51.0
- flash-attn>=2.4.3
- vllm>=0.7.3
- vllm>=0.8.3
We provide a [Dockerfile](./Dockerfile) to easily build environments.
We recommend using the [pre-built docker image](https://hub.docker.com/r/hiyouga/verl) in EasyR1.
```bash
# stable
docker pull hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix
# nightly
docker pull hiyouga/verl:ngc-th2.6.0-cu120-vllm0.8.2
docker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0
```
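For reference, one possible way to start a container from this image is shown below; the mount path and shell are illustrative, and your GPU runtime flags may differ depending on your Docker/NVIDIA setup.
```bash
# Illustrative only: mounts the current directory and opens a shell in the
# pre-built image listed above. Adjust paths and GPU flags to your setup.
docker run -it --gpus all --ipc=host \
    -v "$(pwd)":/workspace \
    hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 \
    /bin/bash
```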
### Hardware Requirements
......@@ -96,20 +93,42 @@ python3 scripts/model_merger.py --local_dir checkpoints/easy_r1/exp_name/global_
Please refer to the example datasets to prepare your own dataset; a short loading sketch follows this list.
- Text dataset: https://huggingface.co/datasets/hiyouga/math12k
- Vision-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
> [!TIP]
> EasyR1 already supports multi-image datasets.
- Image-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
- Multi-image-text dataset: https://huggingface.co/datasets/hiyouga/journeybench-multi-image-vqa
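A quick way to see the expected layout is to load the example datasets with the `datasets` library and inspect their columns. The split and column names below are assumptions drawn from these public datasets, not a specification of EasyR1's data loader.
```python
# Minimal sketch: inspect the example datasets to mirror their layout in your
# own data. Split/column names are assumptions based on the public datasets.
from datasets import load_dataset

text_ds = load_dataset("hiyouga/math12k", split="train")
image_ds = load_dataset("hiyouga/geometry3k", split="train")

print(text_ds.column_names)   # text dataset schema
print(image_ds.column_names)  # image-text dataset schema
print(image_ds[0])            # one sample with prompt, image(s), and answer
```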
## How to Understand GRPO in EasyR1
![image](assets/easyr1_grpo.png)
- To learn about the GRPO algorithm, you can refer to [Hugging Face's blog](https://huggingface.co/docs/trl/v0.15.2/en/grpo_trainer).
- To learn about the GRPO algorithm, you can refer to [Hugging Face's blog](https://huggingface.co/docs/trl/v0.16.1/en/grpo_trainer).
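At its core, GRPO replaces a learned value baseline with a group-relative one: for each prompt, several responses are sampled and each reward is normalized by the group's mean and standard deviation. The sketch below illustrates only this advantage computation with made-up tensors; it is not EasyR1's internal implementation.
```python
# Group-relative advantage used by GRPO (illustrative sketch, not EasyR1 code).
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_prompts, group_size) scalar rewards, one row per prompt."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)  # responses above the group mean get positive advantage

# Example: 2 prompts, 4 sampled responses each
rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                        [0.0, 0.0, 1.0, 0.0]])
print(grpo_advantages(rewards))
```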
## How to Run 70B+ Model in Multi-node Environment
Please see the **[veRL's official doc](https://verl.readthedocs.io/en/latest/start/multinode.html)** for multi-node training and Ray debugger.
1. Start the Ray head node.
```bash
ray start --head --port=6379 --dashboard-host=0.0.0.0
```
2. Start the Ray worker node and connect to the head node.
```bash
ray start --address=<head_node_ip>:6379
```
3. Check the Ray resource pool.
```bash
ray status
```
4. Run training script on the Ray head node only.
```bash
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
```
See the **[veRL official documentation](https://verl.readthedocs.io/en/latest/start/multinode.html)** for more details about multi-node training and the Ray debugger.
## Other Baselines
......@@ -117,14 +136,20 @@ We also reproduced the following two baselines of the [R1-V](https://github.com/
- [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on counting problem.
- [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on GeoQA problem.
## Performance Baselines
See [baselines.md](assets/baselines.md).
## Awesome Work using EasyR1
- **MMR1**: Advancing the Frontiers of Multimodal Reasoning. [![[code]](https://img.shields.io/github/stars/LengSicong/MMR1)](https://github.com/LengSicong/MMR1)
- **Vision-R1**: Incentivizing Reasoning Capability in Multimodal Large Language Models. [![[code]](https://img.shields.io/github/stars/Osilly/Vision-R1)](https://github.com/Osilly/Vision-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06749-blue)](https://arxiv.org/abs/2503.06749)
- **Seg-Zero**: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement. [![[code]](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06520-blue)](https://arxiv.org/abs/2503.06520)
- **MetaSpatial**: Reinforcing 3D Spatial Reasoning in VLMs for the Metaverse. [![[code]](https://img.shields.io/github/stars/PzySeere/MetaSpatial)](https://github.com/PzySeere/MetaSpatial) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.18470-blue)](https://arxiv.org/abs/2503.18470)
- **Temporal-R1**: Evolving Temporal Reasoning Capability into LMMs via Temporal Consistent Reward
[![[code]](https://img.shields.io/github/stars/appletea233/Temporal-R1)](https://github.com/appletea233/Temporal-R1)
- **Temporal-R1**: Evolving Temporal Reasoning Capability into LMMs via Temporal Consistent Reward. [![[code]](https://img.shields.io/github/stars/appletea233/Temporal-R1)](https://github.com/appletea233/Temporal-R1)
- **NoisyRollout**: Reinforcing Visual Reasoning with Data Augmentation. [![[code]](https://img.shields.io/github/stars/John-AI-Lab/NoisyRollout)](https://github.com/John-AI-Lab/NoisyRollout) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.13055-blue)](https://arxiv.org/pdf/2504.13055)
- **GUI-R1**: A Generalist R1-Style Vision-Language Action Model For GUI Agents. [![[code]](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)](https://github.com/ritzz-ai/GUI-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.10458-blue)](https://arxiv.org/abs/2504.10458)
## TODO
- Support LoRA (high priority).
......@@ -146,9 +171,17 @@ These features are temporarily disabled for now, we plan to fix them one-by-one
## FAQs
> ValueError: Image features and image tokens do not match: tokens: 8192, features 9800
Increase the `data.max_prompt_length` or reduce the `data.max_pixels`.
> RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
Reduce the `worker.rollout.gpu_memory_utilization`.
Reduce the `worker.rollout.gpu_memory_utilization` and enable `worker.actor.offload.offload_params`.
> RuntimeError: 0 active drivers ([]). There should only be one.
Uninstall `deepspeed` from the current python environment.
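For reference, these fixes are plain config overrides. A hypothetical command is shown below; the `config=examples/config.yaml` entry point and the numeric values are assumptions, and in practice you would add such overrides to one of the scripts under `examples/`.
```bash
# Hypothetical values; tune max_prompt_length / max_pixels / gpu_memory_utilization
# to your data and hardware. The config keys match examples/config.yaml.
python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.max_prompt_length=4096 \
    data.max_pixels=1048576 \
    worker.rollout.gpu_memory_utilization=0.5 \
    worker.actor.offload.offload_params=true
```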
## Citation
......
# Baselines
Environment: [hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0](https://hub.docker.com/layers/hiyouga/verl/ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0/images/sha256-335ed6cd1fe73090e458409cfa4394d6abf4cd0503ca44dbafdc28ff72e5ed20)
EasyR1 version: [v0.3.0](https://github.com/hiyouga/EasyR1/tree/v0.3.0)
Welcome to contribute new data points!
## Algorithm Baselines
### [Qwen2.5-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) on [Math12k](https://huggingface.co/datasets/hiyouga/math12k)
| Size | Algorithm | Bits | LR | KL | Test Score |
| ---- | ----------- | ---- | ---- | ---- | ---------- |
| 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.73->0.79 |
### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
| Size | Algorithm | Bits | LR | KL | Test Score |
| ---- | ----------- | ---- | ---- | ---- | ---------- |
| 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.39->0.52 |
| 7B | GRPO | BF16 | 1e-6 | 1e-2 | 0.39->0.52 |
| 7B | GRPO | AMP | 1e-6 | 1e-3 | 0.39->0.52 |
| 7B | RLOO | AMP | 1e-6 | 1e-2 | 0.39->0.53 |
| 3B | GRPO | AMP | 1e-6 | 1e-2 | 0.27->0.44 |
| 32B | GRPO | BF16 | 1e-6 | 1e-2 | 0.46->0.61 |
> [!NOTE]
> The hyper-parameters not listed are all the same as the default values.
## Performance Baselines
### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
| Size | GPU Type | Bits | Batch Size | vLLM Util | vLLM TP | Peak Mem | Peak VRAM | Throughput | Sec per step | Actor MFU |
| ---- | ------------- | ---- | ---------- | --------- | ------- | -------- | --------- | ---------- | ------------ | --------- |
| 3B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 120GB | 35GB | 1200 | 180s | 6.3% |
| 7B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 140GB | 60GB | 1200 | 180s | 13.6% |
| 7B | 8 * H100 80GB | AMP | 10 / 20 | 0.6 | 2 | 150GB | 75GB | 1400 | 170s | 19.2% |
| 7B | 8 * L20 48GB | AMP | 4 / 16 | 0.6 | 2 | 150GB | 44GB | 410 | 580s | 26.5% |
| 7B | 8 * H100 80GB | BF16 | 4 / 16 | 0.6 | 2 | 150GB | 50GB | 1280 | 190s | 13.9% |
| 32B | 8 * H100 80GB | BF16 | 1 / 8 | 0.6 | 8 | 240GB | 68GB | 360 | 860s | 11.2% |
- Batch Size: micro_batch_size_per_device_for_update / micro_batch_size_per_device_for_experience
- vLLM Util: rollout.gpu_memory_utilization
- vLLM TP: rollout.tensor_parallel_size
- Peak Mem: Peak CPU memory usage
- Peak VRAM: Peak GPU memory usage
- Throughput: Number of tokens processed per second per GPU in one training step
- Sec per step: Average time per step in seconds
> [!NOTE]
> The hyper-parameters not listed are all the same as the default values.
assets/wechat.jpg: image replaced (111 KB → 113 KB)
......@@ -9,6 +9,6 @@ python3 -m verl.trainer.main \
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_clevr \
trainer.n_gpus_per_node=2
......@@ -9,6 +9,6 @@ python3 -m verl.trainer.main \
data.format_prompt=./examples/format_prompt/r1v_format.jinja \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
trainer.n_gpus_per_node=8
......@@ -7,7 +7,7 @@ data:
max_prompt_length: 2048
max_response_length: 2048
rollout_batch_size: 512
val_batch_size: -1
val_batch_size: 1024
format_prompt: ./examples/format_prompt/math_format.jinja
shuffle: true
seed: 1
......@@ -71,7 +71,7 @@ worker:
reward:
reward_type: function
score_function: ./examples/score_function/math.py:compute_score
reward_function: ./examples/reward_function/math.py:compute_score
trainer:
total_episodes: 15
......
......@@ -18,9 +18,17 @@ import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple
import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
PretrainedConfig,
PreTrainedModel,
)
def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
......@@ -34,14 +42,23 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
raise ValueError(f"Unsupported placement: {placement}")
def upload_model_to_huggingface(local_path: str, remote_path: str):
# Push to hugging face
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
args = parser.parse_args()
local_dir: str = args.local_dir
assert not args.local_dir.endswith("huggingface"), "The local_dir should not end with huggingface"
local_dir = args.local_dir
assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."
# copy rank zero to find the shape of (dp, fsdp)
rank = 0
......@@ -51,22 +68,26 @@ if __name__ == "__main__":
if match:
world_size = match.group(1)
break
assert world_size, "No model file with the proper format"
state_dict = torch.load(
os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt"), map_location="cpu"
)
assert world_size, "No model file with the proper format."
rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
pivot_key = sorted(state_dict.keys())[0]
weight = state_dict[pivot_key]
assert isinstance(weight, torch.distributed._tensor.DTensor)
# get sharding info
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
if isinstance(weight, DTensor):
# get sharding info
device_mesh = weight.device_mesh
mesh = device_mesh.mesh
mesh_dim_names = device_mesh.mesh_dim_names
else:
# for non-DTensor
mesh = np.array([int(world_size)], dtype=np.int64)
mesh_dim_names = ("fsdp",)
print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}"
assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."
if "tp" in mesh_dim_names:
# fsdp * tp
......@@ -77,13 +98,12 @@ if __name__ == "__main__":
total_shards = mesh.shape[-1]
mesh_shape = (mesh.shape[-1],)
print(f"Processing model shards with {total_shards} {mesh_shape} in total")
print(f"Processing {total_shards} model shards in total.")
model_state_dict_lst = []
model_state_dict_lst.append(state_dict)
model_state_dict_lst.extend([""] * (total_shards - 1))
def process_one_shard(rank):
def process_one_shard(rank, model_state_dict_lst):
model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
model_state_dict_lst[rank] = state_dict
......@@ -91,8 +111,9 @@ if __name__ == "__main__":
with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
for rank in range(1, total_shards):
executor.submit(process_one_shard, rank)
state_dict = {}
executor.submit(process_one_shard, rank, model_state_dict_lst)
state_dict: Dict[str, List[torch.Tensor]] = {}
param_placements: Dict[str, List[Placement]] = {}
keys = set(model_state_dict_lst[0].keys())
for key in keys:
......@@ -101,8 +122,8 @@ if __name__ == "__main__":
try:
tensor = model_state_dict.pop(key)
except Exception:
print("-" * 30)
print(model_state_dict)
print(f"Cannot find key {key} in rank {rank}.")
if isinstance(tensor, DTensor):
state_dict[key].append(tensor._local_tensor.bfloat16())
placements = tuple(tensor.placements)
......@@ -115,7 +136,7 @@ if __name__ == "__main__":
else:
assert param_placements[key] == placements
else:
state_dict[key] = tensor.bfloat16()
state_dict[key].append(tensor.bfloat16())
del model_state_dict_lst
......@@ -123,43 +144,44 @@ if __name__ == "__main__":
if not isinstance(state_dict[key], list):
print(f"No need to merge key {key}")
continue
# merge shards
placements: Tuple[Shard] = param_placements[key]
if len(mesh_shape) == 1:
# 1-D list, FSDP without TP
assert len(placements) == 1
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
if key in param_placements:
# merge shards
placements: Tuple[Shard] = param_placements[key]
if len(mesh_shape) == 1:
# 1-D list, FSDP without TP
assert len(placements) == 1
shards = state_dict[key]
state_dict[key] = merge_by_placement(shards, placements[0])
else:
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet.")
else:
# 2-D list, FSDP + TP
raise NotImplementedError("FSDP + TP is not supported yet")
state_dict[key] = torch.cat(state_dict[key], dim=0)
print("Writing to local disk")
print("Merge completed.")
hf_path = os.path.join(local_dir, "huggingface")
config = AutoConfig.from_pretrained(hf_path)
if "ForTokenClassification" in config.architectures[0]:
auto_model = AutoModelForTokenClassification
elif "ForCausalLM" in config.architectures[0]:
auto_model = AutoModelForCausalLM
elif "ForConditionalGeneration" in config.architectures[0]:
auto_model = AutoModelForVision2Seq
config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
architectures: List[str] = getattr(config, "architectures", ["Unknown"])
if "ForTokenClassification" in architectures[0]:
AutoClass = AutoModelForTokenClassification
elif "ForCausalLM" in architectures[0]:
AutoClass = AutoModelForCausalLM
elif "ForConditionalGeneration" in architectures[0]:
AutoClass = AutoModelForVision2Seq
else:
raise NotImplementedError(f"Unknown architecture {config.architectures}")
raise NotImplementedError(f"Unknown architecture {architectures}.")
with torch.device("meta"):
model = auto_model.from_config(config, torch_dtype=torch.bfloat16)
model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)
assert isinstance(model, PreTrainedModel)
model.to_empty(device="cpu")
print(f"Saving model to {hf_path}")
print(f"Saving model to {hf_path}...")
model.save_pretrained(hf_path, state_dict=state_dict)
del state_dict
del model
if args.hf_upload_path:
# Push to hugging face
from huggingface_hub import HfApi
del state_dict, model
api = HfApi()
api.create_repo(repo_id=args.hf_upload_path, private=False, exist_ok=True)
api.upload_folder(folder_path=hf_path, repo_id=args.hf_upload_path, repo_type="model")
if args.hf_upload_path:
upload_model_to_huggingface(hf_path, args.hf_upload_path)
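Given the arguments above, a typical invocation merges the FSDP shards under `--local_dir` and, if `--hf_upload_path` is given, pushes the merged checkpoint; the repo id below is a placeholder.
```bash
# Merge shards into huggingface/ next to the checkpoint; --hf_upload_path is optional.
python3 scripts/model_merger.py \
    --local_dir path_to_your_actor_checkpoint \
    --hf_upload_path your_hf_username/your_merged_model  # placeholder repo id
```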
......@@ -51,7 +51,7 @@ class DataConfig:
def post_init(self):
if self.format_prompt is not None:
if os.path.exists(self.format_prompt):
if os.path.exists(self.format_prompt): # ray job uses absolute path
self.format_prompt = os.path.abspath(self.format_prompt)
else:
self.format_prompt = None
......@@ -94,7 +94,7 @@ class TrainerConfig:
if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path) # ray job uses absolute path
if self.load_checkpoint_path is not None:
self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
......
......@@ -65,12 +65,11 @@ class Runner:
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
RemoteRewardManager = ray.remote(FunctionRewardManager).options(num_cpus=config.worker.reward.num_cpus)
reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
train_dataloader, val_dataloader = create_dataloader(
config=config.data, tokenizer=tokenizer, processor=processor
)
train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)
trainer = RayPPOTrainer(
config=config,
......
......@@ -19,16 +19,14 @@ This trainer supports model-agnostic model initialization with huggingface
import os
import uuid
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Type
import numpy as np
import ray
import torch
from codetiming import Timer
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
......@@ -40,9 +38,10 @@ from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str
from ..utils.py_functional import convert_dict_to_str, timer
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
......@@ -162,14 +161,6 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma:
return data
@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
class RayPPOTrainer:
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
......@@ -185,8 +176,8 @@ class RayPPOTrainer:
role_worker_mapping: dict[Role, Type[Worker]],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
val_reward_fn: Optional[Callable[[DataProto], Tuple[torch.Tensor, Dict[str, List[float]]]]] = None,
reward_fn: Optional[FunctionRewardManager] = None,
val_reward_fn: Optional[FunctionRewardManager] = None,
):
self.tokenizer = tokenizer
self.processor = processor
......@@ -307,7 +298,6 @@ class RayPPOTrainer:
test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)
print("validation generation end")
# Store generated outputs
output_ids = test_output_gen_batch.batch["responses"]
......@@ -317,7 +307,7 @@ class RayPPOTrainer:
test_batch = test_batch.union(test_output_gen_batch)
# evaluate using reward_function
reward_tensor, reward_metrics = self.val_reward_fn(test_batch)
reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch))
# Store scores
scores = reward_tensor.sum(-1).cpu().tolist()
......@@ -504,20 +494,20 @@ class RayPPOTrainer:
non_tensor_batch_keys=["raw_prompt_ids"],
)
with _timer("step", timing_raw):
with timer("step", timing_raw):
# generate a batch
with _timer("gen", timing_raw): # wg: worker group
with timer("gen", timing_raw): # wg: worker group
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
if self.config.algorithm.adv_estimator == "remax":
with _timer("gen_max", timing_raw):
with timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["temperature"] = 0
gen_baseline_batch.meta_info["n"] = 1
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor, _ = self.reward_fn(batch)
reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch))
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
......@@ -532,19 +522,6 @@ class RayPPOTrainer:
batch = batch.union(gen_batch_output)
batch.non_tensor_batch.pop("multi_modal_data", None)
# compute reward
with _timer("reward", timing_raw):
if self.use_reward_model:
raise NotImplementedError("Reward model is not supported yet.")
# we combine with rule-based rm
reward_tensor, reward_metrics = self.reward_fn(batch)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {
f"reward/{key}": value for key, value in reduce_metrics(reward_metrics).items()
}
metrics.update(reward_metrics)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group based adv computation such as GRPO and rloo
......@@ -553,30 +530,38 @@ class RayPPOTrainer:
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# compute reward
with timer("reward", timing_raw):
reward_ref = self.reward_fn.compute_reward.remote(batch)
# recompute old_log_probs
with _timer("old", timing_raw):
with timer("old", timing_raw):
old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
batch = batch.union(old_log_probs)
# compute ref_log_probs
if self.use_reference_policy:
with _timer("ref", timing_raw):
with timer("ref", timing_raw):
ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
batch = batch.union(ref_log_probs)
# compute values
if self.use_critic:
with _timer("values", timing_raw):
with timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer("adv", timing_raw):
with timer("adv", timing_raw):
# get token level scores
reward_tensor, reward_metrics = ray.get(reward_ref)
batch.batch["token_level_scores"] = reward_tensor
reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()}
metrics.update(reward_metrics)
# apply kl penalty if available
if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
# apply kl penalty to reward
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty
)
batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
......@@ -591,7 +576,7 @@ class RayPPOTrainer:
# update critic
if self.use_critic:
with _timer("update_critic", timing_raw):
with timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
......@@ -599,7 +584,7 @@ class RayPPOTrainer:
# update actor
if self.config.trainer.critic_warmup <= self.global_step:
with _timer("update_actor", timing_raw):
with timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
......@@ -611,13 +596,13 @@ class RayPPOTrainer:
and self.config.trainer.val_freq > 0
and self.global_step % self.config.trainer.val_freq == 0
):
with _timer("validation", timing_raw):
with timer("validation", timing_raw):
val_metrics = self._validate()
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
with _timer("save_checkpoint", timing_raw):
with timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
......
......@@ -17,11 +17,13 @@ Contains small Python utility functions
import importlib.util
import re
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Dict, List, Union
import numpy as np
import yaml
from codetiming import Timer
from yaml import Dumper
......@@ -101,3 +103,11 @@ def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") ->
def convert_dict_to_str(data: Dict[str, Any]) -> str:
return yaml.dump(data, indent=2)
@contextmanager
def timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
......@@ -33,7 +33,7 @@ class ModelConfig:
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.model_path is not None and os.path.exists(self.model_path):
if self.model_path is not None and os.path.exists(self.model_path): # ray job uses absolute path
self.model_path = os.path.abspath(self.model_path)
if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
......
......@@ -262,10 +262,10 @@ class FSDPWorker(Worker):
else:
sync_module_states = False
param_init_fn = None
## TODO: pin the model to a specific device
rank = torch.cuda.set_device(self.rank)
model = model.to(rank)
# rank = torch.cuda.set_device(self.rank)
# model = model.to(rank)
print(f"!!! local_rank={self.rank}, torch.cuda.current_device()={torch.cuda.current_device()}")
self.fsdp_module = FSDP(
model,
sharding_strategy=sharding_strategy,
......@@ -284,7 +284,7 @@ class FSDPWorker(Worker):
if self._is_actor or self._is_critic:
if optim_config.strategy == "adamw":
self.optimizer = torch.optim.AdamW(
self.fsdp_module.parameters(),
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
......@@ -292,7 +292,7 @@ class FSDPWorker(Worker):
)
elif optim_config.strategy == "adamw_bf16":
self.optimizer = AnyPrecisionAdamW(
self.fsdp_module.parameters(),
filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
lr=optim_config.lr,
betas=optim_config.betas,
weight_decay=optim_config.weight_decay,
......
......@@ -23,20 +23,21 @@ from typing import Optional
@dataclass
class RewardConfig:
reward_type: str = "function"
score_function: Optional[str] = None
score_function_kwargs: dict = field(default_factory=dict)
reward_function: Optional[str] = None
reward_function_kwargs: dict = field(default_factory=dict)
skip_special_tokens: bool = True
num_cpus: int = 1
"""auto keys"""
score_function_name: Optional[str] = field(default=None, init=False)
reward_function_name: Optional[str] = field(default=None, init=False)
def post_init(self):
if self.score_function is not None:
if ":" not in self.score_function:
self.score_function_name = "main"
if self.reward_function is not None: # support custom reward function, e.g., ./math.py:main
if ":" not in self.reward_function:
self.reward_function_name = "main"
else:
self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)
if os.path.exists(self.score_function):
self.score_function = os.path.abspath(self.score_function)
if os.path.exists(self.reward_function): # ray job uses absolute path
self.reward_function = os.path.abspath(self.reward_function)
else:
self.score_function = None
self.reward_function = None
......@@ -16,7 +16,6 @@ import importlib.util
import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
......@@ -33,38 +32,37 @@ class RewardScore(TypedDict):
accuracy: Optional[float]
ScoreFunction = Callable[[str, str], RewardScore]
RewardFunction = Callable[[str, str], RewardScore]
@dataclass
class FunctionRewardManager:
config: RewardConfig
tokenizer: PreTrainedTokenizer
"""Reward manager for rule-based reward."""
def __post_init__(self):
"""Load score function."""
if self.config.score_function is None:
raise ValueError("Score function is not provided.")
def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
if config.reward_function is None:
raise ValueError("Reward function is not provided.")
if not os.path.exists(self.config.score_function):
raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
if not os.path.exists(config.reward_function):
raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")
spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function)
module = importlib.util.module_from_spec(spec)
try:
sys.modules["custom_score_fn"] = module
sys.modules["custom_reward_fn"] = module
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Failed to load score function: {e}")
raise RuntimeError(f"Failed to load reward function: {e}")
if not hasattr(module, self.config.score_function_name):
raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")
if not hasattr(module, config.reward_function_name):
raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
self.score_fn = partial(score_fn, **self.config.score_function_kwargs)
reward_fn: RewardFunction = getattr(module, config.reward_function_name)
print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
self.config = config
self.tokenizer = tokenizer
def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
reward_metrics = defaultdict(list)
for i in range(len(data)):
......@@ -79,7 +77,7 @@ class FunctionRewardManager:
)
ground_truth = data_item.non_tensor_batch["ground_truth"]
score = self.score_fn(response_str, ground_truth)
score = self.reward_fn(response_str, ground_truth)
reward_tensor[i, valid_response_length - 1] = score["overall"]
for key, value in score.items():
reward_metrics[key].append(value)
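For reference, a custom reward function passed via `worker.reward.reward_function=/path/to/file.py:compute_score` only needs to match the `RewardFunction` signature used by this manager: it receives the decoded response string and the ground truth, and returns a dict whose `overall` value becomes the token-level reward while the remaining keys are logged as metrics. The sketch below is a hypothetical example (the `<answer>` tag format and the score weights are assumptions), not the bundled `math.py` or `r1v.py` implementation.
```python
# Hypothetical custom reward function matching RewardFunction = Callable[[str, str], RewardScore].
import re
from typing import Dict

def compute_score(response: str, ground_truth: str) -> Dict[str, float]:
    match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)  # assumed answer tag format
    answer = match.group(1).strip() if match else response.strip()
    accuracy = float(answer == ground_truth.strip())
    format_score = float(match is not None)
    # "overall" is written into the reward tensor; other keys become reward metrics.
    return {"overall": 0.9 * accuracy + 0.1 * format_score,
            "accuracy": accuracy,
            "format": format_score}
```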
......
......@@ -69,7 +69,7 @@ class vLLMRollout(BaseRollout):
self.inference_engine = LLM(
model=model_path,
skip_tokenizer_init=False,
trust_remote_code=config.trust_remote_code,
trust_remote_code=True,
load_format="dummy",
dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
seed=config.seed,
......@@ -79,13 +79,12 @@ class vLLMRollout(BaseRollout):
gpu_memory_utilization=config.gpu_memory_utilization,
max_num_batched_tokens=config.max_num_batched_tokens,
disable_log_stats=config.disable_log_stats,
enforce_eager=config.enforce_eager,
enforce_eager=True,
disable_custom_all_reduce=True,
limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
disable_mm_preprocessor_cache=True,
enable_chunked_prefill=config.enable_chunked_prefill,
enable_sleep_mode=False, # nv True rocm False
# swap_space=20,
enable_sleep_mode=False,
)
# Offload vllm model to reduce peak memory usage
......