Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cd42bf87
Commit
cd42bf87
authored
Dec 27, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev-wm-1218' into v0.9.2-dev
parents
43546076
9925dd0e
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
383 additions
and
141 deletions
+383
-141
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+14
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+232
-83
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+9
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+4
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+95
-38
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+4
-2
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+25
-14
No files found.
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
cd42bf87
...
@@ -39,6 +39,20 @@ def get_w8a8_int8_marlin_weights(
...
@@ -39,6 +39,20 @@ def get_w8a8_int8_marlin_weights(
return
weight
return
weight
def
w8a8_nt_kpack2_marlin_weight
(
w8a8_w
,
# [size_n, size_k// 2 ]
k_tile
=
16
,
n_tile
=
16
,
):
assert
w8a8_w
.
dtype
==
torch
.
int8
,
"w8a8_w 必须是 int8 类型"
size_n
,
size_k
=
w8a8_w
.
shape
assert
size_n
%
k_tile
==
0
and
size_k
%
n_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
n_tile
,
n_tile
,
size_k
//
k_tile
,
k_tile
))
w8a8_w
=
w8a8_w
.
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
return
w8a8_w
def
sparse_cutlass_supported
()
->
bool
:
def
sparse_cutlass_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
return
False
return
False
...
@@ -455,12 +469,10 @@ def apply_int8_linear(
...
@@ -455,12 +469,10 @@ def apply_int8_linear(
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
if
m
<=
16
:
m_
=
m
m_
=
m
elif
m
<=
64
:
elif
m
<=
64
:
m_
=
(
m
//
4
)
*
4
#取值到最近的4的倍数
m_
=
(
m
//
4
)
*
4
#取值到最近的4的倍数
elif
m
<=
160
:
elif
m
<=
160
:
m_
=
(
m
//
8
)
*
8
m_
=
(
m
//
8
)
*
8
elif
m
<
200
:
#256
elif
m
<
200
:
#256
m_
=
160
m_
=
160
elif
m
<
480
:
#512
elif
m
<
480
:
#512
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
cd42bf87
...
@@ -40,9 +40,14 @@ from vllm.compilation.decorators import support_torch_compile
...
@@ -40,9 +40,14 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
VllmConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
VllmConfig
,
get_current_vllm_config
)
get_current_vllm_config
)
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_dp_group
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
get_tensor_model_parallel_rank
,
tensor_model_parallel_reduce_scatter
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
SharedFusedMoE
from
vllm.model_executor.layers.fused_moe.utils
import
EPSharedExperts
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -175,6 +180,13 @@ class DeepseekV2MoE(nn.Module):
...
@@ -175,6 +180,13 @@ class DeepseekV2MoE(nn.Module):
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
physical_expert_end
=
(
self
.
physical_expert_start
+
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
dp_size
=
get_dp_group
().
world_size
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
self
.
use_deepep
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
if
not
self
.
use_deepep
:
self
.
experts
=
FusedMoE
(
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
...
@@ -205,6 +217,36 @@ class DeepseekV2MoE(nn.Module):
...
@@ -205,6 +217,36 @@ class DeepseekV2MoE(nn.Module):
),
),
prefix
=
f
"
{
prefix
}
.shared_experts"
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
)
else
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
self
.
shared_experts
=
EPSharedExperts
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
self
.
experts
=
SharedFusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
enable_eplb
=
self
.
enable_eplb
,
num_redundant_experts
=
self
.
n_redundant_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
shared_experts
=
self
.
shared_experts
)
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
from
vllm.two_batch_overlap.two_batch_overlap
import
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
self
.
tbo_all_reduce
=
tbo_all_reduce
...
@@ -215,9 +257,10 @@ class DeepseekV2MoE(nn.Module):
...
@@ -215,9 +257,10 @@ class DeepseekV2MoE(nn.Module):
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
Union
[
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
,
xqxs
=
xqxs
)
shared_output
=
self
.
shared_experts
(
hidden_states
,
xqxs
=
xqxs
)
...
@@ -257,8 +300,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -257,8 +300,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
))
final_hidden_states
))
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
else
:
else
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
if
not
self
.
enable_expert_parallel
:
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
i_q
,
i_s
=
None
,
None
i_q
,
i_s
=
None
,
None
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
...
@@ -268,7 +310,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -268,7 +310,6 @@ class DeepseekV2MoE(nn.Module):
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -296,6 +337,40 @@ class DeepseekV2MoE(nn.Module):
...
@@ -296,6 +337,40 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
self
.
use_deepep
:
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
:
...
@@ -336,6 +411,7 @@ class DeepseekV2Attention(nn.Module):
...
@@ -336,6 +411,7 @@ class DeepseekV2Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
reduce_results
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -394,7 +470,8 @@ class DeepseekV2Attention(nn.Module):
...
@@ -394,7 +470,8 @@ class DeepseekV2Attention(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
prefix
=
f
"
{
prefix
}
.o_proj"
,
reduce_results
=
reduce_results
)
if
rope_scaling
:
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
...
@@ -488,6 +565,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -488,6 +565,7 @@ class DeepseekV2MLAAttention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
reduce_results
:
bool
=
True
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -565,7 +643,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -565,7 +643,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
)
prefix
=
f
"
{
prefix
}
.o_proj"
,
reduce_results
=
reduce_results
)
if
rope_scaling
:
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
...
@@ -803,6 +882,44 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -803,6 +882,44 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index.
# with the layer's index.
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
dp_size
=
get_dp_group
().
world_size
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
config
=
config
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekV2MoE
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
enable_eplb
=
enable_eplb
,
)
else
:
self
.
mlp
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
is_mtp_layer
=
False
if
self
.
layer_idx
==
config
.
num_hidden_layers
:
self
.
is_mtp_layer
=
True
reduce_results
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
and
\
self
.
tp_size
>
1
and
not
self
.
is_mtp_layer
:
reduce_results
=
False
if
model_config
.
use_mla
:
if
model_config
.
use_mla
:
attn_cls
=
DeepseekV2MLAAttention
attn_cls
=
DeepseekV2MLAAttention
else
:
else
:
...
@@ -823,25 +940,9 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -823,25 +940,9 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
reduce_results
=
reduce_results
)
)
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
DeepseekV2MoE
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
enable_eplb
=
enable_eplb
,
)
else
:
self
.
mlp
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -850,6 +951,8 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -850,6 +951,8 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
use_fused_rms_quant
=
envs
.
USE_FUSED_RMS_QUANT
self
.
use_fused_rms_quant
=
envs
.
USE_FUSED_RMS_QUANT
self
.
use_fused_custom_all_reduce
=
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
self
.
use_fused_custom_all_reduce
=
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
def
forward_fused_rmsquant
(
def
forward_fused_rmsquant
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -956,11 +1059,27 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -956,11 +1059,27 @@ class DeepseekV2DecoderLayer(nn.Module):
else
:
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
if
not
self
.
is_mtp_layer
:
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
and
\
self
.
layer_idx
>
self
.
config
.
first_k_dense_replace
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
,
dim
=
0
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
)
)
if
not
self
.
is_mtp_layer
:
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
if
self
.
layer_idx
==
self
.
config
.
first_k_dense_replace
:
residual
=
residual
.
tensor_split
(
self
.
tp_size
)[
self
.
tp_rank
]
hidden_states
=
tensor_model_parallel_reduce_scatter
(
hidden_states
,
dim
=
0
)
if
hidden_states
.
dtype
==
torch
.
float16
:
if
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Fix FP16 overflow
# We scale both hidden_states and residual before
# We scale both hidden_states and residual before
...
@@ -974,8 +1093,26 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -974,8 +1093,26 @@ class DeepseekV2DecoderLayer(nn.Module):
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
if
self
.
is_mtp_layer
:
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
ori_bs
=
hidden_states
.
shape
[
0
]
pad_size
=
(
ori_bs
+
self
.
tp_size
-
1
)
//
self
.
tp_size
*
self
.
tp_size
-
ori_bs
if
pad_size
>
0
:
hidden_states
=
torch
.
nn
.
functional
.
pad
(
hidden_states
.
contiguous
(),
[
0
,
0
,
0
,
pad_size
],
value
=
0
).
contiguous
()
new_bs
=
(
ori_bs
+
pad_size
)
//
self
.
tp_size
hidden_states
=
hidden_states
[
self
.
tp_rank
*
new_bs
:
(
self
.
tp_rank
+
1
)
*
new_bs
,
:].
contiguous
()
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
self
.
is_mtp_layer
:
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
,
dim
=
0
)
hidden_states
=
hidden_states
[:
ori_bs
,
:]
if
isinstance
(
self
.
mlp
,
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Fix FP16 overflow
...
@@ -1052,6 +1189,14 @@ class DeepseekV2Model(nn.Module):
...
@@ -1052,6 +1189,14 @@ class DeepseekV2Model(nn.Module):
make_empty_intermediate_tensors_factory
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
self
.
dp_size
=
get_dp_group
().
world_size
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -1083,6 +1228,10 @@ class DeepseekV2Model(nn.Module):
...
@@ -1083,6 +1228,10 @@ class DeepseekV2Model(nn.Module):
})
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
self
.
use_deepep
and
self
.
tp_size
>
1
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
,
dim
=
0
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/parameter.py
View file @
cd42bf87
...
@@ -96,6 +96,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -96,6 +96,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def
__init__
(
self
,
output_dim
:
int
,
**
kwargs
):
def
__init__
(
self
,
output_dim
:
int
,
**
kwargs
):
self
.
_output_dim
=
output_dim
self
.
_output_dim
=
output_dim
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
expect_tp_size
=
-
1
@
property
@
property
def
output_dim
(
self
):
def
output_dim
(
self
):
...
@@ -103,6 +105,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -103,6 +105,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def
load_column_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
def
load_column_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
self
.
expect_tp_size
==
1
:
tp_rank
=
0
shard_size
=
self
.
data
.
shape
[
self
.
output_dim
]
shard_size
=
self
.
data
.
shape
[
self
.
output_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
tp_rank
*
shard_size
,
shard_size
)
tp_rank
*
shard_size
,
shard_size
)
...
@@ -123,6 +127,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -123,6 +127,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data
=
self
.
data
param_data
=
self
.
data
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
self
.
expect_tp_size
==
1
:
tp_rank
=
0
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
param_data
=
param_data
.
narrow
(
self
.
output_dim
,
shard_offset
,
shard_size
)
shard_size
)
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
loaded_weight
=
loaded_weight
.
narrow
(
self
.
output_dim
,
...
@@ -167,6 +173,7 @@ class RowvLLMParameter(BasevLLMParameter):
...
@@ -167,6 +173,7 @@ class RowvLLMParameter(BasevLLMParameter):
def
__init__
(
self
,
input_dim
:
int
,
**
kwargs
):
def
__init__
(
self
,
input_dim
:
int
,
**
kwargs
):
self
.
_input_dim
=
input_dim
self
.
_input_dim
=
input_dim
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
expect_tp_size
=
-
1
@
property
@
property
def
input_dim
(
self
):
def
input_dim
(
self
):
...
@@ -174,6 +181,8 @@ class RowvLLMParameter(BasevLLMParameter):
...
@@ -174,6 +181,8 @@ class RowvLLMParameter(BasevLLMParameter):
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
if
self
.
expect_tp_size
==
1
:
tp_rank
=
0
shard_size
=
self
.
data
.
shape
[
self
.
input_dim
]
shard_size
=
self
.
data
.
shape
[
self
.
input_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
self
.
input_dim
,
loaded_weight
=
loaded_weight
.
narrow
(
self
.
input_dim
,
tp_rank
*
shard_size
,
shard_size
)
tp_rank
*
shard_size
,
shard_size
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
cd42bf87
...
@@ -25,6 +25,7 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDe
...
@@ -25,6 +25,7 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata, MLACommonDe
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.utils
import
prepare_eagle_input_kernel
from
vllm.v1.spec_decode.utils
import
prepare_eagle_input_kernel
from
vllm.utils
import
round_up
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -186,6 +187,7 @@ class EagleProposer:
...
@@ -186,6 +187,7 @@ class EagleProposer:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
else
:
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
...
@@ -224,8 +226,8 @@ class EagleProposer:
...
@@ -224,8 +226,8 @@ class EagleProposer:
with
set_forward_context
(
per_layer_attn_metadata
,
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
):
skip_cuda_graphs
=
not
decoding
):
#
skip_cuda_graphs=not decoding):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cd42bf87
...
@@ -28,7 +28,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
...
@@ -28,7 +28,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
get_tp_group
,
graph_capture
,
is_global_first_rank
,
get_pp_group
,
get_tp_group
,
graph_capture
,
is_global_first_rank
,
prepare_communication_buffer_for_model
)
prepare_communication_buffer_for_model
,
get_tensor_model_parallel_world_size
)
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
,
set_profilling
)
set_forward_context
,
set_profilling
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -46,7 +47,7 @@ from vllm.sequence import IntermediateTensors
...
@@ -46,7 +47,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
GiB_bytes
,
LazyLoader
,
async_tensor_h2d
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
round_up
)
is_pin_memory_available
,
round_up
,
round_down
)
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
)
...
@@ -331,6 +332,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -331,6 +332,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
self
.
draft_probs
:
Optional
[
DraftProbs
]
=
None
self
.
ep_sp
=
False
self
.
dp_size
=
self
.
parallel_config
.
data_parallel_size
self
.
tp_size
=
self
.
parallel_config
.
tensor_parallel_size
self
.
enable_expert_parallel
=
self
.
parallel_config
.
enable_expert_parallel
if
self
.
enable_expert_parallel
and
self
.
dp_size
>
1
and
self
.
tp_size
>
1
:
self
.
ep_sp
=
True
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
"""
Update the order of requests in the batch based on the attention
Update the order of requests in the batch based on the attention
...
@@ -1267,7 +1275,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1267,7 +1275,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
:
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
envs
.
VLLM_ALL2ALL_BACKEND
!=
'naive'
:
# Early exit.
# Early exit.
return
0
,
None
return
0
,
None
...
@@ -1344,6 +1352,17 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1344,6 +1352,17 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata
,
spec_decode_metadata
,
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
self
.
tp_size
)
if
(
self
.
use_cuda_graph
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_input_tokens
)
else
:
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Use piecewise CUDA graphs.
...
@@ -1792,14 +1811,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1792,14 +1811,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_token_ids
=
draft_result
draft_token_ids
=
draft_result
else
:
else
:
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
draft_token_ids
,
draft_probs
=
draft_result
draft_token_ids
,
draft_probs
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
if
self
.
draft_probs
is
None
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
draft_req_ids
)
draft_probs
,
draft_req_ids
)
else
:
else
:
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
...
@@ -1920,6 +1941,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -1920,6 +1941,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
time_after_load
-
time_before_load
)
time_after_load
-
time_before_load
)
prepare_communication_buffer_for_model
(
self
.
model
)
prepare_communication_buffer_for_model
(
self
.
model
)
if
hasattr
(
self
,
"drafter"
):
prepare_communication_buffer_for_model
(
self
.
drafter
.
model
)
if
is_mixture_of_experts
(
if
is_mixture_of_experts
(
self
.
model
)
and
self
.
parallel_config
.
enable_eplb
:
self
.
model
)
and
self
.
parallel_config
.
enable_eplb
:
logger
.
info
(
"EPLB is enabled for model %s."
,
logger
.
info
(
"EPLB is enabled for model %s."
,
...
@@ -2092,6 +2116,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2092,6 +2116,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
is_profile
:
bool
=
False
,
is_profile
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
if
num_tokens
<
self
.
tp_size
:
num_tokens
=
self
.
tp_size
# Padding for DP
# Padding for DP
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_tokens
)
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_tokens
)
num_tokens
+=
num_pad
num_tokens
+=
num_pad
...
@@ -2099,10 +2128,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2099,10 +2128,12 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
# has num_tokens in total.
assert
num_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
assert
num_tokens
<=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_reqs
=
self
.
scheduler_config
.
max_num_seqs
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
num_reqs
=
min
(
num_tokens
,
max_num_reqs
)
min_tokens_per_req
=
num_tokens
//
num_reqs
min_tokens_per_req
=
num_tokens
//
num_reqs
num_actual_tokens
=
num_tokens
if
not
is_profile
and
self
.
speculative_config
is
not
None
\
if
not
is_profile
and
self
.
speculative_config
is
not
None
\
and
self
.
speculative_config
.
num_lookahead_slots
>
0
\
and
self
.
speculative_config
.
num_lookahead_slots
>
0
\
...
@@ -2110,7 +2141,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2110,7 +2141,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
min_tokens_per_req
=
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_tokens
//
min_tokens_per_req
num_reqs
=
num_tokens
//
min_tokens_per_req
if
self
.
ep_sp
:
num_actual_tokens
=
round_down
(
num_tokens
,
1
+
self
.
speculative_config
.
num_lookahead_slots
)
num_reqs
=
num_actual_tokens
//
min_tokens_per_req
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
num_scheduled_tokens_list
=
[
min_tokens_per_req
]
*
num_reqs
if
not
self
.
ep_sp
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
else
:
if
self
.
speculative_config
is
not
None
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
min_tokens_per_req
else
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
assert
len
(
num_scheduled_tokens_list
)
==
num_reqs
assert
len
(
num_scheduled_tokens_list
)
==
num_reqs
...
@@ -2135,7 +2178,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2135,7 +2178,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
# seq_lens_tensor=seq_lens_tensor,
# seq_lens_tensor=seq_lens_tensor,
num_reqs
=
num_reqs
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_
actual_
tokens
,
max_query_len
=
num_tokens
,
max_query_len
=
num_tokens
,
num_speculative_tokens
=
num_speculative_tokens
,
num_speculative_tokens
=
num_speculative_tokens
,
)
)
...
@@ -2156,6 +2199,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
...
@@ -2156,6 +2199,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
input_ids
=
None
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
inputs_embeds
=
self
.
inputs_embeds
[:
num_tokens
]
else
:
else
:
self
.
input_ids
[:
num_tokens
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
[:
num_tokens
]
input_ids
=
self
.
input_ids
[:
num_tokens
]
inputs_embeds
=
None
inputs_embeds
=
None
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
...
@@ -3166,6 +3211,17 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3166,6 +3211,17 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
spec_decode_metadata
,
spec_decode_metadata
,
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
self
.
tp_size
)
if
(
self
.
use_cuda_graph
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_input_tokens
)
else
:
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Use piecewise CUDA graphs.
...
@@ -3608,16 +3664,17 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
...
@@ -3608,16 +3664,17 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_token_ids
=
draft_result
draft_token_ids
=
draft_result
else
:
else
:
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
draft_token_ids
,
draft_probs
=
draft_result
draft_token_ids
,
draft_probs
=
draft_result
spec_token_ids
=
draft_token_ids
.
tolist
()
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
if
self
.
draft_probs
is
None
:
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
draft_req_ids
)
draft_probs
,
draft_req_ids
)
else
:
else
:
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
spec_token_ids
=
draft_token_ids
.
tolist
()
return
spec_token_ids
return
spec_token_ids
#TODO:稳定后使用GPUModelRunnerMTP替换GPUModelRunner
#TODO:稳定后使用GPUModelRunnerMTP替换GPUModelRunner
if
envs
.
VLLM_USE_ZERO_MTP
:
if
envs
.
VLLM_USE_ZERO_MTP
:
...
...
vllm/zero_overhead/v1/eagle.py
View file @
cd42bf87
...
@@ -10,6 +10,7 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
...
@@ -10,6 +10,7 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
EagleProposer
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
EagleProposer
from
vllm.utils
import
round_up
class
V1ZeroEagleProposer
(
EagleProposer
):
class
V1ZeroEagleProposer
(
EagleProposer
):
...
@@ -110,6 +111,7 @@ class V1ZeroEagleProposer(EagleProposer):
...
@@ -110,6 +111,7 @@ class V1ZeroEagleProposer(EagleProposer):
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
else
:
num_input_tokens
=
num_tokens
num_input_tokens
=
num_tokens
# copy inputs to buffer for cudagraph
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
...
@@ -148,8 +150,8 @@ class V1ZeroEagleProposer(EagleProposer):
...
@@ -148,8 +150,8 @@ class V1ZeroEagleProposer(EagleProposer):
with
set_forward_context
(
per_layer_attn_metadata
,
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
):
skip_cuda_graphs
=
not
decoding
):
#
skip_cuda_graphs=not decoding):
ret_hidden_states
=
self
.
model
(
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
cd42bf87
...
@@ -424,6 +424,17 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -424,6 +424,17 @@ class V1ZeroModelRunner(GPUModelRunner):
spec_decode_metadata
,
spec_decode_metadata
,
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens_np
)
=
(
self
.
_prepare_inputs
(
scheduler_output
))
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if
self
.
ep_sp
:
num_input_tokens
=
round_up
(
num_scheduled_tokens
,
tp_size
)
if
(
self
.
use_cuda_graph
and
num_input_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_input_tokens
)
else
:
if
(
self
.
use_cuda_graph
if
(
self
.
use_cuda_graph
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
and
num_scheduled_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
# Use piecewise CUDA graphs.
# Use piecewise CUDA graphs.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment