OpenDAS / EasyR1 · Commit c132cbcb
Authored Apr 02, 2025 by chenych
0402 update
Parent: f92481f0

Showing 12 changed files with 470 additions and 340 deletions (+470 −340)
verl/workers/critic/base.py (+2 −2)
verl/workers/critic/config.py (+4 −3)
verl/workers/critic/dp_critic.py (+121 −64)
verl/workers/fsdp_workers.py (+215 −153)
verl/workers/reward/custom.py (+23 −30)
verl/workers/rollout/base.py (+1 −1)
verl/workers/rollout/config.py (+6 −5)
verl/workers/rollout/vllm_rollout/__init__.py (+1 −2)
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (+41 −49)
verl/workers/sharding_manager/base.py (+1 −1)
verl/workers/sharding_manager/fsdp_ulysses.py (+4 −5)
verl/workers/sharding_manager/fsdp_vllm.py (+51 −25)
verl/workers/critic/base.py (View file @ c132cbcb)

@@ -20,8 +20,8 @@ from typing import Any, Dict

 import torch

-from verl import DataProto
-from verl.workers.critic.config import CriticConfig
+from ...protocol import DataProto
+from .config import CriticConfig


 __all__ = ["BasePPOCritic"]
verl/workers/critic/config.py (View file @ c132cbcb)

@@ -17,17 +17,18 @@ Critic config

 from dataclasses import dataclass, field

-from verl.workers.actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig
+from ..actor.config import FSDPConfig, ModelConfig, OffloadConfig, OptimConfig


 @dataclass
 class CriticConfig:
     strategy: str = "fsdp"
     global_batch_size: int = 256
-    micro_batch_size_per_device_for_update: int = field(default=-1, init=False)
-    micro_batch_size_per_device_for_experience: int = field(default=-1, init=False)
+    micro_batch_size_per_device_for_update: int = 4
+    micro_batch_size_per_device_for_experience: int = 16
     max_grad_norm: float = 1.0
     cliprange_value: float = 0.5
     ppo_epochs: int = 1
     padding_free: bool = False
     ulysses_sequence_parallel_size: int = 1
     model: ModelConfig = field(default_factory=ModelConfig)
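The removed defaults used the dataclass "auto key" idiom: field(default=-1, init=False) keeps an attribute out of __init__ so it can be derived later from other settings (the same pattern still appears in RolloutConfig further down). A minimal, self-contained sketch of that idiom; the names here are illustrative and not taken from the repository:

from dataclasses import dataclass, field


@dataclass
class TrainBatchConfig:
    global_batch_size: int = 256
    micro_batch_size_per_device_for_update: int = 4
    # "auto key": excluded from __init__ and filled in once the world size is known.
    global_batch_size_per_device: int = field(default=-1, init=False)

    def finalize(self, world_size: int) -> None:
        self.global_batch_size_per_device = self.global_batch_size // world_size


cfg = TrainBatchConfig()
cfg.finalize(world_size=8)
assert cfg.global_batch_size_per_device == 32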
verl/workers/critic/dp_critic.py (View file @ c132cbcb)

@@ -20,17 +20,23 @@ from collections import defaultdict

 from typing import Any, Dict

 import torch
 import torch.distributed
+from ray.experimental.tqdm_ray import tqdm
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from tqdm import tqdm

-from verl import DataProto
-from verl.trainer import core_algos
-from verl.utils.py_functional import append_to_dict
-from verl.utils.torch_functional import masked_mean
-from verl.workers.critic.base import BasePPOCritic
-from verl.workers.critic.config import CriticConfig
+from ...protocol import DataProto
+from ...trainer import core_algos
+from ...utils import torch_functional as VF
+from ...utils.py_functional import append_to_dict
+from ...utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs
+from .base import BasePPOCritic
+from .config import CriticConfig

 try:
     from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
 except ImportError:
     pass


 __all__ = ["DataParallelPPOCritic"]
@@ -45,6 +51,7 @@ class DataParallelPPOCritic(BasePPOCritic):

     def _forward_micro_batch(self, micro_batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         input_ids = micro_batch["input_ids"]
         batch_size, seqlen = input_ids.shape
         attention_mask = micro_batch["attention_mask"]
         position_ids = micro_batch["position_ids"]
         responses = micro_batch["responses"]
@@ -52,20 +59,61 @@ class DataParallelPPOCritic(BasePPOCritic):

         if position_ids.dim() == 3:  # qwen2vl mrope
             position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)

-        vision_inputs = {}
-        if "pixel_values" in micro_batch:
-            vision_inputs["pixel_values"] = torch.cat(micro_batch["pixel_values"], dim=0)
-            vision_inputs["image_grid_thw"] = torch.cat(micro_batch["image_grid_thw"], dim=0)
+        multi_modal_inputs = {}
+        if "multi_modal_inputs" in micro_batch:
+            for key in micro_batch["multi_modal_inputs"][0].keys():
+                multi_modal_inputs[key] = torch.cat(
+                    [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0
+                )

         if self.config.padding_free:
-            # TODO (yaowei): preprocess data for padding_free and ulysses
-            raise NotImplementedError
+            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
+            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
+            # unpad the position_ids to align the rotary
+            if position_ids.dim() == 3:
+                position_ids_rmpad = (
+                    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
+                    .transpose(0, 1)
+                    .unsqueeze(1)
+                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
+            else:
+                position_ids_rmpad = index_first_axis(
+                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
+                ).transpose(0, 1)
+
+            # pad and slice the inputs if sp > 1
+            if self.config.ulysses_sequence_parallel_size > 1:
+                input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
+                    input_ids_rmpad, position_ids_rmpad, sp_size=self.config.ulysses_sequence_parallel_size
+                )
+
+            # only pass input_ids and position_ids to enable flash_attn_varlen
+            output = self.critic_module(
+                input_ids=input_ids_rmpad,
+                attention_mask=None,
+                position_ids=position_ids_rmpad,
+                **multi_modal_inputs,
+                use_cache=False,
+            )  # prevent model thinks we are generating
+            values_rmpad = output.logits
+            values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)
+
+            # gather output if sp > 1
+            if self.config.ulysses_sequence_parallel_size > 1:
+                values_rmpad = gather_outputs_and_unpad(values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)
+
+            # pad it back
+            values = pad_input(values_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
+            values = values[:, -response_length - 1 : -1]
         else:
             output = self.critic_module(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
-                **vision_inputs,
+                **multi_modal_inputs,
                 use_cache=False,
             )
             values: torch.Tensor = output.logits
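For reference, the new padding-free branch removes padded positions with flash_attn's unpad_input, runs the model on the packed tokens only, and scatters the per-token values back with pad_input. A torch-only sketch of the same round trip with toy shapes; plain indexing stands in for the flash_attn helpers and a per-token sum stands in for the critic head:

import torch

batch_size, seqlen, hidden = 2, 5, 3
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 0]])
hidden_states = torch.randn(batch_size, seqlen, hidden)

# "unpad": keep only positions whose mask is 1 -> (total_nnz, hidden)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
unpadded = hidden_states.reshape(-1, hidden)[indices]

values_rmpad = unpadded.sum(dim=-1)  # stand-in for the model output per packed token

# "pad it back": scatter the packed values into the (batch, seqlen) layout
values = torch.zeros(batch_size * seqlen)
values[indices] = values_rmpad
values = values.reshape(batch_size, seqlen)
print(values.shape)  # torch.Size([2, 5])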
@@ -81,7 +129,12 @@ class DataParallelPPOCritic(BasePPOCritic):

             self.critic_module.parameters(), max_norm=self.config.max_grad_norm
         )
-        self.critic_optimizer.step()
+        if not torch.isfinite(grad_norm):
+            print("Gradient norm is not finite. Skip update.")
+        else:
+            self.critic_optimizer.step()
+
+        self.critic_optimizer.zero_grad()
         return grad_norm

     @torch.no_grad()
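The new _optimizer_step skips the optimizer step when the clipped gradient norm is not finite and zeroes the gradients afterwards. A minimal stand-alone sketch of the same guard with a plain PyTorch model (toy module and optimizer, outside FSDP):

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# clip_grad_norm_ returns the total gradient norm; if it overflowed to inf/nan,
# skip the step instead of writing a corrupted update into the weights.
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if torch.isfinite(grad_norm):
    optimizer.step()
optimizer.zero_grad()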
@@ -89,18 +142,21 @@ class DataParallelPPOCritic(BasePPOCritic):

         self.critic_module.eval()

         select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
-        if "pixel_values" in data.non_tensor_batch.keys():
-            non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
+        if "multi_modal_inputs" in data.non_tensor_batch.keys():
+            non_tensor_select_keys = ["multi_modal_inputs"]
         else:
-            non_tensor_select_keys = None
+            non_tensor_select_keys = []

         micro_batches = data.select(select_keys, non_tensor_select_keys).split(
             self.config.micro_batch_size_per_device_for_experience
         )
         values_lst = []
-        for micro_batch in tqdm(micro_batches, "Compute values", disable=(self.rank != 0)):
-            micro_batch.to("cuda")
-            values = self._forward_micro_batch(micro_batch)
+        if self.rank == 0:
+            micro_batches = tqdm(micro_batches, desc="Compute values", position=2)
+
+        for micro_batch in micro_batches:
+            model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
+            values = self._forward_micro_batch(model_inputs)
             values_lst.append(values)
         values = torch.concat(values_lst, dim=0)
@@ -114,55 +170,56 @@ class DataParallelPPOCritic(BasePPOCritic):

         self.critic_module.train()

         select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
-        if "pixel_values" in data.non_tensor_batch.keys():
-            non_tensor_select_keys = ["pixel_values", "image_grid_thw"]
+        if "multi_modal_inputs" in data.non_tensor_batch.keys():
+            non_tensor_select_keys = ["multi_modal_inputs"]
         else:
-            non_tensor_select_keys = None
+            non_tensor_select_keys = []

-        # TODO (yaowei): support ppo epochs
         # Split to make minibatch iterator for updating the actor
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
         mini_batches = data.select(select_keys, non_tensor_select_keys).split(self.config.global_batch_size_per_device)

         metrics = defaultdict(list)
-        n = len(mini_batches)
-        for i, mini_batch in enumerate(mini_batches):
-            gradient_accumulation = (
-                self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
-            )
-            micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
-            self.critic_optimizer.zero_grad()
-            for micro_batch in tqdm(micro_batches, desc=f"Update critic [{i + 1}/{n}]", disable=(self.rank != 0)):
-                micro_batch.to("cuda")
-                model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-                responses = model_inputs["responses"]
-                attention_mask = model_inputs["attention_mask"]
-                values = model_inputs["values"]
-                returns = model_inputs["returns"]
-                response_length = responses.size(1)
-                eos_mask = attention_mask[:, -response_length - 1 : -1]
-                vpreds = self._forward_micro_batch(data)
-                vf_loss, vf_clipfrac = core_algos.compute_value_loss(
-                    vpreds=vpreds,
-                    values=values,
-                    returns=returns,
-                    eos_mask=eos_mask,
-                    cliprange_value=self.config.cliprange_value,
-                )
-                loss = vf_loss / gradient_accumulation
-                loss.backward()
-                batch_metrics = {
-                    "critic/vf_loss": vf_loss.detach().item(),
-                    "critic/vf_clipfrac": vf_clipfrac.detach().item(),
-                    "critic/vpred_mean": masked_mean(vpreds, eos_mask).detach().item(),
-                }
-                append_to_dict(metrics, batch_metrics)
-            grad_norm = self._optimizer_step()
-            append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})
+        for _ in range(self.config.ppo_epochs):
+            if self.rank == 0:
+                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
+
+            for mini_batch in mini_batches:
+                gradient_accumulation = (
+                    self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
+                )
+                micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
+                if self.rank == 0:
+                    micro_batches = tqdm(micro_batches, desc="Update critic", position=3)
+
+                for micro_batch in micro_batches:
+                    model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
+                    responses = model_inputs["responses"]
+                    attention_mask = model_inputs["attention_mask"]
+                    values = model_inputs["values"]
+                    returns = model_inputs["returns"]
+                    response_length = responses.size(1)
+                    eos_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
+                    vpreds = self._forward_micro_batch(model_inputs)
+                    vf_loss, vf_clipfrac = core_algos.compute_value_loss(
+                        vpreds=vpreds,
+                        returns=returns,
+                        values=values,
+                        eos_mask=eos_mask,
+                        cliprange_value=self.config.cliprange_value,
+                    )
+                    loss = vf_loss / gradient_accumulation
+                    loss.backward()
+                    batch_metrics = {
+                        "critic/vf_loss": vf_loss.detach().item(),
+                        "critic/vf_clipfrac": vf_clipfrac.detach().item(),
+                        "critic/vpred_mean": VF.masked_mean(vpreds, eos_mask).detach().item(),
+                    }
+                    append_to_dict(metrics, batch_metrics)
+
+                grad_norm = self._optimizer_step()
+                append_to_dict(metrics, {"critic/grad_norm": grad_norm.detach().item()})

-        self.critic_optimizer.zero_grad()
         return metrics
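The eos_mask slice above is what the "shift left for value computation" comment refers to: the value predicted at position t scores the token generated at t + 1, so the mask is the response window of the attention mask shifted left by one. A tiny illustration with made-up lengths:

import torch

# prompt of 3 tokens, response window of 5 positions (3 real response tokens + 2 padding)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
response_length = 5
eos_mask = attention_mask[:, -response_length - 1 : -1]
print(eos_mask)  # tensor([[1, 1, 1, 1, 0]]) -> last prompt position plus the first response positions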
verl/workers/fsdp_workers.py (View file @ c132cbcb)

@@ -15,8 +15,10 @@
 The main entry point to run the PPO algorithm
 """

-from typing import Literal
+from typing import Literal, Optional, Union

+import numpy as np
+import psutil
 import torch
 import torch.distributed as dist
 from accelerate import init_empty_weights

@@ -34,13 +36,13 @@ from transformers import (
 )
 from transformers.modeling_utils import no_init_weights

-from verl import DataProto
-from verl.single_controller.base import Worker
-from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils import get_tokenizer, get_processor
-from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
-from verl.utils.flops_counter import FlopsCounter
-from verl.utils.fsdp_utils import (
+from ..models.monkey_patch import apply_ulysses_patch
+from ..protocol import DataProto
+from ..single_controller.base import Worker
+from ..single_controller.base.decorator import Dispatch, register
+from ..utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
+from ..utils.flops_counter import FlopsCounter
+from ..utils.fsdp_utils import (
     get_fsdp_wrap_policy,
     get_init_fn,
     load_fsdp_model,
@@ -48,16 +50,16 @@ from verl.utils.fsdp_utils import (
     offload_fsdp_model,
     offload_fsdp_optimizer,
 )
-from verl.utils.model_utils import print_model_size
-from verl.utils.performance import log_gpu_memory_usage
-from verl.utils.torch_dtypes import PrecisionType
-from verl.utils.torch_functional import get_constant_schedule_with_warmup
-from verl.workers.actor import DataParallelPPOActor
-from verl.workers.config import FSDPConfig, ModelConfig, OptimConfig, WorkerConfig
-from verl.workers.critic import DataParallelPPOCritic
-from verl.workers.rollout.vllm_rollout import vLLMRollout
-from verl.workers.sharding_manager import FSDPVLLMShardingManager
-from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
+from ..utils.model_utils import print_gpu_memory_usage, print_model_size
+from ..utils.tokenizer import get_processor, get_tokenizer
+from ..utils.torch_dtypes import PrecisionType
+from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_with_warmup
+from .actor import DataParallelPPOActor
+from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
+from .critic import DataParallelPPOCritic
+from .rollout.vllm_rollout import vLLMRollout
+from .sharding_manager import FSDPVLLMShardingManager
+from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager


 class FSDPWorker(Worker):
@@ -68,77 +70,95 @@ class FSDPWorker(Worker):
     ):
         super().__init__()
         self.config = config
+        self.role = role

         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")

         # build device mesh for FSDP
         # TODO: support FSDP hybrid shard for larger model
+        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
+        self._is_critic = self.role == "critic"
+        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
+        self._is_ref = self.role in ["ref", "actor_rollout_ref"]
+        self._use_param_offload = False
+        self._use_optimizer_offload = False
+        if self._is_actor:
+            self._use_param_offload = self.config.actor.offload.offload_params
+            self._use_optimizer_offload = self.config.actor.offload.offload_optimizer
+            self._init_config(self.config.actor, "actor")
+        elif self._is_critic:
+            self._use_param_offload = self.config.critic.offload.offload_params
+            self._use_optimizer_offload = self.config.critic.offload.offload_optimizer
+            self._init_config(self.config.critic, "critic")
+        elif self._is_ref:
+            # NOTE: it seems that manual offload is slower than FSDP offload
+            self._use_param_offload = self.config.ref.offload.offload_params
+            self._init_config(self.config.ref, "ref")

+    def _init_config(self, config: Union[ActorConfig, CriticConfig, RefConfig], role: Literal["actor", "critic", "ref"]):
         world_size = dist.get_world_size()
-        self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
+        fsdp_size = config.fsdp.fsdp_size
+        if fsdp_size <= 0 or fsdp_size >= world_size:
+            self.device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
+        else:  # hsdp
+            self.device_mesh = init_device_mesh(
+                "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=("ddp", "fsdp")
+            )

-        # build device mesh for Ulysses Sequence Parallel
-        self.ulysses_sequence_parallel_size = self.config.actor.ulysses_sequence_parallel_size
-        if self.ulysses_sequence_parallel_size > 1:
+        if config.ulysses_sequence_parallel_size > 1:
             self.ulysses_device_mesh = init_device_mesh(
                 "cuda",
-                mesh_shape=(world_size // self.ulysses_sequence_parallel_size, self.ulysses_sequence_parallel_size),
-                mesh_dim_names=["dp", "sp"],
+                mesh_shape=(world_size // config.ulysses_sequence_parallel_size, config.ulysses_sequence_parallel_size),
+                mesh_dim_names=("dp", "sp"),
             )
         else:
             self.ulysses_device_mesh = None

         self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

-        self.role = role
-        self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
-        self._is_critic = self.role == "critic"
-        self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
-        self._is_ref = self.role in ["ref", "actor_rollout_ref"]
+        if not hasattr(config, "global_batch_size"):  # ref model
+            return

-        self._use_param_offload = False
-        self._use_optimizer_offload = False
-        if self._is_actor:
-            self._use_param_offload = self.config.actor.offload.param_offload
-            self._use_optimizer_offload = self.config.actor.offload.optimizer_offload
-        elif self._is_critic:
-            self._use_param_offload = self.config.critic.offload.param_offload
-            self._use_optimizer_offload = self.config.critic.offload.optimizer_offload
-        elif self._is_ref:
-            # NOTE: it seems that manual offload is slowly than FSDP offload
-            self._use_param_offload = self.config.ref.offload.param_offload
+        if self.config.rollout.n > 1:
+            config.global_batch_size *= self.config.rollout.n
+            self.print_rank0(f"{role} will use global batch size {config.global_batch_size}.")

-        # normalize config
-        if self._is_actor:
-            self.config.actor.global_batch_size *= self.config.rollout.n
-            self.config.actor.global_batch_size_per_device = (
-                self.config.actor.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
-            )
-            assert (
-                self.config.actor.global_batch_size_per_device
-                % self.config.actor.micro_batch_size_per_device_for_update
-                == 0
-            )
-        elif self._is_critic:
-            self.config.critic.global_batch_size *= self.config.rollout.n
-            self.config.critic.global_batch_size_per_device = (
-                self.config.critic.global_batch_size // self.device_mesh.shape[0] * self.ulysses_sequence_parallel_size
-            )
-            assert (
-                self.config.critic.global_batch_size_per_device
-                % self.config.critic.micro_batch_size_per_device_for_update
-                == 0
-            )
+        config.global_batch_size_per_device = (
+            config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size
+        )
+        if config.global_batch_size_per_device == 0:
+            raise ValueError(f"{role} global batch size must be larger than num gpus.")
+
+        if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0:
+            raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.")
+
+        if (
+            config.fsdp.enable_cpu_offload
+            and config.global_batch_size_per_device != config.micro_batch_size_per_device_for_update
+        ):
+            raise ValueError(f"{role} cannot use FSDP's CPU offload when gradient accumulation is enabled.")
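The batch-size normalization above multiplies the global batch by rollout.n, divides it across the data-parallel mesh (scaled by the Ulysses sequence-parallel size), and then requires divisibility by the per-device micro batch. A worked sketch of that arithmetic with assumed numbers (8 GPUs, n = 4, sequence-parallel size 2; values are illustrative only):

world_size = 8
rollout_n = 4
ulysses_sequence_parallel_size = 2
micro_batch_size_per_device_for_update = 4

global_batch_size = 128 * rollout_n  # each prompt yields n sampled responses
global_batch_size_per_device = global_batch_size // world_size * ulysses_sequence_parallel_size

assert global_batch_size_per_device > 0, "global batch size must be larger than num gpus"
assert global_batch_size_per_device % micro_batch_size_per_device_for_update == 0
gradient_accumulation = global_batch_size_per_device // micro_batch_size_per_device_for_update
print(global_batch_size_per_device, gradient_accumulation)  # 128 32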
     def _build_model_optimizer(
         self,
         model_config: ModelConfig,
         fsdp_config: FSDPConfig,
-        optim_config: OptimConfig,
+        optim_config: Optional[OptimConfig],
         padding_free: bool = False,
     ) -> None:
-        self.tokenizer = get_tokenizer(model_config.tokenizer_path, trust_remote_code=model_config.trust_remote_code)
-        self.processor = get_processor(model_config.tokenizer_path)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer_path,
+            trust_remote_code=model_config.trust_remote_code,
+            use_fast=True,
+        )
+        self.processor = get_processor(
+            model_config.tokenizer_path,
+            trust_remote_code=model_config.trust_remote_code,
+            use_fast=True,
+        )
         self.model_config = AutoConfig.from_pretrained(
             model_config.model_path,
             trust_remote_code=model_config.trust_remote_code,

@@ -156,7 +176,8 @@ class FSDPWorker(Worker):
         self.print_rank0(f"Model config: {self.model_config}")

         if padding_free:
-            raise NotImplementedError("Padding free is not implemented yet.")
+            apply_ulysses_patch(self.model_config.model_type)
+            self.print_rank0("Ulysses patch applied!")

         if fsdp_config.torch_dtype is None:
             torch_dtype = torch.float32 if self._is_actor or self._is_critic else torch.bfloat16

@@ -170,13 +191,13 @@ class FSDPWorker(Worker):
         else:
             auto_class = AutoModelForCausalLM

-        if self.rank == 0:
+        if (not fsdp_config.enable_rank0_init) or self.device_mesh.get_local_rank("fsdp") == 0:
             model = auto_class.from_pretrained(
                 model_config.model_path,
                 config=self.model_config,
                 torch_dtype=torch_dtype,
                 attn_implementation="flash_attention_2",
-                device_map="cpu",
+                device_map="cpu" if fsdp_config.enable_rank0_init else "cuda",
                 low_cpu_mem_usage=True,
                 trust_remote_code=model_config.trust_remote_code,
             )

@@ -195,29 +216,50 @@ class FSDPWorker(Worker):
         if model_config.enable_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

-        dist.barrier()
-        if self.rank == 0:
-            print_model_size(model)
+        if not (self._is_actor or self._is_critic):
+            model.requires_grad_(False)
+
+        if model_config.freeze_vision_tower:
+            if hasattr(model, "visual"):
+                model.visual.requires_grad_(False)
+                fsdp_config.use_orig_params = True
+                self.print_rank0("Vision tower is set to not trainable.")
+            else:
+                self.print_rank0("No vision tower found.")

-        log_gpu_memory_usage("After init from huggingface model")
+        dist.barrier()
+        print_model_size(model)
+        print_gpu_memory_usage("After huggingface model init")

         mixed_precision = MixedPrecision(
             param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
             reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
             buffer_dtype=PrecisionType.to_dtype(fsdp_config.mp_buffer_dtype),
         )
         auto_wrap_policy = get_fsdp_wrap_policy(model)
-        if fsdp_config.enable_full_shard:
-            sharding_strategy = ShardingStrategy.FULL_SHARD
-        else:
-            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+        self.print_rank0(f"FSDP wrap policy: {auto_wrap_policy}.")
+
+        if self.device_mesh.ndim == 2:
+            if fsdp_config.enable_full_shard:
+                sharding_strategy = ShardingStrategy.HYBRID_SHARD
+            else:
+                sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
+        else:
+            if fsdp_config.enable_full_shard:
+                sharding_strategy = ShardingStrategy.FULL_SHARD
+            else:
+                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

-        if fsdp_config.param_offload or fsdp_config.optimizer_offload:
-            cpu_offload = CPUOffload(offload_params=fsdp_config.param_offload)
+        if fsdp_config.enable_cpu_offload:
+            cpu_offload = CPUOffload(offload_params=True)
         else:
             cpu_offload = None

-        if self.rank == 0:
-            print(f"FSDP wrap policy: {auto_wrap_policy}.")
+        if fsdp_config.enable_rank0_init:
+            sync_module_states = True
+            param_init_fn = get_init_fn(model, device="cuda") if self.rank != 0 else None
+        else:
+            sync_module_states = False
+            param_init_fn = None

         self.fsdp_module = FSDP(
             model,
@@ -225,53 +267,60 @@ class FSDPWorker(Worker):
             cpu_offload=cpu_offload,
             auto_wrap_policy=auto_wrap_policy,
             mixed_precision=mixed_precision,
-            param_init_fn=get_init_fn(model, device="cuda") if self.rank != 0 else None,
+            param_init_fn=param_init_fn,
             device_id=torch.cuda.current_device(),
-            sync_module_states=True,
+            sync_module_states=sync_module_states,
             forward_prefetch=False,
-            use_orig_params=False,
+            use_orig_params=fsdp_config.use_orig_params,
             device_mesh=self.device_mesh,
         )
-        log_gpu_memory_usage("After Actor FSDP init")
+        print_gpu_memory_usage("After FSDP module init")

         if self._is_actor or self._is_critic:
-            self.optimizer = torch.optim.AdamW(
-                self.fsdp_module.parameters(),
-                lr=optim_config.lr,
-                betas=optim_config.betas,
-                weight_decay=optim_config.weight_decay,
-            )
-            num_warmup_steps = int(optim_config.lr_warmup_steps_ratio * optim_config.training_steps)
+            if optim_config.strategy == "adamw":
+                self.optimizer = torch.optim.AdamW(
+                    self.fsdp_module.parameters(),
+                    lr=optim_config.lr,
+                    betas=optim_config.betas,
+                    weight_decay=optim_config.weight_decay,
+                    fused=True,
+                )
+            elif optim_config.strategy == "adamw_bf16":
+                self.optimizer = AnyPrecisionAdamW(
+                    self.fsdp_module.parameters(),
+                    lr=optim_config.lr,
+                    betas=optim_config.betas,
+                    weight_decay=optim_config.weight_decay,
+                )
+            else:
+                raise NotImplementedError(f"Optimizer {optim_config.strategy} not supported.")
+
+            num_warmup_steps = int(optim_config.lr_warmup_ratio * optim_config.training_steps)
             self.lr_scheduler = get_constant_schedule_with_warmup(
                 optimizer=self.optimizer, num_warmup_steps=num_warmup_steps
             )
+            print_gpu_memory_usage("After optimizer init")
         else:
             self.optimizer, self.lr_scheduler = None, None

-        log_gpu_memory_usage("After actor optimizer init")
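get_constant_schedule_with_warmup ramps the learning rate linearly for num_warmup_steps and then holds it constant. A minimal torch-only sketch of the same schedule via LambdaLR (toy model and values; the warmup count mirrors int(lr_warmup_ratio * training_steps)):

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
num_warmup_steps = 10

def lr_lambda(step: int) -> float:
    # linear warmup to the base LR, then constant
    return min(1.0, float(step) / float(max(1, num_warmup_steps)))

scheduler = LambdaLR(optimizer, lr_lambda)
for _ in range(3):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # still warming up: 3/10 of the base LR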
     def _build_rollout(self) -> None:
         # TODO(sgm): support FSDP hybrid shard for larger model
         tp_size = self.config.rollout.tensor_parallel_size
         dp_size = self.world_size // tp_size
         assert self.world_size % tp_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by tp_size: {tp_size}"
+            f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}"
         )
-        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=["dp", "tp"])
-        log_gpu_memory_usage("Before building vllm rollout")
+        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp"))
         self.rollout = vLLMRollout(
             model_path=self.config.actor.model.model_path,
             config=self.config.rollout,
             tokenizer=self.tokenizer,
         )
-        log_gpu_memory_usage("After building vllm rollout")
         self.rollout_sharding_manager = FSDPVLLMShardingManager(
             module=self.fsdp_module,
             inference_engine=self.rollout.inference_engine,
             device_mesh=rollout_device_mesh,
         )
-        log_gpu_memory_usage("After building sharding manager")
+        print_gpu_memory_usage("After vllm init")

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):

@@ -280,11 +329,21 @@ class FSDPWorker(Worker):
             fsdp_config = self.config.critic.fsdp
             optim_config = self.config.critic.optim
             padding_free = self.config.critic.padding_free
-        else:
+            role = "critic"
+        elif self._is_actor:
             model_config = self.config.actor.model
             fsdp_config = self.config.actor.fsdp
             optim_config = self.config.actor.optim
             padding_free = self.config.actor.padding_free
+            role = "actor"
+        elif self._is_ref:
+            model_config = self.config.actor.model
+            fsdp_config = self.config.ref.fsdp
+            optim_config = None
+            padding_free = self.config.ref.padding_free
+            role = "ref"
+        else:
+            raise ValueError(f"Unknown role {role}.")

         if self._is_actor or self._is_critic or self._is_ref:
             self._build_model_optimizer(
@@ -293,11 +352,13 @@ class FSDPWorker(Worker):
             optim_config=optim_config,
             padding_free=padding_free,
         )
-        # get the original unwrapped module
-        self.unwrapped_model = self.fsdp_module._fsdp_wrapped_module
-        if self._use_optimizer_offload and not self._is_critic:
+
+        if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
+            print_gpu_memory_usage(f"After offload {role} model during init")
+
         if self._use_optimizer_offload:
             offload_fsdp_optimizer(optimizer=self.optimizer)
-            log_gpu_memory_usage("After offload actor optimizer during init")
+            print_gpu_memory_usage(f"After offload {role} optimizer during init")

         if self._is_actor:
             self.actor = DataParallelPPOActor(

@@ -317,7 +378,10 @@ class FSDPWorker(Worker):
             self._build_rollout()

         if self._is_ref:
-            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.fsdp_module)
+            self.ref_policy = DataParallelPPOActor(
+                config=self.config.ref,
+                actor_module=self.fsdp_module,
+            )

         if self._is_actor or self._is_critic:
             self.flops_counter = FlopsCounter(self.model_config)

@@ -325,42 +389,37 @@ class FSDPWorker(Worker):
             model=self.fsdp_module,
             optimizer=self.optimizer,
             lr_scheduler=self.lr_scheduler,
-            tokenizer=self.tokenizer,
-            processor=self.processor,
+            processing_class=self.processor if self.processor is not None else self.tokenizer,
         )
-        torch.cuda.empty_cache()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, path: str, global_step: int = 0, remove_previous_ckpt: bool = False):
+    def save_checkpoint(self, path: str):
         assert self._is_actor or self._is_critic
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

-        self.checkpoint_manager.save_checkpoint(
-            local_path=path,
-            global_step=global_step,
-            remove_previous_ckpt=remove_previous_ckpt,
-        )
+        self.checkpoint_manager.save_checkpoint(path)
         dist.barrier()
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def load_checkpoint(self, path: str, del_local_after_load: bool = True):
+    def load_checkpoint(self, path: str):
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

-        self.checkpoint_manager.load_checkpoint(path=path, del_local_after_load=del_local_after_load)
+        self.checkpoint_manager.load_checkpoint(path)
         dist.barrier()
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-    """ActorRolloutRefWorker"""
+        if self._use_optimizer_offload:
+            offload_fsdp_optimizer(self.optimizer)

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
         assert self._is_actor
         data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

@@ -368,7 +427,6 @@ class FSDPWorker(Worker):
         if self._use_optimizer_offload:
             load_fsdp_optimizer(optimizer=self.optimizer)

-        log_gpu_memory_usage("Before update policy")
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data=data)
             with Timer(name="update_policy", logger=None) as timer:
@@ -377,17 +435,27 @@ class FSDPWorker(Worker):
                 delta_time = timer.last
             global_num_tokens = data.meta_info["global_token_num"]
             estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
-            metrics["mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
+            metrics["perf/mfu_actor"] = (
+                estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
+            )
+            metrics["perf/max_memory_allocated_gb"] = (
+                torch.cuda.max_memory_allocated() - self.rollout_sharding_manager.freed_bytes
+            ) / (1024**3)
+            metrics["perf/max_memory_reserved_gb"] = (
+                torch.cuda.max_memory_reserved() - self.rollout_sharding_manager.freed_bytes
+            ) / (1024**3)
+            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

             self.lr_scheduler.step()
             lr = self.lr_scheduler.get_last_lr()[0]
             metrics["actor/lr"] = lr
-            log_gpu_memory_usage("After update policy")

-            # TODO: here, we should return all metrics
-            output = DataProto(meta_info={"metrics": metrics})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
-            output = output.to("cpu")
+            # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+            output = DataProto(
+                non_tensor_batch={
+                    key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
+                }
+            )

         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
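As the new comment notes, the metrics are moved into non_tensor_batch because DataProto concatenation keeps non_tensor_batch but not meta_info. The wrapping pattern turns scalars into length-1 arrays so every metric is concatenatable along axis 0; a numpy-only sketch with made-up metric values:

import numpy as np

metrics = {"actor/lr": 1e-6, "perf/mfu_actor": 0.31, "critic/vf_loss": [0.8, 0.7]}

# np.isscalar is False for lists/arrays, so scalars become length-1 arrays while
# per-step lists are kept as-is; everything can then be concatenated across workers.
non_tensor_batch = {key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()}

merged = {key: np.concatenate([arr, arr]) for key, arr in non_tensor_batch.items()}
print({key: arr.shape for key, arr in merged.items()})  # {'actor/lr': (2,), 'perf/mfu_actor': (2,), 'critic/vf_loss': (4,)}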
@@ -395,7 +463,7 @@ class FSDPWorker(Worker):
         if self._use_optimizer_offload:
             offload_fsdp_optimizer(optimizer=self.optimizer)

-        torch.cuda.empty_cache()
+        output = output.to("cpu")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)

@@ -422,22 +490,17 @@ class FSDPWorker(Worker):
         if self._use_optimizer_offload:
             offload_fsdp_optimizer(optimizer=self.optimizer)

-        log_gpu_memory_usage("After entering rollout sharding manager")
-
             prompts = self.rollout_sharding_manager.preprocess_data(prompts)
             output = self.rollout.generate_sequences(prompts=prompts)
-            log_gpu_memory_usage("After rollout generation")
             output = self.rollout_sharding_manager.postprocess_data(output)

         output = output.to("cpu")
-        torch.cuda.empty_cache()  # clear kv cache
-        log_gpu_memory_usage("After recompute log prob")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def compute_log_prob(self, data: DataProto):
+    def compute_log_probs(self, data: DataProto):
         assert self._is_actor
         data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

@@ -452,8 +515,6 @@ class FSDPWorker(Worker):
             )
             output = self.ulysses_sharding_manager.postprocess_data(output)

-        output = output.to("cpu")
-
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
         # unshard the root FSDP module
         if self.world_size > 1:

@@ -462,13 +523,13 @@ class FSDPWorker(Worker):
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        torch.cuda.empty_cache()
-        log_gpu_memory_usage("After compute_log_prob")
+        output = output.to("cpu")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
-    def compute_ref_log_prob(self, data: DataProto):
+    def compute_ref_log_probs(self, data: DataProto):
         assert self._is_ref
         data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

@@ -476,11 +537,9 @@ class FSDPWorker(Worker):
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data)
             output = self.ref_policy.compute_log_prob(data=data)
-            output = DataProto.from_dict(tensors={"ref_log_prob": output})
+            output = DataProto.from_dict(tensors={"ref_log_probs": output})
             output = self.ulysses_sharding_manager.postprocess_data(output)

-        output = output.to("cpu")
-
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
         # unshard the root FSDP module
         if self.world_size > 1:

@@ -489,15 +548,13 @@ class FSDPWorker(Worker):
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        torch.cuda.empty_cache()
-        log_gpu_memory_usage("After compute_ref_log_prob")
+        output = output.to("cpu")
         return output

-    """CriticWorker"""
-
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_values(self, data: DataProto):
         assert self._is_critic
         data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

@@ -507,15 +564,15 @@ class FSDPWorker(Worker):
             output = DataProto.from_dict(tensors={"values": values})
             output = self.ulysses_sharding_manager.postprocess_data(data=output)

-        output = output.to("cpu")
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        torch.cuda.empty_cache()
+        output = output.to("cpu")
         return output

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_critic(self, data: DataProto):
         data = data.to(torch.cuda.current_device())
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)

@@ -530,21 +587,26 @@ class FSDPWorker(Worker):
             delta_time = timer.last
             global_num_tokens = data.meta_info["global_token_num"]
             estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
-            metrics["mfu/critic"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
+            metrics["perf/mfu_critic"] = (
+                estimated_flops * self.config.actor.ppo_epochs / (promised_flops * self.world_size)
+            )

             self.lr_scheduler.step()
             lr = self.lr_scheduler.get_last_lr()[0]
             metrics["critic/lr"] = lr

-            output = DataProto(batch=None, meta_info={"metrics": metrics})
-            output = self.ulysses_sharding_manager.postprocess_data(data=output)
+            # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+            output = DataProto(
+                non_tensor_batch={
+                    metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items()
+                }
+            )

-        output = output.to("cpu")
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
         if self._use_optimizer_offload:
             offload_fsdp_optimizer(optimizer=self.optimizer)

-        torch.cuda.empty_cache()
+        output = output.to("cpu")
         return output
verl/workers/reward/custom.py (View file @ c132cbcb)

@@ -13,55 +13,48 @@
 # limitations under the License.

+from collections import defaultdict
+from typing import Any, Callable, Dict, Tuple, TypedDict
+
 import torch
 from transformers import PreTrainedTokenizer

-from verl import DataProto
-from verl.utils.reward_score import math_compute_score, r1v_compute_score
+from ...protocol import DataProto
+from ...utils.reward_score import math_compute_score, r1v_compute_score
+
+
+class RewardScore(TypedDict):
+    overall: float
+    format: float
+    accuracy: float


 class CustomRewardManager:
-    def __init__(self, tokenizer: PreTrainedTokenizer, num_examine: int, compute_score: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, compute_score: str):
         self.tokenizer = tokenizer
-        self.num_examine = num_examine
         if compute_score == "math":
-            self.compute_score = math_compute_score
+            self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
         elif compute_score == "r1v":
-            self.compute_score = r1v_compute_score
+            self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
         else:
             raise NotImplementedError()

-    def __call__(self, data: DataProto) -> torch.Tensor:
+    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, Any]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
-        already_print = 0
+        reward_metrics = defaultdict(list)
         for i in range(len(data)):
             data_item = data[i]  # DataProtoItem
-            prompt_ids = data_item.batch["prompts"]
-            prompt_length = prompt_ids.shape[-1]
-            valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
-            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
             response_ids = data_item.batch["responses"]
-            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
+            response_mask = data_item.batch["response_mask"]
+            valid_response_length = response_mask.sum()
             valid_response_ids = response_ids[:valid_response_length]

-            # decode
-            prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
             response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
-            ground_truth = data_item.non_tensor_batch["answer"]
+            ground_truth = data_item.non_tensor_batch["ground_truth"]

             score = self.compute_score(response_str, ground_truth)
-            reward_tensor[i, valid_response_length - 1] = score
-
-            if already_print < self.num_examine:
-                already_print += 1
-                print("[prompt]", prompt_str)
-                print("[response]", response_str)
-                print("[ground_truth]", ground_truth)
-                print("[score]", score)
+            reward_tensor[i, valid_response_length - 1] = score["overall"]
+            for key, value in score.items():
+                reward_metrics[key].append(value)

-        return reward_tensor
+        return reward_tensor, reward_metrics
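After this change the scorer returns a RewardScore dict rather than a bare float, and the manager keeps a per-key metrics list alongside the reward tensor. A self-contained sketch of that contract; toy_compute_score is a hypothetical scorer, not the repository's math/r1v implementations:

from collections import defaultdict
from typing import TypedDict


class RewardScore(TypedDict):
    overall: float
    format: float
    accuracy: float


def toy_compute_score(response: str, ground_truth: str) -> RewardScore:
    # Hypothetical rule: 0.1 for a well-formed answer tag, 0.9 for containing the reference answer.
    format_score = 0.1 if "<answer>" in response else 0.0
    accuracy = 0.9 if ground_truth in response else 0.0
    return {"overall": format_score + accuracy, "format": format_score, "accuracy": accuracy}


reward_metrics = defaultdict(list)
score = toy_compute_score("<answer>42</answer>", "42")
for key, value in score.items():
    reward_metrics[key].append(value)  # same accumulation pattern as the new __call__

print(score["overall"], dict(reward_metrics))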
verl/workers/rollout/base.py (View file @ c132cbcb)

@@ -14,7 +14,7 @@
 from abc import ABC, abstractmethod

-from verl import DataProto
+from ...protocol import DataProto


 __all__ = ["BaseRollout"]
verl/workers/rollout/config.py (View file @ c132cbcb)

@@ -16,15 +16,18 @@ Rollout config
 """

 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict


 @dataclass
 class RolloutConfig:
     name: str = "vllm"
+    n: int = 1
     temperature: float = 1.0
-    top_k: int = -1
     top_p: float = 1.0
-    dtype: str = "bfloat16"
+    top_k: int = -1
+    limit_images: int = 0
+    dtype: str = "bf16"
     gpu_memory_utilization: float = 0.5
     ignore_eos: bool = False
     enforce_eager: bool = False

@@ -34,9 +37,7 @@ class RolloutConfig:
     max_num_batched_tokens: int = 8192
     max_num_seqs: int = 1024
     disable_log_stats: bool = True
-    do_sample: bool = True
-    n: int = 1
-    limit_images: int = 0
+    val_override_config: Dict[str, Any] = field(default_factory=dict)
     """auto keys"""
     prompt_length: int = field(default=-1, init=False)
     response_length: int = field(default=-1, init=False)
verl/workers/rollout/vllm_rollout/__init__.py (View file @ c132cbcb)

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .dtensor_weight_loaders import load_dtensor_weights
 from .vllm_rollout_spmd import vLLMRollout

-__all__ = ["vLLMRollout", "load_dtensor_weights"]
+__all__ = ["vLLMRollout"]
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (View file @ c132cbcb)

@@ -18,26 +18,29 @@ When working with FSDP:
 - Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
 """

 import os
 from contextlib import contextmanager
 from typing import Any, List, Union

+import numpy as np
 import torch
 import torch.distributed
 from tensordict import TensorDict
 from transformers import PreTrainedTokenizer
 from vllm import LLM, RequestOutput, SamplingParams

-from verl import DataProto
-from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
-from verl.workers.rollout.base import BaseRollout
-from verl.workers.rollout.config import RolloutConfig
+from ....protocol import DataProto
+from ....utils import torch_functional as VF
+from ....utils.torch_dtypes import PrecisionType
+from ..base import BaseRollout
+from ..config import RolloutConfig


-def _repeat_interleave(features: Union[torch.Tensor, List[Any]], repeats: int) -> Union[torch.Tensor, List[Any]]:
-    if isinstance(features, torch.Tensor):
-        return features.repeat_interleave(repeats, dim=0)
+def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
+    if isinstance(value, torch.Tensor):
+        return value.repeat_interleave(repeats, dim=0)
     else:
-        return [feature for feature in features for _ in range(repeats)]
+        return np.repeat(value, repeats, axis=0)


 class vLLMRollout(BaseRollout):
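The numpy branch now uses np.repeat, which (like torch.repeat_interleave) repeats each element consecutively rather than tiling the whole array, so the n copies of one sample stay adjacent. A tiny comparison of the two paths:

import numpy as np
import torch

t = torch.tensor([[1, 2], [3, 4]])
a = np.array(["x", "y"], dtype=object)

# Both keep all copies of a given sample next to each other: sample 0, sample 0, sample 1, ...
print(t.repeat_interleave(2, dim=0))  # tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
print(np.repeat(a, 2, axis=0))        # ['x' 'x' 'y' 'y']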
@@ -50,6 +53,7 @@ class vLLMRollout(BaseRollout):
         tokenizer: the task/model tokenizer
         """
         super().__init__()
+        self.rank = int(os.getenv("RANK", "0"))
         self.config = config
         self.pad_token_id = tokenizer.pad_token_id
         if config.tensor_parallel_size > torch.distributed.get_world_size():

@@ -69,7 +73,7 @@ class vLLMRollout(BaseRollout):
             model=model_path,
             skip_tokenizer_init=False,
             tensor_parallel_size=config.tensor_parallel_size,
-            dtype=config.dtype,
+            dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
             gpu_memory_utilization=config.gpu_memory_utilization,
             enforce_eager=config.enforce_eager,
             max_model_len=config.prompt_length + config.response_length,

@@ -77,6 +81,7 @@ class vLLMRollout(BaseRollout):
             enable_sleep_mode=True,
             distributed_executor_backend="external_launcher",
             disable_custom_all_reduce=True,
+            disable_mm_preprocessor_cache=True,
             disable_log_stats=config.disable_log_stats,
             enable_chunked_prefill=config.enable_chunked_prefill,
             **vllm_init_kwargs,

@@ -111,7 +116,7 @@ class vLLMRollout(BaseRollout):
             setattr(self.sampling_params, key, value)

     @torch.no_grad()
-    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
+    def generate_sequences(self, prompts: DataProto) -> DataProto:
         # left-padded attention_mask
         input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
         attention_mask: torch.Tensor = prompts.batch["attention_mask"]
@@ -119,54 +124,40 @@ class vLLMRollout(BaseRollout):
         eos_token_id: int = prompts.meta_info["eos_token_id"]
         batch_size = input_ids.size(0)

-        do_sample = prompts.meta_info.get("do_sample", True)
-        if not do_sample:
-            kwargs = {
-                "n": 1,
-                "temperature": 0.0,
-                "top_p": 1.0,
-                "top_k": -1,
-                "min_p": 0.0,
-            }
-
         non_tensor_batch = prompts.non_tensor_batch
         if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
             raise RuntimeError("vllm sharding manager is not work properly.")

-        if "images" in non_tensor_batch:
+        if "multi_modal_data" in non_tensor_batch:
             vllm_inputs = []
-            for raw_prompt_ids, images in zip(
-                non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("images")
-            ):
-                vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": {"image": images}})
+            for raw_prompt_ids, multi_modal_data in zip(
+                non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")
+            ):
+                vllm_inputs.append({"prompt_token_ids": list(raw_prompt_ids), "multi_modal_data": multi_modal_data})
         else:
             vllm_inputs = [
-                {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
+                {"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
             ]

         # users can customize different sampling_params at different run
-        with self.update_sampling_params(**kwargs):
+        with self.update_sampling_params(**prompts.meta_info):
             completions: List[RequestOutput] = self.inference_engine.generate(
-                prompts=vllm_inputs, sampling_params=self.sampling_params
+                prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=(self.rank == 0)
             )
-            response_ids = []
-            for completion in completions:
-                for output in completion.outputs:
-                    response_ids.append(output.token_ids)
-
-            response_ids = pad_2d_list_to_length(
-                response_ids, self.pad_token_id, max_length=self.config.response_length
-            ).to(input_ids.device)
-
-            if self.config.n > 1 and do_sample:
-                batch_size = batch_size * self.config.n
-                input_ids = _repeat_interleave(input_ids, self.config.n)
-                attention_mask = _repeat_interleave(attention_mask, self.config.n)
-                position_ids = _repeat_interleave(position_ids, self.config.n)
-                if "pixel_values" in non_tensor_batch.keys():
-                    non_tensor_batch["pixel_values"] = _repeat_interleave(non_tensor_batch["pixel_values"], self.config.n)
-                    non_tensor_batch["image_grid_thw"] = _repeat_interleave(non_tensor_batch["image_grid_thw"], self.config.n)
+            response_ids = [output.token_ids for completion in completions for output in completion.outputs]
+            response_ids = VF.pad_2d_list_to_length(
+                response_ids, self.pad_token_id, max_length=self.config.response_length
+            ).to(input_ids.device)
+
+            if self.sampling_params.n > 1:
+                batch_size = batch_size * self.sampling_params.n
+                input_ids = _repeat_interleave(input_ids, self.sampling_params.n)
+                attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
+                position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
+                if "multi_modal_inputs" in non_tensor_batch.keys():
+                    non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(
+                        non_tensor_batch["multi_modal_inputs"], self.sampling_params.n
+                    )

         sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
         response_length = response_ids.size(1)
@@ -180,10 +171,10 @@ class vLLMRollout(BaseRollout):
             # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
             response_position_ids = position_ids[..., -1:] + delta_position_id
             position_ids = torch.cat([position_ids, response_position_ids], dim=-1)

-        response_attention_mask = get_eos_mask(
-            response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype
-        )
-        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
+        response_mask = VF.get_eos_mask(
+            response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
+        )
+        attention_mask = torch.cat((attention_mask, response_mask), dim=-1)

         # all the tp ranks should contain the same data here. data in all ranks are valid
         batch = TensorDict(

@@ -192,6 +183,7 @@ class vLLMRollout(BaseRollout):
                 "responses": response_ids,
                 "input_ids": sequence_ids,  # here input_ids become the whole sentences
                 "attention_mask": attention_mask,
+                "response_mask": response_mask,
                 "position_ids": position_ids,
             },
             batch_size=batch_size,
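The response_mask produced by get_eos_mask is what the reward manager and critic rely on above. Assuming the helper keeps response positions up to and including the first EOS and zeroes everything after it, a torch-only sketch of equivalent logic (names and exact semantics are my assumption, not the verl implementation):

import torch

def eos_mask_sketch(response_ids: torch.Tensor, eos_token_id: int, dtype=torch.long) -> torch.Tensor:
    # 1 for every position up to and including the first EOS, 0 afterwards.
    is_eos = (response_ids == eos_token_id).long()
    seen_before = is_eos.cumsum(dim=-1) - is_eos
    return (seen_before == 0).to(dtype)

resp = torch.tensor([[5, 7, 2, 9, 9]])  # assume eos_token_id = 2
print(eos_mask_sketch(resp, eos_token_id=2))  # tensor([[1, 1, 1, 0, 0]])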
verl/workers/sharding_manager/base.py (View file @ c132cbcb)

@@ -15,7 +15,7 @@
 Sharding manager to implement HybridEngine
 """

-from verl import DataProto
+from ...protocol import DataProto


 class BaseShardingManager:
verl/workers/sharding_manager/fsdp_ulysses.py (View file @ c132cbcb)

@@ -17,9 +17,8 @@ Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT
 from torch.distributed.device_mesh import DeviceMesh

-from verl import DataProto
-from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
+from ...protocol import DataProto, all_gather_data_proto
+from ...utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group
 from .base import BaseShardingManager

@@ -48,9 +47,9 @@ class FSDPUlyssesShardingManager(BaseShardingManager):
         In Ulysses, we need to make sure the same data is used across a SP group
         """
         if self.device_mesh is not None:
+            sp_size = self.device_mesh["sp"].size()
             sp_group = self.device_mesh["sp"].get_group()
-            data = data.to("cuda")
-            data.all_gather(sp_group)
+            all_gather_data_proto(data, size=sp_size, group=sp_group)

         return data
verl/workers/sharding_manager/fsdp_vllm.py (View file @ c132cbcb)

@@ -12,19 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
+from typing import Dict, Iterable, Tuple, Union
+
 import torch
 import torch.distributed as dist
+from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 from vllm import LLM
 from vllm.distributed import parallel_state as vllm_ps

-from verl import DataProto
-from verl.utils.performance import log_gpu_memory_usage
-from verl.workers.rollout.vllm_rollout import load_dtensor_weights
+from ...protocol import DataProto, all_gather_data_proto
+from ...utils.model_utils import print_gpu_memory_usage
 from .base import BaseShardingManager

@@ -38,11 +39,22 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         self.module = module
         self.inference_engine = inference_engine
         self.device_mesh = device_mesh
-        FSDP.set_state_dict_type(
-            self.module,
-            state_dict_type=StateDictType.SHARDED_STATE_DICT,
-            state_dict_config=ShardedStateDictConfig(),
-        )
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            FSDP.set_state_dict_type(
+                self.module,
+                state_dict_type=StateDictType.SHARDED_STATE_DICT,
+                state_dict_config=ShardedStateDictConfig(),
+            )
+
+        self.world_size = dist.get_world_size()
+        self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
+        self.tp_rank = vllm_ps.get_tensor_model_parallel_rank()
+        self.tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
+
+        # Record freed bytes to estimate memory usage correctly
+        # https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119
+        self.freed_bytes = 0

         # Note that torch_random_states may be different on each dp rank
         self.torch_random_states = torch.cuda.get_rng_state()
@@ -55,29 +67,45 @@ class FSDPVLLMShardingManager(BaseShardingManager):
         else:
             self.gen_random_states = None

+    def _make_weight_iterator(
+        self, actor_weights: Dict[str, Union[torch.Tensor, DTensor]]
+    ) -> Iterable[Tuple[str, torch.Tensor]]:
+        for name, tensor in actor_weights.items():
+            yield name, tensor.full_tensor() if self.world_size != 1 else tensor
+
     def __enter__(self):
-        log_gpu_memory_usage("Before state_dict() in sharding manager")
+        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
+        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
+        # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory
+        # to speed up memory allocations.
+        #
+        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
+        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
+        torch.cuda.empty_cache()
+        print_gpu_memory_usage("Before state_dict() in sharding manager")
         actor_weights = self.module.state_dict()
-        log_gpu_memory_usage("After state_dict() in sharding manager")
+        print_gpu_memory_usage("After state_dict() in sharding manager")

         self.inference_engine.wake_up()
-        load_dtensor_weights(
-            actor_weights, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
-        )
-        log_gpu_memory_usage("After sync model weights in sharding manager")
+        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+        model.load_weights(self._make_weight_iterator(actor_weights))
+        print_gpu_memory_usage("After sync model weights in sharding manager")

         del actor_weights
         torch.cuda.empty_cache()
-        log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
+        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")

         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(self.gen_random_states)

     def __exit__(self, exc_type, exc_value, traceback):
-        log_gpu_memory_usage("Before vllm offload in sharding manager")
+        print_gpu_memory_usage("Before vllm offload in sharding manager")
+        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
         self.inference_engine.sleep(level=1)
-        log_gpu_memory_usage("After vllm offload in sharding manager")
+        free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
+        self.freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+        print_gpu_memory_usage("After vllm offload in sharding manager")

         self.module.train()
         torch.cuda.empty_cache()  # add empty cache after each compute
@@ -88,15 +116,13 @@ class FSDPVLLMShardingManager(BaseShardingManager):
             torch.cuda.set_rng_state(self.torch_random_states)

     def preprocess_data(self, data: DataProto) -> DataProto:
-        tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
-        data = data.to("cuda")
-        data.all_gather(tp_group)
+        """All gather across tp group to make each rank has identical input."""
+        all_gather_data_proto(data, size=self.tp_size, group=self.tp_group)
         return data

     def postprocess_data(self, data: DataProto) -> DataProto:
-        dp_rank = dist.get_rank()
-        tp_size = vllm_ps.get_tensor_model_parallel_world_size()
-        if tp_size > 1:
-            data = data.chunk(chunks=tp_size)[dp_rank % tp_size]
+        """Get chunk data of this tp rank since we do all gather in preprocess."""
+        if self.tp_size > 1:
+            data = data.chunk(chunks=self.tp_size)[self.tp_rank]

         return data
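The new freed_bytes bookkeeping takes the difference of torch.cuda.mem_get_info() before and after the vLLM engine goes to sleep, so later memory metrics can subtract what the engine released. A stand-alone sketch of the same measurement around any block that returns GPU memory to the driver (the release step here is a toy stand-in for engine.sleep(level=1)):

import torch

def measure_freed_bytes(release_fn) -> int:
    # mem_get_info() reports (free_bytes, total_bytes) as seen by the CUDA driver.
    free_before, _ = torch.cuda.mem_get_info()
    release_fn()
    free_after, _ = torch.cuda.mem_get_info()
    return free_after - free_before

if torch.cuda.is_available():
    buffers = [torch.empty(64 * 1024 * 1024, device="cuda") for _ in range(4)]  # ~1 GiB total

    def release():
        buffers.clear()
        torch.cuda.empty_cache()  # hand cached blocks back to the driver so mem_get_info sees them

    print(f"freed ~{measure_freed_bytes(release) / 1024**3:.2f} GiB")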