Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c0e74be
Commit
3c0e74be
authored
Apr 16, 2026
by
wujl5
Committed by
wangmin6
Apr 16, 2026
Browse files
[fix] MLP传入量化参数,MLP allGather通讯优化
parent
c637d1aa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
7 deletions
+46
-7
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+46
-7
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
3c0e74be
...
@@ -183,6 +183,44 @@ class DeepseekAttention(nn.Module):
...
@@ -183,6 +183,44 @@ class DeepseekAttention(nn.Module):
return
output
return
output
def eff_2d_iqis_all_gather(
    iqis: tuple[torch.Tensor, torch.Tensor],
    tp_size: int | None = None,
    tp_rank: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-gather quantized activations and their scales in one collective.

    The float32 scale column is reinterpreted as 4 int8 bytes and appended
    to the int8 activation matrix, so a single
    ``tensor_model_parallel_all_gather`` call along dim 0 replaces two
    separate calls (one for the values, one for the scales).

    Args:
        iqis: ``(iq_tensor, is_tensor)`` where ``iq_tensor`` is a 2-D int8
            tensor of shape ``[m_local, n]`` and ``is_tensor`` is a 2-D
            float32 tensor of shape ``[m_local, 1]``.
        tp_size: Optional expected number of gathered shards. When given,
            the gathered row count is checked against
            ``m_local * tp_size``; when ``None`` that check is skipped.
        tp_rank: Accepted for interface compatibility; not used here.

    Returns:
        ``(iq_gathered, is_gathered)``: the int8 values with ``n`` columns
        and the float32 scales with one column, both taken from the
        gathered buffer.

    Raises:
        AssertionError: If the inputs do not have the dtypes/shapes
            described above, or the gathered shapes are not as expected.
    """
    assert iqis is not None
    iq_tensor, is_tensor = iqis
    assert isinstance(iq_tensor, torch.Tensor)
    assert isinstance(is_tensor, torch.Tensor)
    assert iq_tensor.dtype == torch.int8, f"iq_tensor dtype is {iq_tensor.dtype}"
    assert is_tensor.dtype == torch.float32, f"is_tensor dtype is {is_tensor.dtype}"
    assert iq_tensor.dim() == 2
    assert is_tensor.dim() == 2

    m_local, n = iq_tensor.shape
    assert is_tensor.shape[0] == m_local, f"{is_tensor.shape[0]} != {iq_tensor.shape[0]}"
    assert is_tensor.shape[1] == 1, f"is_tensor dim 1 = {is_tensor.shape[1]}"

    # Reinterpret each float32 scale as 4 raw int8 bytes so it can share a
    # buffer with the int8 activations. iq_tensor is already int8, so it
    # needs no reinterpretation.
    is_int8_2d = is_tensor.view(torch.int8)
    # torch.cat allocates a fresh, contiguous result.
    combined_2d = torch.cat([iq_tensor, is_int8_2d], dim=1)  # [m_local, n + 4]

    combined_gathered = tensor_model_parallel_all_gather(combined_2d, dim=0)

    # Column slices of the gathered buffer are strided views; make them
    # contiguous so the scale bytes can be reinterpreted as float32 below.
    iq_gathered = combined_gathered[:, :n].contiguous()
    is_gathered_int8 = combined_gathered[:, n:].contiguous()

    # tp_size is optional: only verify the gathered row count when the
    # caller supplied it (multiplying by None would raise TypeError).
    if tp_size is not None:
        assert iq_gathered.shape[0] == m_local * tp_size, \
            f"iq_gathered dim0={iq_gathered.shape[0]}, expected {m_local * tp_size}"
        # is_gathered_int8 should be [m_local*tp_size, 4]
        assert is_gathered_int8.shape[0] == m_local * tp_size, \
            f"is_gathered_int8 dim0={is_gathered_int8.shape[0]}, expected {m_local * tp_size}"
    assert is_gathered_int8.shape[1] == 4, f"is_gathered_int8 dim1={is_gathered_int8.shape[1]}"

    is_gathered = is_gathered_int8.view(torch.float32)
    return (iq_gathered, is_gathered)
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -232,13 +270,14 @@ class DeepseekV2MLP(nn.Module):
...
@@ -232,13 +270,14 @@ class DeepseekV2MLP(nn.Module):
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP# and not get_forward_context().draft_model
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if
enable_mla_cp
:
if
enable_mla_cp
:
if
iqis
is
not
None
and
iqis
[
0
]
is
not
None
and
iqis
[
1
]
is
not
None
:
if
iqis
is
not
None
and
iqis
[
0
]
is
not
None
and
iqis
[
1
]
is
not
None
:
if
False
:
i_q_gahter
=
tensor_model_parallel_all_gather
(
iqis
[
0
].
contiguous
(),
0
)
i_q_gahter
=
tensor_model_parallel_all_gather
(
iqis
[
0
].
contiguous
(),
0
)
i_s_gather
=
tensor_model_parallel_all_gather
(
iqis
[
1
].
contiguous
(),
0
)
i_s_gather
=
tensor_model_parallel_all_gather
(
iqis
[
1
].
contiguous
(),
0
)
iqis
=
(
i_q_gahter
,
i_s_gather
)
iqis
=
(
i_q_gahter
,
i_s_gather
)
else
:
else
:
x
=
tensor_model_parallel_
all_gather
(
iqis
=
eff_2d_iqis_all_gather
(
iqis
,
tp_size
=
self
.
tp_size
,
tp_rank
=
get_
tensor_model_parallel_
rank
())
x
.
contiguous
(),
0
else
:
)
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
)
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
iqis
=
iqis
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
iqis
=
iqis
)
...
@@ -1233,7 +1272,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1233,7 +1272,7 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input
=
update_hs
update_input
=
update_hs
)
)
new_resi
=
residual
new_resi
=
residual
if
skip_moe_large_batch_size
:
if
skip_moe_large_batch_size
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
:
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
else
:
else
:
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment