Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b58514dd
Commit
b58514dd
authored
Apr 18, 2026
by
王敏
Browse files
[perf]1.优化pcp代码 2.优化ep低延迟模式调度,消除空泡
parent
c462f3a0
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
149 additions
and
187 deletions
+149
-187
vllm/config/parallel.py
vllm/config/parallel.py
+7
-0
vllm/config/vllm.py
vllm/config/vllm.py
+7
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-14
vllm/envs.py
vllm/envs.py
+7
-9
vllm/forward_context.py
vllm/forward_context.py
+12
-6
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+7
-20
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-1
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+13
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+7
-6
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+3
-4
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+27
-92
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+1
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+11
-6
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+30
-24
No files found.
vllm/config/parallel.py
View file @
b58514dd
...
...
@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out.
"""
enable_lightly_cp
:
bool
=
False
"""Use lightly context parallel."""
enable_lightly_cplb
:
bool
=
False
"""Use lightly context parallel load balancing."""
@
field_validator
(
"disable_nccl_for_dp_synchronization"
,
mode
=
"wrap"
)
@
classmethod
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
...
...
vllm/config/vllm.py
View file @
b58514dd
...
...
@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs
self
.
_post_init_kv_transfer_config
()
if
self
.
parallel_config
.
enable_lightly_cp
and
not
self
.
model_config
.
enforce_eager
:
raise
ValueError
(
"Lightly context parallel currently only supports the eager mode!!!"
)
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism
...
...
@@ -1186,8 +1192,7 @@ class VllmConfig:
if
(
self
.
parallel_config
.
tensor_parallel_size
>
1
and
(
self
.
compilation_config
.
pass_config
.
enable_sp
)
#or envs.VLLM_MLA_CP)
and
self
.
compilation_config
.
pass_config
.
enable_sp
):
cudagraph_capture_sizes
=
self
.
update_sizes_for_sequence_parallelism
(
cudagraph_capture_sizes
...
...
vllm/engine/arg_utils.py
View file @
b58514dd
...
...
@@ -582,6 +582,9 @@ class EngineArgs:
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
tokens_only
:
bool
=
False
enable_lightly_cp
:
bool
=
ParallelConfig
.
enable_lightly_cp
enable_lightly_cplb
:
bool
=
ParallelConfig
.
enable_lightly_cplb
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
...
...
@@ -899,6 +902,15 @@ class EngineArgs:
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
)
parallel_group
.
add_argument
(
"--enable-lightly-cp"
,
**
parallel_kwargs
[
"enable_lightly_cp"
],
)
parallel_group
.
add_argument
(
"--enable-lightly-cplb"
,
**
parallel_kwargs
[
"enable_lightly_cplb"
],
)
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_group
=
parser
.
add_argument_group
(
...
...
@@ -1500,20 +1512,6 @@ class EngineArgs:
data_parallel_external_lb
=
(
self
.
data_parallel_external_lb
or
self
.
data_parallel_rank
is
not
None
)
if
(
envs
.
VLLM_MLA_CP
and
self
.
max_num_batched_tokens
is
not
None
and
self
.
max_num_batched_tokens
<
self
.
tensor_parallel_size
**
3
):
raise
ValueError
(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger
.
info
(
"[MLACP] VLLM_MLA_CP is %s"
,
envs
.
VLLM_MLA_CP
)
logger
.
info
(
"[MLACP] VLLM_MLA_CPLB is %s"
,
envs
.
VLLM_MLA_CPLB
)
# Local DP rank = 1, use pure-external LB.
if
data_parallel_external_lb
:
assert
self
.
data_parallel_rank
is
not
None
,
(
...
...
@@ -1644,6 +1642,8 @@ class EngineArgs:
cp_kv_cache_interleave_size
=
self
.
cp_kv_cache_interleave_size
,
_api_process_count
=
self
.
_api_process_count
,
_api_process_rank
=
self
.
_api_process_rank
,
enable_lightly_cp
=
self
.
enable_lightly_cp
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
,
)
speculative_config
=
self
.
create_speculative_config
(
...
...
vllm/envs.py
View file @
b58514dd
...
...
@@ -324,8 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_MLA_CP
:
bool
=
False
VLLM_MLA_CPLB
:
bool
=
False
VLLM_LIGHTLY_CP_THRESHOULD
:
int
=
2048
def
get_default_cache_root
():
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
...
...
@@ -2012,13 +2013,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_DSA"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# If set to 1/True, enable mla context parallel
"VLLM_MLA_CP"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_MLA_CP"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_MLA_CPLB"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_MLA_CPLB"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# MLA_CP open threshold
"VLLM_LIGHTLY_CP_THRESHOULD"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_LIGHTLY_CP_THRESHOULD"
,
"2048"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/forward_context.py
View file @
b58514dd
...
...
@@ -242,7 +242,8 @@ class ForwardContext:
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
enable_mla_cp
:
bool
=
False
enable_lightly_cp
:
bool
=
False
enable_lightly_cplb
:
bool
=
False
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
...
...
@@ -279,7 +280,8 @@ def create_forward_context(
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_mla_cp
:
bool
=
False
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
):
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
...
...
@@ -307,7 +309,8 @@ def create_forward_context(
skip_compiled
=
skip_compiled
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
enable_mla_cp
,
enable_lightly_cp
=
enable_lightly_cp
,
enable_lightly_cplb
=
enable_lightly_cplb
,
additional_kwargs
=
additional_kwargs
or
{},
)
...
...
@@ -341,7 +344,8 @@ def set_forward_context(
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_mla_cp
:
bool
=
False
,
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
...
...
@@ -353,7 +357,8 @@ def set_forward_context(
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
DPMetadata
|
None
=
None
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
(
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
\
envs
.
VLLM_ALL2ALL_BACKEND
!=
"deepep_low_latency"
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
):
# If num_tokens_across_dp hasn't already been initialized, then
...
...
@@ -404,7 +409,8 @@ def set_forward_context(
skip_compiled
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
enable_mla_cp
enable_lightly_cp
,
enable_lightly_cplb
)
try
:
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
b58514dd
...
...
@@ -205,20 +205,6 @@ def moe_grouped_gemm(
return
output
def
native_w8a8_perChannel_batch_matmul
(
q_a1_all
,
weight13
,
qa1_scale_all
,
w13_scale
,
output_dtype
):
A
=
q_a1_all
.
to
(
torch
.
float32
)
B
=
weight13
.
to
(
torch
.
float32
)
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
],
"Dimension mismatch"
C
=
torch
.
bmm
(
A
,
B
.
transpose
(
1
,
2
))
# [E, M, K]
C
=
qa1_scale_all
*
C
*
w13_scale
.
transpose
(
1
,
2
)
# Broadcast per-column scale
C
=
C
.
to
(
output_dtype
)
return
C
def
scales_shape_stride_dtype
(
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
...
...
@@ -589,7 +575,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
**
_
):
assert
expert_tokens_meta
is
not
None
...
...
@@ -612,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
expected_m
=
self
.
estimate_expected_m
(
global_num_experts
=
global_num_experts
,
max_tokens_per_expert
=
max_num_tokens
,
topk
=
topk_ids
.
size
(
-
1
),
)
# expected_m = self.estimate_expected_m(
# global_num_experts=global_num_experts,
# max_tokens_per_expert=max_num_tokens,
# topk=topk_ids.size(-1),
# )
expected_m
=
self
.
get_expected_m
()
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
fp8_m_grouped_gemm_nt_masked
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
b58514dd
...
...
@@ -854,7 +854,7 @@ class FusedMoE(CustomOp):
def
use_dp_chunking
(
self
)
->
bool
:
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
#
or self.moe_parallel_config.use_deepep_ll_kernels
or
self
.
moe_parallel_config
.
use_mori_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
b58514dd
...
...
@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self
.
quant_config
=
quant_config
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
self
.
expected_m
=
max_num_tokens
@
staticmethod
def
expects_unquantized_inputs
(
...
...
@@ -774,6 +775,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
chooses to do weight application.
"""
raise
NotImplementedError
def
set_expected_m
(
self
,
expected_m
):
self
.
expected_m
=
expected_m
def
get_expected_m
(
self
):
return
self
.
expected_m
def
_slice_scales
(
...
...
@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
expected_m
=
(
hidden_states
.
shape
[
0
]
*
self
.
fused_experts
.
num_dispatchers
*
topk_ids
.
shape
[
1
]
+
global_num_experts
)
//
global_num_experts
self
.
fused_experts
.
set_expected_m
(
expected_m
)
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
...
...
vllm/model_executor/layers/mla.py
View file @
b58514dd
...
...
@@ -3,7 +3,6 @@
from
dataclasses
import
dataclass
import
torch
from
vllm.attention.layer
import
MLAAttention
from
vllm.config
import
CacheConfig
import
vllm.envs
as
envs
...
...
@@ -115,6 +114,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
self
.
prefix
=
prefix
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -189,11 +189,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
# if not use_fused_rms_rope_concat:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
kv_c_normed
=
tensor_model_parallel_all_gather
(
kv_c_normed
.
contiguous
(),
0
)
...
...
@@ -202,7 +203,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
en
vs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
if
en
able_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c_normed
=
torch
.
index_select
(
kv_c_normed
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
...
...
@@ -243,7 +244,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"expose 'cos_sin_cache'."
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
kv_c
=
tensor_model_parallel_all_gather
(
kv_c
.
contiguous
(),
0
)
...
...
@@ -251,7 +252,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
en
vs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
if
en
able_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c
=
torch
.
index_select
(
kv_c
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
b58514dd
...
...
@@ -198,8 +198,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_
mla
_cp
:
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
lightly
_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
...
...
@@ -212,7 +212,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
inputs_embeds
=
torch
.
index_select
(
inputs_embeds
,
0
,
scatter_indexes_tensor
)
...
...
@@ -228,7 +227,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx
,
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
b58514dd
...
...
@@ -183,10 +183,9 @@ class DeepseekAttention(nn.Module):
return
output
def
eff_2d_
iqis_all_gather
(
def
iqis_all_gather
(
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
tp_size
:
int
|
None
=
None
,
tp_rank
:
int
|
None
=
None
tp_size
:
int
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
iqis
is
not
None
iq_tensor
,
is_tensor
=
iqis
...
...
@@ -221,6 +220,7 @@ def eff_2d_iqis_all_gather(
is_gathered
=
is_gathered_int8
.
view
(
torch
.
float32
)
return
(
iq_gathered
,
is_gathered
)
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -267,15 +267,10 @@ class DeepseekV2MLP(nn.Module):
x
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
):
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if
enable_
mla
_cp
:
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
lightly
_cp
:
if
iqis
is
not
None
and
iqis
[
0
]
is
not
None
and
iqis
[
1
]
is
not
None
:
if
False
:
i_q_gahter
=
tensor_model_parallel_all_gather
(
iqis
[
0
].
contiguous
(),
0
)
i_s_gather
=
tensor_model_parallel_all_gather
(
iqis
[
1
].
contiguous
(),
0
)
iqis
=
(
i_q_gahter
,
i_s_gather
)
else
:
iqis
=
eff_2d_iqis_all_gather
(
iqis
,
tp_size
=
self
.
tp_size
,
tp_rank
=
get_tensor_model_parallel_rank
())
iqis
=
iqis_all_gather
(
iqis
,
tp_size
=
self
.
tp_size
)
else
:
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
)
...
...
@@ -293,72 +288,12 @@ class DeepseekV2MLP(nn.Module):
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
x
=
tensor_model_parallel_reduce_scatter
(
x
.
contiguous
(),
dim
=
0
)
return
x
elif
self
.
tp_size
>
1
:
x
=
tensor_model_parallel_all_reduce
(
x
)
return
x
class
DeepseekV2SharedMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
QuantizationConfig
|
None
=
None
,
reduce_results
:
bool
=
True
,
is_sequence_parallel
=
False
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
disable_tp
=
is_sequence_parallel
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
):
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
iqis
=
iqis
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
from
lmslim.quantize.quant_ops
import
lm_fuse_silu_mul_quant
xq
,
xs
=
lm_fuse_silu_mul_quant
(
gate_up
)
x
,
_
=
self
.
down_proj
(
gate_up
,
iqis
=
(
xq
,
xs
))
else
:
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
DeepseekV2MoE
(
nn
.
Module
):
...
...
@@ -431,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
else
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
DeepseekV2
Shared
MLP
(
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
...
...
@@ -477,8 +412,8 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if
enable_
mla
_cp
:
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
0
)
...
...
@@ -553,7 +488,7 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
final_hidden_states
=
tensor_model_parallel_reduce_scatter
(
final_hidden_states
.
contiguous
(),
0
)
...
...
@@ -889,13 +824,14 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_
mla
_cp
:
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
lightly
_cp
:
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
envs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
# we only quant q here since k quant is fused with cache insertion
...
...
@@ -964,8 +900,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
num_heads
=
num_heads
tp_size
=
get_tensor_model_parallel_world_size
()
assert
num_heads
%
tp_size
==
0
#
self.num_local_heads = num_heads // tp_size
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
envs
.
VLLM_MLA_CP
else
self
.
num_heads
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cp
else
self
.
num_heads
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
...
...
@@ -999,7 +935,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
)
else
:
self
.
q_proj
=
ColumnParallelLinear
(
...
...
@@ -1008,7 +944,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
...
...
@@ -1017,7 +953,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
...
...
@@ -1025,7 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
...
...
@@ -1262,8 +1198,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
*=
1.0
/
self
.
routed_scaling_factor
# Fully Connected
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
skip_moe_large_batch_size
=
enable_mla_cp
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
assert
self
.
post_attention_layernorm
.
has_weight
is
True
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
...
...
@@ -1272,7 +1207,7 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input
=
update_hs
)
new_resi
=
residual
if
skip_moe_large_batch_size
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
):
if
enable_lightly_cp
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
):
hidden_states
=
self
.
mlp
(
hidden_states
)
else
:
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
...
...
@@ -1437,8 +1372,8 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_
mla
_cp
:
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
lightly
_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
...
...
@@ -1481,7 +1416,7 @@ class DeepseekV2Model(nn.Module):
)
if
not
get_pp_group
().
is_last_rank
:
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
residual
=
tensor_model_parallel_all_gather
(
residual
.
contiguous
(),
dim
=
0
)
return
IntermediateTensors
(
...
...
@@ -1490,7 +1425,7 @@ class DeepseekV2Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
...
...
vllm/v1/attention/backend.py
View file @
b58514dd
...
...
@@ -332,7 +332,6 @@ class CommonAttentionMetadata:
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens
:
int
"""Total number of tokens in batch"""
max_query_len
:
int
"""Longest query in batch"""
...
...
@@ -348,7 +347,7 @@ class CommonAttentionMetadata:
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
cp_common_metadata
:
CpCommonAttentionMetadata
|
None
=
None
enable_
mla
_cp
:
bool
=
False
enable_
lightly
_cp
:
bool
=
False
causal
:
bool
=
True
...
...
vllm/v1/spec_decode/eagle.py
View file @
b58514dd
...
...
@@ -78,7 +78,8 @@ class SpecDecodeBaseProposer:
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
envs
.
VLLM_MLA_CPLB
\
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cplb
\
else
vllm_config
.
scheduler_config
.
max_num_seqs
*
2
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
...
...
@@ -224,7 +225,10 @@ class SpecDecodeBaseProposer:
self
.
scatter_indexes_tensor
=
None
self
.
gather_indexes_tensor
=
None
if
envs
.
VLLM_MLA_CP
:
self
.
enable_lightly_cp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
vllm_config
.
parallel_config
.
enable_lightly_cplb
if
self
.
enable_lightly_cp
:
self
.
query_start_loc
=
CpuGpuBuffer
(
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
...
...
@@ -339,8 +343,8 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
enable_
mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
runner
.
mla
_cp_threshould
if
enable_
mla
_cp
:
enable_
lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly
_cp_threshould
if
enable_
lightly
_cp
:
num_tokens_dp_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_dp_padded
)
common_attn_metadata
=
self
.
_prepare_cp_metadata
(
...
...
@@ -436,7 +440,8 @@ class SpecDecodeBaseProposer:
),
scatter_indexes_tensor
=
self
.
scatter_indexes_tensor
,
gather_indexes_tensor
=
self
.
gather_indexes_tensor
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
runner
.
mla_cp_threshould
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
...
...
@@ -513,7 +518,7 @@ class SpecDecodeBaseProposer:
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
common_attn_metadata
=
common_attn_metadata
.
cp_common_metadata
common_attn_metadata
.
num_actual_tokens
=
batch_size
...
...
vllm/v1/worker/dp_utils.py
View file @
b58514dd
...
...
@@ -6,6 +6,7 @@ import numpy as np
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.logger
import
init_logger
...
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
if
parallel_config
.
data_parallel_size
==
1
:
if
parallel_config
.
data_parallel_size
==
1
or
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
:
# Early exit.
return
False
,
None
,
cudagraph_mode
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b58514dd
...
...
@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs
,
)
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.utils.torch_utils
import
async_tensor_h2d
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
...
@@ -382,12 +383,15 @@ class GPUModelRunner(
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
#self.max_num_reqs = scheduler_config.max_num_seqs
self
.
enable_lightly_cp
=
self
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
self
.
parallel_config
.
enable_lightly_cplb
self
.
max_num_reqs
=
(
scheduler_config
.
max_num_seqs
if
not
envs
.
VLLM_MLA_CPLB
if
not
self
.
enable_lightly_cplb
else
scheduler_config
.
max_num_seqs
*
2
)
self
.
mla
_cp_threshould
=
512
self
.
lightly
_cp_threshould
=
envs
.
VLLM_LIGHTLY_CP_THRESHOULD
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
...
...
@@ -1525,7 +1529,7 @@ class GPUModelRunner(
local_scatter_indexes_tensor
=
None
gather_indexes_tensor
=
None
if
envs
.
VLLM_MLA_CPLB
:
if
self
.
enable_lightly_cp
:
rank_tokens
=
0
rank_pad_tokens
=
0
accu_q_start
=
0
...
...
@@ -1736,7 +1740,7 @@ class GPUModelRunner(
cp_common_metadata
=
cp_common_metadata
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_
mla
_cp
=
True
enable_
lightly
_cp
=
True
)
return
cm_base
...
...
@@ -2040,8 +2044,8 @@ class GPUModelRunner(
if
self
.
model_config
.
enable_return_routed_experts
:
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
mla_cp_enable
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
mla
_cp_threshould
if
not
mla_cp_enable
:
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
lightly
_cp_threshould
if
not
enable_lightly_cp
:
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
...
...
@@ -2183,19 +2187,19 @@ class GPUModelRunner(
cm
.
block_table_tensor
=
_get_block_table
(
kv_cache_gid
)
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
if
cm
.
seq_indexes_list
is
not
None
:
if
enable_lightly_cp
and
cm
.
seq_indexes_list
is
not
None
:
cm
.
block_table_tensor
=
cm
.
block_table_tensor
[
cm
.
seq_indexes_list
]
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
if
mla_cp_enable
:
if
enable_lightly_cp
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
spec_decode_common_attn_metadata
=
cm
#spec_decode_common_attn_metadata = cm
else
:
if
mla_cp_enable
:
if
enable_lightly_cp
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
spec_decode_common_attn_metadata
=
cm
...
...
@@ -2230,7 +2234,7 @@ class GPUModelRunner(
_metadata
.
mm_prefix_range
=
req_doc_ranges
# type: ignore[attr-defined]
if
(
(
not
envs
.
VLLM_MLA_CP
)
(
not
self
.
enable_lightly_cp
)
and
spec_decode_common_attn_metadata
is
not
None
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
)
):
...
...
@@ -3110,7 +3114,7 @@ class GPUModelRunner(
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
if
envs
.
VLLM_MLA_CP
and
num_scheduled_tokens
>
self
.
mla
_cp_threshould
:
if
self
.
enable_lightly_cp
and
num_scheduled_tokens
>
self
.
lightly
_cp_threshould
:
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
...
...
@@ -3808,7 +3812,7 @@ class GPUModelRunner(
)
num_tokens_padded
=
batch_desc
.
num_tokens
if
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla
_cp_threshould
:
if
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly
_cp_threshould
:
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
...
...
@@ -3927,7 +3931,8 @@ class GPUModelRunner(
skip_compiled
=
has_encoder_input
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
...
...
@@ -4421,7 +4426,7 @@ class GPUModelRunner(
)
#total_num_tokens = common_attn_metadata.num_actual_tokens
if
(
envs
.
VLLM_MLA_CP
self
.
enable_lightly_cp
and
common_attn_metadata
.
cp_common_metadata
is
not
None
):
total_num_tokens
=
(
...
...
@@ -4952,9 +4957,6 @@ class GPUModelRunner(
or
cudagraph_runtime_mode
.
valid_runtime_modes
()
)
# if envs.VLLM_MLA_CP:
# num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
...
...
@@ -5116,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
model_kwargs
=
self
.
_init_model_kwargs
()
else
:
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens_padded
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
None
...
...
@@ -5159,7 +5158,8 @@ class GPUModelRunner(
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
),
):
outputs
=
self
.
model
(
...
...
@@ -5232,9 +5232,15 @@ class GPUModelRunner(
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices_device
=
torch
.
from_numpy
(
logit_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
# logit_indices_device = torch.from_numpy(logit_indices).to(
# self.device, non_blocking=True
# )
logit_indices
=
logit_indices
.
tolist
()
logit_indices_device
=
async_tensor_h2d
(
logit_indices
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
return
hidden_states
,
hidden_states
[
logit_indices_device
]
@
torch
.
inference_mode
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment