Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f3d1f95b
Commit
f3d1f95b
authored
Apr 15, 2026
by
王敏
Browse files
[Feat]添加CPLB功能,支持PCP模式下负载均衡
parent
20254503
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
16 deletions
+57
-16
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+11
-0
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+20
-7
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+25
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
vllm/model_executor/layers/mla.py
View file @
f3d1f95b
...
...
@@ -201,6 +201,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
envs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c_normed
=
torch
.
index_select
(
kv_c_normed
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
attn_out
=
self
.
mla_attn
(
q
,
kv_c_normed
,
...
...
@@ -244,6 +250,11 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe
=
tensor_model_parallel_all_gather
(
k_pe
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
envs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
kv_c
=
torch
.
index_select
(
kv_c
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
attn_out
=
self
.
mla_attn
(
q
[...,
self
.
qk_nope_head_dim
:],
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
f3d1f95b
...
...
@@ -200,6 +200,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_mla_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
inputs_embeds
=
inputs_embeds_per_rank
[
self
.
tp_rank
].
contiguous
()
...
...
@@ -209,6 +211,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
inputs_embeds
=
torch
.
index_select
(
inputs_embeds
,
0
,
scatter_indexes_tensor
)
previous_hidden_states
=
torch
.
index_select
(
previous_hidden_states
,
0
,
scatter_indexes_tensor
)
if
positions
is
not
None
:
positions
=
torch
.
index_select
(
positions
,
0
,
scatter_indexes_tensor
)
hidden_states
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
...
...
@@ -220,6 +230,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
enable_mla_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
gather_indexes_tensor
)
return
hidden_states
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
f3d1f95b
...
...
@@ -855,6 +855,9 @@ class Indexer(nn.Module):
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
envs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
# we only quant q here since k quant is fused with cache insertion
if
not
current_platform
.
is_rocm
()
or
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
...
...
@@ -1397,6 +1400,8 @@ class DeepseekV2Model(nn.Module):
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_mla_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states
=
hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
...
...
@@ -1407,6 +1412,15 @@ class DeepseekV2Model(nn.Module):
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
scatter_indexes_tensor
)
if
residual
is
not
None
:
residual
=
torch
.
index_select
(
residual
,
0
,
scatter_indexes_tensor
)
if
positions
is
not
None
:
positions
=
torch
.
index_select
(
positions
,
0
,
scatter_indexes_tensor
)
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config
=
getattr
(
self
.
config
,
"llama_4_scaling"
,
None
)
...
...
@@ -1439,6 +1453,9 @@ class DeepseekV2Model(nn.Module):
if
enable_mla_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
hidden_states
=
torch
.
index_select
(
hidden_states
,
0
,
gather_indexes_tensor
)
return
hidden_states
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
f3d1f95b
...
...
@@ -5159,7 +5159,7 @@ class GPUModelRunner(
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
,
),
):
outputs
=
self
.
model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment