Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d698d6f2
"examples/vscode:/vscode.git/clone" did not exist on "47a1f11bffdd12cd59d90d79ff9867b7b3ac5b69"
Commit
d698d6f2
authored
Nov 01, 2025
by
王敏
Browse files
[feat]整合mori和deepep相关代码
parent
7293a072
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
93 additions
and
655 deletions
+93
-655
vllm.zip
vllm.zip
+0
-0
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+10
-10
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
...odel_executor/layers/fused_moe/ep_moe/token_dispatcher.py
+0
-559
vllm/model_executor/layers/fused_moe/mori_moe/ep_moe_utlis.py
.../model_executor/layers/fused_moe/mori_moe/ep_moe_utlis.py
+0
-0
vllm/model_executor/layers/fused_moe/mori_moe/layer.py
vllm/model_executor/layers/fused_moe/mori_moe/layer.py
+20
-48
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+33
-28
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+24
-10
No files found.
vllm.zip
0 → 100644
View file @
d698d6f2
File added
vllm/distributed/device_communicators/all2all.py
View file @
d698d6f2
...
@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
...
@@ -171,16 +171,16 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank
=
None
num_qps_per_rank
=
None
if
self
.
internode
:
if
self
.
internode
:
num_rdma_bytes
=
int
(
1e9
/
2
)
#1024 * 1024 * 1024
num_rdma_bytes
=
int
(
1e9
/
2
)
#1024 * 1024 * 1024
num_qps_per_rank
=
30
#self.num_sms // 2
num_qps_per_rank
=
30
#self.num_sms // 2
import
deep_ep
#
import deep_ep
num_nvl_bytes
,
num_rdma_bytes
=
0
,
0
#
num_nvl_bytes, num_rdma_bytes = 0, 0
hidden_size
=
7168
#
hidden_size = 7168
hidden_bytes
=
hidden_size
*
2
#
hidden_bytes = hidden_size * 2
for
config
in
(
deep_ep
.
Buffer
.
get_dispatch_config
(
self
.
cpu_group
.
size
()),
deep_ep
.
Buffer
.
get_combine_config
(
self
.
cpu_group
.
size
())):
#
for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
num_nvl_bytes
=
max
(
config
.
get_nvl_buffer_size_hint
(
hidden_bytes
,
self
.
cpu_group
.
size
()),
num_nvl_bytes
)
#
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
num_rdma_bytes
=
max
(
config
.
get_rdma_buffer_size_hint
(
hidden_bytes
,
self
.
cpu_group
.
size
()),
num_rdma_bytes
)
#
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else
:
else
:
num_rdma_bytes
=
0
num_rdma_bytes
=
0
num_qps_per_rank
=
1
num_qps_per_rank
=
1
...
...
vllm/envs.py
View file @
d698d6f2
...
@@ -175,6 +175,7 @@ if TYPE_CHECKING:
...
@@ -175,6 +175,7 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
VLLM_P2P_ASYNC
:
bool
=
False
VLLM_P2P_ASYNC
:
bool
=
False
VLLM_P2P_BUF_TOKENS
:
int
=
30000
VLLM_P2P_BUF_TOKENS
:
int
=
30000
VLLM_ENABLE_MOE_GROUP_GEMM
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1151,6 +1152,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS"
:
"VLLM_P2P_BUF_TOKENS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_P2P_BUF_TOKENS"
,
"30000"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_P2P_BUF_TOKENS"
,
"30000"
)),
# pd separation p2p async buf tokens
"VLLM_ENABLE_MOE_GROUP_GEMM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ENABLE_MOE_GROUP_GEMM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/ep_moe/token_dispatcher.py
deleted
100644 → 0
View file @
7293a072
import
os
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.distributed.parallel_state
import
(
get_dp_group
,
get_tp_group
,
get_ep_group
,
get_tensor_model_parallel_rank
)
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
(
EPSharedExperts
,
maybe_move_tensor_to_cpu
,
maybe_move_tensor_to_cpu_block
,
permute
,
sort_chunks_by_idxs
,
unpermute
,
all_to_all
,
EpMoeConfig
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_gather
,
expert_parallel_all_gather
,
expert_parallel_gather
)
from
vllm.platforms
import
current_platform
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.config
import
get_current_vllm_config
from
lightop
import
groupgemm_permute
,
groupgemm_unpermute
cuda_dtoh_stream
=
torch
.
cuda
.
Stream
()
cuda_dtoh_sync_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
class
MoETokenDispatcher
(
nn
.
Module
):
"""
MoE Token Dispatcher
"""
def
__init__
(
self
,
config
:
EpMoeConfig
)
->
None
:
"""
Initialize the MoE Token Dispatcher.
"""
super
().
__init__
()
self
.
config
=
config
self
.
tp_size
=
1
self
.
ep_size
=
config
.
ep_size
@
property
def
ep_group
(
self
):
"""Get expert model parallel group."""
return
get_ep_group
()
@
property
def
tp_group
(
self
):
"""Get expert tensor parallel group."""
return
get_tp_group
()
@
property
def
tp_rank
(
self
):
"""Get expert tensor parallel rank."""
return
0
#get_tensor_model_parallel_rank()
@
property
def
tp_ep_group
(
self
):
"""Get expert tensor and model parallel group."""
return
get_ep_group
()
@
abstractmethod
def
token_permutation
(
self
,
tokens
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
):
"""Dispatch tokens to experts.
Args:
tokens (torch.Tensor): Input tokens.
probs (torch.Tensor): The routing probability tensor [num_tokens, num_experts].
routing_map (torch.Tensor): Token to expert mapping tensor.
Returns:
torch.Tensor: Tokens tensor.
"""
raise
NotImplementedError
(
"Dispatch function not implemented."
)
@
abstractmethod
def
token_unpermutation
(
self
,
expert_output
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
=
None
):
"""Restores the expert output to its original ordering.
Args:
expert_output (torch.Tensor): The output tensor from the expert models.
bias (torch.Tensor): The bias tensor.
Returns:
(torch.Tensor, torch.Tensor): Unpermuted activation and optional bias.
"""
raise
NotImplementedError
(
"Restore function not implemented."
)
def
set_shared_experts
(
self
,
shared_experts
):
"""Set shared expert to the dispatcher."""
assert
self
.
config
.
moe_shared_expert_overlap
self
.
shared_experts
=
shared_experts
class
MoEAlltoAllTokenDispatcher
(
MoETokenDispatcher
):
"""
AlltoAll-based token dispatcher.
The workflow of AlltoAll token dispatcher is as follows:
(1) preprocess(): calculate necessary metadata for communication and permute
(2) token_permutation(): permute->A2A(EP)->AG(TP)->sort_chunk(if num_local_experts>1)
(3) token_unpermutation(): sort_chunk(if num_local_experts>1)->RS(TP)->A2A(EP)->unpermute
"""
def
__init__
(
self
,
num_local_experts
:
int
,
local_expert_indices
:
List
[
int
],
config
:
EpMoeConfig
,
layer_name
:
str
=
""
)
->
None
:
"""
Initialize the AlltoAll token dispatcher.
Args:
num_local_experts (int): Number of local experts on the current device.
local_expert_indices (List[int]): Indices of local experts on the current device.
config (TransformerConfig): Configuration for the transformer model.
"""
super
().
__init__
(
config
=
config
)
self
.
num_local_experts
=
num_local_experts
assert
config
.
num_moe_experts
is
not
None
self
.
num_experts
=
config
.
num_moe_experts
assert
self
.
num_local_experts
>
0
,
"Expected at least one expert"
self
.
local_expert_indices
=
local_expert_indices
assert
(
len
(
self
.
local_expert_indices
)
==
self
.
num_local_experts
),
"Invalid local expert indices"
for
i
in
range
(
len
(
self
.
local_expert_indices
)
-
1
):
assert
(
self
.
local_expert_indices
[
i
]
==
self
.
local_expert_indices
[
i
+
1
]
-
1
),
"local_expert_indices must be continous"
self
.
layer_name
=
layer_name
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self
.
input_splits
=
None
# [ep_size]. Represents the number of tokens received by the current rank from
# other EP ranks.
self
.
output_splits
=
None
# [tp_size]. Represents the number of tokens received by the current rank from
# other TP ranks.
#self.output_splits_tp = None
self
.
permute_idx_device
=
torch
.
device
(
"cuda"
)
if
self
.
config
.
moe_permute_fusion
else
None
input_chunk_idxs
=
torch
.
arange
(
self
.
num_experts
*
self
.
tp_size
,
device
=
self
.
permute_idx_device
)
# [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts.
self
.
sort_input_by_local_experts
=
input_chunk_idxs
.
reshape
(
-
1
,
self
.
num_local_experts
).
T
.
ravel
()
# [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts.
self
.
restore_output_by_local_experts
=
input_chunk_idxs
.
reshape
(
self
.
num_local_experts
,
-
1
).
T
.
ravel
()
# A cuda stream synchronization is needed in self.token_permutation() in some cases,
# because there are several non-blocking DtoH data transfers called at
# `self.cuda_dtoh_point`. The synchronization happens at `self.cuda_sync_point`, which is
# decided based on the MoE and parallel settings. Valid points are "before_permutation_1",
# "before_ep_alltoall", "before_permutation_2", "before_finish", and "no_sync".
self
.
cuda_sync_point
=
"no_sync"
self
.
cuda_sync_point_priority
=
{
"before_permutation_1"
:
0
,
"before_ep_alltoall"
:
1
,
"before_permutation_2"
:
2
,
"before_finish"
:
3
,
"no_sync"
:
4
,
}
self
.
cuda_dtoh_point
=
"before_permutation_1"
#self.cuda_dtoh_stream = torch.cuda.Stream()
# Whether to use gather or all-gather to gather the logits.
self
.
use_all_gather
=
current_platform
.
use_all_gather
()
self
.
probs
=
None
# For smuggling this layer into the fused moe custom op
vllm_config
=
get_current_vllm_config
()
compilation_config
=
vllm_config
.
compilation_config
if
layer_name
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
layer_name
))
compilation_config
.
static_forward_context
[
layer_name
]
=
self
def
preprocess
(
self
,
routing_map
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Preprocess token routing map for AlltoAll communication and token permutation.
This method computes the number of tokens assigned to each expert based on the routing_map.
It also initializes the necessary data structures for AlltoAll communication, such as input
and output splits, and the mapping between global tokens and local experts. This method
should not call any DtoH data copying due to performance consideration. The necessary DtoH
copies are made on the `self.cuda_dtoh_stream` at `self.cuda_dtoh_point`.
Args:
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
[num_tokens, num_experts].
Returns:
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
"""
# [num_experts], number of tokens assigned to each expert from the current rank's input.
num_local_tokens_per_expert
=
routing_map
.
sum
(
dim
=
0
).
long
()
self
.
num_out_tokens
=
routing_map
.
size
(
0
)
*
self
.
config
.
moe_router_topk
# ===================================================
# Calculate input_splits, output_splits for alltoall/allgather in variable size.
# ===================================================
# [ep_size]. Represents the number of tokens sent by the current rank to other
# EP ranks.
self
.
input_splits
=
num_local_tokens_per_expert
.
reshape
(
self
.
ep_size
,
self
.
num_local_experts
).
sum
(
axis
=
1
)
# Gather the global distribution of tokens across ranks.
# num_global_tokens_per_expert represents the number of tokens sent to each
# expert by all ranks.
# [tp_size, ep_size, num_experts]
if
self
.
use_all_gather
:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
num_global_tokens_per_expert
=
expert_parallel_all_gather
(
num_local_tokens_per_expert
)
\
.
reshape
(
self
.
ep_size
,
self
.
tp_size
,
self
.
num_experts
)
\
.
transpose
(
0
,
1
)
else
:
# None may be returned for rank > 0
num_global_tokens_per_expert
=
expert_parallel_gather
(
num_local_tokens_per_expert
)
\
.
reshape
(
self
.
ep_size
,
self
.
tp_size
,
self
.
num_experts
)
\
.
transpose
(
0
,
1
)
# [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts]
num_global_tokens_per_local_expert
=
num_global_tokens_per_expert
[
:,
:,
self
.
local_expert_indices
[
0
]
:
self
.
local_expert_indices
[
-
1
]
+
1
].
contiguous
()
# [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size]
num_global_tokens_per_rank
=
num_global_tokens_per_local_expert
.
sum
(
axis
=
2
)
# [tp_size, ep_size] -> [ep_size]
# self.output_splits represents the number of tokens received by the current rank
# from other EP rank.
self
.
output_splits
=
num_global_tokens_per_rank
[
self
.
tp_rank
]
# [tp_size, ep_size] -> [tp_size]
# self.output_splits_tp represents the number of tokens received by the current
# rank from other TP rank.
#self.output_splits_tp = num_global_tokens_per_rank.sum(axis=1)
# [tp_size, ep_size, num_local_experts] -> [num_local_experts]
num_tokens_per_local_expert
=
num_global_tokens_per_local_expert
.
sum
(
dim
=
(
0
,
1
))
# A synchronization is needed before expert parallel AlltoAll communication
# to get the `input_splits` and `output_splits` CPU values.
#self._maybe_update_cuda_sync_point("before_ep_alltoall")
if
self
.
num_local_experts
>
1
:
# [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self
.
num_global_tokens_per_local_expert
=
num_global_tokens_per_local_expert
.
view
(
-
1
,
self
.
num_local_experts
)
# if not self.config.moe_permute_fusion:
# # A synchronization is needed before permutation 2
# # to get the `num_global_tokens_per_local_expert` CPU value.
# self._maybe_update_cuda_sync_point("before_permutation_2")
# assert (
# self.cuda_sync_point_priority[self.cuda_dtoh_point]
# <= self.cuda_sync_point_priority[self.cuda_sync_point]
# ), "cuda_sync_point must be after cuda_dtoh_point."
return
num_tokens_per_local_expert
def
token_permutation
(
self
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
routing_map
=
routing_map
assert
routing_map
.
dim
()
==
2
,
"Expected 2D tensor for token2expert mask"
assert
routing_map
.
dtype
==
torch
.
bool
,
"Expected bool tensor for mask"
tokens_per_expert
=
self
.
preprocess
(
self
.
routing_map
)
if
self
.
config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
self
.
shared_experts
.
pre_forward_comm
(
hidden_states
.
view
(
self
.
hidden_shape
))
global_input_tokens
=
torch
.
ops
.
vllm
.
token_permutation_forward
(
tokens_per_expert
,
hidden_states
,
probs
,
routing_map
,
self
.
layer_name
)
return
global_input_tokens
,
tokens_per_expert
def
token_permutation_impl
(
self
,
tokens_per_expert
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Dispatch tokens to local experts using AlltoAll communication.
This method performs the following steps:
1. Preprocess the routing map to get metadata for communication and permutation.
2. Permute input tokens for AlltoAll communication.
3. Perform expert parallel AlltoAll communication.
4. Sort tokens by local expert (if multiple local experts exist).
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert
=
self
.
_maybe_dtoh_and_synchronize
(
"before_permutation_1"
,
tokens_per_expert
)
self
.
hidden_shape
=
hidden_states
.
shape
if
self
.
config
.
apply_router_weight_on_input
:
self
.
probs
=
probs
assert
probs
.
dim
()
==
2
,
"Expected 2D tensor for probs"
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
self
.
hidden_shape_before_permute
=
hidden_states
.
shape
if
False
:
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
=
permute
(
hidden_states
,
routing_map
,
num_out_tokens
=
self
.
num_out_tokens
,
fused
=
self
.
config
.
moe_permute_fusion
)
else
:
cuda_permute_result
=
groupgemm_permute
(
hidden_states
,
routing_map
)
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
\
self
.
expert_m_count
=
cuda_permute_result
# Perform expert parallel AlltoAll communication
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# "before_ep_alltoall", tokens_per_expert
# )
###test##############
#cuda_dtoh_stream.synchronize()
cuda_dtoh_sync_event
.
synchronize
()
###test##############
global_input_tokens
=
all_to_all
(
self
.
ep_group
.
device_group
,
permutated_local_input_tokens
,
self
.
output_splits
,
self
.
input_splits
)
if
self
.
config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
self
.
shared_experts
.
linear_fc1_forward_and_act
(
global_input_tokens
)
# Permutation 2: Sort tokens by local expert.
# tokens_per_expert = self._maybe_dtoh_and_synchronize(
# "before_permutation_2", tokens_per_expert
# )
if
self
.
num_local_experts
>
1
:
global_input_tokens
=
sort_chunks_by_idxs
(
global_input_tokens
,
self
.
num_global_tokens_per_local_expert
.
ravel
(),
self
.
sort_input_by_local_experts
,
fused
=
self
.
config
.
moe_permute_fusion
,
)
#tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
return
global_input_tokens
def
token_unpermutation
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
token_unpermutation_forward
(
hidden_states
,
self
.
layer_name
)
def
token_unpermutation_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Reverse the token permutation to restore the original order.
This method performs the following steps:
1. Unsort tokens by local expert (if multiple local experts exist).
2. Perform expert parallel AlltoAll communication to restore the original order.
3. Unpermute tokens to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
# Unpermutation 2: Unsort tokens by local expert.
if
self
.
num_local_experts
>
1
:
hidden_states
=
sort_chunks_by_idxs
(
hidden_states
,
self
.
num_global_tokens_per_local_expert
.
T
.
ravel
(),
self
.
restore_output_by_local_experts
,
fused
=
self
.
config
.
moe_permute_fusion
,
)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
permutated_local_input_tokens
=
all_to_all
(
self
.
ep_group
.
device_group
,
hidden_states
,
self
.
input_splits
,
self
.
output_splits
)
if
self
.
config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
self
.
shared_experts
.
linear_fc2_forward
(
permutated_local_input_tokens
)
self
.
shared_experts
.
post_forward_comm
()
# Unpermutation 1: AlltoAll output to output
if
False
:
output
=
unpermute
(
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
restore_shape
=
self
.
hidden_shape_before_permute
,
probs
=
self
.
probs
,
routing_map
=
self
.
routing_map
,
fused
=
self
.
config
.
moe_permute_fusion
,
)
else
:
output
=
groupgemm_unpermute
(
permutated_local_input_tokens
,
self
.
reversed_local_input_permutation_mapping
,
list
(
self
.
hidden_shape_before_permute
),
self
.
probs
,
self
.
routing_map
,
self
.
expert_m_count
)
# Reshape the output tensor
output
=
output
.
view
(
self
.
hidden_shape
)
# Add shared experts output
if
self
.
config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
.
get_output
()
if
hidden_states
.
dtype
!=
torch
.
float16
:
output
=
output
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
output
=
output
+
shared_output
\
*
(
1.
/
self
.
config
.
routed_scaling_factor
)
return
output
def
_maybe_update_cuda_sync_point
(
self
,
point
:
str
):
"""
Update the CUDA sync point if the priority of the new point is higher than the current
sync point, which means the new point is reached earlier than the current sync point.
"""
if
(
self
.
cuda_sync_point_priority
[
point
]
<
self
.
cuda_sync_point_priority
[
self
.
cuda_sync_point
]
):
self
.
cuda_sync_point
=
point
def
_maybe_dtoh_and_synchronize
(
self
,
point
:
str
,
tokens_per_expert
:
torch
.
Tensor
=
None
)
->
torch
.
Tensor
:
"""
Move all possible GPU tensors to CPU and make a synchronization at the expected point.
"""
if
point
==
self
.
cuda_dtoh_point
:
# Move all possible GPU tensors to CPU at self.cuda_dtoh_point.
on_side_stream
=
torch
.
cuda
.
current_stream
()
!=
cuda_dtoh_stream
if
on_side_stream
:
cuda_dtoh_stream
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
cuda_dtoh_stream
):
# TODO: use MemcpyBatchAsync instead.
# tokens_per_expert = maybe_move_tensor_to_cpu(
# tokens_per_expert, record_stream=on_side_stream
# )
self
.
input_splits
=
maybe_move_tensor_to_cpu
(
self
.
input_splits
,
as_numpy
=
True
,
record_stream
=
on_side_stream
)
self
.
output_splits
=
maybe_move_tensor_to_cpu
(
self
.
output_splits
,
as_numpy
=
True
,
record_stream
=
on_side_stream
)
# self.output_splits_tp = maybe_move_tensor_to_cpu(
# self.output_splits_tp, as_numpy=True, record_stream=on_side_stream
# )
self
.
num_out_tokens
=
maybe_move_tensor_to_cpu
(
self
.
num_out_tokens
,
record_stream
=
on_side_stream
)
if
self
.
num_local_experts
>
1
and
not
self
.
config
.
moe_permute_fusion
:
self
.
num_global_tokens_per_local_expert
=
maybe_move_tensor_to_cpu
(
self
.
num_global_tokens_per_local_expert
,
record_stream
=
on_side_stream
)
cuda_dtoh_sync_event
.
record
()
# if point == self.cuda_sync_point:
# # Synchronize with the dtoh stream at self.cuda_sync_point.
# cuda_dtoh_stream.synchronize()
return
tokens_per_expert
def
token_permutation_forward
(
tokens_per_expert
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
return
self
.
token_permutation_impl
(
tokens_per_expert
,
hidden_states
,
probs
,
routing_map
)
def
token_permutation_forward_fake
(
tokens_per_expert
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"token_permutation_forward"
,
op_func
=
token_permutation_forward
,
mutates_args
=
[
"tokens_per_expert"
,
"hidden_states"
,
"probs"
,
"routing_map"
],
fake_impl
=
token_permutation_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
def
token_unpermutation_forward
(
hidden_states
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
return
self
.
token_unpermutation_impl
(
hidden_states
)
def
token_unpermutation_forward_fake
(
hidden_states
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"token_unpermutation_forward"
,
op_func
=
token_unpermutation_forward
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
token_unpermutation_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/
ep
_moe/ep_moe_utlis.py
→
vllm/model_executor/layers/fused_moe/
mori
_moe/ep_moe_utlis.py
View file @
d698d6f2
File moved
vllm/model_executor/layers/fused_moe/
ep
_moe/layer.py
→
vllm/model_executor/layers/fused_moe/
mori
_moe/layer.py
View file @
d698d6f2
import
os
from
typing
import
Callable
,
Optional
import
logging
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
dataclasses
import
dataclass
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
import
torch
import
torch
import
torch.
nn.functional
as
F
import
torch.
distributed
as
dist
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -18,10 +15,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoEMethodBase
,
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.ep_moe.token_dispatcher
import
MoEAlltoAllTokenDispatcher
from
vllm.model_executor.layers.fused_moe.mori_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.model_executor.layers.fused_moe.ep_moe.ep_moe_utlis
import
EpMoeConfig
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
import
torch.distributed
as
dist
try
:
try
:
import
mori
import
mori
...
@@ -35,8 +30,8 @@ logger = init_logger(__name__)
...
@@ -35,8 +30,8 @@ logger = init_logger(__name__)
_MORI_OP
=
None
_MORI_OP
=
None
@
CustomOp
.
register
(
"unquantized_
ep
_moe"
)
@
CustomOp
.
register
(
"unquantized_
mori
_moe"
)
class
Unquantized
EPGroupedGemm
Method
(
UnquantizedFusedMoEMethod
):
class
Unquantized
MoriMoe
Method
(
UnquantizedFusedMoEMethod
):
"""MoE method without quantization."""
"""MoE method without quantization."""
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
def
__init__
(
self
,
moe
:
FusedMoEConfig
):
...
@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -44,9 +39,9 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
self
.
topk_indices_dtype
=
None
self
.
topk_indices_dtype
=
None
self
.
moe
=
moe
self
.
moe
=
moe
self
.
rocm_aiter_moe_enabled
=
False
# is_rocm_aiter_moe_enabled()
self
.
rocm_aiter_moe_enabled
=
False
def
apply_ep
(
def
apply_
mori_
ep
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
...
@@ -162,7 +157,7 @@ class UnquantizedEPGroupedGemmMethod(UnquantizedFusedMoEMethod):
forward_native
=
forward_cuda
forward_native
=
forward_cuda
class
EP
MoE
(
FusedMoE
):
class
Mori
MoE
(
FusedMoE
):
"""
"""
dp+ep MoE Expert Parallel Impl
dp+ep MoE Expert Parallel Impl
...
@@ -194,7 +189,6 @@ class EPMoE(FusedMoE):
...
@@ -194,7 +189,6 @@ class EPMoE(FusedMoE):
enable_eplb
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
num_redundant_experts
:
int
=
0
,
num_redundant_experts
:
int
=
0
,
moe_permute_fusion
:
bool
=
False
,
moe_permute_fusion
:
bool
=
False
,
moe_shared_expert_overlap
:
bool
=
False
):
):
super
().
__init__
(
num_experts
,
top_k
,
hidden_size
,
super
().
__init__
(
num_experts
,
top_k
,
hidden_size
,
intermediate_size
,
params_dtype
,
intermediate_size
,
params_dtype
,
...
@@ -215,7 +209,6 @@ class EPMoE(FusedMoE):
...
@@ -215,7 +209,6 @@ class EPMoE(FusedMoE):
moe_router_topk
=
self
.
top_k
,
moe_router_topk
=
self
.
top_k
,
# TODO: support fusion permute
# TODO: support fusion permute
moe_permute_fusion
=
moe_permute_fusion
,
moe_permute_fusion
=
moe_permute_fusion
,
moe_shared_expert_overlap
=
moe_shared_expert_overlap
,
ep_size
=
self
.
ep_size
,
ep_size
=
self
.
ep_size
,
num_moe_experts
=
self
.
global_num_experts
,
num_moe_experts
=
self
.
global_num_experts
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
...
@@ -228,21 +221,14 @@ class EPMoE(FusedMoE):
...
@@ -228,21 +221,14 @@ class EPMoE(FusedMoE):
self
.
local_expert_indices
=
[
self
.
local_expert_indices
=
[
local_expert_indices_offset
+
i
for
i
in
range
(
self
.
local_num_experts
)
local_expert_indices_offset
+
i
for
i
in
range
(
self
.
local_num_experts
)
]
]
self
.
use_shared_expert
=
False
# self.token_dispatcher = MoEAlltoAllTokenDispatcher(
# self.local_num_experts, self.local_expert_indices,
# config=self.ep_moe_config, layer_name=f"{self.layer_name}.token_dispatcher",
# )
self
.
shared_expert_overlap
=
moe_shared_expert_overlap
self
.
shared_experts
=
None
self
.
shared_experts
=
None
self
.
scales
=
None
self
.
scales
=
None
self
.
use_int8_dispatch
=
True
self
.
use_int8_dispatch
=
True
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
max_num_inp_token_per_rank
=
1024
#vllm_config.scheduler_config.max_num_seqs
self
.
max_num_inp_token_per_rank
=
1024
#vllm_config.scheduler_config.max_num_seqs
self
.
mori_op
=
self
.
get_mori_op
()
self
.
mori_op
=
self
.
get_mori_op
()
def
get_mori_op
(
self
):
def
get_mori_op
(
self
):
...
@@ -252,10 +238,6 @@ class EPMoE(FusedMoE):
...
@@ -252,10 +238,6 @@ class EPMoE(FusedMoE):
assert
world_group
is
not
None
assert
world_group
is
not
None
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"mori_ep"
,
get_ep_group
().
device_group
)
torch
.
_C
.
_distributed_c10d
.
_register_process_group
(
"mori_ep"
,
get_ep_group
().
device_group
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"mori_ep"
)
mori
.
shmem
.
shmem_torch_process_group_init
(
"mori_ep"
)
# world_group = torch.distributed.group.WORLD
# assert world_group is not None
# torch._C._distributed_c10d._register_process_group("default", world_group)
# mori.shmem.shmem_torch_process_group_init("default")
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
multi_node
=
self
.
ep_size
/
8
>
1
multi_node
=
self
.
ep_size
/
8
>
1
...
@@ -278,7 +260,6 @@ class EPMoE(FusedMoE):
...
@@ -278,7 +260,6 @@ class EPMoE(FusedMoE):
max_token_type_size
=
2
,
max_token_type_size
=
2
,
block_num
=
80
,
block_num
=
80
,
warp_num_per_block
=
4
,
warp_num_per_block
=
4
,
# kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
kernel_type
=
mori
.
ops
.
EpDispatchCombineKernelType
.
InterNode
if
multi_node
else
\
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
mori
.
ops
.
EpDispatchCombineKernelType
.
IntraNode
)
)
...
@@ -290,14 +271,11 @@ class EPMoE(FusedMoE):
...
@@ -290,14 +271,11 @@ class EPMoE(FusedMoE):
if
self
.
shared_experts
is
None
:
if
self
.
shared_experts
is
None
:
self
.
shared_experts
=
shared_experts
self
.
shared_experts
=
shared_experts
# if self.shared_expert_overlap:
# self.token_dispatcher.set_shared_experts(self.shared_experts)
def
create_quant_method
(
self
,
moe
,
quant_config
,
prefix
):
def
create_quant_method
(
self
,
moe
,
quant_config
,
prefix
):
# Note: get_quant_method will look at the layer's local_num_experts
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
# for heuristic purposes, so it must be initialized first.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
None
quant_method
:
Optional
[
QuantizeMethodBase
]
=
None
quant_method
=
(
Unquantized
EPGroupedGemm
Method
(
moe
)
if
quant_config
is
None
quant_method
=
(
Unquantized
MoriMoe
Method
(
moe
)
if
quant_config
is
None
else
quant_config
.
get_quant_method
(
self
,
prefix
))
else
quant_config
.
get_quant_method
(
self
,
prefix
))
assert
quant_method
is
not
None
assert
quant_method
is
not
None
...
@@ -310,7 +288,7 @@ class EPMoE(FusedMoE):
...
@@ -310,7 +288,7 @@ class EPMoE(FusedMoE):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
):
return
torch
.
ops
.
vllm
.
ep
_moe_forward
(
hidden_states
,
router_logits
,
return
torch
.
ops
.
vllm
.
mori
_moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
)
self
.
layer_name
)
def
get_expert_weights
(
self
)
->
Iterable
[
torch
.
Tensor
]:
def
get_expert_weights
(
self
)
->
Iterable
[
torch
.
Tensor
]:
...
@@ -350,7 +328,7 @@ class EPMoE(FusedMoE):
...
@@ -350,7 +328,7 @@ class EPMoE(FusedMoE):
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
)
use_fused_gate
=
self
.
use_fused_gate
)
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
self
.
use_int8_dispatch
:
if
self
.
use_int8_dispatch
:
...
@@ -377,11 +355,10 @@ class EPMoE(FusedMoE):
...
@@ -377,11 +355,10 @@ class EPMoE(FusedMoE):
hidden_states
,
hidden_states
,
topk_weights
,
topk_weights
,
scales
,
scales
,
topk_ids
,
topk_ids
#layer_idx=int(self.layer_name.split('.')[2])
)
)
expert_output
=
self
.
quant_method
.
apply_ep
(
expert_output
=
self
.
quant_method
.
apply_
mori_
ep
(
layer
=
self
,
layer
=
self
,
x
=
dispatch_output
,
x
=
dispatch_output
,
topk_weights
=
dispatch_weights
,
topk_weights
=
dispatch_weights
,
...
@@ -394,7 +371,6 @@ class EPMoE(FusedMoE):
...
@@ -394,7 +371,6 @@ class EPMoE(FusedMoE):
num_local_tokens
=
dispatch_recv_num_token
,
num_local_tokens
=
dispatch_recv_num_token
,
config_select_bs
=
hidden_states
.
shape
[
0
]
*
self
.
ep_size
/
self
.
dp_size
,
config_select_bs
=
hidden_states
.
shape
[
0
]
*
self
.
ep_size
/
self
.
dp_size
,
scales
=
dispatch_scales
if
self
.
use_int8_dispatch
else
None
scales
=
dispatch_scales
if
self
.
use_int8_dispatch
else
None
# routed_scaling_factor=self.routed_scaling_factor,
)
)
# self.sync()
# self.sync()
...
@@ -404,11 +380,7 @@ class EPMoE(FusedMoE):
...
@@ -404,11 +380,7 @@ class EPMoE(FusedMoE):
# self.sync()
# self.sync()
if
not
self
.
ep_moe_config
.
moe_shared_expert_overlap
and
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
# shared_output = (
# self.maybe_all_reduce_tensor_model_parallel(
# shared_output))
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
else
:
...
@@ -420,7 +392,7 @@ class EPMoE(FusedMoE):
...
@@ -420,7 +392,7 @@ class EPMoE(FusedMoE):
return
final_hidden_states
return
final_hidden_states
def
ep
_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
mori
_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
...
@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
...
@@ -429,16 +401,16 @@ def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
def
ep
_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
mori
_moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
)
->
torch
.
Tensor
:
layer_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"
ep
_moe_forward"
,
op_name
=
"
mori
_moe_forward"
,
op_func
=
ep
_moe_forward
,
op_func
=
mori
_moe_forward
,
mutates_args
=
[
"hidden_states"
,
"router_logits"
],
mutates_args
=
[
"hidden_states"
,
"router_logits"
],
fake_impl
=
ep
_moe_forward_fake
,
fake_impl
=
mori
_moe_forward_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
)
\ No newline at end of file
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
d698d6f2
...
@@ -166,6 +166,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -166,6 +166,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self
.
use_deepep
=
parallel_config
.
enable_expert_parallel
and
\
self
.
use_deepep
=
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
enable_moe_group_gemm
=
parallel_config
.
enable_expert_parallel
and
envs
.
VLLM_ENABLE_MOE_GROUP_GEMM
def
create_weights
(
def
create_weights
(
...
@@ -250,32 +252,36 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -250,32 +252,36 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
routed_scaling_factor
:
Optional
[
float
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
**
_
):
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
if
not
self
.
enable_moe_group_gemm
:
return
fused_experts_impl_w4a8_marlin
(
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
x
,
return
fused_experts_impl_w4a8_marlin
(
w1
,
x
,
w2
,
w1
,
topk_ids
=
topk_ids
,
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
topk_weights
=
topk_weights
,
global_reduce_buffer
=
global_reduce_buffer
,
workspace
=
workspace
,
inplace
=
True
,
global_reduce_buffer
=
global_reduce_buffer
,
use_int4_w4a8
=
True
,
inplace
=
True
,
per_channel_quant
=
True
,
use_int4_w4a8
=
True
,
activation
=
activation
,
per_channel_quant
=
True
,
expert_map
=
expert_map
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
w1_scale
=
w1_scale
,
global_num_experts
=
global_num_experts
,
w2_scale
=
w2_scale
,
w1_scale
=
w1_scale
,
a1_scale
=
a1_scale
,
w2_scale
=
w2_scale
,
a2_scale
=
a2_scale
,
a1_scale
=
a1_scale
,
use_nn_moe
=
use_nn_moe
,
a2_scale
=
a2_scale
,
shared_output
=
shared_output
,
use_nn_moe
=
use_nn_moe
,
routed_scaling_factor
=
routed_scaling_factor
,
shared_output
=
shared_output
,
)
routed_scaling_factor
=
routed_scaling_factor
,
)
def
apply_ep
(
#dp+ep
else
:
# TODO:
return
None
def
apply_mori_ep
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
...
@@ -310,12 +316,11 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_
scale
,
a1_scale
=
scale
s
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
num_local_tokens
=
num_local_tokens
,
num_local_tokens
=
num_local_tokens
,
config_select_bs
=
config_select_bs
,
config_select_bs
=
config_select_bs
,
q_scales
=
scales
)
)
def
apply
(
def
apply
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
d698d6f2
...
@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
...
@@ -43,8 +43,8 @@ from vllm.distributed import (get_ep_group, get_pp_group, get_dp_group,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.
ep
_moe.layer
import
EP
MoE
from
vllm.model_executor.layers.fused_moe.
mori
_moe.layer
import
Mori
MoE
from
vllm.model_executor.layers.fused_moe.
ep
_moe.ep_moe_utlis
import
EPSharedExperts
from
vllm.model_executor.layers.fused_moe.
mori
_moe.ep_moe_utlis
import
EPSharedExperts
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module):
...
@@ -167,10 +167,10 @@ class DeepseekV2MoE(nn.Module):
self
.
n_local_physical_experts
)
self
.
n_local_physical_experts
)
dp_size
=
get_dp_group
().
world_size
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
and
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
self
.
use_mori_ep
=
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
EP
MoE
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
Mori
MoE
self
.
experts
=
moe_cls
(
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
...
@@ -225,12 +225,12 @@ class DeepseekV2MoE(nn.Module):
...
@@ -225,12 +225,12 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
not
self
.
use_mori_ep
:
if
not
self
.
enable_expert_parallel
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
shared_output
=
shared_output
)
else
:
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
...
@@ -249,8 +249,22 @@ class DeepseekV2MoE(nn.Module):
...
@@ -249,8 +249,22 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
else
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
if
not
self
.
use_mori_ep
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
if
not
self
.
use_mori_ep
:
if
not
self
.
use_mori_ep
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment