Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44d451ff
Commit
44d451ff
authored
Feb 12, 2025
by
王敏
Browse files
优化ep moe
parent
36e35fac
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
75 additions
and
17 deletions
+75
-17
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+3
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+0
-3
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+69
-8
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+1
-4
No files found.
csrc/moe/moe_align_sum_kernels.cu
View file @
44d451ff
...
...
@@ -307,8 +307,9 @@ __global__ void ep_moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* assigned to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
if
(
topk_ids
[
i
]
>=
start_expert
&&
topk_ids
[
i
]
<
end_expert
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
]
-
start_expert
)];
expert_id
=
topk_ids
[
i
];
if
(
expert_id
>=
start_expert
&&
expert_id
<
end_expert
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
expert_id
-
start_expert
)];
}
}
...
...
vllm/engine/arg_utils.py
View file @
44d451ff
...
...
@@ -420,6 +420,7 @@ class EngineArgs:
default
=
EngineArgs
.
tensor_parallel_size
,
help
=
'Number of tensor parallel replicas.'
)
parser
.
add_argument
(
'--moe-ep-size'
,
'-ep'
,
type
=
int
,
default
=
EngineArgs
.
moe_ep_size
,
help
=
'Number of moe expert parallel replicas.'
)
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
44d451ff
...
...
@@ -1345,9 +1345,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
if
moe_ep_size
==
1
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
44d451ff
...
...
@@ -234,6 +234,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
moe_ep_size
:
Optional
[
int
]
=
1
,
start_expert
:
Optional
[
int
]
=
-
1
,
end_expert
:
Optional
[
int
]
=
-
1
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
44d451ff
...
...
@@ -22,6 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model."""
import
os
import
re
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
...
...
@@ -56,6 +58,7 @@ from .interfaces import SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
class
DeepseekV2MLP
(
nn
.
Module
):
...
...
@@ -100,6 +103,7 @@ class DeepseekV2MoE(nn.Module):
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
):
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -139,7 +143,8 @@ class DeepseekV2MoE(nn.Module):
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
)
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
,
moe_ep_size
=
moe_ep_size
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
...
...
@@ -490,6 +495,7 @@ class DeepseekV2DecoderLayer(nn.Module):
model_config
:
ModelConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
moe_ep_size
:
int
=
1
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
...
...
@@ -529,6 +535,7 @@ class DeepseekV2DecoderLayer(nn.Module):
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
moe_ep_size
=
moe_ep_size
)
else
:
self
.
mlp
=
DeepseekV2MLP
(
...
...
@@ -577,7 +584,7 @@ class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
moe_ep_size
:
int
=
1
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -604,6 +611,7 @@ class DeepseekV2Model(nn.Module):
model_config
=
model_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
moe_ep_size
=
moe_ep_size
),
prefix
=
f
"
{
prefix
}
.layers"
)
...
...
@@ -662,8 +670,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
moe_ep_size
=
self
.
parallel_config
.
moe_ep_size
self
.
model
=
DeepseekV2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
),
moe_ep_size
=
self
.
moe_ep_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
...
...
@@ -672,6 +684,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
...
...
@@ -730,11 +748,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
if
self
.
moe_ep_size
==
1
:
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
else
:
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping_ep
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
,
moe_ep_size
=
self
.
moe_ep_size
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
...
...
@@ -803,6 +829,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Skip loading extra expert weights for ep moe mode
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
...
...
@@ -810,6 +840,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
if
self
.
use_llama_nn
and
self
.
quant_method
is
None
:
lay_key_words
=
[
"self_attn.q_a_proj.weight"
,
"self_attn.kv_a_proj_with_mqa.weight"
,
"mlp.gate.weight"
,
"mlp.gate_up_proj.weight"
,
"mlp.down_proj"
,
"shared_experts.gate_up_proj"
,
"shared_experts.down_proj"
,
"self_attn.q_proj.weight"
,
"self_attn.q_b_proj.weight"
,
"self_attn.kv_b_proj.weight"
,
"self_attn.o_proj.weight"
,
"lm_head.weight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
_weight
=
torch
.
zeros_like
(
weight
.
data
)
ori_shape
=
_weight
.
shape
ops
.
trans_w16_gemm
(
_weight
,
weight
.
data
,
_weight
.
shape
[
0
],
_weight
.
shape
[
1
])
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
return
loaded_params
...
...
vllm/model_executor/models/deepseek_v3.py
View file @
44d451ff
...
...
@@ -683,9 +683,6 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_mla
=
False
if
hasattr
(
vllm_config
.
model_config
,
"use_mla"
):
self
.
use_mla
=
vllm_config
.
model_config
.
use_mla
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -753,7 +750,7 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
else
:
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
_ep
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment