Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe70dcb2
Commit
fe70dcb2
authored
Sep 07, 2025
by
王敏
Browse files
临时添加cudagraph代码,目前还有问题
parent
121db653
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
264 additions
and
92 deletions
+264
-92
vllm/config.py
vllm/config.py
+3
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-0
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
+18
-0
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
+85
-9
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
...odel_executor/layers/fused_moe/ep_moe/token_dispatcher.py
+157
-83
No files found.
vllm/config.py
View file @
fe70dcb2
...
...
@@ -4320,6 +4320,9 @@ class CompilationConfig:
self
.
splitting_ops
=
[]
if
self
.
full_cuda_graph
else
[
"vllm.unified_attention"
,
"vllm.unified_attention_with_output"
,
"vllm.token_permutation_forward"
,
"vllm.token_unpermutation_forward"
,
"vllm.ep_moe_forward"
,
]
...
...
vllm/distributed/parallel_state.py
View file @
fe70dcb2
...
...
@@ -948,6 +948,7 @@ def init_distributed_environment(
"Fallback Gloo backend is not available."
)
backend
=
"gloo"
# this backend is used for WORLD
backend
=
"cpu:gloo,cuda:nccl"
torch
.
distributed
.
init_process_group
(
backend
=
backend
,
init_method
=
distributed_init_method
,
...
...
vllm/model_executor/layers/fused_moe/ep_moe/ep_moe_utlis.py
View file @
fe70dcb2
...
...
@@ -204,6 +204,24 @@ def maybe_move_tensor_to_cpu(tensor, as_numpy=False, record_stream=False):
return
tensor
def maybe_move_tensor_to_cpu_block(tensor, as_numpy=False, record_stream=False):
    """Return a CPU copy of ``tensor`` when it is a CUDA tensor.

    Anything that is not a CUDA tensor (``None``, a CPU tensor, any other
    object) is handed back untouched, whatever ``as_numpy`` says.

    Args:
        tensor (torch.Tensor or None): Candidate to copy to the CPU.
        as_numpy (bool): Convert the CPU copy to a numpy array.
        record_stream (bool): Call ``record_stream`` on the source tensor
            with the current CUDA stream once the copy has been issued, so
            its memory is not reused while a side-stream DtoH transfer may
            still be reading it.

    Returns:
        The CPU copy (tensor or numpy array), or the input itself.
    """
    if not (torch.is_tensor(tensor) and tensor.is_cuda):
        return tensor

    # ``Tensor.to`` is called without ``non_blocking``, i.e. with its default.
    host_copy = tensor.to(torch.device("cpu"))
    if as_numpy:
        host_copy = host_copy.numpy()
    if record_stream:
        tensor.record_stream(torch.cuda.current_stream())
    return host_copy
def
sort_chunks_by_idxs
(
input
:
torch
.
Tensor
,
split_sizes
:
torch
.
Tensor
,
sorted_idxs
:
torch
.
Tensor
,
fused
:
bool
=
False
):
...
...
vllm/model_executor/layers/fused_moe/ep_moe/layer.py
View file @
fe70dcb2
...
...
@@ -8,8 +8,10 @@ import torch.nn.functional as F
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.distributed.parallel_state
import
get_ep_group
,
get_node_count
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
...
...
@@ -19,10 +21,13 @@ from vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher import MoEAllt
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
from
lightop
import
groupgemm
#import mori
import
torch.distributed
as
dist
logger
=
init_logger
(
__name__
)
_MORI_OP
=
None
@
CustomOp
.
register
(
"unquantized_ep_moe"
)
class
UnquantizedEPGroupedGemmMethod
(
UnquantizedFusedMoEMethod
):
...
...
@@ -36,7 +41,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
zero_token_count
=
None
def
apply
(
def
apply
_ep
(
self
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -218,7 +223,8 @@ class EPMoE(FusedMoE):
self
.
use_shared_expert
=
False
self
.
token_dispatcher
=
MoEAlltoAllTokenDispatcher
(
self
.
local_num_experts
,
self
.
local_expert_indices
,
config
=
self
.
ep_moe_config
self
.
local_num_experts
,
self
.
local_expert_indices
,
config
=
self
.
ep_moe_config
,
#layer_name=f"{self.layer_name}.token_dispatcher",
)
self
.
shared_expert_overlap
=
moe_shared_expert_overlap
...
...
@@ -226,6 +232,36 @@ class EPMoE(FusedMoE):
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
if
False
:
self
.
mori_op
=
self
.
get_mori_op
()
def get_mori_op(self):
    """Return the module-wide mori dispatch/combine op, building it on first use.

    The op is cached in the module global ``_MORI_OP``; every later call
    (from any layer) returns that same object. On the first call the EP
    device group is registered under the name "mori_ep", mori's shmem layer
    is initialised from it, and an ``EpDispatchCombineOp`` is created from
    this layer's EP rank/size, hidden size, local expert count and top-k.

    NOTE(review): ``import mori`` is commented out at the top of this file,
    so calling this method as-is raises ``NameError``; the only visible call
    site is guarded by ``if False``.
    """
    global _MORI_OP
    if _MORI_OP is not None:
        return _MORI_OP

    process_group_name = "mori_ep"
    torch._C._distributed_c10d._register_process_group(
        process_group_name, get_ep_group().device_group)
    mori.shmem.shmem_torch_process_group_init(process_group_name)

    model_dtype = get_current_vllm_config().model_config.dtype
    dispatch_combine_config = mori.ops.EpDispatchCombineConfig(
        data_type=model_dtype,
        rank=self.ep_rank,
        world_size=self.ep_size,
        hidden_dim=self.hidden_size,
        scale_dim=0,
        scale_type_size=model_dtype.itemsize,
        # Hard-coded upper bound on tokens per rank.
        max_num_inp_token_per_rank=10000,
        num_experts_per_rank=self.local_num_experts,
        num_experts_per_token=self.top_k,
        max_token_type_size=4,
        # block_num / warp_num_per_block are intentionally not passed.
        kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode,
    )
    _MORI_OP = mori.ops.EpDispatchCombineOp(dispatch_combine_config)
    return _MORI_OP
def
set_shared_experts
(
self
,
shared_experts
:
torch
.
nn
.
Module
):
if
self
.
shared_experts
is
None
:
self
.
shared_experts
=
shared_experts
...
...
@@ -243,6 +279,10 @@ class EPMoE(FusedMoE):
assert
isinstance
(
quant_method
,
FusedMoEMethodBase
)
return
quant_method
def sync(self) -> None:
    """Synchronize CUDA, then barrier across the default process group.

    Used by the (currently disabled) mori dispatch path before it prints
    debug information about the dispatched tensors.
    """
    # Wait for outstanding CUDA work first so the barrier below is reached
    # only after this rank's kernels have finished.
    torch.cuda.synchronize()
    # No group argument is given, so this is a barrier on the default group.
    dist.barrier()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
return
torch
.
ops
.
vllm
.
ep_moe_forward
(
hidden_states
,
router_logits
,
...
...
@@ -267,21 +307,57 @@ class EPMoE(FusedMoE):
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
True
:
probs
=
None
if
self
.
apply_router_weight_on_input
:
probs
=
torch
.
zeros_like
(
router_logits
,
dtype
=
topk_weights
.
dtype
).
scatter
(
1
,
topk_ids
,
topk_weights
)
routing_map
=
torch
.
zeros_like
(
router_logits
).
int
().
scatter
(
1
,
topk_ids
,
1
).
bool
()
(
dispatched_input
,
tokens_per_expert
)
=
self
.
token_dispatcher
.
token_permutation
(
hidden_states
,
probs
,
routing_map
)
else
:
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
scales
=
torch
.
rand
(
hidden_states
.
shape
[
0
],
0
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
,
)
(
dispatched_input
,
dispatch_weights
,
dispatch_scales
,
dispatch_indices
,
dispatch_recv_num_token
,
)
=
self
.
mori_op
.
dispatch
(
hidden_states
,
topk_weights
,
scales
,
topk_ids
,
)
tokens_per_expert
=
dispatch_recv_num_token
self
.
sync
()
print
(
"######################dispatched_input shape:{} dispatch_weights.shape:{} dispatch_indices shape:{}"
.
format
(
dispatched_input
.
shape
,
dispatch_weights
.
shape
,
dispatch_indices
.
shape
))
print
(
"####################dispatch_recv_num_token:"
,
dispatch_recv_num_token
.
tolist
())
#print("####################dispatch_weights:", dispatch_weights.tolist())
#print("####################dispatch_indices:", dispatch_indices.tolist())
# Matrix multiply.
expert_output
=
self
.
quant_method
.
apply
(
expert_output
=
self
.
quant_method
.
apply
_ep
(
layer
=
self
,
hidden_states
=
dispatched_input
,
tokens_per_expert
=
tokens_per_expert
)
if
True
:
final_hidden_states
=
self
.
token_dispatcher
.
token_unpermutation
(
expert_output
)
else
:
combine_output
,
_
=
self
.
mori_op
.
combine
(
expert_output
,
dispatch_weights
,
topk_ids
)
final_hidden_states
=
combine_output
[:
hidden_states
.
shape
[
0
],
:]
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
...
...
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
View file @
fe70dcb2
...
...
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.distributed.parallel_state
import
(
get_dp_group
,
get_tp_group
,
...
...
@@ -11,6 +12,7 @@ from vllm.distributed.parallel_state import (get_dp_group,
get_tensor_model_parallel_rank
)
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
(
EPSharedExperts
,
maybe_move_tensor_to_cpu
,
maybe_move_tensor_to_cpu_block
,
permute
,
sort_chunks_by_idxs
,
unpermute
,
...
...
@@ -21,12 +23,16 @@ from vllm.distributed import (tensor_model_parallel_all_gather,
expert_parallel_all_gather
,
expert_parallel_gather
)
from
vllm.platforms
import
current_platform
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.config
import
get_current_vllm_config
cuda_dtoh_stream
=
torch
.
cuda
.
Stream
()
cuda_dtoh_sync_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
class
MoETokenDispatcher
:
class
MoETokenDispatcher
(
nn
.
Module
):
"""
MoE Token Dispatcher
"""
...
...
@@ -35,6 +41,7 @@ class MoETokenDispatcher:
"""
Initialize the MoE Token Dispatcher.
"""
super
().
__init__
()
self
.
config
=
config
self
.
tp_size
=
1
...
...
@@ -106,7 +113,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
"""
def
__init__
(
self
,
num_local_experts
:
int
,
local_expert_indices
:
List
[
int
],
config
:
EpMoeConfig
self
,
num_local_experts
:
int
,
local_expert_indices
:
List
[
int
],
config
:
EpMoeConfig
,
layer_name
:
str
=
""
)
->
None
:
"""
Initialize the AlltoAll token dispatcher.
...
...
@@ -130,6 +137,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
local_expert_indices
[
i
]
==
self
.
local_expert_indices
[
i
+
1
]
-
1
),
"local_expert_indices must be continous"
self
.
layer_name
=
layer_name
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self
.
input_splits
=
None
...
...
@@ -174,6 +182,13 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
probs
=
None
self
.
dpsk_fp16_quick
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
)
==
'1'
# For smuggling this layer into the fused moe custom op
vllm_config
=
get_current_vllm_config
()
compilation_config
=
vllm_config
.
compilation_config
if
layer_name
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
layer_name
))
compilation_config
.
static_forward_context
[
layer_name
]
=
self
def
preprocess
(
self
,
routing_map
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
...
...
@@ -196,7 +211,6 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
num_out_tokens
=
routing_map
.
size
(
0
)
*
self
.
config
.
moe_router_topk
if
self
.
ep_size
>
1
or
self
.
tp_size
>
1
:
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
...
...
@@ -240,16 +254,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# A synchronization is needed before expert parallel AlltoAll communication
# to get the `input_splits` and `output_splits` CPU values.
self
.
_maybe_update_cuda_sync_point
(
"before_ep_alltoall"
)
else
:
num_global_tokens_per_local_expert
=
num_local_tokens_per_expert
.
reshape
(
self
.
num_experts
)
num_tokens_per_local_expert
=
num_local_tokens_per_expert
# A synchronization is needed before the returns
# to get the `num_tokens_per_local_expert` CPU value.
self
.
_maybe_update_cuda_sync_point
(
"before_finish"
)
#self._maybe_update_cuda_sync_point("before_ep_alltoall")
if
self
.
num_local_experts
>
1
:
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
...
...
@@ -257,21 +262,40 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
self
.
num_global_tokens_per_local_expert
=
num_global_tokens_per_local_expert
.
view
(
-
1
,
self
.
num_local_experts
)
if
not
self
.
config
.
moe_permute_fusion
:
# A synchronization is needed before permutation 2
# to get the `num_global_tokens_per_local_expert` CPU value.
self
.
_maybe_update_cuda_sync_point
(
"before_permutation_2"
)
assert
(
self
.
cuda_sync_point_priority
[
self
.
cuda_dtoh_point
]
<=
self
.
cuda_sync_point_priority
[
self
.
cuda_sync_point
]
),
"cuda_sync_point must be after cuda_dtoh_point."
#
if not self.config.moe_permute_fusion:
#
# A synchronization is needed before permutation 2
#
# to get the `num_global_tokens_per_local_expert` CPU value.
#
self._maybe_update_cuda_sync_point("before_permutation_2")
#
assert (
#
self.cuda_sync_point_priority[self.cuda_dtoh_point]
#
<= self.cuda_sync_point_priority[self.cuda_sync_point]
#
), "cuda_sync_point must be after cuda_dtoh_point."
return
num_tokens_per_local_expert
def token_permutation(
    self,
    hidden_states: torch.Tensor,
    probs: Optional[torch.Tensor],
    routing_map: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute routing metadata, then dispatch tokens through the custom op.

    The metadata step (``preprocess``) runs here; the permutation and
    communication itself is delegated to
    ``torch.ops.vllm.token_permutation_forward``, which looks this
    dispatcher up again by ``self.layer_name``.

    Args:
        hidden_states: Input token activations.
        probs: Per-token routing weights; callers may pass ``None``.
        routing_map: 2D boolean token-to-expert mask.

    Returns:
        A tuple ``(global_input_tokens, tokens_per_expert)``: the op's
        output and the value returned by ``preprocess``.

    Raises:
        AssertionError: If ``routing_map`` is not a 2D bool tensor.
    """
    self.routing_map = routing_map
    assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
    assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"

    # Record this call's input shape before anything reads it. Otherwise the
    # shared-expert branch below would use whatever an earlier call left
    # behind, because ``token_permutation_impl`` only assigns it later.
    self.hidden_shape = hidden_states.shape

    tokens_per_expert = self.preprocess(self.routing_map)

    if self.config.moe_shared_expert_overlap and self.shared_experts is not None:
        self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))

    global_input_tokens = torch.ops.vllm.token_permutation_forward(
        tokens_per_expert, hidden_states, probs, routing_map, self.layer_name)
    return global_input_tokens, tokens_per_expert
def
token_permutation_impl
(
self
,
tokens_per_expert
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Dispatch tokens to local experts using AlltoAll communication.
...
...
@@ -293,23 +317,16 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert
=
self
.
_maybe_dtoh_and_synchronize
(
"before_permutation_1"
,
tokens_per_expert
)
self
.
hidden_shape
=
hidden_states
.
shape
if
self
.
config
.
apply_router_weight_on_input
:
self
.
probs
=
probs
self
.
routing_map
=
routing_map
assert
probs
.
dim
()
==
2
,
"Expected 2D tensor for probs"
assert
routing_map
.
dim
()
==
2
,
"Expected 2D tensor for token2expert mask"
assert
routing_map
.
dtype
==
torch
.
bool
,
"Expected bool tensor for mask"
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
tokens_per_expert
=
self
.
preprocess
(
self
.
routing_map
)
if
self
.
config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
self
.
shared_experts
.
pre_forward_comm
(
hidden_states
.
view
(
self
.
hidden_shape
))
# Permutation 1: input to AlltoAll input
tokens_per_expert
=
self
.
_maybe_dtoh_and_synchronize
(
"before_permutation_1"
,
tokens_per_expert
)
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
self
.
hidden_shape_before_permute
=
hidden_states
.
shape
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
=
permute
(
...
...
@@ -350,11 +367,16 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
#tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return
global_input_tokens
,
tokens_per_expert
return
global_input_tokens
def token_unpermutation(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Restore expert outputs to the original token order via the custom op.

    Thin wrapper: forwards ``hidden_states`` and this dispatcher's
    ``layer_name`` to ``torch.ops.vllm.token_unpermutation_forward`` and
    returns whatever the op returns.
    """
    restored = torch.ops.vllm.token_unpermutation_forward(
        hidden_states, self.layer_name)
    return restored
def
token_unpermutation_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Reverse the token permutation to restore the original order.
...
...
@@ -468,3 +490,55 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
# cuda_dtoh_stream.synchronize()
return
tokens_per_expert
def token_permutation_forward(
    tokens_per_expert: torch.Tensor,
    hidden_states: torch.Tensor,
    probs: Optional[torch.Tensor],
    routing_map: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op entry point for the token permutation step.

    Looks up the dispatcher registered under ``layer_name`` in the current
    forward context's ``no_compile_layers`` and delegates to its
    ``token_permutation_impl``.

    Args:
        tokens_per_expert: Per-expert token counts produced by ``preprocess``.
        hidden_states: Input token activations.
        probs: Routing weights, or ``None``. It is annotated ``Optional``
            because the op schema is derived from these annotations and
            callers pass ``None`` when router weights are not applied on
            the input.
        routing_map: Boolean token-to-expert mask.
        layer_name: Key of the dispatcher in the forward context.

    Returns:
        The tensor returned by ``token_permutation_impl``.

    Raises:
        KeyError: If no layer is registered under ``layer_name``.
    """
    forward_context: ForwardContext = get_forward_context()
    dispatcher = forward_context.no_compile_layers[layer_name]
    return dispatcher.token_permutation_impl(
        tokens_per_expert, hidden_states, probs, routing_map)
def token_permutation_forward_fake(
    tokens_per_expert: torch.Tensor,
    hidden_states: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (shape-only) implementation of ``token_permutation_forward``.

    Returns an uninitialised tensor with the same shape, dtype and device as
    ``hidden_states``; every other argument is ignored.

    NOTE(review): this assumes the real op's output matches the shape of
    ``hidden_states`` - confirm against ``token_permutation_impl``.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
# Register the permutation step as a custom op so that it can be invoked as
# ``torch.ops.vllm.token_permutation_forward``. All four tensor arguments are
# declared as mutated; ``token_permutation_forward_fake`` supplies the
# shape-only implementation.
direct_register_custom_op(
    op_name="token_permutation_forward",
    op_func=token_permutation_forward,
    mutates_args=["tokens_per_expert", "hidden_states", "probs", "routing_map"],
    fake_impl=token_permutation_forward_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
def token_unpermutation_forward(
    hidden_states: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op entry point for the token unpermutation step.

    Fetches the dispatcher stored under ``layer_name`` in the current forward
    context's ``no_compile_layers`` and returns the result of its
    ``token_unpermutation_impl`` applied to ``hidden_states``.
    """
    forward_context: ForwardContext = get_forward_context()
    dispatcher = forward_context.no_compile_layers[layer_name]
    return dispatcher.token_unpermutation_impl(hidden_states)
def token_unpermutation_forward_fake(
    hidden_states: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (shape-only) implementation of ``token_unpermutation_forward``.

    Returns an uninitialised tensor with the same shape, dtype and device as
    ``hidden_states``; ``layer_name`` is ignored.

    NOTE(review): this assumes the real op's output matches the shape of
    ``hidden_states`` - confirm against ``token_unpermutation_impl``.
    """
    placeholder = torch.empty_like(hidden_states)
    return placeholder
# Register the unpermutation step as a custom op so that it can be invoked as
# ``torch.ops.vllm.token_unpermutation_forward``. ``hidden_states`` is
# declared as mutated; ``token_unpermutation_forward_fake`` supplies the
# shape-only implementation.
direct_register_custom_op(
    op_name="token_unpermutation_forward",
    op_func=token_unpermutation_forward,
    mutates_args=["hidden_states"],
    fake_impl=token_unpermutation_forward_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment