Commit 6065b946 authored by chenych

Update 0427

parent 2369eb2b
@@ -43,8 +43,6 @@ EasyR1 is built on **[HybirdEngine](https://arxiv.org/abs/2409.19256)** and the latest
>
> Training uses wandb. After the environment is installed, log in to wandb first.

## Tutorial: Train Qwen2.5-VL on the [Geometry3K](https://huggingface.co/datasets/hiyouga/geometry3k) dataset with GRPO in just three steps.
![image](assets/qwen2_5_vl_7b_geo.png)
@@ -91,7 +89,7 @@ python: 3.10
torch: 2.4.1
deepspeed: 0.14.2+das.opt2.dtk2504
flash-attn: 2.6.1+das.opt4.dtk2504
vllm: 0.8.3
```
`Tips: the dtk driver, python, torch and other DCU-related tool versions above must correspond strictly one-to-one`
@@ -106,13 +104,13 @@ pip install -e .
### GRPO Training
```bash
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
```
### Merge the Checkpoint in Hugging Face Format
```bash
python3 scripts/model_merger.py --local_dir path_to_your_actor_checkpoint
```
> [!NOTE]
@@ -131,7 +129,8 @@ python3 scripts/model_merger.py --local_dir path_to_your_last_actor_checkpoint
## Other Baselines
- [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on the counting problem.
- [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on the GeoQA problem.
### Known Issues
......
@@ -33,19 +33,16 @@ EasyR1 is efficient and scalable due to the design of **[HybirdEngine](https://a
### Software Requirements
- Python 3.9+
- transformers>=4.51.0
- flash-attn>=2.4.3
- vllm>=0.8.3

We provide a [Dockerfile](./Dockerfile) to easily build environments.
We recommend using the [pre-built docker image](https://hub.docker.com/r/hiyouga/verl) in EasyR1.
```bash
docker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0
```
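If you use the pre-built image, a typical interactive launch looks roughly like the following; the mount paths and resource flags are only illustrative and depend on your machine.
```bash
# illustrative flags only: adjust GPU visibility, shared memory and mounts for your setup
docker run -it --gpus all --ipc=host --shm-size=16g \
    -v "$(pwd)":/workspace -w /workspace \
    hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 bash
```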
### Hardware Requirements
@@ -96,20 +93,42 @@ python3 scripts/model_merger.py --local_dir checkpoints/easy_r1/exp_name/global_
Please refer to the example datasets to prepare your own dataset.
- Text dataset: https://huggingface.co/datasets/hiyouga/math12k
- Image-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
- Multi-image-text dataset: https://huggingface.co/datasets/hiyouga/journeybench-multi-image-vqa
> [!TIP]
> EasyR1 already supports multi-image datasets; a minimal dataset-loading sketch follows below.
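Before building a custom dataset, it can help to inspect the schema of one of the example datasets above; a minimal sketch with the `datasets` library (the `train` split name is an assumption, check the dataset card):
```python
# inspect an example dataset to see which columns a custom dataset should provide
from datasets import load_dataset

dataset = load_dataset("hiyouga/geometry3k", split="train")  # split name assumed
print(dataset.column_names)
print(dataset[0])
```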
## How to Understand GRPO in EasyR1
![image](assets/easyr1_grpo.png)
- To learn about the GRPO algorithm, you can refer to [Hugging Face's blog](https://huggingface.co/docs/trl/v0.16.1/en/grpo_trainer).
## How to Run 70B+ Model in Multi-node Environment
1. Start the Ray head node.
```bash
ray start --head --port=6379 --dashboard-host=0.0.0.0
```
2. Start the Ray worker node and connect to the head node.
```bash
ray start --address=<head_node_ip>:6379
```
3. Check the Ray resource pool.
```bash
ray status
```
4. Run the training script on the Ray head node only.
```bash
bash examples/qwen2_5_vl_7b_geo3k_grpo.sh
```
See **[veRL's official doc](https://verl.readthedocs.io/en/latest/start/multinode.html)** for more details about multi-node training and the Ray debugger.
## Other Baselines
@@ -117,14 +136,20 @@ We also reproduced the following two baselines of the [R1-V](https://github.com/
- [CLEVR-70k-Counting](examples/baselines/qwen2_5_vl_3b_clevr.sh): Train the Qwen2.5-VL-3B-Instruct model on the counting problem.
- [GeoQA-8k](examples/baselines/qwen2_5_vl_3b_geoqa8k.sh): Train the Qwen2.5-VL-3B-Instruct model on the GeoQA problem.
## Performance Baselines
See [baselines.md](assets/baselines.md).
## Awesome Work using EasyR1
- **MMR1**: Advancing the Frontiers of Multimodal Reasoning. [![[code]](https://img.shields.io/github/stars/LengSicong/MMR1)](https://github.com/LengSicong/MMR1)
- **Vision-R1**: Incentivizing Reasoning Capability in Multimodal Large Language Models. [![[code]](https://img.shields.io/github/stars/Osilly/Vision-R1)](https://github.com/Osilly/Vision-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06749-blue)](https://arxiv.org/abs/2503.06749)
- **Seg-Zero**: Reasoning-Chain Guided Segmentation via Cognitive Reinforcement. [![[code]](https://img.shields.io/github/stars/dvlab-research/Seg-Zero)](https://github.com/dvlab-research/Seg-Zero) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.06520-blue)](https://arxiv.org/abs/2503.06520)
- **MetaSpatial**: Reinforcing 3D Spatial Reasoning in VLMs for the Metaverse. [![[code]](https://img.shields.io/github/stars/PzySeere/MetaSpatial)](https://github.com/PzySeere/MetaSpatial) [![[arxiv]](https://img.shields.io/badge/arxiv-2503.18470-blue)](https://arxiv.org/abs/2503.18470)
- **Temporal-R1**: Envolving Temporal Reasoning Capability into LMMs via Temporal Consistent Reward. [![[code]](https://img.shields.io/github/stars/appletea233/Temporal-R1)](https://github.com/appletea233/Temporal-R1)
- **NoisyRollout**: Reinforcing Visual Reasoning with Data Augmentation. [![[code]](https://img.shields.io/github/stars/John-AI-Lab/NoisyRollout)](https://github.com/John-AI-Lab/NoisyRollout) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.13055-blue)](https://arxiv.org/pdf/2504.13055)
- **GUI-R1**: A Generalist R1-Style Vision-Language Action Model For GUI Agents. [![[code]](https://img.shields.io/github/stars/ritzz-ai/GUI-R1)](https://github.com/ritzz-ai/GUI-R1) [![[arxiv]](https://img.shields.io/badge/arxiv-2504.10458-blue)](https://arxiv.org/abs/2504.10458)
## TODO
- Support LoRA (high priority).
@@ -146,9 +171,17 @@ These features are temporarily disabled for now, we plan to fix them one-by-one
## FAQs
> ValueError: Image features and image tokens do not match: tokens: 8192, features 9800
Increase the `data.max_prompt_length` or reduce the `data.max_pixels`.
> RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62

Reduce the `worker.rollout.gpu_memory_utilization` and enable `worker.actor.offload.offload_params`.
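Both of the adjustments above (and the one for the image-token mismatch) are plain config overrides in the same `key=value` style used by the example scripts; a sketch with illustrative values, omitting the other arguments of your chosen example script:
```bash
# hypothetical override values; append these to the arguments already in your example script
python3 -m verl.trainer.main \
    data.max_prompt_length=4096 \
    data.max_pixels=1048576 \
    worker.rollout.gpu_memory_utilization=0.5 \
    worker.actor.offload.offload_params=true
```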
> RuntimeError: 0 active drivers ([]). There should only be one.
Uninstall `deepspeed` from the current Python environment.
## Citation
......
# Baselines
Environment: [hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0](https://hub.docker.com/layers/hiyouga/verl/ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0/images/sha256-335ed6cd1fe73090e458409cfa4394d6abf4cd0503ca44dbafdc28ff72e5ed20)
EasyR1 version: [v0.3.0](https://github.com/hiyouga/EasyR1/tree/v0.3.0)
Welcome to contribute new data points!
## Algorithm Baselines
### [Qwen2.5-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) on [Math12k](https://huggingface.co/datasets/hiyouga/math12k)
| Size | Algorithm | Bits | LR | KL | Test Score |
| ---- | ----------- | ---- | ---- | ---- | ---------- |
| 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.73->0.79 |
### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
| Size | Algorithm | Bits | LR | KL | Test Score |
| ---- | ----------- | ---- | ---- | ---- | ---------- |
| 7B | GRPO | AMP | 1e-6 | 1e-2 | 0.39->0.52 |
| 7B | GRPO | BF16 | 1e-6 | 1e-2 | 0.39->0.52 |
| 7B | GRPO | AMP | 1e-6 | 1e-3 | 0.39->0.52 |
| 7B | RLOO | AMP | 1e-6 | 1e-2 | 0.39->0.53 |
| 3B | GRPO | AMP | 1e-6 | 1e-2 | 0.27->0.44 |
| 32B | GRPO | BF16 | 1e-6 | 1e-2 | 0.46->0.61 |
> [!NOTE]
> The hyper-parameters not listed are all the same as the default values.
## Performance Baselines
### [Qwen2.5-VL-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) on [Geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k)
| Size | GPU Type | Bits | Batch Size | vLLM Util | vLLM TP | Peak Mem | Peak VRAM | Throughput | Sec per step | Actor MFU |
| ---- | ------------- | ---- | ---------- | --------- | ------- | -------- | --------- | ---------- | ------------ | --------- |
| 3B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 120GB | 35GB | 1200 | 180s | 6.3% |
| 7B | 8 * H100 80GB | AMP | 4 / 16 | 0.6 | 2 | 140GB | 60GB | 1200 | 180s | 13.6% |
| 7B | 8 * H100 80GB | AMP | 10 / 20 | 0.6 | 2 | 150GB | 75GB | 1400 | 170s | 19.2% |
| 7B | 8 * L20 48GB | AMP | 4 / 16 | 0.6 | 2 | 150GB | 44GB | 410 | 580s | 26.5% |
| 7B | 8 * H100 80GB | BF16 | 4 / 16 | 0.6 | 2 | 150GB | 50GB | 1280 | 190s | 13.9% |
| 32B | 8 * H100 80GB | BF16 | 1 / 8 | 0.6 | 8 | 240GB | 68GB | 360 | 860s | 11.2% |
- Batch Size: micro_batch_size_per_device_for_update / micro_batch_size_per_device_for_experience
- vLLM Util: rollout.gpu_memory_utilization
- vLLM TP: rollout.tensor_parallel_size
- Peak Mem: Peak CPU memory usage
- Peak VRAM: Peak GPU memory usage
- Throughput: Number of tokens per second per GPU in one training step (a short worked example follows the notes below)
- Sec per step: Average time per step in seconds
> [!NOTE]
> The hyper-parameters not listed are all the same as the default values.
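To make the throughput column concrete, throughput x number of GPUs x seconds per step gives a rough token count per optimization step; a back-of-envelope sketch using the 7B, 8 * H100, batch 4 / 16 row:
```python
# rough reading of the 7B / 8 * H100 / 4-16 row: 1200 tokens/s/GPU and 180 s/step
tokens_per_sec_per_gpu = 1200
num_gpus = 8
sec_per_step = 180

tokens_per_step = tokens_per_sec_per_gpu * num_gpus * sec_per_step
print(f"~{tokens_per_step / 1e6:.2f}M tokens per training step")  # ~1.73M
```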
assets/wechat.jpg: image replaced (111 KB → 113 KB)
@@ -9,6 +9,6 @@ python3 -m verl.trainer.main \
    data.format_prompt=./examples/format_prompt/r1v_format.jinja \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
    worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
    trainer.experiment_name=qwen2_5_vl_3b_clevr \
    trainer.n_gpus_per_node=2
@@ -9,6 +9,6 @@ python3 -m verl.trainer.main \
    data.format_prompt=./examples/format_prompt/r1v_format.jinja \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
    worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
    trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
    trainer.n_gpus_per_node=8
@@ -7,7 +7,7 @@ data:
  max_prompt_length: 2048
  max_response_length: 2048
  rollout_batch_size: 512
  val_batch_size: 1024
  format_prompt: ./examples/format_prompt/math_format.jinja
  shuffle: true
  seed: 1
@@ -71,7 +71,7 @@ worker:
  reward:
    reward_type: function
    reward_function: ./examples/reward_function/math.py:compute_score

trainer:
  total_episodes: 15
......
@@ -18,9 +18,17 @@ import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoModelForVision2Seq,
    PretrainedConfig,
    PreTrainedModel,
)


def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
@@ -34,14 +42,23 @@ def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
        raise ValueError(f"Unsupported placement: {placement}")


def upload_model_to_huggingface(local_path: str, remote_path: str):
    # Push to hugging face
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
    api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
    args = parser.parse_args()
    local_dir: str = args.local_dir

    assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."

    # copy rank zero to find the shape of (dp, fsdp)
    rank = 0
@@ -51,22 +68,26 @@ if __name__ == "__main__":
        if match:
            world_size = match.group(1)
            break

    assert world_size, "No model file with the proper format."

    rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
    state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    if isinstance(weight, DTensor):
        # get sharding info
        device_mesh = weight.device_mesh
        mesh = device_mesh.mesh
        mesh_dim_names = device_mesh.mesh_dim_names
    else:
        # for non-DTensor
        mesh = np.array([int(world_size)], dtype=np.int64)
        mesh_dim_names = ("fsdp",)

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
    assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."

    if "tp" in mesh_dim_names:
        # fsdp * tp
@@ -77,13 +98,12 @@ if __name__ == "__main__":
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing {total_shards} model shards in total.")
    model_state_dict_lst = []
    model_state_dict_lst.append(state_dict)
    model_state_dict_lst.extend([""] * (total_shards - 1))

    def process_one_shard(rank, model_state_dict_lst):
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
@@ -91,8 +111,9 @@ if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
        for rank in range(1, total_shards):
            executor.submit(process_one_shard, rank, model_state_dict_lst)

    state_dict: Dict[str, List[torch.Tensor]] = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
@@ -101,8 +122,8 @@ if __name__ == "__main__":
            try:
                tensor = model_state_dict.pop(key)
            except Exception:
                print(f"Cannot find key {key} in rank {rank}.")

            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
@@ -115,7 +136,7 @@ if __name__ == "__main__":
                else:
                    assert param_placements[key] == placements
            else:
                state_dict[key].append(tensor.bfloat16())

    del model_state_dict_lst
@@ -123,43 +144,44 @@ if __name__ == "__main__":
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue

        if key in param_placements:
            # merge shards
            placements: Tuple[Shard] = param_placements[key]
            if len(mesh_shape) == 1:
                # 1-D list, FSDP without TP
                assert len(placements) == 1
                shards = state_dict[key]
                state_dict[key] = merge_by_placement(shards, placements[0])
            else:
                # 2-D list, FSDP + TP
                raise NotImplementedError("FSDP + TP is not supported yet.")
        else:
            state_dict[key] = torch.cat(state_dict[key], dim=0)

    print("Merge completed.")
    hf_path = os.path.join(local_dir, "huggingface")
    config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
    architectures: List[str] = getattr(config, "architectures", ["Unknown"])

    if "ForTokenClassification" in architectures[0]:
        AutoClass = AutoModelForTokenClassification
    elif "ForCausalLM" in architectures[0]:
        AutoClass = AutoModelForCausalLM
    elif "ForConditionalGeneration" in architectures[0]:
        AutoClass = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {architectures}.")

    with torch.device("meta"):
        model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)

    assert isinstance(model, PreTrainedModel)
    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}...")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict, model

    if args.hf_upload_path:
        upload_model_to_huggingface(hf_path, args.hf_upload_path)
@@ -51,7 +51,7 @@ class DataConfig:
    def post_init(self):
        if self.format_prompt is not None:
            if os.path.exists(self.format_prompt):  # ray job uses absolute path
                self.format_prompt = os.path.abspath(self.format_prompt)
            else:
                self.format_prompt = None
@@ -94,7 +94,7 @@ class TrainerConfig:
        if self.save_checkpoint_path is None:
            self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

        self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)  # ray job uses absolute path
        if self.load_checkpoint_path is not None:
            self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
......
@@ -65,12 +65,11 @@ class Runner:
        }
        resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

        RemoteRewardManager = ray.remote(FunctionRewardManager).options(num_cpus=config.worker.reward.num_cpus)
        reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
        val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)

        train_dataloader, val_dataloader = create_dataloader(config.data, tokenizer, processor)

        trainer = RayPPOTrainer(
            config=config,
......
@@ -19,16 +19,14 @@ This trainer supports model-agonistic model initialization with huggingface
import os
import uuid
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Type

import numpy as np
import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -40,9 +38,10 @@ from ..single_controller.ray.base import create_colocated_worker_cls
from ..utils import torch_functional as VF
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
from ..utils.logger import Tracker
from ..utils.py_functional import convert_dict_to_str, timer
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import FunctionRewardManager
from . import core_algos
from .config import PPOConfig
from .metrics import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics, reduce_metrics
@@ -162,14 +161,6 @@ def compute_advantage(data: DataProto, adv_estimator: AdvantageEstimator, gamma:
    return data


class RayPPOTrainer:
    """
    Note that this trainer runs on the driver process on a single CPU/GPU node.
@@ -185,8 +176,8 @@ class RayPPOTrainer:
        role_worker_mapping: dict[Role, Type[Worker]],
        resource_pool_manager: ResourcePoolManager,
        ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,
        reward_fn: Optional[FunctionRewardManager] = None,
        val_reward_fn: Optional[FunctionRewardManager] = None,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
@@ -307,7 +298,6 @@ class RayPPOTrainer:
            test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
            test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
            test_output_gen_batch = unpad_dataproto(test_output_gen_batch, pad_size=pad_size)

            # Store generated outputs
            output_ids = test_output_gen_batch.batch["responses"]
@@ -317,7 +307,7 @@ class RayPPOTrainer:
            test_batch = test_batch.union(test_output_gen_batch)

            # evaluate using reward_function
            reward_tensor, reward_metrics = ray.get(self.val_reward_fn.compute_reward.remote(test_batch))

            # Store scores
            scores = reward_tensor.sum(-1).cpu().tolist()
@@ -504,20 +494,20 @@ class RayPPOTrainer:
                    non_tensor_batch_keys=["raw_prompt_ids"],
                )

                with timer("step", timing_raw):
                    # generate a batch
                    with timer("gen", timing_raw):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

                    if self.config.algorithm.adv_estimator == "remax":
                        with timer("gen_max", timing_raw):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info["temperature"] = 0
                            gen_baseline_batch.meta_info["n"] = 1
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor, _ = ray.get(self.reward_fn.compute_reward.remote(batch))
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)

                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
@@ -532,19 +522,6 @@ class RayPPOTrainer:
                    batch = batch.union(gen_batch_output)
                    batch.non_tensor_batch.pop("multi_modal_data", None)

                    # balance the number of valid tokens on each dp rank.
                    # Note that this breaks the order of data inside the batch.
                    # Please take care when you implement group based adv computation such as GRPO and rloo
@@ -553,30 +530,38 @@ class RayPPOTrainer:
                    # compute global_valid tokens
                    batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()

                    # compute reward
                    with timer("reward", timing_raw):
                        reward_ref = self.reward_fn.compute_reward.remote(batch)

                    # recompute old_log_probs
                    with timer("old", timing_raw):
                        old_log_probs = self.actor_rollout_wg.compute_log_probs(batch)
                        batch = batch.union(old_log_probs)

                    # compute ref_log_probs
                    if self.use_reference_policy:
                        with timer("ref", timing_raw):
                            ref_log_probs = self.ref_policy_wg.compute_ref_log_probs(batch)
                            batch = batch.union(ref_log_probs)

                    # compute values
                    if self.use_critic:
                        with timer("values", timing_raw):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with timer("adv", timing_raw):
                        # get token level scores
                        reward_tensor, reward_metrics = ray.get(reward_ref)
                        batch.batch["token_level_scores"] = reward_tensor
                        reward_metrics = {f"reward/{k}": v for k, v in reduce_metrics(reward_metrics).items()}
                        metrics.update(reward_metrics)

                        # apply kl penalty if available
                        if not self.config.algorithm.use_kl_loss and self.use_reference_policy:
                            # apply kl penalty to reward
                            batch, kl_metrics = apply_kl_penalty(batch, self.kl_ctrl, self.config.algorithm.kl_penalty)
                            metrics.update(kl_metrics)
                        else:
                            batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
@@ -591,7 +576,7 @@ class RayPPOTrainer:
                    # update critic
                    if self.use_critic:
                        with timer("update_critic", timing_raw):
                            critic_output = self.critic_wg.update_critic(batch)
                            critic_metrics = reduce_metrics(critic_output.non_tensor_batch)
@@ -599,7 +584,7 @@ class RayPPOTrainer:
                    # update actor
                    if self.config.trainer.critic_warmup <= self.global_step:
                        with timer("update_actor", timing_raw):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                            actor_metrics = reduce_metrics(actor_output.non_tensor_batch)
@@ -611,13 +596,13 @@ class RayPPOTrainer:
                        and self.config.trainer.val_freq > 0
                        and self.global_step % self.config.trainer.val_freq == 0
                    ):
                        with timer("validation", timing_raw):
                            val_metrics = self._validate()
                            metrics.update(val_metrics)

                    if self.config.trainer.save_freq > 0 and self.global_step % self.config.trainer.save_freq == 0:
                        with timer("save_checkpoint", timing_raw):
                            self._save_checkpoint()

                # collect metrics
......
@@ -17,11 +17,13 @@ Contain small python utility functions
import importlib.util
import re
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Dict, List, Union

import numpy as np
import yaml
from codetiming import Timer
from yaml import Dumper
@@ -101,3 +103,11 @@ def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") ->
def convert_dict_to_str(data: Dict[str, Any]) -> str:
    return yaml.dump(data, indent=2)


@contextmanager
def timer(name: str, timing_raw: Dict[str, float]):
    with Timer(name=name, logger=None) as timer:
        yield
    timing_raw[name] = timer.last
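For context, the trainer changes above consume this helper as `with timer("gen", timing_raw): ...`; a standalone usage sketch, assuming the package is importable as `verl`:
```python
import time
from typing import Dict

from verl.utils.py_functional import timer  # the helper added above

timing_raw: Dict[str, float] = {}
with timer("gen", timing_raw):
    time.sleep(0.1)  # stand-in for a rollout / generation call

print(timing_raw)  # each named block records its wall-clock duration, e.g. {"gen": 0.10...}
```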
@@ -33,7 +33,7 @@ class ModelConfig:
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

        if self.model_path is not None and os.path.exists(self.model_path):  # ray job uses absolute path
            self.model_path = os.path.abspath(self.model_path)

        if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
......
@@ -262,10 +262,10 @@ class FSDPWorker(Worker):
        else:
            sync_module_states = False
            param_init_fn = None

        # rank = torch.cuda.set_device(self.rank)
        # model = model.to(rank)
        print(f"!!! local_rank={self.rank}, torch.cuda.current_device()={torch.cuda.current_device()}")
        self.fsdp_module = FSDP(
            model,
            sharding_strategy=sharding_strategy,
@@ -284,7 +284,7 @@ class FSDPWorker(Worker):
        if self._is_actor or self._is_critic:
            if optim_config.strategy == "adamw":
                self.optimizer = torch.optim.AdamW(
                    filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
                    lr=optim_config.lr,
                    betas=optim_config.betas,
                    weight_decay=optim_config.weight_decay,
@@ -292,7 +292,7 @@ class FSDPWorker(Worker):
                )
            elif optim_config.strategy == "adamw_bf16":
                self.optimizer = AnyPrecisionAdamW(
                    filter(lambda p: p.requires_grad, self.fsdp_module.parameters()),
                    lr=optim_config.lr,
                    betas=optim_config.betas,
                    weight_decay=optim_config.weight_decay,
......
@@ -23,20 +23,21 @@ from typing import Optional
@dataclass
class RewardConfig:
    reward_type: str = "function"
    reward_function: Optional[str] = None
    reward_function_kwargs: dict = field(default_factory=dict)
    skip_special_tokens: bool = True
    num_cpus: int = 1
    """auto keys"""
    reward_function_name: Optional[str] = field(default=None, init=False)

    def post_init(self):
        if self.reward_function is not None:  # support custom reward function, e.g., ./math.py:main
            if ":" not in self.reward_function:
                self.reward_function_name = "main"
            else:
                self.reward_function, self.reward_function_name = self.reward_function.rsplit(":", maxsplit=1)

            if os.path.exists(self.reward_function):  # ray job uses absolute path
                self.reward_function = os.path.abspath(self.reward_function)
            else:
                self.reward_function = None
@@ -16,7 +16,6 @@ import importlib.util
import os
import sys
from collections import defaultdict
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -33,38 +32,37 @@ class RewardScore(TypedDict):
    accuracy: Optional[float]


RewardFunction = Callable[[str, str], RewardScore]


class FunctionRewardManager:
    """Reward manager for rule-based reward."""

    def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
        if config.reward_function is None:
            raise ValueError("Reward function is not provided.")

        if not os.path.exists(config.reward_function):
            raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")

        spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function)
        module = importlib.util.module_from_spec(spec)
        try:
            sys.modules["custom_reward_fn"] = module
            spec.loader.exec_module(module)
        except Exception as e:
            raise RuntimeError(f"Failed to load reward function: {e}")

        if not hasattr(module, config.reward_function_name):
            raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")

        reward_fn: RewardFunction = getattr(module, config.reward_function_name)
        print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
        self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
        self.config = config
        self.tokenizer = tokenizer

    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i in range(len(data)):
@@ -79,7 +77,7 @@ class FunctionRewardManager:
            )
            ground_truth = data_item.non_tensor_batch["ground_truth"]

            score = self.reward_fn(response_str, ground_truth)
            reward_tensor[i, valid_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)
......
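For orientation, a reward function referenced via `worker.reward.reward_function=<path>:<name>` must match the `RewardFunction = Callable[[str, str], RewardScore]` signature above and return at least an `overall` score, which the manager writes into the reward tensor; a minimal sketch (file path and scoring rule are illustrative):
```python
# e.g. saved as ./examples/reward_function/my_reward.py (hypothetical file)
from typing import Dict


def compute_score(response: str, ground_truth: str) -> Dict[str, float]:
    # exact-match reward; "overall" is what FunctionRewardManager reads, other keys become metrics
    accuracy = float(response.strip() == ground_truth.strip())
    return {"overall": accuracy, "accuracy": accuracy}
```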
@@ -69,7 +69,7 @@ class vLLMRollout(BaseRollout):
        self.inference_engine = LLM(
            model=model_path,
            skip_tokenizer_init=False,
            trust_remote_code=True,
            load_format="dummy",
            dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
            seed=config.seed,
@@ -79,13 +79,12 @@ class vLLMRollout(BaseRollout):
            gpu_memory_utilization=config.gpu_memory_utilization,
            max_num_batched_tokens=config.max_num_batched_tokens,
            disable_log_stats=config.disable_log_stats,
            enforce_eager=True,
            disable_custom_all_reduce=True,
            limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
            disable_mm_preprocessor_cache=True,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_sleep_mode=False,
        )
        # Offload vllm model to reduce peak memory usage
......