OpenDAS / EasyR1 · Commits · 20247eb8
"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "d8a136a11332117229aca73abbc5edcf2b9ebd76"
Commit 20247eb8, authored May 06, 2025 by chenych

Update 0506

Parent: 6065b946
Showing 20 changed files with 132 additions and 44 deletions (+132 −44)
- README.md (+23 −1)
- examples/baselines/qwen2_5_vl_3b_clevr.sh (+5 −0)
- examples/baselines/qwen2_5_vl_3b_geoqa8k.sh (+5 −0)
- examples/config.yaml (+6 −4)
- examples/qwen2_5_7b_math_grpo.sh (+5 −5)
- examples/qwen2_5_vl_32b_geo3k_grpo.sh (+4 −0)
- examples/qwen2_5_vl_3b_geo3k_grpo.sh (+4 −0)
- examples/qwen2_5_vl_7b_geo3k_grpo.sh (+4 −0)
- examples/qwen2_5_vl_7b_geo3k_reinforce.sh (+4 −0)
- examples/qwen2_5_vl_7b_geo3k_swanlab.sh (+4 −0)
- examples/qwen2_5_vl_7b_multi_image.sh (+3 −0)
- examples/qwen3_4b_math_grpo.sh (+13 −0)
- examples/reward_function/math.py (+20 −14)
- examples/reward_function/r1v.py (+8 −8)
- examples/runtime_env.yaml (+2 −1)
- verl/models/monkey_patch.py (+1 −1)
- verl/trainer/config.py (+2 −1)
- verl/trainer/main.py (+11 −3)
- verl/trainer/ray_trainer.py (+2 −2)
- verl/utils/checkpoint/fsdp_checkpoint_manager.py (+6 −4)
README.md
@@ -101,6 +101,13 @@ pip install "tensordict<0.6"

```
pip install -e .
```

### Datasets

You can build your own dataset by following the sample datasets below (a minimal loading sketch follows the list):
- Text dataset: https://huggingface.co/datasets/hiyouga/math12k
- Image-text dataset: https://huggingface.co/datasets/hiyouga/geometry3k
- Multi-image-text dataset: https://huggingface.co/datasets/hiyouga/journeybench-multi-image-vqa
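As a quick sanity check, one of these datasets can be inspected locally with the `datasets` library. This is an illustrative sketch only; the split names are assumed from the `@train` / `@test` references used in the example scripts:

```python
from datasets import load_dataset

# Inspect the text-only math dataset referenced by examples/qwen2_5_7b_math_grpo.sh.
train_split = load_dataset("hiyouga/math12k", split="train")

print(train_split)      # row count and column names
print(train_split[0])   # one example record
```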
### GRPO Training

```bash
...
```

@@ -114,7 +121,8 @@ python3 scripts/model_merger.py --local_dir path_to_your_actor_checkpoint
> [!NOTE]
> If you cannot connect to Hugging Face, first install `pip install -U huggingface_hub hf_transfer`, then add `export HF_ENDPOINT=https://hf-mirror.com` before launching.
>
> If you want to use the SwanLab logger, consider running `bash examples/qwen2_5_vl_7b_geo3k_swanlab.sh`.
## Custom Datasets

@@ -137,3 +145,17 @@ python3 scripts/model_merger.py --local_dir path_to_your_actor_checkpoint

These features are temporarily disabled for now, and we plan to fix them one by one in future updates.

- Vision-language models are currently incompatible with padding-free training and the DeepSpeed Ulysses parallelism method.
### Common Issues and Solutions

> ValueError: Image features and image tokens do not match: tokens: 8192, features 9800

Increase the value of `data.max_prompt_length` or decrease the value of `data.max_pixels`.

> RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62

Decrease the value of `worker.rollout.gpu_memory_utilization` and make sure `worker.actor.offload.offload_params` is enabled.

> RuntimeError: 0 active drivers ([]). There should only be one.

Uninstall `deepspeed` from the current Python environment.
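As a rough illustration of the first error above, the number of visual tokens a single image can emit scales with `data.max_pixels`. The sketch below assumes Qwen2.5-VL's packing of roughly one visual token per 28×28-pixel area after 2×2 patch merging, so treat the numbers as estimates only:

```python
# Rough estimate only, assuming one visual token per 28x28-pixel area
# (14x14 patches merged 2x2) as in Qwen2.5-VL.
PIXELS_PER_TOKEN = 28 * 28


def estimated_image_tokens(max_pixels: int) -> int:
    return max_pixels // PIXELS_PER_TOKEN


# With the default data.max_pixels = 4194304, one large image alone can emit
# ~5349 visual tokens, which quickly crowds out a fixed prompt-length budget
# once the question text and multiple images share the same prompt.
print(estimated_image_tokens(4_194_304))  # -> 5349
```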
examples/baselines/qwen2_5_vl_3b_clevr.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
```

@@ -9,6 +13,7 @@ python3 -m verl.trainer.main \

```bash
    data.format_prompt=./examples/format_prompt/r1v_format.jinja \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
    worker.reward.reward_type=sequential \
    worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
    trainer.experiment_name=qwen2_5_vl_3b_clevr \
    trainer.n_gpus_per_node=2
```
examples/baselines/qwen2_5_vl_3b_geoqa8k.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
```

@@ -9,6 +13,7 @@ python3 -m verl.trainer.main \

```bash
    data.format_prompt=./examples/format_prompt/r1v_format.jinja \
    worker.actor.model.model_path=${MODEL_PATH} \
    worker.rollout.tensor_parallel_size=1 \
    worker.reward.reward_type=sequential \
    worker.reward.reward_function=./examples/reward_function/r1v.py:compute_score \
    trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
    trainer.n_gpus_per_node=8
```
examples/config.yaml

```diff
@@ -9,6 +9,7 @@ data:
   rollout_batch_size: 512
   val_batch_size: 1024
   format_prompt: ./examples/format_prompt/math_format.jinja
+  override_chat_template: null
   shuffle: true
   seed: 1
   max_pixels: 4194304
@@ -70,16 +71,17 @@ worker:
       offload_params: false

   reward:
-    reward_type: function
+    reward_type: batch
     reward_function: ./examples/reward_function/math.py:compute_score

 trainer:
-  total_episodes: 15
-  logger: ["console", "wandb"]
+  total_epochs: 15
+  max_steps: null
   project_name: easy_r1
   experiment_name: qwen2_5_7b_math_grpo
-  n_gpus_per_node: 8
+  logger: ["console", "wandb"]
   nnodes: 1
+  n_gpus_per_node: 8
   val_freq: 5  # -1 to disable
   val_before_train: true
   val_only: false
```
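A minimal sketch of reading these fields back with OmegaConf, the same library the trainer imports in `verl/trainer/main.py`. The path and nesting are taken from this diff, but treat the snippet as an illustration rather than the trainer's own loading code:

```python
from omegaconf import OmegaConf

# Inspect the fields this commit touches in examples/config.yaml.
config = OmegaConf.load("examples/config.yaml")

print(config.data.override_chat_template)  # None (newly added key, defaults to null)
print(config.worker.reward.reward_type)    # "batch" (was "function")
print(config.trainer.total_epochs)         # 15 (renamed from total_episodes)
```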
examples/qwen2_5_7b_math_grpo.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-7B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.train_files=hiyouga/math12k@train \
    data.val_files=hiyouga/math12k@test \
    worker.actor.model.model_path=${MODEL_PATH} \
    trainer.experiment_name=qwen2_5_7b_math_grpo \
    trainer.n_gpus_per_node=8
```
examples/qwen2_5_vl_32b_geo3k_grpo.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-32B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    ...
```
examples/qwen2_5_vl_3b_geo3k_grpo.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-3B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    ...
```
examples/qwen2_5_vl_7b_geo3k_grpo.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    ...
```
examples/qwen2_5_vl_7b_geo3k_reinforce.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    ...
```
examples/qwen2_5_vl_7b_geo3k_swanlab.sh

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    ...
```
examples/qwen2_5_vl_7b_multi_image.sh

```bash
#!/bin/bash

# REMINDER: this script uses test data split and should ONLY be used for debugging. DO NOT use for training.

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct  # replace it with your local file path

python3 -m verl.trainer.main \
    ...
```
examples/qwen3_4b_math_grpo.sh (new file, mode 100644)

```bash
#!/bin/bash

set -x

export PYTHONUNBUFFERED=1

MODEL_PATH=Qwen/Qwen3-4B  # replace it with your local file path

python3 -m verl.trainer.main \
    config=examples/config.yaml \
    data.max_response_length=4096 \
    worker.actor.model.model_path=${MODEL_PATH} \
    trainer.experiment_name=qwen3_4b_math_grpo
```
examples/reward_function/math.py

```diff
@@ -13,28 +13,34 @@
 # limitations under the License.

 import re
-from typing import Dict
+from typing import Dict, List

 from mathruler.grader import extract_boxed_content, grade_answer


-def format_reward(predict_str: str) -> float:
+def format_reward(predict: str) -> float:
     pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
-    format_match = re.fullmatch(pattern, predict_str)
+    format_match = re.fullmatch(pattern, predict)
     return 1.0 if format_match else 0.0


-def accuracy_reward(predict_str: str, ground_truth: str) -> float:
-    answer = extract_boxed_content(predict_str)
+def accuracy_reward(predict: str, ground_truth: str) -> float:
+    answer = extract_boxed_content(predict)
     return 1.0 if grade_answer(answer, ground_truth) else 0.0


-def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
-    predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str)  # handle qwen2.5vl-32b format
-    format_score = format_reward(predict_str)
-    accuracy_score = accuracy_reward(predict_str, ground_truth)
-    return {
-        "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
-        "format": format_score,
-        "accuracy": accuracy_score,
-    }
+def compute_score(predicts: List[str], ground_truths: List[str], format_weight: float = 0.1) -> List[Dict[str, float]]:
+    scores = []
+    for predict, ground_truth in zip(predicts, ground_truths):
+        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)  # handle qwen2.5vl-32b format
+        format_score = format_reward(predict)
+        accuracy_score = accuracy_reward(predict, ground_truth)
+        scores.append(
+            {
+                "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
+                "format": format_score,
+                "accuracy": accuracy_score,
+            }
+        )
+
+    return scores
```
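To see what the new batch interface expects, here is a small, hedged usage sketch: the prediction strings are made up for illustration, `mathruler` must be installed, and the module is loaded by file path in the same spirit as the `reward_function` option that points at it:

```python
import importlib.util

# Load the reward module directly from its file path (run from the repo root).
spec = importlib.util.spec_from_file_location("math_reward", "examples/reward_function/math.py")
math_reward = importlib.util.module_from_spec(spec)
spec.loader.exec_module(math_reward)

# Hypothetical rollouts; in training these come from the actor's generations.
predicts = [
    "<think>2 + 2 = 4</think> The answer is \\boxed{4}.",
    "<think>guessing</think> The answer is \\boxed{7}.",
]
ground_truths = ["4", "5"]

scores = math_reward.compute_score(predicts, ground_truths, format_weight=0.1)
print(scores[0])  # {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}
print(scores[1])  # {'overall': 0.1, 'format': 1.0, 'accuracy': 0.0}
```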
examples/reward_function/r1v.py

```diff
@@ -18,16 +18,16 @@ from typing import Dict
 from mathruler.grader import grade_answer


-def format_reward(predict_str: str) -> float:
+def format_reward(predict: str) -> float:
     pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
-    format_match = re.fullmatch(pattern, predict_str)
+    format_match = re.fullmatch(pattern, predict)
     return 1.0 if format_match else 0.0


-def accuracy_reward(predict_str: str, ground_truth: str) -> float:
+def accuracy_reward(predict: str, ground_truth: str) -> float:
     try:
-        content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
-        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
+        content_match = re.search(r"<answer>(.*?)</answer>", predict)
+        given_answer = content_match.group(1).strip() if content_match else predict.strip()
         if grade_answer(given_answer, ground_truth.strip()):
             return 1.0
@@ -37,9 +37,9 @@ def accuracy_reward(predict_str: str, ground_truth: str) -> float:
     return 0.0


-def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
-    format_score = format_reward(predict_str)
-    accuracy_score = accuracy_reward(predict_str, ground_truth)
+def compute_score(predict: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
+    format_score = format_reward(predict)
+    accuracy_score = accuracy_reward(predict, ground_truth)
     return {
         "overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
         "format": format_score,
```
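Unlike `math.py`, this reward keeps a per-sample signature, which is why the CLEVR and GeoQA baseline scripts above set `worker.reward.reward_type=sequential`. A small, hedged usage sketch follows; the strings are illustrative only, `mathruler` must be installed, and the module is loaded by file path from the repo root:

```python
import importlib.util

# Load the r1v reward module directly from its file path.
spec = importlib.util.spec_from_file_location("r1v_reward", "examples/reward_function/r1v.py")
r1v_reward = importlib.util.module_from_spec(spec)
spec.loader.exec_module(r1v_reward)

# One prediction scored at a time; a sequential reward manager makes this call per sample.
predict = "<think>The image shows three cubes.</think>\n<answer>3</answer>"
score = r1v_reward.compute_score(predict, "3", format_weight=0.5)
print(score["overall"], score["format"])  # 1.0 1.0
```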
examples/runtime_env.yaml

```diff
@@ -3,6 +3,7 @@ excludes: ["/.git/"]
 env_vars:
   TOKENIZERS_PARALLELISM: "true"
   NCCL_DEBUG: "WARN"
-  VLLM_LOGGING_LEVEL: "INFO"
+  VLLM_LOGGING_LEVEL: "WARN"
   TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
   PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
+  PYTHONUNBUFFERED: "1"
```
verl/models/monkey_patch.py

```diff
@@ -20,7 +20,7 @@ from .transformers.qwen2_vl import qwen2_vl_attn_forward


 def apply_ulysses_patch(model_type: str) -> None:
-    if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
+    if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
         ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
     elif model_type in ("qwen2_vl", "qwen2_5_vl"):
         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
```
verl/trainer/config.py

```diff
@@ -43,6 +43,7 @@ class DataConfig:
     rollout_batch_size: int = 512
     val_batch_size: int = -1
     format_prompt: Optional[str] = None
+    override_chat_template: Optional[str] = None
     shuffle: bool = True
     seed: int = 1
     max_pixels: int = 4194304
@@ -73,7 +74,7 @@ class AlgorithmConfig:

 @dataclass
 class TrainerConfig:
-    total_episodes: int = 10
+    total_epochs: int = 10
     max_steps: Optional[int] = None
     project_name: str = "easy_r1"
     experiment_name: str = "demo"
```
verl/trainer/main.py

```diff
@@ -20,7 +20,7 @@ from omegaconf import OmegaConf
 from ..single_controller.ray import RayWorkerGroup
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..workers.fsdp_workers import FSDPWorker
-from ..workers.reward import FunctionRewardManager
+from ..workers.reward import BatchFunctionRewardManager, SequentialFunctionRewardManager
 from .config import PPOConfig
 from .data_loader import create_dataloader
 from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role
@@ -38,11 +38,13 @@ class Runner:
         # instantiate tokenizer
         tokenizer = get_tokenizer(
             config.worker.actor.model.model_path,
+            override_chat_template=config.data.override_chat_template,
             trust_remote_code=config.worker.actor.model.trust_remote_code,
             use_fast=True,
         )
         processor = get_processor(
             config.worker.actor.model.model_path,
+            override_chat_template=config.data.override_chat_template,
             trust_remote_code=config.worker.actor.model.trust_remote_code,
             use_fast=True,
         )
@@ -65,7 +67,14 @@ class Runner:
         }
         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

-        RemoteRewardManager = ray.remote(FunctionRewardManager).options(num_cpus=config.worker.reward.num_cpus)
+        if config.worker.reward.reward_type == "sequential":
+            RewardManager = SequentialFunctionRewardManager
+        elif config.worker.reward.reward_type == "batch":
+            RewardManager = BatchFunctionRewardManager
+        else:
+            raise NotImplementedError(f"Unknown reward type {config.worker.reward.reward_type}.")
+
+        RemoteRewardManager = ray.remote(RewardManager).options(num_cpus=config.worker.reward.num_cpus)
         reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
         val_reward_fn = RemoteRewardManager.remote(config.worker.reward, tokenizer)
@@ -117,7 +126,6 @@ def main():
             runtime_env=runtime_env)
     else:
         ray.init(runtime_env=runtime_env)

     runner = Runner.remote()
     ray.get(runner.run.remote(ppo_config))
```
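The practical consequence of this dispatch is that the function named by `worker.reward.reward_function` has to match the configured `worker.reward.reward_type`. The sketch below is illustrative only; the two expected call shapes are inferred from the bundled `math.py` (batch) and `r1v.py` (sequential) reward files, not from the manager classes themselves:

```python
from typing import Dict, List


# reward_type=batch: one call scores the whole rollout batch and returns one
# score dict per sample (see examples/reward_function/math.py).
def batch_style_compute_score(predicts: List[str], ground_truths: List[str]) -> List[Dict[str, float]]:
    return [{"overall": 0.0, "format": 0.0, "accuracy": 0.0} for _ in predicts]


# reward_type=sequential: the manager calls the function once per sample
# (see examples/reward_function/r1v.py).
def sequential_style_compute_score(predict: str, ground_truth: str) -> Dict[str, float]:
    return {"overall": 0.0, "format": 0.0, "accuracy": 0.0}
```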
verl/trainer/ray_trainer.py

```diff
@@ -247,7 +247,7 @@ class RayPPOTrainer:
         if config.trainer.max_steps is not None:
             self.training_steps = config.trainer.max_steps
         else:
-            self.training_steps = len(train_dataloader) * config.trainer.total_episodes
+            self.training_steps = len(train_dataloader) * config.trainer.total_epochs

         config.worker.actor.optim.training_steps = self.training_steps
         config.worker.critic.optim.training_steps = self.training_steps
@@ -473,7 +473,7 @@ class RayPPOTrainer:
         if self.config.trainer.val_only:
             return

-        for _ in tqdm(range(self.config.trainer.total_episodes), desc="Episode", position=0):
+        for _ in tqdm(range(self.config.trainer.total_epochs), desc="Epoch", position=0):
             for batch_dict in tqdm(self.train_dataloader, desc="Running step", position=1):
                 self.global_step += 1
                 if self.global_step > self.training_steps:
```
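In other words, when `trainer.max_steps` stays `null`, the total step budget comes from the dataloader length times the renamed `total_epochs`. A small illustrative calculation with made-up numbers; the dataset size and the drop-last behaviour are assumptions, not values taken from this commit:

```python
# Hypothetical numbers, for illustration only.
train_samples = 12_000        # rough size of a math12k-like training split
rollout_batch_size = 512      # data.rollout_batch_size from examples/config.yaml
total_epochs = 15             # trainer.total_epochs from examples/config.yaml

# len(train_dataloader), assuming the last partial batch is dropped.
steps_per_epoch = train_samples // rollout_batch_size
training_steps = steps_per_epoch * total_epochs
print(steps_per_epoch, training_steps)  # 23 345
```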
verl/utils/checkpoint/fsdp_checkpoint_manager.py

```diff
@@ -55,11 +55,13 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         # every rank download its own checkpoint
         model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
         optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
-        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
-        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
+        extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
+        print(f"[rank-{self.rank}]: Loading model from {os.path.abspath(model_path)}.")
+        print(f"[rank-{self.rank}]: Loading optimizer from {os.path.abspath(optim_path)}.")
+        print(f"[rank-{self.rank}]: Loading extra_state from {os.path.abspath(extra_path)}.")
         model_state_dict = torch.load(model_path, weights_only=False)
         optim_state_dict = torch.load(optim_path, weights_only=False)
-        extra_state_dict = torch.load(extra_state_path, weights_only=False)
+        extra_state_dict = torch.load(extra_path, weights_only=False)
         state_dict_options = StateDictOptions(cpu_offload=True)
         set_state_dict(
@@ -91,7 +93,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
         print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
-        print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
+        print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
         print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
         torch.save(model_state_dict, model_path)
         torch.save(optim_state_dict, optim_path)
```
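For reference, the per-rank file names these messages report follow the pattern in the f-strings above. A tiny sketch with hypothetical `path`, `world_size`, and `rank` values; the directory layout is an assumption, not taken from this commit:

```python
import os

# Hypothetical values for illustration.
path = "checkpoints/easy_r1/qwen2_5_7b_math_grpo/global_step_10/actor"
world_size, rank = 8, 0

model_path = os.path.join(path, f"model_world_size_{world_size}_rank_{rank}.pt")
optim_path = os.path.join(path, f"optim_world_size_{world_size}_rank_{rank}.pt")
extra_path = os.path.join(path, f"extra_state_world_size_{world_size}_rank_{rank}.pt")
print(model_path)  # .../model_world_size_8_rank_0.pt
```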