Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5dcc5cb8
Commit
5dcc5cb8
authored
Oct 09, 2025
by
王敏
Browse files
优化mori ep
parent
e0ba23b5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
179 additions
and
186 deletions
+179
-186
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-1
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+178
-185
No files found.
vllm/distributed/parallel_state.py
View file @
5dcc5cb8
...
@@ -948,7 +948,7 @@ def init_distributed_environment(
...
@@ -948,7 +948,7 @@ def init_distributed_environment(
"Fallback Gloo backend is not available."
)
"Fallback Gloo backend is not available."
)
backend
=
"gloo"
backend
=
"gloo"
# this backend is used for WORLD
# this backend is used for WORLD
parallel_config
=
config
.
parallel_config
data_parallel_size
=
parallel_config
.
data_parallel_size
data_parallel_size
=
parallel_config
.
data_parallel_size
use_mori_ep
=
envs
.
VLLM_USE_MORI_EP
and
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
use_mori_ep
=
envs
.
VLLM_USE_MORI_EP
and
data_parallel_size
>
1
and
parallel_config
.
enable_expert_parallel
if
use_mori_ep
:
if
use_mori_ep
:
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
5dcc5cb8
...
@@ -10,30 +10,31 @@ import torch.nn.functional as F
...
@@ -10,30 +10,31 @@ import torch.nn.functional as F
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_ep_group
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.distributed.parallel_state
import
get_ep_group
,
get_node_count
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.distributed
import
expert_parallel_all_gather
,
expert_parallel_all_reduce
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
import
mori
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
lmslim.layers.gemm.int8_utils
import
(
try
:
per_token_group_quant_int8
,
import
mori
from
lmslim.layers.gemm.int8_utils
import
(
per_token_quant_int8
)
per_token_quant_int8
)
except
ImportError
:
is_mori_available
=
False
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_MORI_OP
=
None
_MORI_OP
=
None
@
CustomOp
.
register
(
"unquantized_ep_moe"
)
@
CustomOp
.
register
(
"unquantized_ep_moe"
)
class
UnquantizedEPGroupedGemmMethod
(
UnquantizedFusedMoEMethod
):
class
UnquantizedEPGroupedGemmMethod
(
UnquantizedFusedMoEMethod
):
"""MoE method without quantization."""
"""MoE method without quantization."""
...
@@ -166,6 +167,7 @@ class EPMoE(FusedMoE):
...
@@ -166,6 +167,7 @@ class EPMoE(FusedMoE):
dp+ep MoE Expert Parallel Impl
dp+ep MoE Expert Parallel Impl
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
num_experts
:
int
,
# Global number of experts
num_experts
:
int
,
# Global number of experts
...
@@ -241,26 +243,27 @@ class EPMoE(FusedMoE):
...
@@ -241,26 +243,27 @@ class EPMoE(FusedMoE):
self
.
scales
=
None
self
.
scales
=
None
self
.
use_int8_dispatch
=
True
self
.
use_int8_dispatch
=
True
vllm_config
=
get_current_vllm_config
()
self
.
max_num_inp_token_per_rank
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
mori_op
=
self
.
get_mori_op
()
self
.
mori_op
=
self
.
get_mori_op
()
self
.
first
=
True
self
.
first
=
True
def
get_mori_op
(
self
):
def
get_mori_op
(
self
):
global
_MORI_OP
global
_MORI_OP
if
_MORI_OP
is
None
:
if
_MORI_OP
is
None
:
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
#torch._C._distributed_c10d._register_process_group("mori_ep", get_ep_group().device_group)
#mori.shmem.shmem_torch_process_group_init("mori_ep")
world_group
=
torch
.
distributed
.
group
.
WORLD
world_group
=
torch
.
distributed
.
group
.
WORLD
assert
world_group
is
not
None
assert
world_group
is
not
None
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"default"
,
world_group
)
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"mori_ep"
,
get_ep_group
().
device_group
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"default"
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"mori_ep"
)
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
multi_node
=
self
.
ep_size
/
8
>
1
multi_node
=
self
.
ep_size
/
8
>
1
mori_data_type
=
vllm_config
.
model_config
.
dtype
mori_data_type
=
vllm_config
.
model_config
.
dtype
mori_scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
mori_scale_type_size
=
vllm_config
.
model_config
.
dtype
.
itemsize
if
self
.
use_int8_dispatch
:
if
self
.
use_int8_dispatch
:
mori_scale_type_size
=
4
mori_scale_type_size
=
4
...
@@ -272,12 +275,13 @@ class EPMoE(FusedMoE):
...
@@ -272,12 +275,13 @@ class EPMoE(FusedMoE):
hidden_dim
=
self
.
hidden_size
,
hidden_dim
=
self
.
hidden_size
,
scale_dim
=
1
if
self
.
use_int8_dispatch
else
0
,
scale_dim
=
1
if
self
.
use_int8_dispatch
else
0
,
scale_type_size
=
mori_scale_type_size
,
scale_type_size
=
mori_scale_type_size
,
max_num_inp_token_per_rank
=
2048
,
max_num_inp_token_per_rank
=
self
.
max_num_inp_token_per_rank
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_rank
=
self
.
local_num_experts
,
num_experts_per_token
=
self
.
top_k
,
num_experts_per_token
=
self
.
top_k
,
max_token_type_size
=
2
,
max_token_type_size
=
2
,
block_num
=
80
,
block_num
=
80
,
warp_num_per_block
=
16
,
warp_num_per_block
=
16
,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
)
)
...
@@ -304,7 +308,7 @@ class EPMoE(FusedMoE):
...
@@ -304,7 +308,7 @@ class EPMoE(FusedMoE):
return
quant_method
return
quant_method
def
sync
(
self
):
def
sync
(
self
):
#torch.cuda.synchronize()
#
torch.cuda.synchronize()
dist
.
barrier
()
dist
.
barrier
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
...
@@ -331,7 +335,6 @@ class EPMoE(FusedMoE):
...
@@ -331,7 +335,6 @@ class EPMoE(FusedMoE):
if
name
not
in
NON_EXPERT_WEIGHTS
if
name
not
in
NON_EXPERT_WEIGHTS
]
]
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
):
...
@@ -365,8 +368,7 @@ class EPMoE(FusedMoE):
...
@@ -365,8 +368,7 @@ class EPMoE(FusedMoE):
)
)
scales
=
self
.
scales
scales
=
self
.
scales
# self.sync()
#self.sync()
(
(
dispatch_output
,
dispatch_output
,
...
@@ -380,60 +382,50 @@ class EPMoE(FusedMoE):
...
@@ -380,60 +382,50 @@ class EPMoE(FusedMoE):
scales
,
scales
,
topk_ids
,
topk_ids
,
)
)
#self.sync()
# self.sync()
expect_m
=
hidden_states
.
shape
[
0
]
*
self
.
ep_size
dispatch_output_clip
=
dispatch_output
[:
expect_m
]
dispatch_weights_clip
=
dispatch_weights
[:
expect_m
]
dispatch_indices_clip
=
dispatch_indices
[:
expect_m
]
dispatch_scales_clip
=
dispatch_scales
[:
expect_m
]
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
x
=
dispatch_output_clip
,
topk_weights
=
dispatch_weights_clip
,
topk_ids
=
dispatch_indices_clip
,
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
num_local_tokens
=
dispatch_recv_num_token
,
config_select_bs
=
hidden_states
.
shape
[
0
],
scales
=
dispatch_scales_clip
if
self
.
use_int8_dispatch
else
None
#routed_scaling_factor=self.routed_scaling_factor,
)
# if self.first and hidden_states.shape[0] == 2:
# self.first = False
# import numpy as np
# np.save(f'/work/vllm_profile/ep{self.ep_rank}_topk_ids.npy', dispatch_indices_clip.cpu().numpy())
# print("##################config_select_bs:{} topk_ids shape:{} num_local_tokens:{}".format(hidden_states.shape[0],
# topk_ids.shape,
# dispatch_recv_num_token))
# expect_m = topk_ids.shape[0] * self.ep_size
# dispatch_output_clip = dispatch_output[:expect_m]
# dispatch_weights_clip = dispatch_weights[:expect_m]
# dispatch_indices_clip = dispatch_indices[:expect_m]
# dispatch_scales_clip = dispatch_scales[:expect_m]
# expert_output = self.quant_method.apply_ep(
# expert_output = self.quant_method.apply_ep(
# layer=self,
# layer=self,
# x=dispatch_output,
# x=dispatch_output
_clip
,
# topk_weights=dispatch_weights,
# topk_weights=dispatch_weights
_clip
,
# topk_ids=dispatch_indices,
# topk_ids=dispatch_indices
_clip
,
# global_num_experts=self.global_num_experts,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# expert_map=self.expert_map,
# activation=self.activation,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0]
*2
,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales if self.use_int8_dispatch else None
# scales=dispatch_scales
_clip
if self.use_int8_dispatch else None
# #routed_scaling_factor=self.routed_scaling_factor,
# #routed_scaling_factor=self.routed_scaling_factor,
# )
# )
#self.sync()
expert_output
=
self
.
quant_method
.
apply_ep
(
layer
=
self
,
x
=
dispatch_output
,
topk_weights
=
dispatch_weights
,
topk_ids
=
dispatch_indices
,
global_num_experts
=
self
.
global_num_experts
,
expert_map
=
self
.
expert_map
,
activation
=
self
.
activation
,
apply_router_weight_on_input
=
self
.
apply_router_weight_on_input
,
use_nn_moe
=
self
.
use_nn_moe
,
num_local_tokens
=
dispatch_recv_num_token
,
config_select_bs
=
hidden_states
.
shape
[
0
],
scales
=
dispatch_scales
if
self
.
use_int8_dispatch
else
None
# routed_scaling_factor=self.routed_scaling_factor,
)
# self.sync()
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
#self.sync()
#
self.sync()
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
# if shared_expert_overlap is True, the expert calculation happens in
# if shared_expert_overlap is True, the expert calculation happens in
...
@@ -452,6 +444,7 @@ class EPMoE(FusedMoE):
...
@@ -452,6 +444,7 @@ class EPMoE(FusedMoE):
return
final_hidden_states
return
final_hidden_states
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
ep_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
...
@@ -472,5 +465,5 @@ direct_register_custom_op(
...
@@ -472,5 +465,5 @@ direct_register_custom_op(
mutates_args
=
[
"hidden_states"
,
"router_logits"
],
mutates_args
=
[
"hidden_states"
,
"router_logits"
],
fake_impl
=
ep_moe_forward_fake
,
fake_impl
=
ep_moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment