OpenDAS / EasyR1

Commit ff7fb65e, authored Apr 09, 2025 by chenych

Update

parent c132cbcb
Showing 11 changed files with 47 additions and 391 deletions (+47 −391)
verl/workers/actor/config.py                                   +3   −1
verl/workers/actor/dp_actor.py                                 +10  −5
verl/workers/critic/dp_critic.py                               +3   −3
verl/workers/fsdp_workers.py                                   +7   −3
verl/workers/reward/config.py                                  +2   −1
verl/workers/reward/custom.py                                  +11  −7
verl/workers/rollout/__init__.py                               +2   −1
verl/workers/rollout/config.py                                 +2   −3
verl/workers/rollout/vllm_rollout/__init__.py                  +0   −18
verl/workers/rollout/vllm_rollout/dtensor_weight_loaders.py    +0   −340
verl/workers/rollout/vllm_rollout_spmd.py                      +7   −9
verl/workers/actor/config.py  (+3 −1)

@@ -72,7 +72,9 @@ class ActorConfig:
     micro_batch_size_per_device_for_update: int = 4
     micro_batch_size_per_device_for_experience: int = 16
     max_grad_norm: float = 1.0
-    clip_ratio: float = 0.2
+    clip_ratio_low: float = 0.2
+    clip_ratio_high: float = 0.3
+    clip_ratio_dual: float = 3.0
     ppo_epochs: int = 1
     padding_free: bool = False
     ulysses_sequence_parallel_size: int = 1
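The single clip_ratio is replaced by an asymmetric pair (clip_ratio_low, clip_ratio_high) plus a dual-clip bound (clip_ratio_dual). core_algos.compute_policy_loss is not shown in this commit; the following is a minimal sketch, under the assumption that the three fields feed a clip-higher / dual-clip PPO surrogate of the usual form. The function name and formulas here are illustrative, not taken from the repository.

import torch


def dual_clip_policy_loss(
    old_log_probs: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    clip_ratio_low: float = 0.2,
    clip_ratio_high: float = 0.3,
    clip_ratio_dual: float = 3.0,
) -> torch.Tensor:
    """Illustrative asymmetric (clip-higher) PPO surrogate with a dual-clip lower bound."""
    mask = response_mask.float()
    ratio = torch.exp(log_probs - old_log_probs)
    # asymmetric clipping: the upper bound (clip_ratio_high) may be looser than the lower one
    surrogate = torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - clip_ratio_low, 1.0 + clip_ratio_high) * advantages,
    )
    # dual clip: for negative advantages, bound how negative the surrogate can become
    dual_clipped = torch.max(surrogate, clip_ratio_dual * advantages)
    objective = torch.where(advantages < 0, dual_clipped, surrogate)
    # maximize the objective over valid response tokens -> minimize its negative masked mean
    return -(objective * mask).sum() / mask.sum().clamp(min=1.0)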
verl/workers/actor/dp_actor.py  (+10 −5)

@@ -250,18 +250,21 @@ class DataParallelPPOActor(BasePPOActor):
                 # all return: (bsz, response_length)
                 log_probs = self._forward_micro_batch(model_inputs, temperature=temperature)
                 entropy_loss = -VF.masked_mean(log_probs, response_mask)  # estimator of entropy loss
-                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(
+                pg_loss, pg_clipfrac_higher, pg_clipfrac_lower, ppo_kl = core_algos.compute_policy_loss(
                     old_log_probs=old_log_probs,
                     log_probs=log_probs,
                     advantages=advantages,
-                    eos_mask=response_mask,
-                    cliprange=self.config.clip_ratio,
+                    response_mask=response_mask,
+                    clip_ratio_low=self.config.clip_ratio_low,
+                    clip_ratio_high=self.config.clip_ratio_high,
+                    clip_ratio_dual=self.config.clip_ratio_dual,
                 )
                 if "ref_log_probs" in model_inputs:
                     ref_log_probs = model_inputs["ref_log_probs"]
                     # compute kl loss
-                    kld = core_algos.kl_penalty(
+                    kld = core_algos.compute_kl(
                         log_probs=log_probs,
                         ref_log_probs=ref_log_probs,
                         kl_penalty=self.config.kl_penalty,

@@ -276,7 +279,9 @@ class DataParallelPPOActor(BasePPOActor):
                 batch_metrics = {
                     "actor/pg_loss": pg_loss.detach().item(),
-                    "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+                    "actor/pg_clipfrac_higher": pg_clipfrac_higher.detach().item(),
+                    "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
                     "actor/entropy_loss": entropy_loss.detach().item(),
                     "actor/ppo_kl": ppo_kl.detach().item(),
                 }
                 append_to_dict(metrics, batch_metrics)
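The KL term now comes from core_algos.compute_kl, selected by config.kl_penalty. That helper is outside this diff; below is a hedged sketch of the per-token KL estimators such a function commonly switches between. The option names ("kl", "abs", "mse", "low_var_kl") are assumptions, not taken from this commit.

import torch


def compute_kl(log_probs: torch.Tensor, ref_log_probs: torch.Tensor, kl_penalty: str = "kl") -> torch.Tensor:
    """Sketch of per-token KL penalty estimators between the policy and a reference model."""
    log_ratio = log_probs - ref_log_probs
    if kl_penalty == "kl":          # k1: plain log-ratio
        return log_ratio
    if kl_penalty == "abs":         # absolute log-ratio
        return log_ratio.abs()
    if kl_penalty == "mse":         # k2: half squared log-ratio
        return 0.5 * log_ratio.square()
    if kl_penalty == "low_var_kl":  # k3: exp(-x) - 1 + x, non-negative and low variance
        return (-log_ratio).exp() - 1 + log_ratio
    raise NotImplementedError(f"Unknown kl_penalty: {kl_penalty}")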
verl/workers/critic/dp_critic.py  (+3 −3)

@@ -199,14 +199,14 @@ class DataParallelPPOCritic(BasePPOCritic):
                 values = model_inputs["values"]
                 returns = model_inputs["returns"]
                 response_length = responses.size(1)
-                eos_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
+                action_mask = attention_mask[:, -response_length - 1 : -1]  # shift left for value computation
                 vpreds = self._forward_micro_batch(model_inputs)
                 vf_loss, vf_clipfrac = core_algos.compute_value_loss(
                     vpreds=vpreds,
                     returns=returns,
                     values=values,
-                    eos_mask=eos_mask,
+                    action_mask=action_mask,
                     cliprange_value=self.config.cliprange_value,
                 )
                 loss = vf_loss / gradient_accumulation

@@ -215,7 +215,7 @@ class DataParallelPPOCritic(BasePPOCritic):
                 batch_metrics = {
                     "critic/vf_loss": vf_loss.detach().item(),
                     "critic/vf_clipfrac": vf_clipfrac.detach().item(),
-                    "critic/vpred_mean": VF.masked_mean(vpreds, eos_mask).detach().item(),
+                    "critic/vpred_mean": VF.masked_mean(vpreds, action_mask).detach().item(),
                 }
                 append_to_dict(metrics, batch_metrics)
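Only the mask keyword changes here (eos_mask becomes action_mask); core_algos.compute_value_loss itself is not part of the diff. For context, a minimal sketch of a clipped value loss matching this call signature; it is illustrative, not the repository implementation.

import torch


def compute_value_loss(vpreds, returns, values, action_mask, cliprange_value=0.5):
    """Sketch of a PPO-style clipped value-function loss averaged over valid action tokens."""
    mask = action_mask.float()
    # keep the new value predictions close to the values recorded at rollout time
    vpreds_clipped = values + (vpreds - values).clamp(-cliprange_value, cliprange_value)
    vf_loss1 = (vpreds - returns).square()
    vf_loss2 = (vpreds_clipped - returns).square()
    vf_loss = 0.5 * (torch.max(vf_loss1, vf_loss2) * mask).sum() / mask.sum().clamp(min=1.0)
    # fraction of tokens where the clipped term dominated
    vf_clipfrac = ((vf_loss2 > vf_loss1).float() * mask).sum() / mask.sum().clamp(min=1.0)
    return vf_loss, vf_clipfrac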
verl/workers/fsdp_workers.py  (+7 −3)

@@ -57,7 +57,7 @@ from ..utils.torch_functional import AnyPrecisionAdamW, get_constant_schedule_wi
 from .actor import DataParallelPPOActor
 from .config import ActorConfig, CriticConfig, FSDPConfig, ModelConfig, OptimConfig, RefConfig, WorkerConfig
 from .critic import DataParallelPPOCritic
-from .rollout.vllm_rollout import vLLMRollout
+from .rollout import vLLMRollout
 from .sharding_manager import FSDPVLLMShardingManager
 from .sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@@ -75,6 +75,10 @@ class FSDPWorker(Worker):
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl")

+        # improve numerical stability
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
         self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
         self._is_critic = self.role == "critic"
         self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]

@@ -131,7 +135,7 @@ class FSDPWorker(Worker):
             config.global_batch_size // self.device_mesh.size() * config.ulysses_sequence_parallel_size
         )
         if config.global_batch_size_per_device == 0:
-            raise ValueError(f"{role} global batch size must be larger than num gpus.")
+            raise ValueError(f"{role} global batch size * ulysses size must be larger than num gpus.")

         if config.global_batch_size_per_device % config.micro_batch_size_per_device_for_update != 0:
             raise ValueError(f"{role} global batch size per device must be divisible by the micro batch size.")

@@ -413,7 +417,7 @@ class FSDPWorker(Worker):
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)

-        if self._use_optimizer_offload:
+        if self._use_optimizer_offload:  # avoid OOM in resuming
             offload_fsdp_optimizer(self.optimizer)

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
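The reworded error message reflects how global_batch_size_per_device is derived in the context line above: the global batch is divided by the device-mesh size and multiplied back by the Ulysses sequence-parallel size, because ranks in the same Ulysses group consume the same data. A small worked example with illustrative numbers:

# 8 GPUs with Ulysses sequence parallelism of 2 -> 4 data-parallel groups
global_batch_size = 128
world_size = 8
ulysses_sequence_parallel_size = 2

global_batch_size_per_device = global_batch_size // world_size * ulysses_sequence_parallel_size
assert global_batch_size_per_device == 32

# the check that this commit rewords: with too few samples the result collapses to 0
assert 4 // 8 * 2 == 0  # "global batch size * ulysses size must be larger than num gpus"

# each per-device batch must also split evenly into micro batches
micro_batch_size_per_device_for_update = 4
assert global_batch_size_per_device % micro_batch_size_per_device_for_update == 0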
verl/workers/reward/config.py  (+2 −1)

@@ -21,4 +21,5 @@ from dataclasses import dataclass
 @dataclass
 class RewardConfig:
     reward_type: str = "function"
-    compute_score: str = "math"
+    score_function: str = "math"
+    skip_special_tokens: bool = True
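A quick, hypothetical construction showing the renamed field and the new flag; the verl import path is inferred from the file path above and the values are illustrative.

from dataclasses import asdict

from verl.workers.reward.config import RewardConfig  # import path inferred from the file path

reward_config = RewardConfig(score_function="r1v", skip_special_tokens=False)
print(asdict(reward_config))
# {'reward_type': 'function', 'score_function': 'r1v', 'skip_special_tokens': False}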
verl/workers/reward/custom.py  (+11 −7)

@@ -14,13 +14,14 @@
 from collections import defaultdict
-from typing import Any, Callable, Dict, Tuple, TypedDict
+from typing import Callable, Dict, List, Tuple, TypedDict

 import torch
 from transformers import PreTrainedTokenizer

 from ...protocol import DataProto
 from ...utils.reward_score import math_compute_score, r1v_compute_score
+from .config import RewardConfig


 class RewardScore(TypedDict):

@@ -30,16 +31,17 @@ class RewardScore(TypedDict):
 class CustomRewardManager:
-    def __init__(self, tokenizer: PreTrainedTokenizer, compute_score: str):
+    def __init__(self, tokenizer: PreTrainedTokenizer, config: RewardConfig):
+        self.config = config
         self.tokenizer = tokenizer
-        if compute_score == "math":
+        if config.score_function == "math":
             self.compute_score: Callable[[str, str], RewardScore] = math_compute_score
-        elif compute_score == "r1v":
+        elif config.score_function == "r1v":
             self.compute_score: Callable[[str, str], RewardScore] = r1v_compute_score
         else:
-            raise NotImplementedError()
+            raise NotImplementedError(f"Unknown score function {config.score_function}.")

-    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, Any]]:
+    def __call__(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
         for i in range(len(data)):

@@ -49,7 +51,9 @@ class CustomRewardManager:
             valid_response_length = response_mask.sum()
             valid_response_ids = response_ids[:valid_response_length]
-            response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
+            response_str = self.tokenizer.decode(
+                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
+            )
             ground_truth = data_item.non_tensor_batch["ground_truth"]
             score = self.compute_score(response_str, ground_truth)
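The manager is now built from a RewardConfig instead of a bare score-function string. A minimal usage sketch; the tokenizer name and import paths are illustrative assumptions, not taken from the repository.

from transformers import AutoTokenizer

from verl.workers.reward.config import RewardConfig          # import paths inferred from the file paths
from verl.workers.reward.custom import CustomRewardManager

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative model name
config = RewardConfig(score_function="math", skip_special_tokens=True)
reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config)

# reward_tensor, reward_metrics = reward_fn(data)  # data: a DataProto batch produced by the rollout worker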
verl/workers/rollout/__init__.py  (+2 −1)

@@ -14,6 +14,7 @@
 from .config import RolloutConfig
+from .vllm_rollout_spmd import vLLMRollout


-__all__ = ["RolloutConfig"]
+__all__ = ["RolloutConfig", "vLLMRollout"]
verl/workers/rollout/config.py  (+2 −3)

@@ -28,11 +28,10 @@ class RolloutConfig:
     top_k: int = -1
     limit_images: int = 0
     dtype: str = "bf16"
-    gpu_memory_utilization: float = 0.5
+    gpu_memory_utilization: float = 0.6
     ignore_eos: bool = False
     enforce_eager: bool = False
     free_cache_engine: bool = False
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: bool = False  # only for v0 engine
     tensor_parallel_size: int = 2
     max_num_batched_tokens: int = 8192
     max_num_seqs: int = 1024
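For context on these defaults, a hedged sketch of how such fields are typically forwarded to vllm.LLM; the model name is illustrative and the exact wiring inside vLLMRollout may differ.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model name
    dtype="bfloat16",                  # RolloutConfig uses the short form "bf16"
    gpu_memory_utilization=0.6,        # raised from 0.5 in this commit
    enforce_eager=False,
    enable_chunked_prefill=False,      # per the new comment, only relevant for the v0 engine
    tensor_parallel_size=2,
    max_num_batched_tokens=8192,
    max_num_seqs=1024,
)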
verl/workers/rollout/vllm_rollout/__init__.py  (deleted, 100644 → 0)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .vllm_rollout_spmd import vLLMRollout


__all__ = ["vLLMRollout"]
verl/workers/rollout/vllm_rollout/dtensor_weight_loaders.py  (deleted, 100644 → 0)

# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader

from typing import Dict

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import is_pp_missing_parameter


def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue

            stacked_name = name.replace(shard_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if stacked_name.endswith(".bias") and stacked_name not in params_dict:
                continue

            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[stacked_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if "lm_head.weight" in name:
                continue

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters())
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue

        if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue

        # With tie_word_embeddings, we can skip lm_head.weight
        # The weight might appear unnecessarily in the files if the model is
        # processed with quantization, LoRA, fine-tuning, etc.
        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue

            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight)


def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue

        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name:
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue

            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def qwen2vl_dtensor_weight_loader(actor_weights: Dict[str, torch.Tensor], vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (vllm_substr, hf_substr, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    vllm_params = dict(vllm_model.named_parameters(remove_duplicate=False))
    for actor_name, actor_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in actor_name:
            continue

        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_name:
            continue

        for vllm_substr, hf_substr, shard_id in stacked_params_mapping:
            if hf_substr not in actor_name:
                continue

            if "visual" in actor_name:
                continue

            vllm_name = "language_model." + actor_name.replace(hf_substr, vllm_substr)
            if actor_name.endswith(".bias") and actor_name not in vllm_params:
                continue  # skip loading extra bias for GPTQ models

            local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
            vllm_param = vllm_params[vllm_name]
            weight_loader = vllm_param.weight_loader
            weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype), shard_id)
            break
        else:
            if actor_name.endswith(".bias") and actor_name not in vllm_params:
                continue  # skip loading extra bias for GPTQ models

            if "visual" in actor_name:
                vllm_name = actor_name
            else:
                vllm_name = "language_model." + actor_name

            vllm_param = vllm_params[vllm_name]
            local_actor_weight = redistribute_dtensor(param_name=actor_name, loaded_weights=actor_weight)
            weight_loader = getattr(vllm_param, "weight_loader", default_weight_loader)
            weight_loader(vllm_param, local_actor_weight.to(dtype=vllm_param.dtype))


def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=vllm_model.config.n_routed_experts,
    )
    params_dict = dict(vllm_model.named_parameters(remove_duplicate=False))
    for name, loaded_weight in actor_weights.items():
        if "rotary_emb.inv_freq" in name:
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue

            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue

            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, vllm_model):
                continue

            param = params_dict[name]
            local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, vllm_model):
                    continue

                param = params_dict[name]
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(
                    param,
                    local_loaded_weight.to(dtype=param.dtype),
                    weight_name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, vllm_model):
                    continue

                param = params_dict[name]
                local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, local_loaded_weight.to(dtype=param.dtype))


def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None):
    param_name = _process_parameter_names(name=param_name)
    if parallelize_plan is not None:
        assert param_name in parallelize_plan.keys(), (
            f"param name: {param_name} not in parallelize_plan: {parallelize_plan.keys()}"
        )
        placement = parallelize_plan[param_name]
        local_loaded_weights = loaded_weights.redistribute(
            device_mesh=loaded_weights.device_mesh, placements=placement
        ).to_local()
    else:
        local_loaded_weights = loaded_weights.full_tensor()

    return local_loaded_weights


def _process_parameter_names(name):
    # Remove '.weight' if it exists at the end of the string
    if name.endswith(".weight"):
        name = name[:-7]

    # Remove 'model.layers.x.' or 'model.' prefix
    if "model.layers" in name:
        parts = name.split(".")
        # Reconstruct the string without 'model.layers.x.'
        name = ".".join(parts[3:])  # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x'
    elif name.startswith("model."):
        name = name[6:]  # Remove 'model.'

    return name


__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = {
    "LlamaForCausalLM": llama_dtensor_weight_loader,
    "LLaMAForCausalLM": llama_dtensor_weight_loader,
    "MistralForCausalLM": llama_dtensor_weight_loader,  # mistral is the same as llama in vLLM
    "InternLMForCausalLM": llama_dtensor_weight_loader,
    "Phi3ForCausalLM": llama_dtensor_weight_loader,
    "GemmaForCausalLM": gemma_dtensor_weight_loader,
    "Gemma2ForCausalLM": gemma_dtensor_weight_loader,
    "Qwen2ForCausalLM": qwen2_dtensor_weight_loader,
    "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader,
    "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
    "Qwen2_5_VLForConditionalGeneration": qwen2vl_dtensor_weight_loader,
}


# the actor model is .state_dict()
# Load dtensor weights
def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module):
    weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__)
    weight_loader(actor_weights, vllm_model)
    # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu
    # after init, and we need this after sync model weights for in first iter.
    vllm_model = vllm_model.cuda()


def _get_model_weight_loader(arch: str):
    if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__:
        return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch]

    raise ValueError(
        f"Model architectures {arch} are not supported for now. "
        f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}"
    )


# NOTE(sgm): we use per-parameter weight loader in each vllm sub
def update_dtensor_weight_loader():
    pass
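Every loader in the deleted file leans on Python's for/else: the else branch runs only when the loop over stacked_params_mapping completes without a break, that is, when the weight does not belong to a fused (stacked) parameter. A tiny self-contained illustration of that control flow; the weight names are made up.

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def classify(weight_name: str) -> str:
    for fused_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name not in weight_name:
            continue
        result = f"fused into {fused_name} (shard {shard_id})"
        break
    else:
        # runs only if the loop finished without hitting `break`
        result = "plain parameter, loaded directly"
    return result


print(classify("model.layers.0.self_attn.k_proj.weight"))  # fused into qkv_proj (shard k)
print(classify("model.layers.0.mlp.down_proj.weight"))     # plain parameter, loaded directly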
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py → verl/workers/rollout/vllm_rollout_spmd.py  (+7 −9)

@@ -29,11 +29,11 @@ from tensordict import TensorDict
 from transformers import PreTrainedTokenizer
 from vllm import LLM, RequestOutput, SamplingParams

-from ....protocol import DataProto
-from ....utils import torch_functional as VF
-from ....utils.torch_dtypes import PrecisionType
-from ..base import BaseRollout
-from ..config import RolloutConfig
+from ...protocol import DataProto
+from ...utils import torch_functional as VF
+from ...utils.torch_dtypes import PrecisionType
+from .base import BaseRollout
+from .config import RolloutConfig


 def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:

@@ -59,9 +59,6 @@ class vLLMRollout(BaseRollout):
         if config.tensor_parallel_size > torch.distributed.get_world_size():
             raise ValueError("Tensor parallelism size should be less than world size.")

-        if not config.enforce_eager and config.free_cache_engine:
-            raise ValueError("CUDA graph should be disabled when `free_cache_engine` is True.")
-
         if config.max_num_batched_tokens < config.prompt_length + config.response_length:
             raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")

@@ -84,6 +81,7 @@ class vLLMRollout(BaseRollout):
             disable_mm_preprocessor_cache=True,
             disable_log_stats=config.disable_log_stats,
             enable_chunked_prefill=config.enable_chunked_prefill,
+            seed=self.rank // config.tensor_parallel_size,  # dp rank
             **vllm_init_kwargs,
         )

@@ -171,7 +169,7 @@ class vLLMRollout(BaseRollout):
         # position_ids: [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
         response_position_ids = position_ids[..., -1:] + delta_position_id
         position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
-        response_mask = VF.get_eos_mask(
+        response_mask = VF.get_response_mask(
             response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
         )
         attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
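VF.get_eos_mask is renamed to VF.get_response_mask; neither helper's body appears in this diff. Below is a hedged sketch of one common way to build such a mask, keeping every token up to and including the first EOS, which is an assumption about the helper's semantics rather than the repository implementation.

import torch


def get_response_mask(response_ids: torch.Tensor, eos_token_id: int, dtype=torch.long) -> torch.Tensor:
    """Sketch: 1 for tokens up to and including the first EOS, 0 afterwards."""
    is_eos = (response_ids == eos_token_id).long()              # (bsz, response_length)
    # number of EOS tokens seen strictly before each position
    eos_seen_before = torch.cumsum(is_eos, dim=-1) - is_eos
    return (eos_seen_before == 0).to(dtype)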