OpenDAS / EasyR1 · Commits

Commit 2369eb2b
Authored Apr 21, 2025 by chenych
Parent: ac9d2b05

    update

Changes: 43 · Showing 20 changed files with 377 additions and 299 deletions (+377, -299)
verl/single_controller/ray/base.py                 +5   -5
verl/trainer/config.py                             +13  -0
verl/trainer/data_loader.py                        +87  -0
verl/trainer/main.py                               +24  -13
verl/trainer/metrics.py                            +2   -2
verl/trainer/ray_trainer.py                        +23  -96
verl/utils/checkpoint/fsdp_checkpoint_manager.py   +27  -47
verl/utils/dataset.py                              +65  -28
verl/utils/logger/logger.py                        +1   -1
verl/utils/reward_score/__init__.py                +0   -20
verl/utils/torch_dtypes.py                         +9   -21
verl/utils/torch_functional.py                     +41  -30
verl/workers/actor/__init__.py                     +0   -4
verl/workers/actor/config.py                       +7   -0
verl/workers/actor/dp_actor.py                     +2   -6
verl/workers/critic/__init__.py                    +2   -4
verl/workers/fsdp_workers.py                       +9   -2
verl/workers/reward/__init__.py                    +2   -2
verl/workers/reward/config.py                      +19  -2
verl/workers/reward/function.py                    +39  -16
verl/single_controller/ray/base.py

@@ -98,7 +98,7 @@ class RayResourcePool(ResourcePool):
         # print(f"pg_name_prefix = {pg_name_prefix}")
         pg_scheme = [
             [
-                {"CPU": self.max_collocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_collocate_count}
+                {"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count}
                 for _ in range(process_count)
             ]
             for process_count in self._store

@@ -145,8 +145,8 @@ def extract_pg_from_exist(
 def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool:
     assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not"
-    assert rp1.max_collocate_count == rp2.max_collocate_count, (
-        "Both RayResourcePool must has the same max_collocate_count"
+    assert rp1.max_colocate_count == rp2.max_colocate_count, (
+        "Both RayResourcePool must has the same max_colocate_count"
     )
     assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node"
     assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool"

@@ -259,7 +259,7 @@ class RayWorkerGroup(WorkerGroup):
         world_size = resource_pool.world_size
         self._world_size = world_size
         # cia.add_kwarg("_world_size", world_size)
-        num_gpus = 1 / resource_pool.max_collocate_count
+        num_gpus = 1 / resource_pool.max_colocate_count
         rank = -1
         local_world_size = resource_pool.store[0]

@@ -300,7 +300,7 @@ class RayWorkerGroup(WorkerGroup):
         if rank == 0:
             register_center_actor = None
-            for _ in range(360):
+            for _ in range(120):
                 if f"{self.name_prefix}_register_center" not in list_named_actors():
                     time.sleep(1)
                 else:
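Note on the max_collocate_count -> max_colocate_count rename above: only the attribute name changes; the comprehension itself is untouched. A small self-contained sketch of what it builds, using hypothetical values for use_gpu, max_colocate_count and _store:

use_gpu, max_colocate_count, _store = True, 2, [2, 1]
pg_scheme = [
    [
        # one placement-group bundle per process on the node
        {"CPU": max_colocate_count, "GPU": 1} if use_gpu else {"CPU": max_colocate_count}
        for _ in range(process_count)
    ]
    for process_count in _store
]
print(pg_scheme)
# [[{'CPU': 2, 'GPU': 1}, {'CPU': 2, 'GPU': 1}], [{'CPU': 2, 'GPU': 1}]]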
verl/trainer/config.py

@@ -47,6 +47,14 @@ class DataConfig:
     seed: int = 1
     max_pixels: int = 4194304
     min_pixels: int = 262144
+    filter_overlong_prompts: bool = True
+
+    def post_init(self):
+        if self.format_prompt is not None:
+            if os.path.exists(self.format_prompt):
+                self.format_prompt = os.path.abspath(self.format_prompt)
+            else:
+                self.format_prompt = None


 @dataclass

@@ -86,6 +94,10 @@ class TrainerConfig:
         if self.save_checkpoint_path is None:
             self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)
+
+        self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
+        if self.load_checkpoint_path is not None:
+            self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)


 @dataclass
 class PPOConfig:

@@ -97,6 +109,7 @@ class PPOConfig:
     def post_init(self):
         self.worker.rollout.prompt_length = self.data.max_prompt_length
         self.worker.rollout.response_length = self.data.max_response_length
+        self.worker.rollout.trust_remote_code = self.worker.actor.model.trust_remote_code
         self.worker.actor.disable_kl = self.algorithm.disable_kl
         self.worker.actor.use_kl_loss = self.algorithm.use_kl_loss
         self.worker.actor.kl_penalty = self.algorithm.kl_penalty
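Side note on the config additions above, with a minimal sketch (hypothetical paths and names, not part of the commit): the new post_init hooks normalize relative paths to absolute ones, presumably so that Ray workers launched from a different working directory still resolve them, and a format_prompt file that does not exist is silently dropped.

import os

# hypothetical config values
format_prompt = "examples/format_prompt/math.jinja"       # DataConfig.format_prompt
save_checkpoint_path = None                               # TrainerConfig.save_checkpoint_path
project_name, experiment_name = "easyr1", "qwen2_5_vl_7b_math"

# DataConfig.post_init: keep only existing format prompt files, as absolute paths
format_prompt = os.path.abspath(format_prompt) if os.path.exists(format_prompt) else None

# TrainerConfig.post_init: default and absolutize the checkpoint directory
if save_checkpoint_path is None:
    save_checkpoint_path = os.path.join("checkpoints", project_name, experiment_name)
save_checkpoint_path = os.path.abspath(save_checkpoint_path)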
verl/trainer/data_loader.py (new file, 0 → 100644)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..utils.dataset import RLHFDataset, collate_fn
from .config import DataConfig


def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, processor: Optional[ProcessorMixin]) -> None:
    train_dataset = RLHFDataset(
        data_path=config.train_files,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )
    # use sampler for better ckpt resume
    if config.shuffle:
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(config.seed)
        sampler = RandomSampler(data_source=train_dataset, generator=train_dataloader_generator)
    else:
        sampler = SequentialSampler(data_source=train_dataset)

    train_dataloader = StatefulDataLoader(
        dataset=train_dataset,
        batch_size=config.rollout_batch_size,
        sampler=sampler,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=True,
    )

    val_dataset = RLHFDataset(
        data_path=config.val_files,
        tokenizer=tokenizer,
        processor=processor,
        prompt_key=config.prompt_key,
        answer_key=config.answer_key,
        image_key=config.image_key,
        max_prompt_length=config.max_prompt_length,
        truncation="right",
        format_prompt=config.format_prompt,
        min_pixels=config.min_pixels,
        max_pixels=config.max_pixels,
        filter_overlong_prompts=config.filter_overlong_prompts,
    )
    val_dataloader = StatefulDataLoader(
        dataset=val_dataset,
        batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=False,
        drop_last=False,
    )

    assert len(train_dataloader) >= 1
    assert len(val_dataloader) >= 1
    print(f"Size of train dataloader: {len(train_dataloader)}")
    print(f"Size of val dataloader: {len(val_dataloader)}")
    return train_dataloader, val_dataloader
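Note on the new module: StatefulDataLoader from torchdata is used instead of a plain DataLoader because it can export and restore its iteration state, which is what makes the "use sampler for better ckpt resume" comment work in practice. A minimal sketch (hypothetical paths; assumes the train_dataloader returned by create_dataloader above):

import torch

# capture sampler/iterator position so a resumed run continues where it left off
dataloader_state = train_dataloader.state_dict()
torch.save(dataloader_state, "checkpoints/dataloader.pt")  # hypothetical path

# ...later, after rebuilding the dataloader with the same dataset, sampler and seed:
train_dataloader.load_state_dict(torch.load("checkpoints/dataloader.pt"))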
verl/trainer/main.py

@@ -11,21 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
-"""

 import json
 import torch
 import ray
 from omegaconf import OmegaConf

 from ..single_controller.ray import RayWorkerGroup
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..workers.fsdp_workers import FSDPWorker
-from ..workers.reward import CustomRewardManager
+from ..workers.reward import FunctionRewardManager
 from .config import PPOConfig
+from .data_loader import create_dataloader
 from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role

@@ -36,7 +33,6 @@ class Runner:
     def run(self, config: PPOConfig):
         # print config
-        config.deep_post_init()
         print(json.dumps(config.to_dict(), indent=2))

         # instantiate tokenizer

@@ -69,13 +65,19 @@ class Runner:
         }
         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

-        reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
-        val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
+        reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
+        val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
+
+        train_dataloader, val_dataloader = create_dataloader(config=config.data, tokenizer=tokenizer, processor=processor)

         trainer = RayPPOTrainer(
             config=config,
             tokenizer=tokenizer,
             processor=processor,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
             role_worker_mapping=role_worker_mapping,
             resource_pool_manager=resource_pool_manager,
             ray_worker_group_cls=ray_worker_group_cls,

@@ -96,17 +98,26 @@ def main():
     default_config = OmegaConf.merge(default_config, file_config)
     ppo_config = OmegaConf.merge(default_config, cli_args)
-    ppo_config = OmegaConf.to_object(ppo_config)
+    ppo_config: PPOConfig = OmegaConf.to_object(ppo_config)
+    ppo_config.deep_post_init()

     if not ray.is_initialized():
+        runtime_env = {
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "INFO",
+                "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
+                "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:False",
+            }
+        }
         # this is for local ray cluster
         if torch.version.hip is not None:
             ray.init(num_gpus=torch.cuda.device_count(),
                 ignore_reinit_error=True,
-                runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+                runtime_env=runtime_env)
         else:
-            ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
+            ray.init(runtime_env=runtime_env)

     runner = Runner.remote()
     ray.get(runner.run.remote(ppo_config))
verl/trainer/metrics.py

@@ -110,11 +110,11 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
     }


-def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
+def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]:
     total_num_tokens = sum(batch.meta_info["global_token_num"])
     time = timing_raw["step"]
     return {
         "perf/total_num_tokens": total_num_tokens,
         "perf/time_per_step": time,
-        "perf/throughput": total_num_tokens / (time * n_gpus),
+        "perf/throughput": total_num_tokens / (time * num_gpus),
     }
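Note on the renamed num_gpus argument: perf/throughput is tokens per second per GPU. A quick worked example with made-up numbers (not from the diff):

total_num_tokens, step_time, num_gpus = 1_048_576, 64.0, 8
throughput = total_num_tokens / (step_time * num_gpus)
print(throughput)  # 2048.0 tokens per second per GPU for this hypothetical step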
verl/trainer/ray_trainer.py

@@ -30,7 +30,6 @@ import ray
 import torch
 from codetiming import Timer
 from ray.experimental.tqdm_ray import tqdm
-from torch.utils.data import RandomSampler, SequentialSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer, ProcessorMixin

@@ -40,7 +39,6 @@ from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWo
 from ..single_controller.ray.base import create_colocated_worker_cls
 from ..utils import torch_functional as VF
 from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
-from ..utils.dataset import RLHFDataset, collate_fn
 from ..utils.logger import Tracker
 from ..utils.py_functional import convert_dict_to_str
 from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance

@@ -102,24 +100,16 @@ class ResourcePoolManager:
         """Get the resource pool of the worker."""
         return self.resource_pool_dict[self.mapping[role]]

-    def get_n_gpus(self) -> int:
+    def get_num_gpus(self) -> int:
         """Get the number of gpus in this cluster."""
         return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

     def _check_resource_available(self):
         """Check if the resource pool can be satisfied in this ray cluster."""
-        node_available_resources = ray.state.available_resources_per_node()
-        node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
-        total_available_gpus = sum(node_available_gpus.values())
-        total_required_gpus = sum(
-            [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
-        )
-        if total_available_gpus < total_required_gpus:
-            raise ValueError(
-                f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
-            )
+        gpus_available = ray.available_resources().get("GPU", 0)
+        gpus_required = self.get_num_gpus()
+        if gpus_available < gpus_required:  # check total required gpus can be satisfied
+            raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")


 def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):

@@ -128,11 +118,8 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penal
     response_mask = data.batch["response_mask"]

     # compute kl between ref_policy and current policy
-    if "ref_log_probs" in data.batch.keys():
-        kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
-        kld = kld * response_mask  # (batch_size, response_length)
-    else:
-        kld = torch.zeros_like(response_mask, dtype=torch.float32)
+    kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
+    kld = kld * response_mask  # (batch_size, response_length)

     data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld

@@ -193,6 +180,8 @@ class RayPPOTrainer:
         config: PPOConfig,
         tokenizer: PreTrainedTokenizer,
         processor: Optional[ProcessorMixin],
+        train_dataloader: StatefulDataLoader,
+        val_dataloader: StatefulDataLoader,
         role_worker_mapping: dict[Role, Type[Worker]],
         resource_pool_manager: ResourcePoolManager,
         ray_worker_group_cls: Type[RayWorkerGroup] = RayWorkerGroup,

@@ -201,6 +190,8 @@ class RayPPOTrainer:
     ):
         self.tokenizer = tokenizer
         self.processor = processor
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.config = config
         self.reward_fn = reward_fn
         self.val_reward_fn = val_reward_fn

@@ -262,78 +253,13 @@ class RayPPOTrainer:
         ):
             raise ValueError("GRPO and RLOO algorithm need `config.worker.rollout.n > 1`.")

-        self._create_dataloader()
-
-    def _create_dataloader(self) -> None:
-        self.train_dataset = RLHFDataset(
-            data_path=self.config.data.train_files,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
-            prompt_key=self.config.data.prompt_key,
-            answer_key=self.config.data.answer_key,
-            image_key=self.config.data.image_key,
-            max_prompt_length=self.config.data.max_prompt_length,
-            truncation="right",
-            format_prompt=self.config.data.format_prompt,
-            min_pixels=self.config.data.min_pixels,
-            max_pixels=self.config.data.max_pixels,
-        )
-        # use sampler for better ckpt resume
-        if self.config.data.shuffle:
-            train_dataloader_generator = torch.Generator()
-            train_dataloader_generator.manual_seed(self.config.data.seed)
-            sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
-        else:
-            sampler = SequentialSampler(data_source=self.train_dataset)
-
-        self.train_dataloader = StatefulDataLoader(
-            dataset=self.train_dataset,
-            batch_size=self.config.data.rollout_batch_size,
-            sampler=sampler,
-            num_workers=8,
-            collate_fn=collate_fn,
-            pin_memory=False,
-            drop_last=True,
-        )
-
-        self.val_dataset = RLHFDataset(
-            data_path=self.config.data.val_files,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
-            prompt_key=self.config.data.prompt_key,
-            answer_key=self.config.data.answer_key,
-            image_key=self.config.data.image_key,
-            max_prompt_length=self.config.data.max_prompt_length,
-            truncation="right",
-            format_prompt=self.config.data.format_prompt,
-            min_pixels=self.config.data.min_pixels,
-            max_pixels=self.config.data.max_pixels,
-        )
-        self.val_dataloader = StatefulDataLoader(
-            dataset=self.val_dataset,
-            batch_size=len(self.val_dataset) if self.config.data.val_batch_size == -1 else self.config.data.val_batch_size,
-            shuffle=False,
-            num_workers=8,
-            collate_fn=collate_fn,
-            pin_memory=False,
-            drop_last=False,
-        )
-
-        assert len(self.train_dataloader) >= 1
-        assert len(self.val_dataloader) >= 1
-        print(f"Size of train dataloader: {len(self.train_dataloader)}")
-        print(f"Size of val dataloader: {len(self.val_dataloader)}")
-
-        if self.config.trainer.max_steps is not None:
-            training_steps = self.config.trainer.max_steps
-        else:
-            training_steps = len(self.train_dataloader) * self.config.trainer.total_episodes
-
-        self.training_steps = training_steps
-        self.config.worker.actor.optim.training_steps = training_steps
-        self.config.worker.critic.optim.training_steps = training_steps
+        if config.trainer.max_steps is not None:
+            self.training_steps = config.trainer.max_steps
+        else:
+            self.training_steps = len(train_dataloader) * config.trainer.total_episodes
+
+        config.worker.actor.optim.training_steps = self.training_steps
+        config.worker.critic.optim.training_steps = self.training_steps
         print(f"Total training steps: {self.training_steps}")

     def _maybe_log_val_generations(

@@ -366,10 +292,10 @@ class RayPPOTrainer:
             input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
             sample_inputs.extend(input_texts)

-            if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
+            if "multi_modal_data" in test_batch.non_tensor_batch.keys():
                 test_gen_batch = test_batch.pop(
                     batch_keys=["input_ids", "attention_mask", "position_ids"],
-                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
+                    non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                 )
             else:
                 test_gen_batch = test_batch.pop(

@@ -567,10 +493,10 @@ class RayPPOTrainer:
                 batch: DataProto = DataProto.from_single_dict(batch_dict)

                 # pop those keys for generation
-                if "multi_modal_inputs" in batch.non_tensor_batch.keys():
+                if "multi_modal_data" in batch.non_tensor_batch.keys():
                     gen_batch = batch.pop(
                         batch_keys=["input_ids", "attention_mask", "position_ids"],
-                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
+                        non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                     )
                 else:
                     gen_batch = batch.pop(

@@ -604,6 +530,7 @@ class RayPPOTrainer:
                     # repeat to align with repeated responses in rollout
                     batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
                     batch = batch.union(gen_batch_output)
+                    batch.non_tensor_batch.pop("multi_modal_data", None)

                 # compute reward
                 with _timer("reward", timing_raw):

@@ -694,10 +621,10 @@ class RayPPOTrainer:
                     self._save_checkpoint()

                 # collect metrics
-                n_gpus = self.resource_pool_manager.get_n_gpus()
+                num_gpus = self.resource_pool_manager.get_num_gpus()
                 metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
-                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))
                 self.logger.log(data=metrics, step=self.global_step)
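Note on the training-step accounting that moved from _create_dataloader into __init__: max_steps, when set, overrides the episode-based count. A small worked example with made-up numbers (not part of the diff):

len_train_dataloader, total_episodes, max_steps = 50, 4, None

if max_steps is not None:
    training_steps = max_steps
else:
    training_steps = len_train_dataloader * total_episodes

print(training_steps)  # 200 steps; with trainer.max_steps=120 the override would give 120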
verl/utils/checkpoint/fsdp_checkpoint_manager.py

@@ -13,13 +13,12 @@
 # limitations under the License.

 import os
-import warnings
 from typing import Optional, Union

 import torch
 import torch.distributed as dist
+from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

 from .checkpoint_manager import BaseCheckpointManager

@@ -59,21 +58,18 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
         print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
         model_state_dict = torch.load(model_path, weights_only=False)
-        optimizer_state_dict = torch.load(optim_path, weights_only=False)
+        optim_state_dict = torch.load(optim_path, weights_only=False)
         extra_state_dict = torch.load(extra_state_path, weights_only=False)
+        lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

-        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
-                self.model.load_state_dict(model_state_dict)
-                if self.optimizer is not None:
-                    self.optimizer.load_state_dict(optimizer_state_dict)
-
-        self.lr_scheduler.load_state_dict(extra_state_dict["lr_scheduler"])
+        state_dict_options = StateDictOptions(cpu_offload=True)
+        set_state_dict(
+            model=self.model,
+            optimizers=self.optimizer,
+            model_state_dict=model_state_dict,
+            optim_state_dict=optim_state_dict,
+            options=state_dict_options,
+        )
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

         # recover random state
         if "rng" in extra_state_dict:

@@ -84,24 +80,10 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         dist.barrier()

         # every rank will save its own model and optim shard
-        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
-                model_state_dict = self.model.state_dict()
-                if self.optimizer is not None:
-                    optimizer_state_dict = self.optimizer.state_dict()
-                else:
-                    optimizer_state_dict = None
-
-        if self.lr_scheduler is not None:
-            lr_scheduler_state_dict = self.lr_scheduler.state_dict()
-        else:
-            lr_scheduler_state_dict = None
-
+        state_dict_options = StateDictOptions(cpu_offload=True)
+        model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
         extra_state_dict = {
-            "lr_scheduler": lr_scheduler_state_dict,
+            "lr_scheduler": self.lr_scheduler.state_dict(),
             "rng": self.get_rng_state(),
         }
         model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")

@@ -112,9 +94,7 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
         print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
         torch.save(model_state_dict, model_path)
-        if self.optimizer is not None:
-            torch.save(optimizer_state_dict, optim_path)
+        torch.save(optim_state_dict, optim_path)
         torch.save(extra_state_dict, extra_path)

         # wait for everyone to dump to local
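Note on the checkpoint manager rewrite: it moves from FSDP.state_dict_type plus warning suppression to the torch.distributed.checkpoint.state_dict helpers. A minimal sketch of that API, assuming an already-wrapped FSDP model and its optimizer (variable names here are placeholders, not the manager's own code):

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_state_dict,
    set_state_dict,
)

options = StateDictOptions(cpu_offload=True)  # offload gathered shards to CPU before saving

# save path: one sharded state dict per rank
model_state_dict, optim_state_dict = get_state_dict(model, optimizer, options=options)

# load path: push previously saved shards back into the wrapped model and optimizer
set_state_dict(
    model,
    optimizer,
    model_state_dict=model_state_dict,
    optim_state_dict=optim_state_dict,
    options=options,
)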
verl/utils/dataset.py

@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Union
 import numpy as np
 import torch
 from datasets import load_dataset
+from jinja2 import Template
 from PIL import Image
 from PIL.Image import Image as ImageObject
 from torch.utils.data import Dataset

@@ -90,9 +91,10 @@ class RLHFDataset(Dataset, ImageProcessMixin):
         image_key: str = "images",
         max_prompt_length: int = 1024,
         truncation: str = "error",
-        format_prompt: str = None,
-        max_pixels: int = None,
-        min_pixels: int = None,
+        format_prompt: Optional[str] = None,
+        max_pixels: Optional[int] = None,
+        min_pixels: Optional[int] = None,
+        filter_overlong_prompts: bool = True,
     ):
         self.tokenizer = tokenizer
         self.processor = processor

@@ -101,9 +103,9 @@ class RLHFDataset(Dataset, ImageProcessMixin):
         self.image_key = image_key
         self.max_prompt_length = max_prompt_length
         self.truncation = truncation
-        self.format_prompt = format_prompt
         self.max_pixels = max_pixels
         self.min_pixels = min_pixels
+        self.filter_overlong_prompts = filter_overlong_prompts

         if "@" in data_path:
             data_path, data_split = data_path.split("@")

@@ -111,22 +113,29 @@ class RLHFDataset(Dataset, ImageProcessMixin):
             data_split = "train"

         if os.path.isdir(data_path):
+            # when we use dataset builder, we should always refer to the train split
             self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
         elif os.path.isfile(data_path):
             self.dataset = load_dataset("parquet", data_files=data_path, split="train")
-        else:  # remote dataset
+        else:  # load remote dataset from huggingface hub
             self.dataset = load_dataset(data_path, split=data_split)

-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, index):
-        row_dict: dict = self.dataset[index]
-        prompt_str: str = row_dict[self.prompt_key]
-        if self.format_prompt:
-            prompt_str = prompt_str + " " + self.format_prompt.strip()
-
-        if self.image_key in row_dict:
+        self.format_prompt = None
+        if format_prompt:
+            with open(format_prompt, encoding="utf-8") as f:
+                self.format_prompt = f.read()
+
+        if self.filter_overlong_prompts:
+            self.dataset = self.dataset.filter(self._filter_overlong_prompts, desc="Filtering overlong prompts")
+
+    def _build_messages(self, example: Dict[str, Any]) -> List[Dict[str, Any]]:
+        prompt_str: str = example[self.prompt_key]
+        if self.format_prompt:
+            format_prompt = Template(self.format_prompt.strip())
+            prompt_str = format_prompt.render(content=prompt_str)
+
+        if self.image_key in example:
             # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
             content_list = []
             for i, content in enumerate(prompt_str.split("<image>")):

@@ -136,28 +145,47 @@ class RLHFDataset(Dataset, ImageProcessMixin):
                 if content:
                     content_list.append({"type": "text", "text": content})

-            messages = [{"role": "user", "content": content_list}]
+            return [{"role": "user", "content": content_list}]
+        else:
+            return [{"role": "user", "content": prompt_str}]
+
+    def _filter_overlong_prompts(self, example: Dict[str, Any]) -> bool:
+        messages = self._build_messages(example)
+        processing_class = self.processor if self.processor is not None else self.tokenizer
+        return (
+            len(processing_class.apply_chat_template(messages, add_generation_prompt=True)) <= self.max_prompt_length
+        )
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        example: dict = self.dataset[index]
+        messages = self._build_messages(example)
+
+        if self.image_key in example:
             prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
+            images = [self.process_image(image) for image in example.pop(self.image_key)]
             model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
             input_ids = model_inputs.pop("input_ids")[0]
             attention_mask = model_inputs.pop("attention_mask")[0]
-            row_dict["multi_modal_data"] = {"image": images}
-            row_dict["multi_modal_inputs"] = dict(model_inputs)
+            example["multi_modal_data"] = {"image": images}
+            example["multi_modal_inputs"] = dict(model_inputs)
+        else:
+            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
+            input_ids = model_inputs.pop("input_ids")[0]
+            attention_mask = model_inputs.pop("attention_mask")[0]
+
+        if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
             # qwen2vl mrope
             position_ids = get_rope_index(
                 self.processor,
                 input_ids=input_ids,
-                image_grid_thw=model_inputs["image_grid_thw"],
+                image_grid_thw=model_inputs.get("image_grid_thw"),
                 attention_mask=attention_mask,
             )  # (3, seq_length)
         else:
-            messages = [{"role": "user", "content": prompt_str}]
-            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
-            input_ids = model_inputs.pop("input_ids")[0]
-            attention_mask = model_inputs.pop("attention_mask")[0]
             position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)

         input_ids, attention_mask, position_ids = VF.postprocess_data(

@@ -169,9 +197,18 @@ class RLHFDataset(Dataset, ImageProcessMixin):
             left_pad=True,
             truncation=self.truncation,
         )
-        row_dict["input_ids"] = input_ids
-        row_dict["attention_mask"] = attention_mask
-        row_dict["position_ids"] = position_ids
-        row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
-        row_dict["ground_truth"] = row_dict.pop(self.answer_key)
-        return row_dict
+        raw_prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
+        if len(raw_prompt_ids) > self.max_prompt_length:
+            if self.truncation == "left":
+                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length:]
+            elif self.truncation == "right":
+                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
+            elif self.truncation == "error":
+                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
+
+        example["input_ids"] = input_ids
+        example["attention_mask"] = attention_mask
+        example["position_ids"] = position_ids
+        example["raw_prompt_ids"] = raw_prompt_ids
+        example["ground_truth"] = example.pop(self.answer_key)
+        return example
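Note on the format_prompt change in RLHFDataset: the suffix is no longer appended by plain string concatenation; the configured file is now read once and treated as a Jinja2 template rendered with the raw question bound to content. A hypothetical template string and rendering (the actual template shipped with the repo may differ):

from jinja2 import Template

format_prompt = "{{ content }} Please reason step by step, and put your final answer within \\boxed{}."
prompt_str = Template(format_prompt.strip()).render(content="What is 2 + 3?")
print(prompt_str)
# What is 2 + 3? Please reason step by step, and put your final answer within \boxed{}.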
verl/utils/logger/logger.py

@@ -71,7 +71,7 @@ class TensorBoardLogger(Logger):
         os.makedirs(tensorboard_dir, exist_ok=True)
         print(f"Saving tensorboard log to {tensorboard_dir}.")
         self.writer = SummaryWriter(tensorboard_dir)
-        self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={})
+        self.writer.add_hparams(hparam_dict=flatten_dict(config), metric_dict={"placeholder": 0})

     def log(self, data: Dict[str, Any], step: int) -> None:
         for key, value in data.items():
verl/utils/reward_score/__init__.py (deleted, 100644 → 0)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .math import math_compute_score
from .r1v import r1v_compute_score

__all__ = ["math_compute_score", "r1v_compute_score"]
verl/utils/torch_dtypes.py

@@ -15,40 +15,28 @@
 import torch

-HALF_LIST = [16, "16", "fp16", "float16"]
-FLOAT_LIST = [32, "32", "fp32", "float32"]
+HALF_LIST = ["fp16", "float16"]
+FLOAT_LIST = ["fp32", "float32"]
 BFLOAT_LIST = ["bf16", "bfloat16"]


 class PrecisionType:
-    """Type of precision used.
-
-    >>> PrecisionType.HALF == 16
-    True
-    >>> PrecisionType.HALF in (16, "16")
-    True
-    """
-
-    HALF = "16"
-    FLOAT = "32"
-    FULL = "64"
-    BFLOAT = "bf16"
-    MIXED = "mixed"
+    """Type of precision used."""

     @staticmethod
-    def is_fp16(precision):
+    def is_fp16(precision: str) -> bool:
         return precision in HALF_LIST

     @staticmethod
-    def is_fp32(precision):
+    def is_fp32(precision: str) -> bool:
         return precision in FLOAT_LIST

     @staticmethod
-    def is_bf16(precision):
+    def is_bf16(precision: str) -> bool:
         return precision in BFLOAT_LIST

     @staticmethod
-    def to_dtype(precision) -> torch.dtype:
+    def to_dtype(precision: str) -> torch.dtype:
         if precision in HALF_LIST:
             return torch.float16
         elif precision in FLOAT_LIST:

@@ -56,7 +44,7 @@ class PrecisionType:
         elif precision in BFLOAT_LIST:
             return torch.bfloat16
         else:
-            raise RuntimeError(f"unexpected precision: {precision}")
+            raise RuntimeError(f"Unexpected precision: {precision}")

     @staticmethod
     def to_str(precision: torch.dtype) -> str:

@@ -67,4 +55,4 @@ class PrecisionType:
         elif precision == torch.bfloat16:
             return "bfloat16"
         else:
-            raise RuntimeError(f"unexpected precision: {precision}")
+            raise RuntimeError(f"Unexpected precision: {precision}")
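Note on the PrecisionType cleanup: only string aliases are accepted now that the bare 16/32 integer forms were dropped. A quick sanity-check sketch based on the code shown above:

import torch

from verl.utils.torch_dtypes import PrecisionType

assert PrecisionType.is_bf16("bf16") and PrecisionType.is_bf16("bfloat16")
assert PrecisionType.to_dtype("fp16") is torch.float16
assert PrecisionType.to_str(torch.bfloat16) == "bfloat16"
# PrecisionType.to_dtype(16) now raises RuntimeError("Unexpected precision: 16")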
verl/utils/torch_functional.py

 # Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright Meta Platforms, Inc. and affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -22,6 +23,8 @@ import torch.distributed
 import torch.nn.functional as F
 from torch.optim.lr_scheduler import LambdaLR

+from .torch_dtypes import PrecisionType
+
 try:
     from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

@@ -177,7 +180,7 @@ def postprocess_data(
         attention_mask = attention_mask[..., :max_length]
         position_ids = position_ids[..., :max_length]
     elif truncation == "error":
-        raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
+        raise RuntimeError(f"Input sequence length {seq_length} is longer than max length {max_length}.")
     else:
         raise NotImplementedError(f"Unknown truncation method {truncation}.")

@@ -207,11 +210,18 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         use_kahan_summation: bool = True,
-        momentum_dtype: torch.dtype = torch.bfloat16,
-        variance_dtype: torch.dtype = torch.bfloat16,
-        compensation_buffer_dtype: torch.dtype = torch.bfloat16,
+        momentum_dtype: str = "bfloat16",
+        variance_dtype: str = "bfloat16",
+        compensation_buffer_dtype: str = "bfloat16",
     ):
         """
+        AnyPrecisionAdamW: a flexible precision AdamW optimizer
+        with optional Kahan summation for high precision weight updates.
+        Allows direct control over momentum, variance and auxiliary compensation buffer dtypes.
+        Optional Kahan summation is used to offset precision reduction for the weight updates.
+        This allows full training in BFloat16 (equal or better than FP32 results in many cases)
+        due to high precision weight updates.
+
         Args:
             params (iterable): iterable of parameters to optimize or dicts defining parameter groups
             lr (float, optional): learning rate (default: 1e-3)

@@ -270,10 +280,11 @@ class AnyPrecisionAdamW(torch.optim.Optimizer):
             eps = group["eps"]
             use_kahan_summation = group["use_kahan_summation"]

-            momentum_dtype = group["momentum_dtype"]
-            variance_dtype = group["variance_dtype"]
-            compensation_buffer_dtype = group["compensation_buffer_dtype"]
+            momentum_dtype = PrecisionType.to_dtype(group["momentum_dtype"])
+            variance_dtype = PrecisionType.to_dtype(group["variance_dtype"])
+            compensation_buffer_dtype = PrecisionType.to_dtype(group["compensation_buffer_dtype"])

             for p in group["params"]:
+                assert isinstance(p, torch.Tensor)  # lint
                 if p.grad is None:
                     continue
verl/workers/actor/__init__.py

@@ -12,15 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import BasePPOActor
 from .config import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
-from .dp_actor import DataParallelPPOActor

 __all__ = [
     "ActorConfig",
-    "BasePPOActor",
-    "DataParallelPPOActor",
     "FSDPConfig",
     "ModelConfig",
     "OptimConfig",
verl/workers/actor/config.py

@@ -15,6 +15,7 @@
 Actor config
 """

+import os
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional, Tuple

@@ -32,6 +33,12 @@ class ModelConfig:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path

+        if self.model_path is not None and os.path.exists(self.model_path):
+            self.model_path = os.path.abspath(self.model_path)
+
+        if self.tokenizer_path is not None and os.path.exists(self.tokenizer_path):
+            self.tokenizer_path = os.path.abspath(self.tokenizer_path)
+

 @dataclass
 class OptimConfig:
verl/workers/actor/dp_actor.py

@@ -20,9 +20,11 @@ from collections import defaultdict
 from typing import Any, Dict, Optional

 import torch
+from einops import rearrange
 from ray.experimental.tqdm_ray import tqdm
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from transformers.modeling_flash_attention_utils import index_first_axis, pad_input, unpad_input

 from ...protocol import DataProto
 from ...trainer import core_algos

@@ -33,12 +35,6 @@ from .base import BasePPOActor
 from .config import ActorConfig

-try:
-    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
-except ImportError:
-    pass
-

 __all__ = ["DataParallelPPOActor"]
verl/workers/critic/__init__.py

@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import BasePPOCritic
-from .config import CriticConfig
-from .dp_critic import DataParallelPPOCritic
+from .config import CriticConfig, ModelConfig

-__all__ = ["BasePPOCritic", "CriticConfig", "DataParallelPPOCritic", "ModelConfig"]
+__all__ = ["CriticConfig"]
verl/workers/fsdp_workers.py

@@ -54,9 +54,7 @@ from ..utils.model_utils import print_gpu_memory_usage, print_model_size
 from ..utils.tokenizer import get_processor, get_tokenizer
 from ..utils.torch_dtypes import PrecisionType
 from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
-from .actor import DataParallelPPOActor
 from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
-from .critic import DataParallelPPOCritic
 from .rollout import vLLMRollout
 from .sharding_manager import FSDPVLLMShardingManager
 from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -264,6 +262,9 @@ class FSDPWorker(Worker):
         else:
             sync_module_states = False
             param_init_fn = None

+        ## TODO: assign the model to its designated GPU
+        rank = torch.cuda.set_device(self.rank)
+        model = model.to(rank)
         self.fsdp_module = FSDP(
             model,

@@ -365,6 +366,8 @@ class FSDPWorker(Worker):
             print_gpu_memory_usage(f"After offload {role} optimizer during init")

         if self._is_actor:
+            from .actor.dp_actor import DataParallelPPOActor  # lazy import
+
             self.actor = DataParallelPPOActor(
                 config=self.config.actor,
                 actor_module=self.fsdp_module,

@@ -372,6 +375,8 @@ class FSDPWorker(Worker):
             )

         if self._is_critic:
+            from .critic.dp_critic import DataParallelPPOCritic  # lazy import
+
             self.critic = DataParallelPPOCritic(
                 config=self.config,
                 critic_module=self.fsdp_module,

@@ -382,6 +387,8 @@ class FSDPWorker(Worker):
             self._build_rollout()

         if self._is_ref:
+            from .actor.dp_actor import DataParallelPPOActor  # lazy import
+
             self.ref_policy = DataParallelPPOActor(
                 config=self.config.ref,
                 actor_module=self.fsdp_module,
verl/workers/reward/__init__.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from .config import RewardConfig
-from .custom import CustomRewardManager
+from .function import FunctionRewardManager

-__all__ = ["CustomRewardManager", "RewardConfig"]
+__all__ = ["FunctionRewardManager", "RewardConfig"]
verl/workers/reward/config.py

@@ -15,11 +15,28 @@
 Reward config
 """

-from dataclasses import dataclass
+import os
+from dataclasses import dataclass, field
+from typing import Optional


 @dataclass
 class RewardConfig:
     reward_type: str = "function"
-    score_function: str = "math"
+    score_function: Optional[str] = None
+    score_function_kwargs: dict = field(default_factory=dict)
     skip_special_tokens: bool = True
+    """auto keys"""
+    score_function_name: Optional[str] = field(default=None, init=False)
+
+    def post_init(self):
+        if self.score_function is not None:
+            if ":" not in self.score_function:
+                self.score_function_name = "main"
+            else:
+                self.score_function, self.score_function_name = self.score_function.split(":", maxsplit=1)
+
+            if os.path.exists(self.score_function):
+                self.score_function = os.path.abspath(self.score_function)
+            else:
+                self.score_function = None
verl/workers/reward/custom.py → verl/workers/reward/function.py

@@ -12,34 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import importlib.util
+import os
+import sys
 from collections import defaultdict
-from typing import Callable, Dict, List, Tuple, TypedDict
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, List, Optional, Tuple, TypedDict

 import torch
 from transformers import PreTrainedTokenizer

 from ...protocol import DataProto
-from ...utils.reward_score import math_compute_score, r1v_compute_score
 from .config import RewardConfig


 class RewardScore(TypedDict):
     overall: float
-    format: float
-    accuracy: float
+    format: Optional[float]
+    accuracy: Optional[float]


-class CustomRewardManager:
-    def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
-        self.config = config
-        self.tokenizer = tokenizer
-        if config.score_function == "math":
-            self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
-        elif config.score_function == "r1v":
-            self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
-        else:
-            raise NotImplementedError(f"Unknown score function {config.score_function}.")
+ScoreFunction = Callable[[str, str], RewardScore]
+
+
+@dataclass
+class FunctionRewardManager:
+    config: RewardConfig
+    tokenizer: PreTrainedTokenizer
+
+    def __post_init__(self):
+        """Load score function."""
+        if self.config.score_function is None:
+            raise ValueError("Score function is not provided.")
+
+        if not os.path.exists(self.config.score_function):
+            raise FileNotFoundError(f"Score function file {self.config.score_function} not found.")
+
+        spec = importlib.util.spec_from_file_location("custom_score_fn", self.config.score_function)
+        module = importlib.util.module_from_spec(spec)
+        try:
+            sys.modules["custom_score_fn"] = module
+            spec.loader.exec_module(module)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load score function: {e}")
+
+        if not hasattr(module, self.config.score_function_name):
+            raise AttributeError(f"Module {module} does not have function {self.config.score_function_name}.")
+
+        score_fn: ScoreFunction = getattr(module, self.config.score_function_name)
+        print(f"Using score function `{self.config.score_function_name}` from `{self.config.score_function}`.")
+        self.score_fn = partial(score_fn, **self.config.score_function_kwargs)

     def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

@@ -56,7 +79,7 @@ class CustomRewardManager:
             )
             ground_truth = data_item.non_tensor_batch["ground_truth"]

-            score = self.compute_score(response_str, ground_truth)
+            score = self.score_fn(response_str, ground_truth)
             reward_tensor[i, valid_response_length - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)
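Note on the CustomRewardManager -> FunctionRewardManager change: reward functions are no longer hard-coded ("math" / "r1v") but loaded from a user-supplied file given by worker.reward.score_function, optionally suffixed with ":function_name" (default "main"), with extra keyword arguments taken from score_function_kwargs. A hypothetical score-function file that satisfies the RewardScore contract above (file name, tag and logic are illustrative only):

from typing import Optional, TypedDict


class RewardScore(TypedDict):
    overall: float
    format: Optional[float]
    accuracy: Optional[float]


def main(response: str, ground_truth: str, answer_tag: str = "<answer>") -> RewardScore:
    # toy exact-match scoring: full credit when the tagged answer equals the label
    has_format = answer_tag in response
    answer = response.split(answer_tag)[-1].strip() if has_format else response.strip()
    accuracy = 1.0 if answer == ground_truth.strip() else 0.0
    return {"overall": accuracy, "format": 1.0 if has_format else 0.0, "accuracy": accuracy}

Such a file could then be referenced as, for example, worker.reward.score_function=score_functions/exact_match.py (resolved to an absolute path by RewardConfig.post_init) together with score_function_kwargs={"answer_tag": "<answer>"}; both the path and the kwargs here are made up for illustration.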