OpenDAS / EasyR1 · Commits

Commit f92481f0 authored Mar 04, 2025 by chenych
First commit.
Parent: 7121d0b0
Pipeline #2435 failed with stages in 0 seconds
Changes: 88 · Pipelines: 1
Showing 8 changed files with 824 additions and 0 deletions (+824, -0)
verl/workers/rollout/config.py (+45, -0)
verl/workers/rollout/vllm_rollout/__init__.py (+19, -0)
verl/workers/rollout/vllm_rollout/dtensor_weight_loaders.py (+340, -0)
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (+199, -0)
verl/workers/sharding_manager/__init__.py (+21, -0)
verl/workers/sharding_manager/base.py (+32, -0)
verl/workers/sharding_manager/fsdp_ulysses.py (+66, -0)
verl/workers/sharding_manager/fsdp_vllm.py (+102, -0)
verl/workers/rollout/config.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout config
"""
from dataclasses import asdict, dataclass, field


@dataclass
class RolloutConfig:
    name: str = "vllm"
    temperature: float = 1.0
    top_k: int = -1
    top_p: float = 1.0
    dtype: str = "bfloat16"
    gpu_memory_utilization: float = 0.5
    ignore_eos: bool = False
    enforce_eager: bool = False
    free_cache_engine: bool = False
    enable_chunked_prefill: bool = False
    tensor_parallel_size: int = 2
    max_num_batched_tokens: int = 8192
    max_num_seqs: int = 1024
    disable_log_stats: bool = True
    do_sample: bool = True
    n: int = 1
    limit_images: int = 0
    """auto keys"""
    prompt_length: int = field(default=-1, init=False)
    response_length: int = field(default=-1, init=False)

    def to_dict(self):
        return asdict(self)
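For orientation, here is how this dataclass is typically consumed downstream: a minimal sketch, assuming only the fields defined above are available and that the package is importable as laid out; the concrete values are arbitrary examples, not defaults from the repository.

from verl.workers.rollout.config import RolloutConfig

# Construct a rollout config. prompt_length and response_length are "auto keys"
# (init=False), so they are filled in later by the caller rather than passed to __init__.
config = RolloutConfig(temperature=0.7, top_p=0.9, tensor_parallel_size=1)
config.prompt_length = 1024
config.response_length = 512

# to_dict() is what vLLMRollout uses below to pick out the keys that overlap
# with vllm.SamplingParams (temperature, top_p, top_k, n, ...).
print(config.to_dict())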
verl/workers/rollout/vllm_rollout/__init__.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dtensor_weight_loaders import load_dtensor_weights
from .vllm_rollout_spmd import vLLMRollout

__all__ = ["vLLMRollout", "load_dtensor_weights"]
verl/workers/rollout/vllm_rollout/dtensor_weight_loaders.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader
from typing import Dict

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter

def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            stacked_name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if stacked_name.endswith(".bias") and stacked_name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[stacked_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))

def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        # With tie_word_embeddings, we can skip lm_head.weight
        # The weight might appear unnecessarily in the files if the model is
        # processed with quantization, LoRA, fine-tuning, etc.
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight)

def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))

def qwen2vl_dtensor_weight_loader(actor_weights: Dict[str, torch.Tensor], vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (vllm_substr, hf_substr, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    vllm_params = dict(vllm_model.named_parameters(remove_duplicate=False))
    for actor_name, actor_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in actor_name:
            continue
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_name:
            continue
        for vllm_substr, hf_substr, shard_id in stacked_params_mapping:
            if hf_substr not in actor_name:
                continue
            if "visual" in actor_name:
                continue
            vllm_name = "language_model." + actor_name.replace(hf_substr, vllm_substr)
            if actor_name.endswith(".bias") and actor_name not in vllm_params:
                continue  # skip loading extra bias for GPTQ models
            local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
            vllm_param = vllm_params[vllm_name]
            weight_loader = vllm_param.weight_loader
            weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype), shard_id)
            break
        else:
            if actor_name.endswith(".bias") and actor_name not in vllm_params:
                continue  # skip loading extra bias for GPTQ models
            if "visual" in actor_name:
                vllm_name = actor_name
            else:
                vllm_name = "language_model." + actor_name
            vllm_param = vllm_params[vllm_name]
            local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
            weight_loader = getattr(vllm_param, "weight_loader", default_weight_loader)
            weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype))

def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=vllm_model.config.n_routed_experts,
    )
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, vllm_model):
                continue
            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, vllm_model):
                    continue
                param = params_dict[name]
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    local_loaded_weight.to(dtype=param.dtype),
                    weight_name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, vllm_model):
                    continue
                param = params_dict[name]
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, local_loaded_weight.to(dtype=param.dtype))

def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
    param_name = _process_parameter_names(name=param_name)
    if parallelize_plan is not None:
        assert param_name in parallelize_plan.keys(), (
            f"param name: {param_name} not in parallelize_plan: {parallelize_plan.keys()}"
        )
        placement = parallelize_plan[param_name]
        local_loaded_weights = loaded_weights.redistribute(
            device_mesh=loaded_weights.device_mesh, placements=placement
        ).to_local()
    else:
        local_loaded_weights = loaded_weights.full_tensor()

    return local_loaded_weights

def _process_parameter_names(name):
    # Remove '.weight' if it exists at the end of the string
    if name.endswith(".weight"):
        name = name[:-7]

    # Remove 'model.layers.x.' or 'model.' prefix
    if "model.layers" in name:
        parts = name.split(".")
        # Reconstruct the string without 'model.layers.x.'
        name = ".".join(parts[3:])  # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
    elif name.startswith("model."):
        name = name[6:]  # Remove 'model.'

    return name

__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
    "LlamaForCausalLM": llama_dtensor_weight_loader,
    "LLaMAForCausalLM": llama_dtensor_weight_loader,
    "MistralForCausalLM": llama_dtensor_weight_loader,  # mistral is the same as llama in vLLM
    "InternLMForCausalLM": llama_dtensor_weight_loader,
    "Phi3ForCausalLM": llama_dtensor_weight_loader,
    "GemmaForCausalLM": gemma_dtensor_weight_loader,
    "Gemma2ForCausalLM": gemma_dtensor_weight_loader,
    "Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
    "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
    "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
    "Qwen2_5_VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}

# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)
    # NOTE(sgm): to reduce peak memory usage, we offload the vllm model to cpu
    # after init, and we need this after we sync model weights in the first iteration.
    vllm_model = vllm_model.cuda()


def _get_model_weight_loader(arch: str):
    if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
        return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]

    raise ValueError(
        f"Model architecture {arch} is not supported for now. "
        f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}"
    )


# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
    pass
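To make the parameter-name normalization in _process_parameter_names concrete, here is a small standalone check (pure Python, logic copied from the function above; the example parameter names are hypothetical HF-style names, not taken from a specific checkpoint):

def _process_parameter_names(name):
    # mirror of the function above: strip ".weight", then the "model.layers.x." / "model." prefix
    if name.endswith(".weight"):
        name = name[:-7]
    if "model.layers" in name:
        name = ".".join(name.split(".")[3:])
    elif name.startswith("model."):
        name = name[6:]
    return name

print(_process_parameter_names("model.layers.0.self_attn.q_proj.weight"))  # -> self_attn.q_proj
print(_process_parameter_names("model.embed_tokens.weight"))               # -> embed_tokens
print(_process_parameter_names("lm_head.weight"))                          # -> lm_head

The normalized names are what redistribute_dtensor looks up in a parallelize_plan, when one is provided.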
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The vllm_rollout that can be applied in different backends.
When working with FSDP:
- Use DTensor weight loader (recommended) or HF weight loader
- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM
"""
from contextlib import contextmanager
from typing import Any, List, Union

import torch
import torch.distributed
from tensordict import TensorDict
from transformers import PreTrainedTokenizer
from vllm import LLM, RequestOutput, SamplingParams

from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_2d_list_to_length
from verl.workers.rollout.base import BaseRollout
from verl.workers.rollout.config import RolloutConfig

def _repeat_interleave(features: Union[torch.Tensor, List[Any]], repeats: int) -> Union[torch.Tensor, List[Any]]:
    if isinstance(features, torch.Tensor):
        return features.repeat_interleave(repeats, dim=0)
    else:
        return [feature for feature in features for _ in range(repeats)]

class vLLMRollout(BaseRollout):
    def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
        """A vLLM rollout. It requires that the module is supported by vLLM.

        Args:
            module: module here follows huggingface APIs
            config: DictConfig
            tokenizer: the task/model tokenizer
        """
        super().__init__()
        self.config = config
        self.pad_token_id = tokenizer.pad_token_id
        if config.tensor_parallel_size > torch.distributed.get_world_size():
            raise ValueError("Tensor parallelism size should not be larger than the world size.")

        if not config.enforce_eager and config.free_cache_engine:
            raise ValueError("CUDA graph should be disabled when `free_cache_engine` is True.")

        if config.max_num_batched_tokens < config.prompt_length + config.response_length:
            raise ValueError("max_num_batched_tokens should be at least prompt_length + response_length.")

        vllm_init_kwargs = {}
        if config.limit_images > 0:
            vllm_init_kwargs = {"limit_mm_per_prompt": {"image": config.limit_images}}

        self.inference_engine = LLM(
            model=model_path,
            skip_tokenizer_init=False,
            tensor_parallel_size=config.tensor_parallel_size,
            dtype=config.dtype,
            gpu_memory_utilization=config.gpu_memory_utilization,
            enforce_eager=config.enforce_eager,
            max_model_len=config.prompt_length + config.response_length,
            max_num_batched_tokens=config.max_num_batched_tokens,
            enable_sleep_mode=True,
            distributed_executor_backend="external_launcher",
            disable_custom_all_reduce=True,
            disable_log_stats=config.disable_log_stats,
            enable_chunked_prefill=config.enable_chunked_prefill,
            **vllm_init_kwargs,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        sampling_kwargs = {"max_tokens": config.response_length, "detokenize": False}
        default_sampling_params = SamplingParams()
        for key in config.to_dict().keys():
            if hasattr(default_sampling_params, key):
                sampling_kwargs[key] = getattr(config, key)

        print(f"Sampling params: {sampling_kwargs}.")
        self.sampling_params = SamplingParams(**sampling_kwargs)

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # update sampling params
        old_sampling_params_args = {}
        if kwargs:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)

        yield
        # roll back to previous sampling params
        for key, value in old_sampling_params_args.items():
            setattr(self.sampling_params, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        # left-padded attention_mask
        input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
        attention_mask: torch.Tensor = prompts.batch["attention_mask"]
        position_ids: torch.Tensor = prompts.batch["position_ids"]
        eos_token_id: int = prompts.meta_info["eos_token_id"]
        batch_size = input_ids.size(0)

        do_sample = prompts.meta_info.get("do_sample", True)
        if not do_sample:
            kwargs = {
                "n": 1,
                "temperature": 0.0,
                "top_p": 1.0,
                "top_k": -1,
                "min_p": 0.0,
            }

        non_tensor_batch = prompts.non_tensor_batch
        if batch_size != len(non_tensor_batch["raw_prompt_ids"]):
            raise RuntimeError("vllm sharding manager is not working properly.")

        if "images" in non_tensor_batch:
            vllm_inputs = []
            for raw_prompt_ids, images in zip(
                non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("images")
            ):
                vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": {"image": images}})
        else:
            vllm_inputs = [
                {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
            ]

        # users can customize different sampling_params for different runs
        with self.update_sampling_params(**kwargs):
            completions: List[RequestOutput] = self.inference_engine.generate(
                prompts=vllm_inputs, sampling_params=self.sampling_params
            )
        response_ids = []
        for completion in completions:
            for output in completion.outputs:
                response_ids.append(output.token_ids)

        response_ids = pad_2d_list_to_length(
            response_ids, self.pad_token_id, max_length=self.config.response_length
        ).to(input_ids.device)

        if self.config.n > 1 and do_sample:
            batch_size = batch_size * self.config.n
            input_ids = _repeat_interleave(input_ids, self.config.n)
            attention_mask = _repeat_interleave(attention_mask, self.config.n)
            position_ids = _repeat_interleave(position_ids, self.config.n)
            if "pixel_values" in non_tensor_batch.keys():
                non_tensor_batch["pixel_values"] = _repeat_interleave(non_tensor_batch["pixel_values"], self.config.n)
                non_tensor_batch["image_grid_thw"] = _repeat_interleave(
                    non_tensor_batch["image_grid_thw"], self.config.n
                )

        sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
        response_length = response_ids.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_attention_mask = get_eos_mask(
            response_ids=response_ids, eos_token=eos_token_id, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                "prompts": input_ids,
                "responses": response_ids,
                "input_ids": sequence_ids,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            },
            batch_size=batch_size,
        )
        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
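The effect of _repeat_interleave when config.n > 1 (several sampled responses per prompt, so prompt-side tensors must be expanded to match) can be verified in isolation. This is a minimal sketch assuming nothing beyond PyTorch; the toy inputs are made up:

import torch

def _repeat_interleave(features, repeats):
    # same logic as the module-level helper above
    if isinstance(features, torch.Tensor):
        return features.repeat_interleave(repeats, dim=0)
    return [feature for feature in features for _ in range(repeats)]

prompts = torch.tensor([[1, 2], [3, 4]])   # batch of two prompts
print(_repeat_interleave(prompts, 3))       # each row repeated 3 times -> shape (6, 2)
print(_repeat_interleave(["a", "b"], 3))    # ['a', 'a', 'a', 'b', 'b', 'b']

Repeating rows consecutively keeps every group of n responses adjacent to its (repeated) prompt, which is what the concatenation into sequence_ids above relies on.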
verl/workers/sharding_manager/__init__.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import BaseShardingManager
from .fsdp_ulysses import FSDPUlyssesShardingManager
from .fsdp_vllm import FSDPVLLMShardingManager

__all__ = ["BaseShardingManager", "FSDPUlyssesShardingManager", "FSDPVLLMShardingManager"]
verl/workers/sharding_manager/base.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sharding manager to implement HybridEngine
"""
from verl import DataProto


class BaseShardingManager:
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        pass

    def preprocess_data(self, data: DataProto) -> DataProto:
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        return data
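BaseShardingManager defines a context-manager protocol: enter to sync/reshard weights into the inference engine, preprocess the batch, generate, postprocess, then exit to restore the training-side state. A hypothetical no-op subclass illustrates the intended call pattern; a plain list stands in for DataProto so the sketch stays self-contained, and none of these names exist in the repository:

class BaseShardingManager:
    # mirror of the base class above (DataProto typing omitted)
    def __enter__(self):
        pass
    def __exit__(self, exc_type, exc_value, traceback):
        pass
    def preprocess_data(self, data):
        return data
    def postprocess_data(self, data):
        return data

class LoggingShardingManager(BaseShardingManager):  # hypothetical subclass
    def __enter__(self):
        print("sync weights into the inference engine")
    def __exit__(self, exc_type, exc_value, traceback):
        print("offload the inference engine, restore training mode")

manager = LoggingShardingManager()
with manager:
    batch = manager.preprocess_data(["prompt-0", "prompt-1"])  # stand-in for a DataProto
    # ... rollout.generate_sequences(batch) would run here ...
    batch = manager.postprocess_data(batch)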
verl/workers/sharding_manager/fsdp_ulysses.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT
"""
from torch.distributed.device_mesh import DeviceMesh

from verl import DataProto
from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group

from .base import BaseShardingManager


class FSDPUlyssesShardingManager(BaseShardingManager):
    """
    Sharding manager to support data resharding when using FSDP + Ulysses
    """

    def __init__(self, device_mesh: DeviceMesh):
        super().__init__()
        self.device_mesh = device_mesh

    def __enter__(self):
        if self.device_mesh is not None:
            self.prev_sp_group = get_ulysses_sequence_parallel_group()
            set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group())

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_mesh is not None:
            set_ulysses_sequence_parallel_group(self.prev_sp_group)

    def preprocess_data(self, data: DataProto) -> DataProto:
        """
        AllGather data from sp region
        This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE
        In Ulysses, we need to make sure the same data is used across a SP group
        """
        if self.device_mesh is not None:
            sp_group = self.device_mesh["sp"].get_group()
            data = data.to("cuda")
            data.all_gather(sp_group)

        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        """
        Split the data to follow FSDP partition
        """
        if self.device_mesh is not None:
            sp_size = self.device_mesh["sp"].size()
            sp_rank = self.device_mesh["sp"].get_local_rank()
            data = data.chunk(chunks=sp_size)[sp_rank]

        return data
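The split in postprocess_data is an even chunking by sequence-parallel rank: each rank keeps only its slice of the gathered batch. The same pattern applied to a raw tensor, as a sketch with an assumed sp_size of 2 (independent of verl's DataProto):

import torch

sp_size, sp_rank = 2, 1                      # pretend we are rank 1 of a 2-way SP group
data = torch.arange(8).reshape(4, 2)          # a gathered batch of 4 rows
local = data.chunk(chunks=sp_size)[sp_rank]   # keep only this rank's shard (rows 2 and 3)
print(local)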
verl/workers/sharding_manager/fsdp_vllm.py (new file, mode 100644)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from vllm import LLM
from vllm.distributed import parallel_state as vllm_ps

from verl import DataProto
from verl.utils.performance import log_gpu_memory_usage
from verl.workers.rollout.vllm_rollout import load_dtensor_weights

from .base import BaseShardingManager


class FSDPVLLMShardingManager(BaseShardingManager):
    def __init__(
        self,
        module: FSDP,
        inference_engine: LLM,
        device_mesh: DeviceMesh = None,
    ):
        self.module = module
        self.inference_engine = inference_engine
        self.device_mesh = device_mesh
        FSDP.set_state_dict_type(
            self.module,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(),
        )

        # Note that torch_random_states may be different on each dp rank
        self.torch_random_states = torch.cuda.get_rng_state()
        # get a random rng states
        if self.device_mesh is not None:
            gen_dp_rank = self.device_mesh["dp"].get_local_rank()
            torch.cuda.manual_seed(gen_dp_rank + 1000)  # make sure all tp ranks have the same random states
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
        else:
            self.gen_random_states = None

    def __enter__(self):
        log_gpu_memory_usage("Before state_dict() in sharding manager")
        actor_weights = self.module.state_dict()
        log_gpu_memory_usage("After state_dict() in sharding manager")

        self.inference_engine.wake_up()
        load_dtensor_weights(
            actor_weights, self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        )
        log_gpu_memory_usage("After sync model weights in sharding manager")

        del actor_weights
        torch.cuda.empty_cache()
        log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")

        # important: need to manually set the random states of each tp to be identical.
        if self.device_mesh is not None:
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

    def __exit__(self, exc_type, exc_value, traceback):
        log_gpu_memory_usage("Before vllm offload in sharding manager")
        self.inference_engine.sleep(level=1)
        log_gpu_memory_usage("After vllm offload in sharding manager")

        self.module.train()
        torch.cuda.empty_cache()  # add empty cache after each compute

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

    def preprocess_data(self, data: DataProto) -> DataProto:
        tp_group = vllm_ps.get_tensor_model_parallel_group().device_group
        data = data.to("cuda")
        data.all_gather(tp_group)
        return data

    def postprocess_data(self, data: DataProto) -> DataProto:
        dp_rank = dist.get_rank()
        tp_size = vllm_ps.get_tensor_model_parallel_world_size()
        if tp_size > 1:
            data = data.chunk(chunks=tp_size)[dp_rank % tp_size]

        return data
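The RNG bookkeeping above swaps between a per-rank training RNG state and a shared generation RNG state, so that all TP ranks sample identically during rollout and training randomness is untouched afterwards. The same save/swap/restore pattern, sketched on the CPU generator (the real code uses the torch.cuda RNG state; the seed 1000 mirrors the offset used above but is otherwise arbitrary):

import torch

# at construction time: derive a shared generation state, then restore training state
train_state = torch.random.get_rng_state()
torch.manual_seed(1000)                        # seed every rank would share for generation
gen_state = torch.random.get_rng_state()
torch.random.set_rng_state(train_state)

# on entering the sharding manager: remember training state, switch to generation state
train_state = torch.random.get_rng_state()
torch.random.set_rng_state(gen_state)
print(torch.rand(1))                           # identical across "ranks" under this pattern

# on exit: save the advanced generation state, restore training randomness
gen_state = torch.random.get_rng_state()
torch.random.set_rng_state(train_state)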