OpenDAS / EasyR1 · Commits

Commit c132cbcb
Authored Apr 02, 2025 by chenych
Parent: f92481f0

    0402 update

Changes: 72. Showing 20 changed files with 1368 additions and 444 deletions (+1368 -444).
verl/utils/checkpoint/fsdp_checkpoint_manager.py   +49   -58
verl/utils/dataset.py                              +177  -0
verl/utils/flops_counter.py                        +24   -13
verl/utils/fsdp_utils.py                           +16   -3
verl/utils/logger/__init__.py                      +6    -0
verl/utils/logger/gen_logger.py                    +99   -0
verl/utils/logger/logger.py                        +154  -0
verl/utils/model_utils.py                          +26   -11
verl/utils/py_functional.py                        +70   -4
verl/utils/reward_score/math.py                    +33   -6
verl/utils/reward_score/r1v.py                     +29   -11
verl/utils/seqlen_balancing.py                     +264  -0
verl/utils/tokenizer.py                            +9    -19
verl/utils/torch_dtypes.py                         +5    -5
verl/utils/torch_functional.py                     +221  -192
verl/utils/ulysses.py                              +1    -1
verl/workers/actor/base.py                         +2    -2
verl/workers/actor/config.py                       +19   -15
verl/workers/actor/dp_actor.py                     +157  -99
verl/workers/config.py                             +7    -5
verl/utils/checkpoint/fsdp_checkpoint_manager.py

@@ -14,12 +14,13 @@
 import os
 import warnings
+from typing import Optional, Union

 import torch
-import torch.distributed
+import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType
-from transformers import PreTrainedTokenizer, ProcessorMixin
+from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

 from .checkpoint_manager import BaseCheckpointManager
@@ -44,65 +45,56 @@ class FSDPCheckpointManager(BaseCheckpointManager):
         model: FSDP,
         optimizer: torch.optim.Optimizer,
         lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
-        tokenizer: PreTrainedTokenizer,
-        processor: ProcessorMixin,
-        *args,
-        **kwargs,
+        processing_class: Union[PreTrainedTokenizer, ProcessorMixin],
     ):
-        super().__init__(model, optimizer, lr_scheduler, tokenizer, processor)
+        super().__init__(model, optimizer, lr_scheduler, processing_class)

-    def load_checkpoint(self, path=None, *args, **kwargs):
+    def load_checkpoint(self, path: Optional[str] = None):
         if path is None:
             return

         # every rank download its own checkpoint
-        local_model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
-        local_optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
-        local_extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
-        print(f"[rank-{self.rank}]: Loading from {local_model_path} and {local_optim_path} and {local_extra_state_path}")
-        model_state_dict = torch.load(local_model_path)
-        optimizer_state_dict = torch.load(local_optim_path)
-        extra_state_dict = torch.load(local_extra_state_path)
+        model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
+        optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
+        extra_state_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
+        print(f"[rank-{self.rank}]: Loading from {model_path} and {optim_path} and {extra_state_path}.")
+        model_state_dict = torch.load(model_path, weights_only=False)
+        optimizer_state_dict = torch.load(optim_path, weights_only=False)
+        extra_state_dict = torch.load(extra_state_path, weights_only=False)
         lr_scheduler_state_dict = extra_state_dict["lr_scheduler"]

-        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
-        with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
-            self.model.load_state_dict(model_state_dict)
-            if self.optimizer is not None:
-                self.optimizer.load_state_dict(optimizer_state_dict)
-
-        # recover random state
-        if "rng" in extra_state_dict:
-            # 'rng' may not exist for backward compatibility
-            self.load_rng_state(extra_state_dict["rng"])
+        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
+        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
+                self.model.load_state_dict(model_state_dict)
+                if self.optimizer is not None:
+                    self.optimizer.load_state_dict(optimizer_state_dict)

         if self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(lr_scheduler_state_dict)

-    def save_checkpoint(self, local_path: str, global_step: int, remove_previous_ckpt=False, *args, **kwargs):
-        # record the previous global step
-        self.previous_global_step = global_step
-
-        # remove previous local_path
-        # TODO: shall we remove previous ckpt every save?
-        if remove_previous_ckpt:
-            self.remove_previous_save_local_path()
-
-        local_path = self.local_mkdir(local_path)
-        torch.distributed.barrier()
+        # recover random state
+        if "rng" in extra_state_dict:
+            self.load_rng_state(extra_state_dict["rng"])
+
+    def save_checkpoint(self, path: str):
+        path = self.local_mkdir(path)
+        dist.barrier()

         # every rank will save its own model and optim shard
-        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
-        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
+        state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
+        optim_config = ShardedOptimStateDictConfig(offload_to_cpu=True)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
+            with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_config, optim_config):
                 model_state_dict = self.model.state_dict()
                 if self.optimizer is not None:
                     optimizer_state_dict = self.optimizer.state_dict()
                 else:
                     optimizer_state_dict = None

                 if self.lr_scheduler is not None:
                     lr_scheduler_state_dict = self.lr_scheduler.state_dict()
                 else:
@@ -112,29 +104,28 @@ class FSDPCheckpointManager(BaseCheckpointManager):
                 "lr_scheduler": lr_scheduler_state_dict,
                 "rng": self.get_rng_state(),
             }

-            model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
-            optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
-            extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
+            model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
+            optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
+            extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")

-            print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}")
-            print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}")
-            print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}")
+            print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
+            print(f"[rank-{self.rank}]: Saving checkpoint to {os.path.abspath(model_path)}.")
+            print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
             torch.save(model_state_dict, model_path)
-            torch.save(optimizer_state_dict, optim_path)
+            # TODO: address optimizer is None
+            if self.optimizer is not None:
+                torch.save(optimizer_state_dict, optim_path)
+
             torch.save(extra_state_dict, extra_path)

         # wait for everyone to dump to local
-        torch.distributed.barrier()
+        dist.barrier()

         if self.rank == 0:
-            hf_local_path = os.path.join(local_path, "huggingface")
-            os.makedirs(hf_local_path, exist_ok=True)
-            self.model._fsdp_wrapped_module.config.save_pretrained(hf_local_path)
-            if self.processor:
-                self.processor.save_pretrained(hf_local_path)
-            else:
-                self.tokenizer.save_pretrained(hf_local_path)
-
-        torch.distributed.barrier()
-        self.previous_save_local_path = local_path
+            hf_path = os.path.join(path, "huggingface")
+            os.makedirs(hf_path, exist_ok=True)
+            assert isinstance(self.model._fsdp_wrapped_module, PreTrainedModel)
+            self.model._fsdp_wrapped_module.config.save_pretrained(hf_path)
+            self.model._fsdp_wrapped_module.generation_config.save_pretrained(hf_path)
+            self.processing_class.save_pretrained(hf_path)
+
+        dist.barrier()
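For context, the pattern this manager is built on is FSDP's sharded state dict: every rank saves and loads only its own shard, with a barrier before the checkpoint is considered complete. A minimal standalone sketch of the save half (a toy helper, not from this commit; assumes torch.distributed is already initialized and ckpt_dir is a placeholder directory):

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType


def save_model_shard(model: FSDP, ckpt_dir: str) -> None:
    # Each rank materializes only its own (CPU-offloaded) shard and dumps it to a
    # rank-specific file, mirroring the naming scheme used in the diff above.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    config = ShardedStateDictConfig(offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, config):
        shard = model.state_dict()

    torch.save(shard, os.path.join(ckpt_dir, f"model_world_size_{world_size}_rank_{rank}.pt"))
    dist.barrier()  # wait until every rank has written its shard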
verl/utils/dataset.py (new file, 0 → 100644)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from PIL.Image import Image as ImageObject
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin

from ..models.transformers.qwen2_vl import get_rope_index
from . import torch_functional as VF


def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)
    for feature in features:
        for key, value in feature.items():
            if isinstance(value, torch.Tensor):
                tensors[key].append(value)
            else:
                non_tensors[key].append(value)

    for key, value in tensors.items():
        tensors[key] = torch.stack(value, dim=0)

    for key, value in non_tensors.items():
        non_tensors[key] = np.array(value, dtype=object)

    return {**tensors, **non_tensors}


class ImageProcessMixin:
    max_pixels: int
    min_pixels: int

    def process_image(self, image: Union[Dict[str, Any], ImageObject]) -> ImageObject:
        if isinstance(image, dict):
            image = Image.open(BytesIO(image["bytes"]))
        elif isinstance(image, bytes):
            image = Image.open(BytesIO(image))

        if (image.width * image.height) > self.max_pixels:
            resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

        if (image.width * image.height) < self.min_pixels:
            resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height))

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image


class RLHFDataset(Dataset, ImageProcessMixin):
    """
    We assume the dataset contains a column that contains prompts and other information
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        processor: Optional[ProcessorMixin],
        prompt_key: str = "prompt",
        answer_key: str = "answer",
        image_key: str = "images",
        max_prompt_length: int = 1024,
        truncation: str = "error",
        system_prompt: str = None,
        max_pixels: int = None,
        min_pixels: int = None,
    ):
        self.tokenizer = tokenizer
        self.processor = processor
        self.prompt_key = prompt_key
        self.answer_key = answer_key
        self.image_key = image_key
        self.max_prompt_length = max_prompt_length
        self.truncation = truncation
        self.system_prompt = system_prompt
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels

        if "@" in data_path:
            data_path, data_split = data_path.split("@")
        else:
            data_split = "train"

        if os.path.isdir(data_path):
            self.dataset = load_dataset("parquet", data_dir=data_path, split="train")
        elif os.path.isfile(data_path):
            self.dataset = load_dataset("parquet", data_files=data_path, split="train")
        else:  # remote dataset
            self.dataset = load_dataset(data_path, split=data_split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        row_dict: dict = self.dataset[index]
        prompt_str: str = row_dict[self.prompt_key]
        if self.system_prompt:
            prompt_str = " ".join((self.system_prompt.strip(), prompt_str))

        if self.image_key in row_dict:
            # https://huggingface.co/docs/transformers/en/tasks/image_text_to_text
            content_list = []
            for i, content in enumerate(prompt_str.split("<image>")):
                if i != 0:
                    content_list.append({"type": "image"})

                if content:
                    content_list.append({"type": "text", "text": content})

            messages = [{"role": "user", "content": content_list}]
            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            images = [self.process_image(image) for image in row_dict.pop(self.image_key)]
            model_inputs = self.processor(images, [prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
            row_dict["multi_modal_data"] = {"image": images}
            row_dict["multi_modal_inputs"] = dict(model_inputs)
            # qwen2vl mrope
            position_ids = get_rope_index(
                self.processor,
                input_ids=input_ids,
                image_grid_thw=model_inputs["image_grid_thw"],
                attention_mask=attention_mask,
            )  # (3, seq_length)
        else:
            messages = [{"role": "user", "content": prompt_str}]
            prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer([prompt], add_special_tokens=False, return_tensors="pt")
            input_ids = model_inputs.pop("input_ids")[0]
            attention_mask = model_inputs.pop("attention_mask")[0]
            position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)

        input_ids, attention_mask, position_ids = VF.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )
        row_dict["input_ids"] = input_ids
        row_dict["attention_mask"] = attention_mask
        row_dict["position_ids"] = position_ids
        row_dict["raw_prompt_ids"] = self.tokenizer.encode(prompt, add_special_tokens=False)
        row_dict["ground_truth"] = row_dict.pop(self.answer_key)
        return row_dict
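Not part of the commit, but a minimal usage sketch of the new dataset module for the text-only path (the model name and parquet path are placeholders):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from verl.utils.dataset import RLHFDataset, collate_fn

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder checkpoint
dataset = RLHFDataset(
    data_path="data/train.parquet",  # or "org/dataset@train" for a remote split
    tokenizer=tokenizer,
    processor=None,  # text-only; pass an AutoProcessor for <image> prompts
    max_prompt_length=1024,
    truncation="left",
)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # (8, 1024): prompts are left-padded to max_prompt_length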
verl/utils/flops_counter.py

@@ -12,22 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Tuple

 import torch
-from transformers import LlamaConfig, PretrainedConfig, Qwen2Config

-VALID_CONFIG_TYPE = (Qwen2Config, LlamaConfig)
+if TYPE_CHECKING:
+    from transformers.models.llama.configuration_llama import LlamaConfig
+
+VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}


-def get_device_flops(unit="T"):
-    def unit_convert(number, level):
+def get_device_flops(unit: str = "T") -> float:
+    def unit_convert(number: float, level: str):
         units = ["B", "K", "M", "G", "T", "P"]
         if number <= 0:
             return number

         ptr = 0
         while ptr < len(units) and units[ptr] != level:
             number /= 1000
             ptr += 1

         return number

     device_name = torch.cuda.get_device_name()
@@ -55,21 +62,24 @@ class FlopsCounter:
     Example:
         flops_counter = FlopsCounter(config)
        flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time)
     """

-    def __init__(self, config: PretrainedConfig):
-        if not isinstance(config, VALID_CONFIG_TYPE):
-            print(f"Only support config type of {VALID_CONFIG_TYPE}, but got {type(config)}. MFU will always be zero.")
+    def __init__(self, config: "LlamaConfig"):
+        if config.model_type not in VALID_MODLE_TYPE:
+            print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.")

-        self.estimate_func = {"qwen2": self._estimate_qwen2_flops, "llama": self._estimate_qwen2_flops}
+        self.estimate_func = {
+            "llama": self._estimate_llama_flops,
+            "qwen2": self._estimate_llama_flops,
+            "qwen2_vl": self._estimate_llama_flops,
+            "qwen2_5_vl": self._estimate_llama_flops,
+        }
         self.config = config

-    def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time):
+    def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
         return 0

-    def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
-        assert isinstance(self.config, (Qwen2Config, LlamaConfig))
+    def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
         hidden_size = self.config.hidden_size
         vocab_size = self.config.vocab_size
         num_hidden_layers = self.config.num_hidden_layers
@@ -96,6 +106,7 @@ class FlopsCounter:
         seqlen_square_sum = 0
         for seqlen in batch_seqlens:
             seqlen_square_sum += seqlen * seqlen

         attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
         # all_layer & all_token fwd & bwd flops
@@ -103,7 +114,7 @@ class FlopsCounter:
         flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
         return flops_achieved

-    def estimate_flops(self, batch_seqlens, delta_time):
+    def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
         """
         Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
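The intended call pattern follows the class docstring above; a sketch of computing MFU from the two return values (requires a CUDA device; the model name and timing are placeholders):

from transformers import AutoConfig

from verl.utils.flops_counter import FlopsCounter

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder checkpoint
counter = FlopsCounter(config)
batch_seqlens = [512, 1024, 2048]  # valid-token count per sample in the batch
delta_time = 1.7                   # seconds spent on the step (made-up number)
flops_achieved, flops_promised = counter.estimate_flops(batch_seqlens, delta_time)
print(f"MFU: {flops_achieved / flops_promised:.2%}")  # achieved TFLOPs / device peak TFLOPs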
verl/utils/fsdp_utils.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 from collections import defaultdict
 from functools import partial
 from typing import Callable, Union
@@ -73,6 +74,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
     for handle in model._all_handles:
         if handle._offload_params:
             continue

         flat_param = handle.flat_param
         assert (
             flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
@@ -89,7 +91,7 @@ def offload_fsdp_model(model: FSDP, empty_cache: bool = True):
 @torch.no_grad()
-def load_fsdp_model(model: FSDP):
+def load_fsdp_model(model: FSDP, empty_cache: bool = True):
     # lazy init FSDP model
     _lazy_init(model, model)
     assert model._is_root, "Only support root model loading to GPU"
@@ -102,11 +104,15 @@ def load_fsdp_model(model: FSDP):
         # the following still keeps id(._local_shard) != id(.data)
         flat_param._local_shard = flat_param.data

+    if empty_cache:
+        gc.collect()
+

 @torch.no_grad()
-def offload_fsdp_optimizer(optimizer: Optimizer):
+def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
     if not optimizer.state:
         return

     for param_group in optimizer.param_groups:
         for param in param_group["params"]:
             state = optimizer.state[param]
@@ -114,14 +120,21 @@ def offload_fsdp_optimizer(optimizer: Optimizer):
             if isinstance(value, torch.Tensor):
                 state[key] = value.to("cpu", non_blocking=True)

+    if empty_cache:
+        torch.cuda.empty_cache()
+

 @torch.no_grad()
-def load_fsdp_optimizer(optimizer: Optimizer):
+def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
     if not optimizer.state:
         return

     for param_group in optimizer.param_groups:
         for param in param_group["params"]:
             state = optimizer.state[param]
             for key, value in state.items():
                 if isinstance(value, torch.Tensor):
                     state[key] = value.to("cuda", non_blocking=True)

+    if empty_cache:
+        gc.collect()
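These four helpers are meant to be used in matched pairs. A hypothetical wrapper (with_params_offloaded and rollout_fn are illustrative names, not from this commit) showing the park-on-CPU / bring-back cycle:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import Optimizer

from verl.utils.fsdp_utils import (
    load_fsdp_model,
    load_fsdp_optimizer,
    offload_fsdp_model,
    offload_fsdp_optimizer,
)


def with_params_offloaded(model: FSDP, optimizer: Optimizer, rollout_fn):
    # Park parameters and optimizer state on CPU while rollout_fn owns the GPU,
    # then bring both back before the next update step.
    offload_fsdp_model(model)
    offload_fsdp_optimizer(optimizer, empty_cache=True)
    result = rollout_fn()
    load_fsdp_model(model)
    load_fsdp_optimizer(optimizer)
    return result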
verl/utils/logger/__init__.py

@@ -11,3 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .logger import Tracker
+
+
+__all__ = ["Tracker"]
verl/utils/logger/gen_logger.py (new file, 0 → 100644)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple

from ..py_functional import is_package_available


if is_package_available("wandb"):
    import wandb  # type: ignore

if is_package_available("swanlab"):
    import swanlab  # type: ignore


@dataclass
class GenerationLogger(ABC):
    @abstractmethod
    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None: ...


@dataclass
class ConsoleGenerationLogger(GenerationLogger):
    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
        for inp, out, score in samples:
            print(f"[prompt] {inp}\n[output] {out}\n[score] {score}\n")


@dataclass
class WandbGenerationLogger(GenerationLogger):
    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
        # Create column names for all samples
        columns = ["step"] + sum(
            [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], []
        )
        if not hasattr(self, "validation_table"):
            # Initialize the table on first call
            self.validation_table = wandb.Table(columns=columns)

        # Create a new table with same columns and existing data
        # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
        new_table = wandb.Table(columns=columns, data=self.validation_table.data)
        # Add new row with all data
        row_data = [step]
        for sample in samples:
            row_data.extend(sample)

        new_table.add_data(*row_data)
        wandb.log({"val/generations": new_table}, step=step)
        self.validation_table = new_table


@dataclass
class SwanlabGenerationLogger(GenerationLogger):
    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
        swanlab_text_list = []
        for i, sample in enumerate(samples):
            row_text = f"input: {sample[0]}\n\n---\n\noutput: {sample[1]}\n\n---\n\nscore: {sample[2]}"
            swanlab_text_list.append(swanlab.Text(row_text, caption=f"sample {i + 1}"))

        swanlab.log({"val/generations": swanlab_text_list}, step=step)


GEN_LOGGERS = {
    "console": ConsoleGenerationLogger,
    "wandb": WandbGenerationLogger,
    "swanlab": SwanlabGenerationLogger,
}


@dataclass
class AggregateGenerationsLogger:
    def __init__(self, loggers: List[str]):
        self.loggers: List[GenerationLogger] = []
        for logger in loggers:
            if logger in GEN_LOGGERS:
                self.loggers.append(GEN_LOGGERS[logger]())

    def log(self, samples: List[Tuple[str, str, float]], step: int) -> None:
        for logger in self.loggers:
            logger.log(samples, step)
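A quick usage sketch (not in the commit): the aggregate logger silently skips names that are not in GEN_LOGGERS, so callers can pass the same logger names they use for metrics.

from verl.utils.logger.gen_logger import AggregateGenerationsLogger

gen_logger = AggregateGenerationsLogger(["console"])  # add "wandb"/"swanlab" if installed
samples = [  # (prompt, output, score) triples
    ("What is 2 + 2?", "<think>2 + 2 = 4</think>\\boxed{4}", 1.0),
]
gen_logger.log(samples, step=100)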
verl/utils/logger/logger.py (new file, 0 → 100644)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A unified tracking interface that supports logging data to different backend
"""

import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

from torch.utils.tensorboard import SummaryWriter

from ..py_functional import convert_dict_to_str, flatten_dict, is_package_available, unflatten_dict
from .gen_logger import AggregateGenerationsLogger


if is_package_available("mlflow"):
    import mlflow  # type: ignore

if is_package_available("wandb"):
    import wandb  # type: ignore

if is_package_available("swanlab"):
    import swanlab  # type: ignore


class Logger(ABC):
    @abstractmethod
    def __init__(self, config: Dict[str, Any]) -> None: ...

    @abstractmethod
    def log(self, data: Dict[str, Any], step: int) -> None: ...

    def finish(self) -> None:
        pass


class ConsoleLogger(Logger):
    def __init__(self, config: Dict[str, Any]) -> None:
        print("Config\n" + convert_dict_to_str(config))

    def log(self, data: Dict[str, Any], step: int) -> None:
        print(f"Step {step}\n" + convert_dict_to_str(unflatten_dict(data)))


class MlflowLogger(Logger):
    def __init__(self, config: Dict[str, Any]) -> None:
        mlflow.start_run(run_name=config["trainer"]["experiment_name"])
        mlflow.log_params(flatten_dict(config))

    def log(self, data: Dict[str, Any], step: int) -> None:
        mlflow.log_metrics(metrics=data, step=step)


class TensorBoardLogger(Logger):
    def __init__(self, config: Dict[str, Any]) -> None:
        tensorboard_dir = os.getenv("TENSORBOARD_DIR", "tensorboard_log")
        os.makedirs(tensorboard_dir, exist_ok=True)
        print(f"Saving tensorboard log to {tensorboard_dir}.")
        self.writer = SummaryWriter(tensorboard_dir)
        self.writer.add_hparams(flatten_dict(config))

    def log(self, data: Dict[str, Any], step: int) -> None:
        for key, value in data.items():
            self.writer.add_scalar(key, value, step)

    def finish(self):
        self.writer.close()


class WandbLogger(Logger):
    def __init__(self, config: Dict[str, Any]) -> None:
        wandb.init(
            project=config["trainer"]["project_name"],
            name=config["trainer"]["experiment_name"],
            config=config,
        )

    def log(self, data: Dict[str, Any], step: int) -> None:
        wandb.log(data=data, step=step)

    def finish(self) -> None:
        wandb.finish()


class SwanlabLogger(Logger):
    def __init__(self, config: Dict[str, Any]) -> None:
        swanlab_key = os.getenv("SWANLAB_API_KEY")
        swanlab_dir = os.getenv("SWANLAB_DIR", "swanlab_log")
        swanlab_mode = os.getenv("SWANLAB_MODE", "cloud")
        if swanlab_key:
            swanlab.login(swanlab_key)

        swanlab.init(
            project=config["trainer"]["project_name"],
            experiment_name=config["trainer"]["experiment_name"],
            config={"UPPERFRAMEWORK": "EasyR1", "FRAMEWORK": "veRL", **config},
            logdir=swanlab_dir,
            mode=swanlab_mode,
        )

    def log(self, data: Dict[str, Any], step: int) -> None:
        swanlab.log(data=data, step=step)

    def finish(self) -> None:
        swanlab.finish()


LOGGERS = {
    "wandb": WandbLogger,
    "mlflow": MlflowLogger,
    "tensorboard": TensorBoardLogger,
    "console": ConsoleLogger,
    "swanlab": SwanlabLogger,
}


class Tracker:
    def __init__(self, loggers: Union[str, List[str]] = "console", config: Optional[Dict[str, Any]] = None):
        if isinstance(loggers, str):
            loggers = [loggers]

        self.loggers: List[Logger] = []
        for logger in loggers:
            if logger not in LOGGERS:
                raise ValueError(f"{logger} is not supported.")

            self.loggers.append(LOGGERS[logger](config))

        self.gen_logger = AggregateGenerationsLogger(loggers)

    def log(self, data: Dict[str, Any], step: int) -> None:
        for logger in self.loggers:
            logger.log(data=data, step=step)

    def log_generation(self, samples: List[Tuple[str, str, float]], step: int) -> None:
        self.gen_logger.log(samples, step)

    def __del__(self):
        for logger in self.loggers:
            logger.finish()
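A usage sketch for the Tracker facade (not in the commit; the nested config keys mirror what the backends above read):

from verl.utils.logger import Tracker

config = {"trainer": {"project_name": "easy_r1", "experiment_name": "demo_run"}}
tracker = Tracker(loggers=["console"], config=config)  # or ["console", "wandb", ...]
tracker.log({"actor/loss": 0.42, "actor/lr": 1e-6}, step=1)  # flat keys, unflattened by ConsoleLogger
tracker.log_generation([("prompt", "output", 1.0)], step=1)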
verl/utils/model_utils.py

@@ -15,11 +15,28 @@
 Utilities to create common models
 """

+from functools import lru_cache
+from typing import Optional, Tuple
+
 import torch
+import torch.distributed as dist
 from torch import nn


-def get_model_size(model: nn.Module, scale="auto"):
+@lru_cache
+def is_rank0() -> int:
+    return (not dist.is_initialized()) or (dist.get_rank() == 0)
+
+
+def print_gpu_memory_usage(prefix: str = "GPU memory usage") -> None:
+    """Report the current GPU VRAM usage."""
+    if is_rank0():
+        free_mem, total_mem = torch.cuda.mem_get_info()
+        print(f"{prefix}: {(total_mem - free_mem) / (1024**3):.2f} GB / {total_mem / (1024**3):.2f} GB.")
+
+
+def _get_model_size(model: nn.Module, scale: str = "auto") -> Tuple[float, str]:
+    """Compute the model size."""
     n_params = sum(p.numel() for p in model.parameters())

     if scale == "auto":
@@ -41,18 +58,16 @@ def get_model_size(model: nn.Module, scale="auto"):
     elif scale == "":
         pass
     else:
-        raise NotImplementedError(f"Unknown scale {scale}")
+        raise NotImplementedError(f"Unknown scale {scale}.")

     return n_params, scale


-def print_model_size(model: nn.Module, name: str = None):
-    n_params, scale = get_model_size(model, scale="auto")
-    if name is None:
-        name = model.__class__.__name__
-
-    print(f"{name} contains {n_params:.2f}{scale} parameters")
-
-
-def compute_position_id_with_mask(mask):
-    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
+def print_model_size(model: nn.Module, name: Optional[str] = None) -> None:
+    """Print the model size."""
+    if is_rank0():
+        n_params, scale = _get_model_size(model, scale="auto")
+        if name is None:
+            name = model.__class__.__name__
+
+        print(f"{name} contains {n_params:.2f}{scale} parameters.")
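A sketch of the two rank-0-only helpers on a toy module (not in the commit; the printed figures are illustrative):

import torch
from torch import nn

from verl.utils.model_utils import print_gpu_memory_usage, print_model_size

model = nn.Linear(4096, 4096)  # toy stand-in for a policy model
print_model_size(model)        # e.g. "Linear contains 16.78M parameters."
if torch.cuda.is_available():
    print_gpu_memory_usage("After model init")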
verl/utils/py_functional.py

@@ -15,23 +15,89 @@
 Contain small python utility functions
 """

-from typing import Any, Dict, List
+import importlib.util
+import re
+from functools import lru_cache
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import yaml
+from yaml import Dumper
+
+
+def is_sci_notation(number: float) -> bool:
+    pattern = re.compile(r"^[+-]?\d+(\.\d*)?[eE][+-]?\d+$")
+    return bool(pattern.match(str(number)))
+
+
+def float_representer(dumper: Dumper, number: Union[float, np.float32, np.float64]):
+    if is_sci_notation(number):
+        value = str(number)
+        if "." not in value and "e" in value:
+            value = value.replace("e", ".0e", 1)
+    else:
+        value = str(round(number, 3))
+
+    return dumper.represent_scalar("tag:yaml.org,2002:float", value)
+
+
+yaml.add_representer(float, float_representer)
+yaml.add_representer(np.float32, float_representer)
+yaml.add_representer(np.float64, float_representer)
+
+
+@lru_cache
+def is_package_available(name: str) -> bool:
+    return importlib.util.find_spec(name) is not None


 def union_two_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
     """Union two dict. Will throw an error if there is an item not the same object with the same key."""
-    for key, value in dict2.items():
+    for key in dict2.keys():
         if key in dict1:
-            assert dict1[key] != value, f"{key} in meta_dict1 and meta_dict2 are not the same object"
+            assert dict1[key] == dict2[key], f"{key} in dict1 and dict2 are not the same object"

-        dict1[key] = value
+        dict1[key] = dict2[key]

     return dict1


 def append_to_dict(data: Dict[str, List[Any]], new_data: Dict[str, Any]) -> None:
     """Append dict to a dict of list."""
     for key, val in new_data.items():
         if key not in data:
             data[key] = []

         data[key].append(val)
+
+
+def unflatten_dict(data: Dict[str, Any], sep: str = "/") -> Dict[str, Any]:
+    unflattened = {}
+    for key, value in data.items():
+        pieces = key.split(sep)
+        pointer = unflattened
+        for piece in pieces[:-1]:
+            if piece not in pointer:
+                pointer[piece] = {}
+
+            pointer = pointer[piece]
+
+        pointer[pieces[-1]] = value
+
+    return unflattened
+
+
+def flatten_dict(data: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
+    flattened = {}
+    for key, value in data.items():
+        new_key = parent_key + sep + key if parent_key else key
+        if isinstance(value, dict):
+            flattened.update(flatten_dict(value, new_key, sep=sep))
+        else:
+            flattened[new_key] = value
+
+    return flattened
+
+
+def convert_dict_to_str(data: Dict[str, Any]) -> str:
+    return yaml.dump(data, indent=2)
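The two new dict helpers are inverses of each other, which is what lets the loggers accept flat metric keys while printing nested configs. A round-trip sketch (not in the commit):

from verl.utils.py_functional import flatten_dict, unflatten_dict

nested = {"actor": {"lr": 1e-6, "clip": {"low": 0.2, "high": 0.28}}}
flat = flatten_dict(nested)
print(flat)  # {'actor/lr': 1e-06, 'actor/clip/low': 0.2, 'actor/clip/high': 0.28}
assert unflatten_dict(flat) == nested  # round trip restores the nesting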
verl/utils/reward_score/math.py

 # Copyright 2024 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import re
+from typing import Dict

 from mathruler.grader import extract_boxed_content, grade_answer


-def math_compute_score(predict_str: str, ground_truth: str) -> float:
+def math_format_reward(predict_str: str) -> float:
+    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
+    format_match = re.fullmatch(pattern, predict_str)
+    return 1.0 if format_match else 0.0
+
+
+def math_acc_reward(predict_str: str, ground_truth: str) -> float:
     answer = extract_boxed_content(predict_str)
     if answer == "None":
         return 0.0  # no answer

-    if grade_answer(answer, ground_truth):
-        return 1.0  # correct answer
-
-    return 0.1  # wrong answer
+    return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
+    format = math_format_reward(predict_str)
+    accuracy = math_acc_reward(predict_str, ground_truth)
+    return {
+        "overall": 0.9 * accuracy + 0.1 * format,
+        "format": format,
+        "accuracy": accuracy,
+    }
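Behavior sketch for the new dict-valued reward (not in the commit; assumes mathruler grades "12" against "12" as correct):

from verl.utils.reward_score.math import math_compute_score

predict = "<think>3 * 4 = 12</think> The answer is \\boxed{12}."
print(math_compute_score(predict, "12"))
# {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}  (overall = 0.9 * accuracy + 0.1 * format)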
verl/utils/reward_score/r1v.py

 # Copyright 2024 Bytedance Ltd. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import re
+from typing import Dict

 from mathruler.grader import grade_answer


 def r1v_format_reward(predict_str: str) -> float:
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    match = re.fullmatch(pattern, predict_str, re.DOTALL)
-    return 1.0 if match else 0.0
+    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
+    format_match = re.fullmatch(pattern, predict_str)
+    return 1.0 if format_match else 0.0


 def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
     try:
         ground_truth = ground_truth.strip()
         content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
-        pred_answer = content_match.group(1).strip() if content_match else predict_str.strip()
-        if grade_answer(pred_answer, ground_truth):
+        given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
+        if grade_answer(given_answer, ground_truth):
             return 1.0
     except Exception:
         pass

     return 0.0


-def r1v_compute_score(predict_str: str, ground_truth: str) -> float:
-    acc_reward = r1v_accuracy_reward(predict_str, ground_truth)
-    format_reward = r1v_format_reward(predict_str)
-    reward = acc_reward + format_reward
-    reward /= 2
-    return reward
+def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
+    format = r1v_format_reward(predict_str)
+    accuracy = r1v_accuracy_reward(predict_str, ground_truth)
+    return {
+        "overall": 0.5 * accuracy + 0.5 * format,
+        "format": format,
+        "accuracy": accuracy,
+    }
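The r1v variant weights format and accuracy equally; a behavior sketch (not in the commit, same mathruler assumption):

from verl.utils.reward_score.r1v import r1v_compute_score

predict = "<think>Count the objects.</think> <answer>7</answer>"
print(r1v_compute_score(predict, "7"))
# {'overall': 1.0, 'format': 1.0, 'accuracy': 1.0}  (overall = 0.5 * accuracy + 0.5 * format)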
verl/utils/seqlen_balancing.py (new file, 0 → 100644)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import heapq
from typing import List, Tuple

import torch
from tensordict import TensorDict
from torch import distributed as dist


class Set:
    def __init__(self) -> None:
        self.sum = 0
        self.items = []

    def add(self, idx: int, val: int):
        self.items.append((idx, val))
        self.sum += val

    def merge(self, other):
        for idx, val in other.items:
            self.items.append((idx, val))
            self.sum += val

    def __lt__(self, other):
        if self.sum != other.sum:
            return self.sum < other.sum

        if len(self.items) != len(other.items):
            return len(self.items) < len(other.items)

        return self.items < other.items


class State:
    def __init__(self, items: List[Tuple[int, int]], k: int) -> None:
        self.k = k
        # sets should always be decreasing order
        self.sets = [Set() for _ in range(k)]
        assert len(items) in [1, k], f"{len(items)} not in [1, {k}]"
        for i, (idx, seqlen) in enumerate(items):
            self.sets[i].add(idx=idx, val=seqlen)

        self.sets = sorted(self.sets, reverse=True)

    def get_partitions(self):
        partitions = []
        for i in range(len(self.sets)):
            cur_partition = []
            for idx, _ in self.sets[i].items:
                cur_partition.append(idx)

            partitions.append(cur_partition)

        return partitions

    def merge(self, other):
        for i in range(self.k):
            self.sets[i].merge(other.sets[self.k - 1 - i])

        self.sets = sorted(self.sets, reverse=True)

    @property
    def spread(self) -> int:
        return self.sets[0].sum - self.sets[-1].sum

    def __lt__(self, other):
        # least heap, let the state with largest spread to be popped first,
        # if the spread is the same, let the state who has the largest set
        # to be popped first.
        if self.spread != other.spread:
            return self.spread > other.spread

        return self.sets[0] > other.sets[0]

    def __repr__(self) -> str:
        repr_str = "["
        for i in range(self.k):
            if i > 0:
                repr_str += ","

            repr_str += "{"
            for j, (_, seqlen) in enumerate(self.sets[i].items):
                if j > 0:
                    repr_str += ","

                repr_str += str(seqlen)

            repr_str += "}"

        repr_str += "]"
        return repr_str


def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    # see: https://en.wikipedia.org/wiki/Largest_differencing_method
    sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)])
    states_pq: List[State] = []
    if equal_size:
        assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0"
        for offset in range(0, len(sorted_seqlen_list), k_partitions):
            items = []
            for i in range(k_partitions):
                seqlen, idx = sorted_seqlen_list[offset + i]
                items.append((idx, seqlen))

            heapq.heappush(states_pq, State(items=items, k=k_partitions))
    else:
        for seqlen, idx in sorted_seqlen_list:
            heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions))

    while len(states_pq) > 1:
        state0 = heapq.heappop(states_pq)
        state1 = heapq.heappop(states_pq)
        # merge states
        state0.merge(state1)
        heapq.heappush(states_pq, state0)

    final_state = states_pq[0]
    partitions = final_state.get_partitions()
    if equal_size:
        for i, partition in enumerate(partitions):
            assert len(partition) * k_partitions == len(seqlen_list), (
                f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
            )

    return partitions


def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    bias = sum(seqlen_list) + 1 if equal_size else 0
    sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)]
    partitions = [[] for _ in range(k_partitions)]
    partition_sums = [0 for _ in range(k_partitions)]
    for seqlen, i in sorted_seqlen:
        min_idx = None
        for j in range(k_partitions):
            if min_idx is None or partition_sums[j] < partition_sums[min_idx]:
                min_idx = j

        partitions[min_idx].append(i)
        partition_sums[min_idx] += seqlen

    if equal_size:
        for i, partition in enumerate(partitions):
            assert len(partition) * k_partitions == len(seqlen_list), (
                f"{len(partition)} * {k_partitions} != {len(seqlen_list)}"
            )

    return partitions


def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """get order of seq lengths to make partitions balanced, this is
    used in balacing sum of seqlength across dp ranks and microbatches

    Parameters:
        seqlen_list (List[int]):
            seq lengths of each items
        k_partitions (int):
            resulting number of partitions
        equal_size (bool):
            if True, number of items in each partitions must be equal.
            if False, only consider balancing the sum, each partition can have
            variable number of items

    Returns:
        partitions (List[List[int]]):
            return k_partitions list containing the index of items.
    """
    assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]"

    def _check_and_sort_partitions(partitions):
        assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}"
        seen_idx = set()
        sorted_partitions = [None] * k_partitions
        for i, partition in enumerate(partitions):
            assert len(partition) > 0, f"the {i}-th partition is empty"
            for idx in partition:
                seen_idx.add(idx)

            sorted_partitions[i] = sorted(partition)

        assert seen_idx == set(range(len(seqlen_list)))
        return sorted_partitions

    partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(partitions)


def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
    # add some metrics of seqlen sum on dp ranks
    k_partition = len(partitions)
    # assert len(seqlen_list) % k_partition == 0
    batch_size = len(seqlen_list) // k_partition
    min_sum_seqlen = None
    max_sum_seqlen = None
    total_sum_seqlen = 0
    for offset in range(0, len(seqlen_list), batch_size):
        cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
        if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
            min_sum_seqlen = cur_sum_seqlen

        if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen:
            max_sum_seqlen = cur_sum_seqlen

        total_sum_seqlen += cur_sum_seqlen

    balanced_sum_seqlen_list = []
    for partition in partitions:
        cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
        balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)

    # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
    min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
    max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
    return {
        f"{prefix}/min": min_sum_seqlen,
        f"{prefix}/max": max_sum_seqlen,
        f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen,
        f"{prefix}/balanced_min": min_sum_seqlen_balanced,
        f"{prefix}/balanced_max": max_sum_seqlen_balanced,
        f"{prefix}/mean": total_sum_seqlen / len(partitions),
    }


def ceildiv(a, b):
    return -(a // -b)


def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
    """Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len
    and the number of valid tokens in each micro batch is well balanced.
    """
    # this is per local micro_bsz
    max_seq_len = batch["attention_mask"].shape[-1]
    assert max_token_len >= max_seq_len, (
        f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}"
    )
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()
    num_micro_batches = ceildiv(total_seqlen, max_token_len)
    if dist.is_initialized():
        num_micro_batches = torch.tensor([num_micro_batches], device="cuda")
        dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
        num_micro_batches = num_micro_batches.cpu().item()

    seq_len_effective = seq_len_effective.tolist()
    assert num_micro_batches <= len(seq_len_effective)
    micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False)
    micro_batches = []
    for partition in micro_bsz_idx:
        curr_micro_batch = []
        for idx in partition:
            curr_micro_batch.append(batch[idx : idx + 1])

        curr_micro_batch = torch.cat(curr_micro_batch)
        micro_batches.append(curr_micro_batch)

    return micro_batches, micro_bsz_idx


def get_reverse_idx(idx_map):
    reverse_idx_map = copy.deepcopy(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i

    return reverse_idx_map
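A small usage sketch of the Karmarkar-Karp entry point (not in the commit). With equal_size=True every partition gets the same item count while the token sums stay close:

from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions

seqlens = [100, 2000, 300, 1500, 800, 700, 1200, 400]  # skewed lengths, 8 samples
parts = get_seqlen_balanced_partitions(seqlens, k_partitions=4, equal_size=True)
for part in parts:
    print(part, sum(seqlens[i] for i in part))
# Two items per partition; here the sums come out as 2100, 1800, 1600 and 1500
# against a 7000 / 4 = 1750 ideal.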
verl/utils/tokenizer.py

@@ -15,38 +15,28 @@
 from typing import Optional

-from transformers import AutoConfig, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
+from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin


-def get_tokenizer(model_path, correct_pad_token=True, correct_gemma=True, **kwargs) -> PreTrainedTokenizer:
-    """Create a huggingface pretrained tokenizer.
-
-    Args:
-        name (str): The name of the tokenizer.
-        correct_pad_token (bool): Whether to correct the pad token id.
-        correct_gemma (bool): Whether to correct the gemma tokenizer.
-        **kwargs: The keyword arguments for the tokenizer.
-
-    Returns:
-        transformers.PreTrainedTokenizer: The pretrained tokenizer.
-    """
-    config = AutoConfig.from_pretrained(model_path)
+def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
+    """Create a huggingface pretrained tokenizer."""
     tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
-    if correct_gemma and getattr(config, "model_type", None) in ["gemma", "gemma2"]:
-        # the EOS token in gemma2 is ambiguious, which may worsen RL performance.
+    if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
+        # the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance.
+        # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a
         print("Found gemma model. Set eos_token and eos_token_id to <end_of_turn> and 107.")
         tokenizer.eos_token = "<end_of_turn>"

-    if correct_pad_token:
-        if tokenizer.pad_token_id is None:
-            print("Pad token is None. Set it to eos_token.")
-            tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token_id is None:
+        print("Pad token is None. Set it to eos_token.")
+        tokenizer.pad_token = tokenizer.eos_token

     return tokenizer


-def get_processor(model_path, **kwargs) -> Optional[ProcessorMixin]:
+def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
     """Create a huggingface pretrained processor."""
     try:
         processor = AutoProcessor.from_pretrained(model_path, **kwargs)
     except Exception:
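A usage sketch of the slimmed-down helpers (not in the commit; the checkpoint name is a placeholder):

from verl.utils.tokenizer import get_processor, get_tokenizer

model_path = "Qwen/Qwen2-VL-7B-Instruct"  # placeholder checkpoint
tokenizer = get_tokenizer(model_path)  # pad_token falls back to eos_token if unset
processor = get_processor(model_path)  # returns None for text-only models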
verl/utils/torch_dtypes.py

@@ -48,7 +48,7 @@ class PrecisionType:
         return precision in BFLOAT_LIST

     @staticmethod
-    def to_dtype(precision):
+    def to_dtype(precision) -> torch.dtype:
         if precision in HALF_LIST:
             return torch.float16
         elif precision in FLOAT_LIST:
@@ -59,12 +59,12 @@ class PrecisionType:
             raise RuntimeError(f"unexpected precision: {precision}")

     @staticmethod
-    def to_str(precision):
+    def to_str(precision: torch.dtype) -> str:
         if precision == torch.float16:
-            return "fp16"
+            return "float16"
         elif precision == torch.float32:
-            return "fp32"
+            return "float32"
         elif precision == torch.bfloat16:
-            return "bf16"
+            return "bfloat16"
         else:
             raise RuntimeError(f"unexpected precision: {precision}")
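A round-trip sketch (not in the commit; assumes "bf16" is among the aliases in the module's BFLOAT_LIST, which the elided context suggests):

import torch

from verl.utils.torch_dtypes import PrecisionType

dtype = PrecisionType.to_dtype("bf16")       # torch.bfloat16, assuming "bf16" is in BFLOAT_LIST
name = PrecisionType.to_str(torch.bfloat16)  # now "bfloat16" instead of the old "bf16"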
verl/utils/torch_functional.py
View file @
c132cbcb
...
...
@@ -15,15 +15,12 @@
Contain small torch utilities
"""
import
math
from
typing
import
List
,
Literal
,
Union
from
typing
import
List
,
Literal
,
Optional
,
Tuple
,
Union
import
torch
import
torch.distributed
import
torch.nn.functional
as
F
from
torch.optim
import
Optimizer
from
torch.optim.lr_scheduler
import
LambdaLR
from
transformers
import
PreTrainedTokenizer
try
:
...
...
@@ -34,113 +31,85 @@ except ImportError:
FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE
=
False
def
logprobs_from_logits
(
logits
,
labels
):
"""
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
"""
if
FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE
:
batch_dim
=
logits
.
shape
[:
-
1
]
last_dim
=
logits
.
shape
[
-
1
]
logits
=
logits
.
reshape
(
-
1
,
last_dim
)
labels
=
labels
.
reshape
(
-
1
)
output
=
logprobs_from_logits_flash_attn
(
logits
,
labels
)
output
=
output
.
view
(
*
batch_dim
)
else
:
output
=
logprobs_from_logits_v2
(
logits
,
labels
)
return
output
@
torch
.
compiler
.
disable
()
def
log_probs_from_logits_flash_attn
(
logits
:
torch
.
Tensor
,
labels
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
cross_entropy_loss
(
logits
,
labels
,
inplace_backward
=
True
)
if
not
isinstance
(
output
,
tuple
):
raise
ValueError
(
"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
)
def
logprobs_from_logits_flash_attn
(
logits
,
labels
):
output
=
cross_entropy_loss
(
logits
,
labels
)
assert
isinstance
(
output
,
tuple
),
(
"please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]."
)
return
-
output
[
0
]
def
logprobs_from_logits_v2
(
logits
:
torch
.
FloatTensor
,
labels
):
"""
A memory efficient implementation of logprobs_from_logits
"""
if
logits
.
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
logits_labels
=
torch
.
gather
(
logits
,
dim
=-
1
,
index
=
labels
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
# loop to reduce peak mem consumption
logsumexp_values
=
torch
.
stack
([
torch
.
logsumexp
(
l
,
dim
=-
1
)
for
l
in
logits
])
logprobs_labels
=
logits_labels
-
logsumexp_values
# log_softmax(x_i) = x_i - logsumexp(x)
else
:
# logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach
logprobs_labels
=
[]
for
row_logits
,
row_labels
in
zip
(
logits
,
labels
):
# loop to reduce peak mem consumption
row_logprobs
=
F
.
log_softmax
(
row_logits
,
dim
=-
1
)
row_logprobs_labels
=
row_logprobs
.
gather
(
dim
=-
1
,
index
=
row_labels
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
logprobs_labels
.
append
(
row_logprobs_labels
)
def
log_probs_from_logits
(
logits
:
torch
.
Tensor
,
labels
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Compute log probs on the label ids given logits.
logprobs_labels
=
torch
.
stack
(
logprobs_labels
)
return
logprobs_labels
We may use torch compile to speed up computing.
Args:
logits (torch.Tensor): logits of the model, shape (batch_size, seqlen, vocab_size)
labels (torch.Tensor): labels of the model, shape (batch_size, seqlen)
def
clip_by_value
(
x
,
tensor_min
,
tensor_max
):
"""
Tensor extenstion to torch.clamp
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
Returns:
torch.Tensor: log probs of the labels, shape (batch_size, seqlen)
"""
clipped
=
torch
.
max
(
torch
.
min
(
x
,
tensor_max
),
tensor_min
)
return
clipped
batch_dim
=
logits
.
shape
[:
-
1
]
vocab_dim
=
logits
.
shape
[
-
1
]
logits
=
logits
.
contiguous
().
view
(
-
1
,
vocab_dim
)
labels
=
labels
.
contiguous
().
view
(
-
1
)
if
FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE
:
output
=
log_probs_from_logits_flash_attn
(
logits
,
labels
)
else
:
# fall back to torch kernel, upcast logits to fp32
output
=
F
.
cross_entropy
(
logits
.
float
(),
labels
,
reduction
=
"none"
)
def
entropy_from_logits
(
logits
:
torch
.
Tensor
):
"""Calculate entropy from logits."""
pd
=
torch
.
nn
.
functional
.
softmax
(
logits
,
dim
=-
1
)
entropy
=
torch
.
logsumexp
(
logits
,
dim
=-
1
)
-
torch
.
sum
(
pd
*
logits
,
dim
=-
1
)
return
entropy
return
output
.
view
(
*
batch_dim
)
def
masked_mean
(
values
,
mask
,
axis
=
None
)
->
torch
.
Tensor
:
def
masked_mean
(
values
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
dim
:
int
=
None
,
eps
:
float
=
1e-8
)
->
torch
.
Tensor
:
"""Compute mean of tensor with a masked values."""
return
(
values
*
mask
).
sum
(
axis
=
axis
)
/
mask
.
sum
(
axis
=
axi
s
)
return
(
values
*
mask
).
sum
(
dim
=
dim
)
/
(
mask
.
sum
(
dim
=
dim
)
+
ep
s
)
def
masked_var
(
values
,
mask
,
unbiased
=
True
)
:
def
masked_var
(
values
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
unbiased
:
bool
=
True
)
->
torch
.
Tensor
:
"""Compute variance of tensor with masked values."""
mean
=
masked_mean
(
values
,
mask
)
centered_values
=
values
-
mean
variance
=
masked_mean
(
centered_values
**
2
,
mask
)
if
unbiased
:
mask_sum
=
mask
.
sum
()
if
mask_sum
==
0
:
raise
ValueError
(
"At least one element in the mask has to be 1."
)
# note that if mask_sum == 1, then there is a division by zero issue
# to avoid it you just need to use a larger minibatch_size
if
mask_sum
==
1
:
raise
ValueError
(
"The sum of the mask is one, which can cause a division by zero."
)
if
mask_sum
<=
1
:
print
(
"The sum of the mask is less than one, which can cause a division by zero."
)
return
variance
bessel_correction
=
mask_sum
/
(
mask_sum
-
1
)
variance
=
variance
*
bessel_correction
return
variance
def
masked_whiten
(
values
,
mask
,
shift_mean
=
True
)
:
def
masked_whiten
(
values
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
eps
:
float
=
1e-8
)
->
torch
.
Tensor
:
"""Whiten values with masked values."""
mean
,
var
=
masked_mean
(
values
,
mask
),
masked_var
(
values
,
mask
)
whitened
=
(
values
-
mean
)
*
torch
.
rsqrt
(
var
+
1e-8
)
if
not
shift_mean
:
whitened
+=
mean
return
whitened
return
(
values
-
mean
)
*
torch
.
rsqrt
(
var
+
eps
)
-def get_eos_mask(response_ids: torch.Tensor, eos_token: Union[int, List[int]] = 2, dtype=torch.int64):
-    """
-    end of sentence token can be int or list: 1 or [1, 2]
-    e.g. eos_token=1
-    response_ids: [0, 0, 2, 42, 3, 5, 1, 0, 0]
-    eos_mask:     [1, 1, 1, 1,  1, 1, 1, 0, 0]
-    """
-    if isinstance(eos_token, int):
-        eos_token = [eos_token]
+def get_eos_mask(response_ids: torch.Tensor, eos_token_id: Union[int, List[int]] = 2, dtype: torch.dtype = torch.long):
+    """Get the mask for the response ids, the mask will be 0 after the first eos token.
+
+    eos_token_id can be int or list: 1 or [1, 2].
+    ```
+    e.g. eos_token = 1
+    response_ids: [0, 0, 2, 4, 3, 5, 1, 0, 0]
+    eos_mask:     [1, 1, 1, 1, 1, 1, 1, 0, 0]
+    ```
+    """
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
 
     eos_mask = torch.zeros_like(response_ids, dtype=torch.bool)
-    for token in eos_token:
-        eos_mask |= response_ids.eq(token)
+    for token_id in eos_token_id:
+        eos_mask |= response_ids.eq(token_id)
 
     eos_mask = eos_mask.long()
     eos_mask = (torch.cumsum(eos_mask, dim=1) - eos_mask).bool()
...
@@ -148,151 +117,211 @@ def get_eos_mask(response_ids: torch.Tensor, eos_token: Union[int, List[int]] =
     return eos_mask
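The cumsum trick is worth spelling out: subtracting the eos indicator from its cumulative sum marks only positions strictly after the first eos, which the folded lines above then invert so that everything up to and including the first eos is kept. A quick check, assuming the folded tail does something like eos_mask = eos_mask.logical_not().to(dtype):

import torch

response_ids = torch.tensor([[0, 0, 2, 4, 3, 5, 1, 0, 0]])
hit = response_ids.eq(1).long()                  # [0, 0, 0, 0, 0, 0, 1, 0, 0]
after = (torch.cumsum(hit, dim=1) - hit).bool()  # True strictly after the first eos
print(after.logical_not().long())                # tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0]])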
-def pad_2d_list_to_length(response, pad_token_id, max_length=None) -> torch.Tensor:
-    """
-    pad a 2D list (e.g. responses, logprobs) to a 2D tensor.
-    """
-    response_length = max(len(sub_list) for sub_list in response)
-    if max_length is not None and max_length > response_length:
+def pad_2d_list_to_length(response: List[List[int]], pad_token_id: int, max_length: Optional[int] = None) -> torch.Tensor:
+    """Pad a 2D list (e.g. responses, log_probs) to a 2D tensor."""
+    max_response_length = max(len(sub_list) for sub_list in response)
+    if max_length is not None and max_length > max_response_length:
         target_length = max_length
     else:
-        target_length = response_length
+        target_length = max_response_length
 
     padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response]
     tensor = torch.tensor(padded_response)
     return tensor
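For instance, padding two responses of different lengths to a common width (a toy run, assuming the function above is in scope):

responses = [[5, 6, 7], [8, 9]]
print(pad_2d_list_to_length(responses, pad_token_id=0))
# tensor([[5, 6, 7],
#         [8, 9, 0]])
print(pad_2d_list_to_length(responses, pad_token_id=0, max_length=5))
# tensor([[5, 6, 7, 0, 0],
#         [8, 9, 0, 0, 0]])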
-def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
-    """
-    pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
-    input shape: [bs, seq_length]
-    output shape: [bs, max_seq_length]
-    (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
-    """
-    if tensors.shape[-1] >= max_seq_len:
-        return tensors
+def pad_sequence_to_length(tensor: torch.Tensor, max_seq_len: int, pad_token_id: int, left_pad: bool = False) -> torch.Tensor:
+    """Pad an nD tensor in the last dim to max_seq_len."""
+    if tensor.size(-1) >= max_seq_len:
+        return tensor
 
-    pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1])
-    return F.pad(tensors, pad_tuple, "constant", pad_token_id)
+    pad_shape = list(tensor.shape)
+    pad_shape[-1] = max_seq_len - tensor.size(-1)
+    pad_tensor = torch.full(pad_shape, fill_value=pad_token_id, dtype=tensor.dtype, device=tensor.device)
+    return torch.cat((pad_tensor, tensor), dim=-1) if left_pad else torch.cat((tensor, pad_tensor), dim=-1)
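A quick illustration of the two padding directions (toy values, assuming the function above is in scope):

import torch

x = torch.tensor([[1, 2, 3]])
print(pad_sequence_to_length(x, max_seq_len=5, pad_token_id=0))                 # tensor([[1, 2, 3, 0, 0]])
print(pad_sequence_to_length(x, max_seq_len=5, pad_token_id=0, left_pad=True))  # tensor([[0, 0, 1, 2, 3]])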
-def tokenize_and_postprocess_data(
-    prompt: str,
-    tokenizer: PreTrainedTokenizer,
+def postprocess_data(
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    position_ids: torch.Tensor,
     max_length: int,
     pad_token_id: int,
     left_pad: bool = True,
     truncation: Literal["left", "right", "error"] = "error",
 ):
-    """
-    input_data is the output from tokenizer.
-    """
+    """Pad or truncate data."""
     assert truncation in ["left", "right", "error"]
-    input_data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
-    input_ids = input_data["input_ids"][0]
-    attention_mask = input_data["attention_mask"][0]
-    sequence_length = len(input_ids)
-    if sequence_length < max_length:
+    seq_length = len(input_ids)
+    if seq_length < max_length:
         input_ids = pad_sequence_to_length(
             input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad
         )
         attention_mask = pad_sequence_to_length(
             attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad
         )
-    elif sequence_length > max_length:
-        if truncation == "left":
-            # actually, left truncation may not be reasonable
-            input_ids = input_ids[-max_length:]
-            attention_mask = attention_mask[-max_length:]
+        position_ids = pad_sequence_to_length(
+            position_ids, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad
+        )
+    elif seq_length > max_length:
+        if truncation == "left":
+            # actually, left truncation may not be reasonable
+            input_ids = input_ids[..., -max_length:]
+            attention_mask = attention_mask[..., -max_length:]
+            position_ids = position_ids[..., -max_length:]
         elif truncation == "right":
-            input_ids = input_ids[:max_length]
-            attention_mask = attention_mask[:max_length]
+            input_ids = input_ids[..., :max_length]
+            attention_mask = attention_mask[..., :max_length]
+            position_ids = position_ids[..., :max_length]
         elif truncation == "error":
-            raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}")
+            raise NotImplementedError(f"{seq_length} is larger than {max_length}.")
         else:
-            raise NotImplementedError(f"Unknown truncation method {truncation}")
-
-    return input_ids, attention_mask
+            raise NotImplementedError(f"Unknown truncation method {truncation}.")
-
-def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor):
-    """Remove the pad token.
-
-    Args:
-        input_ids shape: [bs, seq_length]
-        attention_mask shape: [bs, seq_length]
-    Returns:
-        no_padding_batch(List[List[int]]): contains the rmpad token ids per query.
-    """
-    no_padding_batch = []
-    for ids, mask in zip(input_ids, attention_mask):
-        no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist())
-
-    return no_padding_batch
-
-def get_cosine_schedule_with_warmup(
-    optimizer: Optimizer,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    min_lr_ratio: float = 0.0,
-    num_cycles: float = 0.5,
-    last_epoch: int = -1,
-):
-    """
-    Create a schedule with a learning rate that decreases following the values of the cosine function between the
-    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
-    initial lr set in the optimizer.
-
-    Args:
-        optimizer (:class:`~torch.optim.Optimizer`):
-            The optimizer for which to schedule the learning rate.
-        num_warmup_steps (:obj:`int`):
-            The number of steps for the warmup phase.
-        num_training_steps (:obj:`int`):
-            The total number of training steps.
-        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
-            The minimum lr ratio w.r.t the maximum.
-        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
-            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
-            following a half-cosine).
-        last_epoch (:obj:`int`, `optional`, defaults to -1):
-            The index of the last epoch when resuming training.
-
-    Return:
-        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
-    coef = (1 - min_lr_ratio) * 0.5
-    intercept = (1 + min_lr_ratio) * 0.5
-
-    def lr_lambda(current_step):
-        if current_step < num_warmup_steps:
-            return float(current_step) / float(max(1, num_warmup_steps))
-
-        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
-        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
-        return max(0.0, x * coef + intercept)
-
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
+    return input_ids, attention_mask, position_ids
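The removed cosine schedule maps a step to an lr multiplier of max(0, cos(2π·num_cycles·progress)·coef + intercept) after warmup, with coef = (1 - min_lr_ratio)/2 and intercept = (1 + min_lr_ratio)/2. A quick numeric check of the endpoints (toy numbers, not from the diff):

import math

num_warmup_steps, num_training_steps, min_lr_ratio, num_cycles = 10, 110, 0.1, 0.5
coef = (1 - min_lr_ratio) * 0.5        # 0.45
intercept = (1 + min_lr_ratio) * 0.5   # 0.55

def factor(step):
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, math.cos(math.pi * num_cycles * 2.0 * progress) * coef + intercept)

print(factor(10))   # 1.0 at the end of warmup: cos(0) * 0.45 + 0.55
print(factor(110))  # 0.1 at the end of training: cos(pi) * 0.45 + 0.55 = min_lr_ratio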
 def get_constant_schedule_with_warmup(
-    optimizer: Optimizer,
+    optimizer: torch.optim.Optimizer,
     num_warmup_steps: int,
     last_epoch: int = -1,
-):
-    def lr_lambda(current_step):
-        return min(1, float(current_step) / float(max(1, num_warmup_steps)))
+) -> torch.optim.lr_scheduler.LRScheduler:
+    """Get the lr scheduler for constant lr."""
+
+    def lr_lambda(current_step: int) -> float:
+        return min(1.0, float(current_step) / float(max(1, num_warmup_steps)))
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 def get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return (
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
     )
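For a concrete picture of what get_unpad_data returns on a right-padded batch (toy mask, not from the diff):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                        # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten()).flatten()                    # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))    # tensor([0, 3, 5])
# indices select the non-pad tokens from the flattened batch; cu_seqlens are the
# cumulative sequence boundaries that flash-attn's varlen kernels expect.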
# https://github.com/meta-llama/llama-cookbook/blob/v0.0.5/src/llama_cookbook/policies/anyprecision_optimizer.py
class AnyPrecisionAdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params: List[torch.Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        use_kahan_summation: bool = True,
        momentum_dtype: torch.dtype = torch.bfloat16,
        variance_dtype: torch.dtype = torch.bfloat16,
        compensation_buffer_dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            params (iterable): iterable of parameters to optimize or dicts defining parameter groups
            lr (float, optional): learning rate (default: 1e-3)
            betas (Tuple[float, float], optional): coefficients used for computing
                running averages of gradient and its square (default: (0.9, 0.999))
            eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
            weight_decay (float, optional): weight decay coefficient (default: 0.0)

            # Any Precision specific
            use_kahan_summation = creates auxiliary buffer to ensure high precision
                model param updates (default: True)
            momentum_dtype = dtype for momentum (default: bfloat16)
            variance_dtype = dtype for uncentered variance (default: bfloat16)
            compensation_buffer_dtype = dtype for Kahan summation buffer (default: bfloat16)

            # Usage
            This optimizer implements optimizer states, and Kahan summation
            for high precision updates, all in user controlled dtypes.
            Defaults are momentum and variance in BF16.
            This can be run in FSDP mixed precision, amp, or full precision,
            depending on what training pipeline you wish to work with.

            Setting use_kahan_summation to False, and changing momentum and
            variance dtypes to FP32, reverts this to a standard AdamW optimizer.
        """
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "use_kahan_summation": use_kahan_summation,
            "momentum_dtype": momentum_dtype,
            "variance_dtype": variance_dtype,
            "compensation_buffer_dtype": compensation_buffer_dtype,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        if closure is not None:
            with torch.enable_grad():
                closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]
            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.")

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)
                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)
                    # variance uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)
                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype)

                # Main processing
                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]
                grad = p.grad

                if weight_decay:
                    # weight decay, AdamW style
                    p.data.mul_(1 - lr * weight_decay)

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # update momentum
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # update uncentered variance

                bias_correction1 = 1 - beta1**step  # adjust using bias1
                step_size = lr / bias_correction1

                denom_correction = (1 - beta2**step) ** 0.5  # adjust using bias2 and avoids math import
                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1)

                if use_kahan_summation:
                    # lr update to compensation
                    compensation = state["compensation"]
                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
                    # update weights with compensation (Kahan summation)
                    # save error back to compensation for next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))
                else:
                    # usual AdamW updates
                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
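The compensation buffer is classic Kahan summation: each step accumulates the rounding error lost when the (possibly low-precision) weight is updated, and feeds it back on the next step. A scalar sketch of the same three-line dance, with hypothetical toy values rather than real optimizer state:

import torch

p = torch.tensor(1.0, dtype=torch.bfloat16)        # low-precision weight
compensation = torch.tensor(0.0, dtype=torch.bfloat16)
update = torch.tensor(1e-3)                        # a step so small bf16 alone would drop it

for _ in range(100):
    compensation += update.to(torch.bfloat16)      # stage the update plus carried-over error
    temp = p.clone()
    p = p + compensation                           # apply; bf16 rounds the result
    compensation = temp - p + compensation         # recover what rounding threw away
# without compensation, 1.0 + 1e-3 rounds back to 1.0 every step and p never moves;
# with it, p tracks the true sum and ends near 1.0 + 100 * 1e-3 ~= 1.1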
verl/utils/ulysses.py View file @ c132cbcb
...
@@ -238,7 +238,7 @@ class Gather(torch.autograd.Function):
     )


-def gather_outpus_and_unpad(
+def gather_outputs_and_unpad(
     x: Tensor,
     gather_dim: int,
     unpad_dim: int = None,
...
verl/workers/actor/base.py View file @ c132cbcb
...
@@ -20,8 +20,8 @@ from typing import Any, Dict
 import torch

-from verl import DataProto
-from verl.workers.actor.config import ActorConfig
+from ...protocol import DataProto
+from .config import ActorConfig

 __all__ = ["BasePPOActor"]
...
verl/workers/actor/config.py View file @ c132cbcb
...
@@ -26,6 +26,7 @@ class ModelConfig:
     override_config: Dict[str, Any] = field(default_factory=dict)
     enable_gradient_checkpointing: bool = True
     trust_remote_code: bool = True
+    freeze_vision_tower: bool = False

     def post_init(self):
         if self.tokenizer_path is None:
...
@@ -37,7 +38,8 @@ class OptimConfig:
     lr: float = 1e-6
     betas: Tuple[float, float] = (0.9, 0.999)
     weight_decay: float = 1e-2
-    lr_warmup_steps_ratio: float = 0.0
+    strategy: str = "adamw"
+    lr_warmup_ratio: float = 0.0
     min_lr_ratio: Optional[float] = None
     warmup_style: str = "constant"
     """auto keys"""
...
@@ -47,9 +49,11 @@ class OptimConfig:
 @dataclass
 class FSDPConfig:
     enable_full_shard: bool = True
-    param_offload: bool = False
-    optimizer_offload: bool = False
+    enable_cpu_offload: bool = False
+    enable_rank0_init: bool = False
     use_orig_params: bool = False
     torch_dtype: Optional[str] = None
     fsdp_size: int = -1
+    mp_param_dtype: str = "bf16"
+    mp_reduce_dtype: str = "fp32"
+    mp_buffer_dtype: str = "fp32"
...
@@ -57,41 +61,41 @@ class FSDPConfig:
 @dataclass
 class OffloadConfig:
-    param_offload: bool = False
-    optimizer_offload: bool = False
+    offload_params: bool = False
+    offload_optimizer: bool = False


 @dataclass
 class ActorConfig:
     strategy: str = "fsdp"
     global_batch_size: int = 256
-    micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
-    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
+    micro_batch_size_per_device_for_update: int = 4
+    micro_batch_size_per_device_for_experience: int = 16
     max_grad_norm: float = 1.0
     clip_ratio: float = 0.2
     entropy_coeff: float = 1e-3
-    use_kl_loss: bool = True
-    kl_loss_coef: float = 1e-3
-    kl_loss_type: str = "low_var_kl"
     ppo_epochs: int = 1
     padding_free: bool = False
     ulysses_sequence_parallel_size: int = 1
     use_torch_compile: bool = True
     model: ModelConfig = field(default_factory=ModelConfig)
     optim: OptimConfig = field(default_factory=OptimConfig)
     fsdp: FSDPConfig = field(default_factory=FSDPConfig)
     offload: OffloadConfig = field(default_factory=OffloadConfig)
     """auto keys"""
     global_batch_size_per_device: int = field(default=-1, init=False)
-
-    def post_init(self):
-        if self.ppo_epochs != 1:
-            raise NotImplementedError
+    disable_kl: bool = field(default=False, init=False)
+    use_kl_loss: bool = field(default=False, init=False)
+    kl_penalty: str = field(default="kl", init=False)
+    kl_coef: float = field(default=0.0, init=False)


 @dataclass
 class RefConfig:
     strategy: str = "fsdp"
     fsdp: FSDPConfig = field(default_factory=FSDPConfig)
     offload: OffloadConfig = field(default_factory=OffloadConfig)
     """auto keys"""
+    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
     padding_free: bool = field(default=False, init=False)
     ulysses_sequence_parallel_size: int = field(default=1, init=False)
+    use_torch_compile: bool = field(default=True, init=False)
verl/workers/actor/dp_actor.py View file @ c132cbcb
...
@@ -17,20 +17,26 @@ Implement Actor
 import os
 from collections import defaultdict
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional

 import torch
+from ray.experimental.tqdm_ray import tqdm
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from tqdm import tqdm

-import verl.utils.torch_functional as verl_F
-from verl import DataProto
-from verl.trainer import core_algos
-from verl.utils.py_functional import append_to_dict
-from verl.utils.torch_functional import logprobs_from_logits, masked_mean
-from verl.workers.actor.base import BasePPOActor
-from verl.workers.actor.config import ActorConfig
+from ...protocol import DataProto
+from ...trainer import core_algos
+from ...utils import torch_functional as VF
+from ...utils.py_functional import append_to_dict
+from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
+from .base import BasePPOActor
+from .config import ActorConfig

+try:
+    from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
+except ImportError:
+    pass

 __all__ = ["DataParallelPPOActor"]
...
@@ -50,17 +56,18 @@ class DataParallelPPOActor(BasePPOActor):
         self.rank = int(os.getenv("RANK", "0"))
         self.actor_module = actor_module
         self.actor_optimizer = actor_optimizer
-        self.compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+        if config.use_torch_compile:
+            self.log_probs_from_logits = torch.compile(VF.log_probs_from_logits, dynamic=True)
+        else:
+            self.log_probs_from_logits = VF.log_probs_from_logits

-    def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature: float) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor], temperature: float) -> torch.Tensor:
         """
         Returns:
             entropy: # (bs, response_len)
             log_probs: # (bs, response_len)
         """
         input_ids = micro_batch["input_ids"]
         batch_size, seqlen = input_ids.shape
         attention_mask = micro_batch["attention_mask"]
         position_ids = micro_batch["position_ids"]
         responses = micro_batch["responses"]
...
@@ -68,29 +75,82 @@ class DataParallelPPOActor(BasePPOActor):
         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

-        vision_inputs = {}
-        if "pixel_values" in micro_batch:
-            vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
-            vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
+        multi_modal_inputs = {}
+        if "multi_modal_inputs" in micro_batch:
+            for key in micro_batch["multi_modal_inputs"][0].keys():
+                multi_modal_inputs[key] = torch.cat(
+                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
+                )

         if self.config.padding_free:
-            # TODO (yaowei): preprocess data for padding_free and ulysses
-            raise NotImplementedError
+            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
+            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+
+            # unpad the position_ids to align the rotary
+            if position_ids.dim() == 3:
+                position_ids_rmpad = (
+                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
+                    .transpose(0, 1)
+                    .unsqueeze(1)
+                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
+            else:
+                position_ids_rmpad = index_first_axis(
+                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
+                ).transpose(0, 1)
+
+            # for compute the log_prob
+            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)
+
+            # pad and slice the inputs if sp > 1
+            if self.config.ulysses_sequence_parallel_size > 1:
+                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
+                    input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
+                )
+                input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(
+                    input_ids_rmpad_rolled, None, self.config.ulysses_sequence_parallel_size
+                )
+
+            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)
+            # only pass input_ids and position_ids to enable flash_attn_varlen
+            output = self.actor_module(
+                input_ids=input_ids_rmpad,
+                attention_mask=None,
+                position_ids=position_ids_rmpad,
+                **multi_modal_inputs,
+                use_cache=False,
+            )  # prevent model thinks we are generating
+            logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
+            logits_rmpad.div_(temperature)
+
+            # ((total_nnz / sp) + pad)
+            log_probs = self.log_probs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
+
+            # gather log_prob if sp > 1
+            if self.config.ulysses_sequence_parallel_size > 1:
+                # gather and unpad for the ulysses sp
+                log_probs = gather_outputs_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+
+            # pad back to (bsz, seqlen)
+            full_log_probs = pad_input(
+                hidden_states=log_probs.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen
+            )
+            log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1]  # (bsz, response_length)
         else:
             output = self.actor_module(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
-                **vision_inputs,
+                **multi_modal_inputs,
                 use_cache=False,
             )
             logits: torch.Tensor = output.logits
             logits.div_(temperature)
             logits = logits[:, -response_length - 1 : -1, :]  # (bsz, response_length, vocab_size)
-            log_probs = logprobs_from_logits(logits, responses)  # (bsz, response_length)
-            entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
+            log_probs = self.log_probs_from_logits(logits, responses)  # (bsz, response_length)

-        return entropy, log_probs
+        return log_probs
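The padding-free path hinges on an unpad/re-pad round trip: real tokens are flattened to (1, total_nnz) for a varlen forward, and the per-token log probs are scattered back to (bsz, seqlen) at the end. A dependency-free sketch of that round trip (pure PyTorch with toy values, standing in for flash-attn's unpad_input/pad_input):

import torch

input_ids = torch.tensor([[11, 12, 13, 0],
                          [21, 22, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
bsz, seqlen = input_ids.shape

indices = torch.nonzero(attention_mask.flatten()).flatten()   # positions of real tokens
tokens_rmpad = input_ids.flatten()[indices]                   # tensor([11, 12, 13, 21, 22])

per_token_scores = tokens_rmpad.float() * 0.1                 # stand-in for per-token log probs
full = torch.zeros(bsz * seqlen)
full[indices] = per_token_scores                              # scatter back; pads stay zero
full = full.view(bsz, seqlen)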
     def _optimizer_step(self) -> torch.Tensor:
         if isinstance(self.actor_module, FSDP):
...
@@ -98,7 +158,12 @@ class DataParallelPPOActor(BasePPOActor):
         else:
             grad_norm = nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.max_grad_norm)

-        self.actor_optimizer.step()
+        if not torch.isfinite(grad_norm):
+            print("Gradient norm is not finite. Skip update.")
+        else:
+            self.actor_optimizer.step()
+
+        self.actor_optimizer.zero_grad()
         return grad_norm

     @torch.no_grad()
...
@@ -124,19 +189,21 @@ class DataParallelPPOActor(BasePPOActor):
         temperature = data.meta_info["temperature"]
         select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
-        if "pixel_values" in data.non_tensor_batch.keys():
-            non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
+        if "multi_modal_inputs" in data.non_tensor_batch.keys():
+            non_tensor_select_keys = ["multi_modal_inputs"]
         else:
-            non_tensor_select_keys = None
+            non_tensor_select_keys = []

         micro_batches = data.select(select_keys, non_tensor_select_keys).split(
             self.config.micro_batch_size_per_device_for_experience
         )
         log_probs_lst = []
-        for micro_batch in tqdm(micro_batches, desc="Compute log probs", disable=(self.rank != 0)):
-            micro_batch.to("cuda")
+        if self.rank == 0:
+            micro_batches = tqdm(micro_batches, desc="Compute log probs", position=2)
+
+        for micro_batch in micro_batches:
             model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-            _, log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
+            log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
             log_probs_lst.append(log_probs)

         log_probs = torch.concat(log_probs_lst, dim=0)
...
@@ -147,83 +214,74 @@ class DataParallelPPOActor(BasePPOActor):
         temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
         select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
-        if self.config.use_kl_loss:
-            select_keys.append("ref_log_prob")
+        if self.config.use_kl_loss and not self.config.disable_kl:
+            select_keys.append("ref_log_probs")

-        if "pixel_values" in data.non_tensor_batch.keys():
-            non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
+        if "multi_modal_inputs" in data.non_tensor_batch.keys():
+            non_tensor_select_keys = ["multi_modal_inputs"]
         else:
-            non_tensor_select_keys = None
+            non_tensor_select_keys = []

-        # TODO (yaowei): support ppo epochs
         # Split to make minibatch iterator for updating the actor
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
         mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)

         metrics = defaultdict(list)
-        n = len(mini_batches)
-        for i, mini_batch in enumerate(mini_batches):
-            gradient_accumulation = (
-                self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
-            )
-            micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
-            self.actor_optimizer.zero_grad()
-            for micro_batch in tqdm(micro_batches, desc=f"Update policy [{i + 1}/{n}]", disable=(self.rank != 0)):
-                micro_batch.to("cuda")
-                model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-                responses = model_inputs["responses"]
-                response_length = responses.size(1)
-                attention_mask = model_inputs["attention_mask"]
-                response_mask = attention_mask[:, -response_length:]
-                old_log_prob = model_inputs["old_log_probs"]
-                advantages = model_inputs["advantages"]
-
-                clip_ratio = self.config.clip_ratio
-                entropy_coeff = self.config.entropy_coeff
-
-                # all return: (bsz, response_length)
-                entropy, log_prob = self._forward_micro_batch(model_inputs, temperature=temperature)
-
-                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
-                    old_log_prob=old_log_prob,
-                    log_prob=log_prob,
-                    advantages=advantages,
-                    eos_mask=response_mask,
-                    cliprange=clip_ratio,
-                )
-                # compute entropy loss from entropy
-                entropy_loss = verl_F.masked_mean(entropy, response_mask)
-
-                # compute policy loss
-                policy_loss = pg_loss - entropy_loss * entropy_coeff
-
-                if self.config.use_kl_loss:
-                    ref_log_prob = model_inputs["ref_log_prob"]
-                    # compute kl loss
-                    kld = core_algos.kl_penalty(
-                        logprob=log_prob,
-                        ref_logprob=ref_log_prob,
-                        kl_penalty=self.config.kl_loss_type,
-                    )
-                    kl_loss = masked_mean(kld, response_mask)
-                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                    metrics["actor/kl_loss"] = kl_loss.detach().item()
-                    metrics["actor/kl_coef"] = self.config.kl_loss_coef
-
-                loss = policy_loss / gradient_accumulation
-                loss.backward()
-
-                batch_metrics = {
-                    "actor/entropy_loss": entropy_loss.detach().item(),
-                    "actor/pg_loss": pg_loss.detach().item(),
-                    "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                    "actor/ppo_kl": ppo_kl.detach().item(),
-                }
-                append_to_dict(metrics, batch_metrics)
-
-            grad_norm = self._optimizer_step()
-            append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
+        for _ in range(self.config.ppo_epochs):
+            if self.rank == 0:
+                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
+
+            for mini_batch in mini_batches:
+                gradient_accumulation = (
+                    self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
+                )
+                micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
+                if self.rank == 0:
+                    micro_batches = tqdm(micro_batches, desc="Update policy", position=3)
+
+                for micro_batch in micro_batches:
+                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
+                    responses = model_inputs["responses"]
+                    response_length = responses.size(1)
+                    attention_mask = model_inputs["attention_mask"]
+                    response_mask = attention_mask[:, -response_length:]
+                    old_log_probs = model_inputs["old_log_probs"]
+                    advantages = model_inputs["advantages"]
+
+                    # all return: (bsz, response_length)
+                    log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
+
+                    pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
+                        old_log_probs=old_log_probs,
+                        log_probs=log_probs,
+                        advantages=advantages,
+                        eos_mask=response_mask,
+                        cliprange=self.config.clip_ratio,
+                    )
+                    if "ref_log_probs" in model_inputs:
+                        ref_log_probs = model_inputs["ref_log_probs"]
+                        # compute kl loss
+                        kld = core_algos.kl_penalty(
+                            log_probs=log_probs,
+                            ref_log_probs=ref_log_probs,
+                            kl_penalty=self.config.kl_penalty,
+                        )
+                        kl_loss = VF.masked_mean(kld, response_mask)
+                        pg_loss = pg_loss + kl_loss * self.config.kl_coef
+                        metrics["actor/kl_loss"] = kl_loss.detach().item()
+                        metrics["actor/kl_coef"] = self.config.kl_coef
+
+                    loss = pg_loss / gradient_accumulation
+                    loss.backward()
+
+                    batch_metrics = {
+                        "actor/pg_loss": pg_loss.detach().item(),
+                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+                        "actor/ppo_kl": ppo_kl.detach().item(),
+                    }
+                    append_to_dict(metrics, batch_metrics)
+
+                grad_norm = self._optimizer_step()
+                append_to_dict(metrics, {"actor/grad_norm": grad_norm.detach().item()})
+                self.actor_optimizer.zero_grad()

         return metrics
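The gradient-accumulation scaling above is what keeps the effective update independent of the micro-batch split; a quick sketch of the arithmetic with toy numbers (not from the diff):

global_batch_size_per_device = 16
micro_batch_size_per_device_for_update = 4
gradient_accumulation = global_batch_size_per_device // micro_batch_size_per_device_for_update  # 4

# each micro-batch contributes loss / 4, so after 4 backward() calls the accumulated
# gradient matches a single backward() over the full mini-batch (mean-reduced loss)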
verl/workers/config.py View file @ c132cbcb
...
@@ -17,10 +17,10 @@ ActorRolloutRef config
 from dataclasses import dataclass, field

-from verl.workers.actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
-from verl.workers.critic import CriticConfig
-from verl.workers.reward import RewardConfig
-from verl.workers.rollout import RolloutConfig
+from .actor import ActorConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig
+from .critic import CriticConfig
+from .reward import RewardConfig
+from .rollout import RolloutConfig

 __all__ = [
...
@@ -46,5 +46,7 @@ class WorkerConfig:
     rollout: RolloutConfig = field(default_factory=RolloutConfig)

     def post_init(self):
-        self.ref.padding_free = self.actor.padding_free
+        self.ref.micro_batch_size_per_device_for_experience = self.actor.micro_batch_size_per_device_for_experience
+        self.ref.padding_free = self.actor.padding_free
         self.ref.ulysses_sequence_parallel_size = self.actor.ulysses_sequence_parallel_size
+        self.ref.use_torch_compile = self.actor.use_torch_compile