OpenDAS / EasyR1 · Commits · 2369eb2b

Commit 2369eb2b, authored Apr 21, 2025 by chenych
Commit message: update
Parent: ac9d2b05

Changes: 43 files in total; this page shows 20 changed files with 377 additions and 299 deletions (+377, -299).
Changed files on this page:

verl/single_controller/ray/base.py (+5, -5)
verl/trainer/config.py (+13, -0)
verl/trainer/data_loader.py (+87, -0)
verl/trainer/main.py (+24, -13)
verl/trainer/metrics.py (+2, -2)
verl/trainer/ray_trainer.py (+23, -96)
verl/utils/checkpoint/fsdp_checkpoint_manager.py (+27, -47)
verl/utils/dataset.py (+65, -28)
verl/utils/logger/logger.py (+1, -1)
verl/utils/reward_score/__init__.py (+0, -20)
verl/utils/torch_dtypes.py (+9, -21)
verl/utils/torch_functional.py (+41, -30)
verl/workers/actor/__init__.py (+0, -4)
verl/workers/actor/config.py (+7, -0)
verl/workers/actor/dp_actor.py (+2, -6)
verl/workers/critic/__init__.py (+2, -4)
verl/workers/fsdp_workers.py (+9, -2)
verl/workers/reward/__init__.py (+2, -2)
verl/workers/reward/config.py (+19, -2)
verl/workers/reward/function.py (+39, -16)
verl/single_controller/ray/base.py (+5, -5)

@@ -98,7 +98,7 @@ class RayResourcePool(ResourcePool):
         # print(f"pg_name_prefix = {pg_name_prefix}")
         pg_scheme = [
             [
-                {"CPU": self.max_collocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_collocate_count}
+                {"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count}
                 for _ in range(process_count)
             ]
             for process_count in self._store

@@ -145,8 +145,8 @@ def extract_pg_from_exist(
 def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
     assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not"
-    assert rp1.max_collocate_count == rp2.max_collocate_count, (
-        "Both RayResourcePool must has the same max_collocate_count"
+    assert rp1.max_colocate_count == rp2.max_colocate_count, (
+        "Both RayResourcePool must has the same max_colocate_count"
     )
     assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node"
     assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool"

@@ -259,7 +259,7 @@ class RayWorkerGroup(WorkerGroup):
         world_size = resource_pool.world_size
         self._world_size = world_size
         # cia.add_kwarg("_world_size", world_size)
-        num_gpus = 1 / resource_pool.max_collocate_count
+        num_gpus = 1 / resource_pool.max_colocate_count
         rank = -1
         local_world_size = resource_pool.store[0]

@@ -300,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
         if rank == 0:
             register_center_actor = None
-            for _ in range(360):
+            for _ in range(120):
                 if f"{self.name_prefix}_register_center" not in list_named_actors():
                     time.sleep(1)
                 else:
verl/trainer/config.py (+13, -0)

@@ -47,6 +47,14 @@ class DataConfig:
     seed: int = 1
     max_pixels: int = 4194304
     min_pixels: int = 262144
+    filter_overlong_prompts: bool = True
+
+    def post_init(self):
+        if self.format_prompt is not None:
+            if os.path.exists(self.format_prompt):
+                self.format_prompt = os.path.abspath(self.format_prompt)
+            else:
+                self.format_prompt = None


 @dataclass

@@ -86,6 +94,10 @@ class TrainerConfig:
         if self.save_checkpoint_path is None:
             self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

+        self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
+        if self.load_checkpoint_path is not None:
+            self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
+

 @dataclass
 class PPOConfig:

@@ -97,6 +109,7 @@ class PPOConfig:
     def post_init(self):
         self.worker.rollout.prompt_length = self.data.max_prompt_length
         self.worker.rollout.response_length = self.data.max_response_length
         self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code
+        self.worker.actor.disable_kl = self.algorithm.disable_kl
         self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
         self.worker.actor.kl_penalty = self.algorithm.kl_penalty
verl/trainer/data_loader.py (new file, 0 → 100644; +87, -0)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..utils.dataset import RLHFDataset, collate_fn
from .config import DataConfig


def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]) -> None:
    train_dataset = RLHFDataset(
        data_path=config.train_files,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )
    # use sampler for better ckpt resume
    if config.shuffle:
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(config.seed)
        sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
    else:
        sampler = SequentialSampler(data_source=train_dataset)

    train_dataloader = StatefulDataLoader(
        dataset=train_dataset,
        batch_size=config.rollout_batch_size,
        sampler=sampler,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=True,
    )

    val_dataset = RLHFDataset(
        data_path=config.val_files,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )
    val_dataloader = StatefulDataLoader(
        dataset=val_dataset,
        batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=False,
    )

    assert len(train_dataloader) >= 1
    assert len(val_dataloader) >= 1
    print(f"Size of train dataloader: {len(train_dataloader)}")
    print(f"Size of val dataloader: {len(val_dataloader)}")
    return train_dataloader, val_dataloader
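The comment "use sampler for better ckpt resume" refers to torchdata's StatefulDataLoader: unlike a plain DataLoader, its iteration position can be captured and restored, so training can resume mid-epoch from a checkpoint. A minimal sketch of that behaviour, separate from this commit (the toy dataset and variable names are illustrative):

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(10))                       # toy stand-in for RLHFDataset
loader = StatefulDataLoader(dataset, batch_size=2)

it = iter(loader)
print(next(it))                                 # tensor([0, 1])
state = loader.state_dict()                     # loader position, saved alongside the model checkpoint

resumed = StatefulDataLoader(dataset, batch_size=2)
resumed.load_state_dict(state)                  # restore the position after a restart
print(next(iter(resumed)))                      # tensor([2, 3])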
verl/trainer/main.py (+24, -13)

@@ -11,21 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
 """

 import json

 import torch
 import ray
 from omegaconf import OmegaConf

 from ..single_controller.ray import RayWorkerGroup
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..workers.fsdp_workers import FSDPWorker
-from ..workers.reward import CustomRewardManager
+from ..workers.reward import FunctionRewardManager
 from .config import PPOConfig
 from .data_loader import create_dataloader
 from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role

@@ -36,7 +33,6 @@ class Runner:
     def run(self, config: PPOConfig):
         # print config
-        config.deep_post_init()
         print(json.dumps(config.to_dict(), indent=2))

         # instantiate tokenizer

@@ -69,13 +65,19 @@ class Runner:
         }
         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

-        reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
-        val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
+        reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
+        val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
+
+        train_dataloader, val_dataloader = create_dataloader(config=config.data, tokenizer=tokenizer, processor=processor)

         trainer = RayPPOTrainer(
             config=config,
             tokenizer=tokenizer,
             processor=processor,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
             role_worker_mapping=role_worker_mapping,
             resource_pool_manager=resource_pool_manager,
             ray_worker_group_cls=ray_worker_group_cls,

@@ -96,17 +98,26 @@ def main():
     default_config = OmegaConf.merge(default_config, file_config)

     ppo_config = OmegaConf.merge(default_config, cli_args)
-    ppo_config = OmegaConf.to_object(ppo_config)
+    ppo_config: PPOConfig = OmegaConf.to_object(ppo_config)
+    ppo_config.deep_post_init()

     if not ray.is_initialized():
+        runtime_env = {
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "INFO",
+                "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
+                "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
+            }
+        }
         # this is for local ray cluster
         if torch.version.hip is not None:
-            ray.init(num_gpus=torch.cuda.device_count(), ignore_reinit_error=True, runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+            ray.init(num_gpus=torch.cuda.device_count(), ignore_reinit_error=True, runtime_env=runtime_env)
         else:
-            ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+            ray.init(runtime_env=runtime_env)

     runner = Runner.remote()
     ray.get(runner.run.remote(ppo_config))
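For reference, the merge order in main() is: dataclass defaults, then the optional YAML file, then CLI dotlist overrides, with later sources winning, before OmegaConf.to_object() turns the result back into a typed PPOConfig. A small self-contained illustration of that pattern (the Demo dataclass and values below are hypothetical, not part of the repository):

from dataclasses import dataclass
from omegaconf import OmegaConf


@dataclass
class Demo:
    rollout_batch_size: int = 256
    shuffle: bool = True


default_config = OmegaConf.structured(Demo)                    # dataclass defaults
file_config = OmegaConf.create({"rollout_batch_size": 512})    # stands in for the YAML config file
cli_args = OmegaConf.from_dotlist(["shuffle=false"])           # stands in for command-line overrides

merged = OmegaConf.merge(default_config, file_config)
merged = OmegaConf.merge(merged, cli_args)
print(OmegaConf.to_object(merged))                             # Demo(rollout_batch_size=512, shuffle=False)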
verl/trainer/metrics.py (+2, -2)

@@ -110,11 +110,11 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Dict[str, Any]:
     }


-def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
+def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]:
     total_num_tokens = sum(batch.meta_info["global_token_num"])
     time = timing_raw["step"]
     return {
         "perf/total_num_tokens": total_num_tokens,
         "perf/time_per_step": time,
-        "perf/throughput": total_num_tokens / (time * n_gpus),
+        "perf/throughput": total_num_tokens / (time * num_gpus),
     }
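The rename does not change the formula: throughput is still total_num_tokens / (time_per_step * num_gpus). As a worked example with made-up numbers, 1,048,576 tokens processed in a 32 s step on 8 GPUs is reported as 1,048,576 / (32 * 8) = 4,096 tokens per GPU-second.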
verl/trainer/ray_trainer.py (+23, -96)

@@ -30,7 +30,6 @@ import ray
 import torch
 from codetiming import Timer
 from ray.experimental.tqdm_ray import tqdm
-from torch.utils.data import RandomSampler, SequentialSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer, ProcessorMixin

@@ -40,7 +39,6 @@ from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
 from ..single_controller.ray.base import create_colocated_worker_cls
 from ..utils import torch_functional as VF
 from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
-from ..utils.dataset import RLHFDataset, collate_fn
 from ..utils.logger import Tracker
 from ..utils.py_functional import convert_dict_to_str
 from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance

@@ -102,24 +100,16 @@ class ResourcePoolManager:
         """Get the resource pool of the worker."""
         return self.resource_pool_dict[self.mapping[role]]

-    def get_n_gpus(self) -> int:
+    def get_num_gpus(self) -> int:
         """Get the number of gpus in this cluster."""
         return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

     def _check_resource_available(self):
         """Check if the resource pool can be satisfied in this ray cluster."""
-        node_available_resources = ray.state.available_resources_per_node()
-        node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
-
-        # check total required gpus can be satisfied
-        total_available_gpus = sum(node_available_gpus.values())
-        total_required_gpus = sum(
-            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
-        )
-        if total_available_gpus < total_required_gpus:
-            raise ValueError(
-                f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
-            )
+        gpus_available = ray.available_resources().get("GPU", 0)
+        gpus_required = self.get_num_gpus()
+        if gpus_available < gpus_required:
+            raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")


 def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):

@@ -128,11 +118,8 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
     response_mask = data.batch["response_mask"]

     # compute kl between ref_policy and current policy
-    if "ref_log_probs" in data.batch.keys():
-        kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
-        kld = kld * response_mask  # (batch_size, response_length)
-    else:
-        kld = torch.zeros_like(response_mask, dtype=torch.float32)
+    kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
+    kld = kld * response_mask  # (batch_size, response_length)

     data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld

@@ -193,6 +180,8 @@ class RayPPOTrainer:
         config: PPOConfig,
         tokenizer: PreTrainedTokenizer,
         processor: Optional[ProcessorMixin],
+        train_dataloader: StatefulDataLoader,
+        val_dataloader: StatefulDataLoader,
         role_worker_mapping: dict[Role, Type[Worker]],
         resource_pool_manager: ResourcePoolManager,
         ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,

@@ -201,6 +190,8 @@ class RayPPOTrainer:
     ):
         self.tokenizer = tokenizer
         self.processor = processor
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.config = config
         self.reward_fn = reward_fn
         self.val_reward_fn = val_reward_fn

@@ -262,78 +253,13 @@ class RayPPOTrainer:
         ):
             raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.")

-        self._create_dataloader()
-
-    def _create_dataloader(self) -> None:
-        self.train_dataset = RLHFDataset(
-            data_path=self.config.data.train_files,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
-            prompt_key=self.config.data.prompt_key,
-            answer_key=self.config.data.answer_key,
-            image_key=self.config.data.image_key,
-            max_prompt_length=self.config.data.max_prompt_length,
-            truncation="right",
-            format_prompt=self.config.data.format_prompt,
-            min_pixels=self.config.data.min_pixels,
-            max_pixels=self.config.data.max_pixels,
-        )
-        # use sampler for better ckpt resume
-        if self.config.data.shuffle:
-            train_dataloader_generator = torch.Generator()
-            train_dataloader_generator.manual_seed(self.config.data.seed)
-            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
-        else:
-            sampler = SequentialSampler(data_source=self.train_dataset)
-
-        self.train_dataloader = StatefulDataLoader(
-            dataset=self.train_dataset,
-            batch_size=self.config.data.rollout_batch_size,
-            sampler=sampler,
-            num_workers=8,
-            collate_fn=collate_fn,
-            pin_memory=False,
-            drop_last=True,
-        )
-
-        self.val_dataset = RLHFDataset(
-            data_path=self.config.data.val_files,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
-            prompt_key=self.config.data.prompt_key,
-            answer_key=self.config.data.answer_key,
-            image_key=self.config.data.image_key,
-            max_prompt_length=self.config.data.max_prompt_length,
-            truncation="right",
-            format_prompt=self.config.data.format_prompt,
-            min_pixels=self.config.data.min_pixels,
-            max_pixels=self.config.data.max_pixels,
-        )
-        self.val_dataloader = StatefulDataLoader(
-            dataset=self.val_dataset,
-            batch_size=len(self.val_dataset) if self.config.data.val_batch_size == -1 else self.config.data.val_batch_size,
-            shuffle=False,
-            num_workers=8,
-            collate_fn=collate_fn,
-            pin_memory=False,
-            drop_last=False,
-        )
-
-        assert len(self.train_dataloader) >= 1
-        assert len(self.val_dataloader) >= 1
-        print(f"Size of train dataloader: {len(self.train_dataloader)}")
-        print(f"Size of val dataloader: {len(self.val_dataloader)}")
-
-        if self.config.trainer.max_steps is not None:
-            training_steps = self.config.trainer.max_steps
+        if config.trainer.max_steps is not None:
+            self.training_steps = config.trainer.max_steps
         else:
-            training_steps = len(self.train_dataloader) * self.config.trainer.total_episodes
+            self.training_steps = len(train_dataloader) * config.trainer.total_episodes

-        self.training_steps = training_steps
-        self.config.worker.actor.optim.training_steps = training_steps
-        self.config.worker.critic.optim.training_steps = training_steps
+        config.worker.actor.optim.training_steps = self.training_steps
+        config.worker.critic.optim.training_steps = self.training_steps
         print(f"Total training steps: {self.training_steps}")

     def _maybe_log_val_generations(

@@ -366,10 +292,10 @@ class RayPPOTrainer:
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)

-            if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
+            if "multi_modal_data" in test_batch.non_tensor_batch.keys():
                 test_gen_batch = test_batch.pop(
                     batch_keys=["input_ids", "attention_mask", "position_ids"],
-                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
+                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                 )
             else:
                 test_gen_batch = test_batch.pop(

@@ -567,10 +493,10 @@ class RayPPOTrainer:
                 batch: DataProto = DataProto.from_single_dict(batch_dict)

                 # pop those keys for generation
-                if "multi_modal_inputs" in batch.non_tensor_batch.keys():
+                if "multi_modal_data" in batch.non_tensor_batch.keys():
                     gen_batch = batch.pop(
                         batch_keys=["input_ids", "attention_mask", "position_ids"],
-                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
+                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                     )
                 else:
                     gen_batch = batch.pop(

@@ -604,6 +530,7 @@ class RayPPOTrainer:
                 # repeat to align with repeated responses in rollout
                 batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
                 batch = batch.union(gen_batch_output)
+                batch.non_tensor_batch.pop("multi_modal_data", None)

                 # compute reward
                 with _timer("reward", timing_raw):

@@ -694,10 +621,10 @@ class RayPPOTrainer:
                     self._save_checkpoint()

                 # collect metrics
-                n_gpus = self.resource_pool_manager.get_n_gpus()
+                num_gpus = self.resource_pool_manager.get_num_gpus()
                 metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))
                 self.logger.log(data=metrics, step=self.global_step)
verl/utils/checkpoint/fsdp_checkpoint_manager.py (+27, -47)

@@ -13,13 +13,12 @@
 # limitations under the License.

 import os
 import warnings
 from typing import Optional, Union

 import torch
 import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

 from .checkpoint_manager import BaseCheckpointManager

@@ -59,21 +58,18 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
         print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
         model_state_dict = torch.load(model_path, weights_only=False)
-        optimizer_state_dict = torch.load(optim_path, weights_only=False)
+        optim_state_dict = torch.load(optim_path, weights_only=False)
         extra_state_dict = torch.load(extra_state_path, weights_only=False)
-        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

-        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
-                self.model.load_state_dict(model_state_dict)
-                if self.optimizer is not None:
-                    self.optimizer.load_state_dict(optimizer_state_dict)
-
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)
+        state_dict_options = StateDictOptions(cpu_offload=True)
+        set_state_dict(
+            model=self.model,
+            optimizers=self.optimizer,
+            model_state_dict=model_state_dict,
+            optim_state_dict=optim_state_dict,
+            options=state_dict_options,
+        )
+        self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"])

         # recover random state
         if "rng" in extra_state_dict:

@@ -84,24 +80,10 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         dist.barrier()

         # every rank will save its own model and optim shard
-        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
-                model_state_dict = self.model.state_dict()
-                if self.optimizer is not None:
-                    optimizer_state_dict = self.optimizer.state_dict()
-                else:
-                    optimizer_state_dict = None
-
-        if self.lr_scheduler is not None:
-            lr_scheduler_state_dict = self.lr_scheduler.state_dict()
-        else:
-            lr_scheduler_state_dict = None
-
+        state_dict_options = StateDictOptions(cpu_offload=True)
+        model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
         extra_state_dict = {
-            "lr_scheduler": lr_scheduler_state_dict,
+            "lr_scheduler": self.lr_scheduler.state_dict(),
             "rng": self.get_rng_state(),
         }
         model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")

@@ -112,9 +94,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
         print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
         torch.save(model_state_dict, model_path)
-        if self.optimizer is not None:
-            torch.save(optimizer_state_dict, optim_path)
+        torch.save(optim_state_dict, optim_path)
         torch.save(extra_state_dict, extra_path)

         # wait for everyone to dump to local
verl/utils/dataset.py (+65, -28)

@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 import torch
 from datasets import load_dataset
+from jinja2 import Template
 from PIL import Image
 from PIL.Image import Image as ImageObject
 from torch.utils.data import Dataset

@@ -90,9 +91,10 @@ class RLHFDataset(Dataset, ImageProcessMixin):
         image_key: str = "images",
         max_prompt_length: int = 1024,
         truncation: str = "error",
-        format_prompt: str = None,
-        max_pixels: int = None,
-        min_pixels: int = None,
+        format_prompt: Optional[str] = None,
+        max_pixels: Optional[int] = None,
+        min_pixels: Optional[int] = None,
+        filter_overlong_prompts: bool = True,
     ):
         self.tokenizer = tokenizer
         self.processor = processor

@@ -101,9 +103,9 @@ class RLHFDataset(Dataset, ImageProcessMixin):
         self.image_key = image_key
         self.max_prompt_length = max_prompt_length
         self.truncation = truncation
-        self.format_prompt = format_prompt
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
+        self.filter_overlong_prompts = filter_overlong_prompts

         if "@" in data_path:
             data_path, data_split = data_path.split("@")

@@ -111,22 +113,29 @@ class RLHFDataset(Dataset, ImageProcessMixin):
             data_split = "train"

         if os.path.isdir(data_path):
             # when we use dataset builder, we should always refer to the train split
             self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
         elif os.path.isfile(data_path):
             self.dataset = load_dataset("parquet", data_files=data_path, split="train")
-        else:  # remote dataset
+        else:  # load remote dataset from huggingface hub
             self.dataset = load_dataset(data_path, split=data_split)

-    def __len__(self):
-        return len(self.dataset)
+        self.format_prompt = None
+        if format_prompt:
+            with open(format_prompt, encoding="utf-8") as f:
+                self.format_prompt = f.read()

-    def __getitem__(self, index):
-        row_dict: dict = self.dataset[index]
-        prompt_str: str = row_dict[self.prompt_key]
-        if self.format_prompt:
-            prompt_str = prompt_str + " " + self.format_prompt.strip()
+        if self.filter_overlong_prompts:
+            self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")

-        if self.image_key in row_dict:
+    def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
+        prompt_str: str = example[self.prompt_key]
+        if self.format_prompt:
+            format_prompt = Template(self.format_prompt.strip())
+            prompt_str = format_prompt.render(content=prompt_str)
+
+        if self.image_key in example:
             # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
             content_list = []
             for i, content in enumerate(prompt_str.split("<image>")):

@@ -136,28 +145,47 @@ class RLHFDataset(Dataset, ImageProcessMixin):
                 if content:
                     content_list.append({"type": "text", "text": content})

-            messages = [{"role": "user", "content": content_list}]
+            return [{"role": "user", "content": content_list}]
+        else:
+            return [{"role": "user", "content": prompt_str}]
+
+    def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
+        messages = self._build_messages(example)
+        processing_class = self.processor if self.processor is not None else self.tokenizer
+        return len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        example: dict = self.dataset[index]
+        messages = self._build_messages(example)
+
+        if self.image_key in example:
             prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
+            images = [self.process_image(image) for image in example.pop(self.image_key)]
             model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
             input_ids = model_inputs.pop("input_ids")[0]
             attention_mask = model_inputs.pop("attention_mask")[0]
-            row_dict["multi_modal_data"] = {"image": images}
-            row_dict["multi_modal_inputs"] = dict(model_inputs)
+            example["multi_modal_data"] = {"image": images}
+            example["multi_modal_inputs"] = dict(model_inputs)
+        else:
+            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
+            input_ids = model_inputs.pop("input_ids")[0]
+            attention_mask = model_inputs.pop("attention_mask")[0]

         if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
             # qwen2vl mrope
             position_ids = get_rope_index(
                 self.processor,
                 input_ids=input_ids,
-                image_grid_thw=model_inputs["image_grid_thw"],
+                image_grid_thw=model_inputs.get("image_grid_thw"),
                 attention_mask=attention_mask,
             )  # (3, seq_length)
         else:
-            messages = [{"role": "user", "content": prompt_str}]
-            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
-            input_ids = model_inputs.pop("input_ids")[0]
-            attention_mask = model_inputs.pop("attention_mask")[0]
             position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)

         input_ids, attention_mask, position_ids = VF.postprocess_data(

@@ -169,9 +197,18 @@ class RLHFDataset(Dataset, ImageProcessMixin):
             left_pad=True,
             truncation=self.truncation,
         )
-        row_dict["input_ids"] = input_ids
-        row_dict["attention_mask"] = attention_mask
-        row_dict["position_ids"] = position_ids
-        row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
-        row_dict["ground_truth"] = row_dict.pop(self.answer_key)
-        return row_dict
+        raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+        if len(raw_prompt_ids) > self.max_prompt_length:
+            if self.truncation == "left":
+                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length:]
+            elif self.truncation == "right":
+                raw_prompt_ids = raw_prompt_ids[:self.max_prompt_length]
+            elif self.truncation == "error":
+                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
+
+        example["input_ids"] = input_ids
+        example["attention_mask"] = attention_mask
+        example["position_ids"] = position_ids
+        example["raw_prompt_ids"] = raw_prompt_ids
+        example["ground_truth"] = example.pop(self.answer_key)
+        return example
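The format_prompt handling changed in two ways: the config now stores a path to a template file (read once in __init__), and the template is rendered with Jinja2 instead of being appended as a plain suffix. A small illustration of the new rendering path; the template wording and question below are made up for the example:

from jinja2 import Template

format_prompt = "{{ content }} Please reason step by step, and put your final answer in \\boxed{}."  # illustrative wording
prompt_str = "What is 2 + 3?"

rendered = Template(format_prompt.strip()).render(content=prompt_str)
print(rendered)
# What is 2 + 3? Please reason step by step, and put your final answer in \boxed{}.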
verl/utils/logger/logger.py (+1, -1)

@@ -71,7 +71,7 @@ class TensorBoardLogger(Logger):
         os.makedirs(tensorboard_dir, exist_ok=True)
         print(f"Saving tensorboard log to {tensorboard_dir}.")
         self.writer = SummaryWriter(tensorboard_dir)
-        self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={})
+        self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={"placeholder": 0})

     def log(self, data: Dict[str, Any], step: int) -> None:
         for key, value in data.items():
verl/utils/reward_score/__init__.py (deleted, 100644 → 0; +0, -20; view at parent ac9d2b05)

The file is removed in its entirety:

-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .math import math_compute_score
-from .r1v import r1v_compute_score
-
-
-__all__ = ["math_compute_score", "r1v_compute_score"]
verl/utils/torch_dtypes.py (+9, -21)

@@ -15,40 +15,28 @@

 import torch

-HALF_LIST = [16, "16", "fp16", "float16"]
-FLOAT_LIST = [32, "32", "fp32", "float32"]
+HALF_LIST = ["fp16", "float16"]
+FLOAT_LIST = ["fp32", "float32"]
 BFLOAT_LIST = ["bf16", "bfloat16"]


 class PrecisionType:
-    """Type of precision used.
-
-    >>> PrecisionType.HALF == 16
-    True
-    >>> PrecisionType.HALF in (16, "16")
-    True
-    """
-
-    HALF = "16"
-    FLOAT = "32"
-    FULL = "64"
-    BFLOAT = "bf16"
-    MIXED = "mixed"
+    """Type of precision used."""

     @staticmethod
-    def is_fp16(precision):
+    def is_fp16(precision: str) -> bool:
         return precision in HALF_LIST

     @staticmethod
-    def is_fp32(precision):
+    def is_fp32(precision: str) -> bool:
         return precision in FLOAT_LIST

     @staticmethod
-    def is_bf16(precision):
+    def is_bf16(precision: str) -> bool:
         return precision in BFLOAT_LIST

     @staticmethod
-    def to_dtype(precision) -> torch.dtype:
+    def to_dtype(precision: str) -> torch.dtype:
         if precision in HALF_LIST:
             return torch.float16
         elif precision in FLOAT_LIST:

@@ -56,7 +44,7 @@ class PrecisionType:
         elif precision in BFLOAT_LIST:
             return torch.bfloat16
         else:
-            raise RuntimeError(f"unexpected precision: {precision}")
+            raise RuntimeError(f"Unexpected precision: {precision}")

     @staticmethod
     def to_str(precision: torch.dtype) -> str:

@@ -67,4 +55,4 @@ class PrecisionType:
         elif precision == torch.bfloat16:
             return "bfloat16"
         else:
-            raise RuntimeError(f"unexpected precision: {precision}")
+            raise RuntimeError(f"Unexpected precision: {precision}")
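After this change PrecisionType is a plain mapper between string precision names and torch dtypes; the integer aliases (16, 32) and the unused class constants are gone. A short usage sketch (not part of the commit):

import torch
from verl.utils.torch_dtypes import PrecisionType

assert PrecisionType.is_bf16("bf16")
assert PrecisionType.to_dtype("fp16") == torch.float16
assert PrecisionType.to_dtype("bfloat16") == torch.bfloat16
assert PrecisionType.to_str(torch.bfloat16) == "bfloat16"
# PrecisionType.to_dtype(16) now raises RuntimeError, since the integer aliases were removed.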
verl/utils/torch_functional.py (+41, -30)

 # Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright Meta Platforms, Inc. and affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -22,6 +23,8 @@ import torch.distributed
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import LambdaLR

+from .torch_dtypes import PrecisionType
+

 try:
     from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

@@ -177,7 +180,7 @@ def postprocess_data(
         attention_mask = attention_mask[..., :max_length]
         position_ids = position_ids[..., :max_length]
     elif truncation == "error":
-        raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
+        raise RuntimeError(f"Input sequence length {seq_length} is longer than max length {max_length}.")
     else:
         raise NotImplementedError(f"Unknown truncation method {truncation}.")

@@ -207,11 +210,18 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         use_kahan_summation: bool = True,
-        momentum_dtype: torch.dtype = torch.bfloat16,
-        variance_dtype: torch.dtype = torch.bfloat16,
-        compensation_buffer_dtype: torch.dtype = torch.bfloat16,
+        momentum_dtype: str = "bfloat16",
+        variance_dtype: str = "bfloat16",
+        compensation_buffer_dtype: str = "bfloat16",
     ):
         """
         AnyPrecisionAdamW: a flexible precision AdamW optimizer
         with optional Kahan summation for high precision weight updates.
         Allows direct control over momentum, variance and auxiliary compensation buffer dtypes.
         Optional Kahan summation is used to offset precision reduction for the weight updates.
         This allows full training in BFloat16 (equal or better than FP32 results in many cases)
         due to high precision weight updates.

         Args:
             params (iterable): iterable of parameters to optimize or dicts defining parameter groups
             lr (float, optional): learning rate (default: 1e-3)

@@ -270,10 +280,11 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
             eps = group["eps"]
             use_kahan_summation = group["use_kahan_summation"]

-            momentum_dtype = group["momentum_dtype"]
-            variance_dtype = group["variance_dtype"]
-            compensation_buffer_dtype = group["compensation_buffer_dtype"]
+            momentum_dtype = PrecisionType.to_dtype(group["momentum_dtype"])
+            variance_dtype = PrecisionType.to_dtype(group["variance_dtype"])
+            compensation_buffer_dtype = PrecisionType.to_dtype(group["compensation_buffer_dtype"])

             for p in group["params"]:
+                assert isinstance(p, torch.Tensor)  # lint
                 if p.grad is None:
                     continue
verl/workers/actor/__init__.py (+0, -4)

@@ -12,15 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import BasePPOActor
 from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
-from .dp_actor import DataParallelPPOActor


 __all__ = [
     "ActorConfig",
-    "BasePPOActor",
-    "DataParallelPPOActor",
     "FSDPConfig",
     "ModelConfig",
     "OptimConfig",
verl/workers/actor/config.py (+7, -0)

@@ -15,6 +15,7 @@
 Actor config
 """

+import os
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional, Tuple

@@ -32,6 +33,12 @@ class ModelConfig:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path

+        if self.model_path is not None and os.path.exists(self.model_path):
+            self.model_path = os.path.abspath(self.model_path)
+
+        if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
+            self.tokenizer_path = os.path.abspath(self.tokenizer_path)
+

 @dataclass
 class OptimConfig:
verl/workers/actor/dp_actor.py (+2, -6)

@@ -20,9 +20,11 @@ from collections import defaultdict
 from typing import Any, Dict, Optional

 import torch
+from einops import rearrange
 from ray.experimental.tqdm_ray import tqdm
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from transformers.modeling_flash_attention_utils import index_first_axis, pad_input, unpad_input

 from ...protocol import DataProto
 from ...trainer import core_algos

@@ -33,12 +35,6 @@ from .base import BasePPOActor
 from .config import ActorConfig

-try:
-    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
-except ImportError:
-    pass
-

 __all__ = ["DataParallelPPOActor"]
verl/workers/critic/__init__.py (+2, -4)

@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import BasePPOCritic
-from .config import CriticConfig, ModelConfig
-from .dp_critic import DataParallelPPOCritic
+from .config import CriticConfig


-__all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
+__all__ = ["CriticConfig"]
verl/workers/fsdp_workers.py (+9, -2)

@@ -54,9 +54,7 @@ from ..utils.model_utils import print_gpu_memory_usage, print_model_size
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..utils.torch_dtypes import PrecisionType
 from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
-from .actor import DataParallelPPOActor
 from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
-from .critic import DataParallelPPOCritic
 from .rollout import vLLMRollout
 from .sharding_manager import FSDPVLLMShardingManager
 from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -264,6 +262,9 @@ class FSDPWorker(Worker):
         else:
             sync_module_states = False
             param_init_fn = None

+        ## TODO: pin the model to a specific GPU
+        rank = torch.cuda.set_device(self.rank)
+        model = model.to(rank)
         self.fsdp_module = FSDP(
             model,

@@ -365,6 +366,8 @@ class FSDPWorker(Worker):
             print_gpu_memory_usage(f"After offload {role} optimizer during init")

         if self._is_actor:
+            from .actor.dp_actor import DataParallelPPOActor  # lazy import
+
             self.actor = DataParallelPPOActor(
                 config=self.config.actor,
                 actor_module=self.fsdp_module,

@@ -372,6 +375,8 @@ class FSDPWorker(Worker):
             )

         if self._is_critic:
+            from .critic.dp_critic import DataParallelPPOCritic  # lazy import
+
             self.critic = DataParallelPPOCritic(
                 config=self.config,
                 critic_module=self.fsdp_module,

@@ -382,6 +387,8 @@ class FSDPWorker(Worker):
             self._build_rollout()

         if self._is_ref:
+            from .actor.dp_actor import DataParallelPPOActor  # lazy import
+
             self.ref_policy = DataParallelPPOActor(
                 config=self.config.ref,
                 actor_module=self.fsdp_module,
verl/workers/reward/__init__.py (+2, -2)

@@ -13,7 +13,7 @@
 # limitations under the License.

 from .config import RewardConfig
-from .custom import CustomRewardManager
+from .function import FunctionRewardManager


-__all__ = ["CustomRewardManager", "RewardConfig"]
+__all__ = ["FunctionRewardManager", "RewardConfig"]
verl/workers/reward/config.py (+19, -2)

@@ -15,11 +15,28 @@
 Reward config
 """

-from dataclasses import dataclass
+import os
+from dataclasses import dataclass, field
+from typing import Optional


 @dataclass
 class RewardConfig:
     reward_type: str = "function"
-    score_function: str = "math"
+    score_function: Optional[str] = None
+    score_function_kwargs: dict = field(default_factory=dict)
     skip_special_tokens: bool = True
+    """auto keys"""
+    score_function_name: Optional[str] = field(default=None, init=False)
+
+    def post_init(self):
+        if self.score_function is not None:
+            if ":" not in self.score_function:
+                self.score_function_name = "main"
+            else:
+                self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
+
+            if os.path.exists(self.score_function):
+                self.score_function = os.path.abspath(self.score_function)
+            else:
+                self.score_function = None
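With this change, score_function is a filesystem path with an optional ":function_name" suffix instead of a built-in scorer name. A small sketch of how post_init() parses it (the path below is hypothetical):

from verl.workers.reward.config import RewardConfig

cfg = RewardConfig(score_function="./examples/score_function/math.py:compute_score")  # hypothetical path
cfg.post_init()
print(cfg.score_function_name)   # "compute_score"; defaults to "main" when no ":" suffix is given
print(cfg.score_function)        # absolute path if the file exists, otherwise reset to None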
verl/workers/reward/custom.py → verl/workers/reward/function.py (renamed; +39, -16)

@@ -12,34 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import importlib.util
+import os
+import sys
 from collections import defaultdict
-from typing import Callable, Dict, List, Tuple, TypedDict
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, List, Optional, Tuple, TypedDict

 import torch
 from transformers import PreTrainedTokenizer

 from ...protocol import DataProto
-from ...utils.reward_score import math_compute_score, r1v_compute_score
 from .config import RewardConfig


 class RewardScore(TypedDict):
     overall: float
-    format: float
-    accuracy: float
+    format: Optional[float]
+    accuracy: Optional[float]


+ScoreFunction = Callable[[str, str], RewardScore]
+
+
+@dataclass
+class FunctionRewardManager:
+    config: RewardConfig
+    tokenizer: PreTrainedTokenizer
+
+    def __post_init__(self):
+        """Load score function."""
+        if self.config.score_function is None:
+            raise ValueError("Score function is not provided.")
+
+        if not os.path.exists(self.config.score_function):
+            raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
+
+        spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
+        module = importlib.util.module_from_spec(spec)
+        try:
+            sys.modules["custom_score_fn"] = module
+            spec.loader.exec_module(module)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load score function: {e}")
+
+        if not hasattr(module, self.config.score_function_name):
+            raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")

-class CustomRewardManager:
-    def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
-        self.config = config
-        self.tokenizer = tokenizer
-        if config.score_function == "math":
-            self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
-        elif config.score_function == "r1v":
-            self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
-        else:
-            raise NotImplementedError(f"Unknown score function {config.score_function}.")
+        score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
+        print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
+        self.score_fn = partial(score_fn, **self.config.score_function_kwargs)

     def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

@@ -56,7 +79,7 @@ class FunctionRewardManager:
             )
             ground_truth = data_item.non_tensor_batch["ground_truth"]

-            score = self.compute_score(response_str, ground_truth)
+            score = self.score_fn(response_str, ground_truth)
             reward_tensor[i, valid_response_length - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)
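The renamed FunctionRewardManager no longer hard-codes the math/r1v scorers; it loads an arbitrary user-supplied Python file via importlib and calls the selected function as score_fn(response_str, ground_truth, **score_function_kwargs), expecting a RewardScore-style dict. A minimal sketch of such a pluggable score-function file (entirely hypothetical, not shipped in this commit):

# my_score_fn.py -- referenced as worker.reward.score_function=/path/to/my_score_fn.py:main
from typing import Dict, Optional


def main(response: str, ground_truth: str) -> Dict[str, Optional[float]]:
    """Return a RewardScore-like dict: "overall" is written into the reward tensor,
    and every key is also logged as a reward metric."""
    answer = response.split("</think>")[-1].strip()          # illustrative parsing
    accuracy = 1.0 if ground_truth.strip() in answer else 0.0
    format_score = 1.0 if "</think>" in response else 0.0
    return {"overall": 0.9 * accuracy + 0.1 * format_score, "accuracy": accuracy, "format": format_score}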