Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3833018c
Commit
3833018c
authored
Dec 15, 2025
by
王敏
Browse files
[feat]1.支持高吞吐模式ep_scatter+deepgemm contiguous+ep_gather方案;2.支持高吞吐模式下ETP,例如dp4 tp4
parent
94c4ca4d
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1120 additions
and
167 deletions
+1120
-167
vllm/envs.py
vllm/envs.py
+4
-0
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+209
-76
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+68
-21
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+124
-47
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
.../model_executor/layers/fused_moe/triton_group_gemm_moe.py
+2
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+572
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+125
-11
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+5
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+11
-9
No files found.
vllm/envs.py
View file @
3833018c
...
...
@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PP_BALANCE
:
bool
=
False
VLLM_USE_ZERO_MTP
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1181,6 +1182,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_ZERO_MTP"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_ZERO_MTP'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
3833018c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
collections.abc
import
Callable
import
deep_ep
import
torch
...
...
@@ -58,39 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
None
return
deep_ep
.
Buffer
.
get_combine_config
(
self
.
dp_size
)
def
sync
(
self
):
# torch.cuda.synchronize()
dist
.
barrier
()
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
Optional
[
torch
.
Tensor
],
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
torch
.
Tensor
|
None
,
rank_topk_ids
:
torch
.
Tensor
,
rank_topk_weights
:
torch
.
Tensor
,
num_experts
:
int
):
rank_topk_weights
:
torch
.
Tensor
,
num_experts
:
int
,
quant_config
:
FusedMoEQuantConfig
,
)
->
Callable
:
has_scales
=
token_scales
is
not
None
(
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
expert_num_tokens
,
is_token_in_rank
,
event
)
=
self
.
buffer
.
get_dispatch_layout
(
(
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
dispatch_expert_num_tokens
,
is_token_in_rank
,
event
,
)
=
self
.
buffer
.
get_dispatch_layout
(
topk_idx
=
rank_topk_ids
,
num_experts
=
num_experts
,
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
allocate_on_comm_stream
=
False
,
)
token_data
=
tokens
if
has_scales
:
token_data
=
(
tokens
,
token_scales
)
(
token_data
,
expert_topk_ids
,
expert_topk_weights
,
expert_num_tokens_per_expert_list
,
self
.
handle
,
event
token_data
,
expert_topk_ids
,
expert_topk_weights
,
expert_num_tokens_per_expert_list
,
self
.
handle
,
event
,
)
=
self
.
buffer
.
dispatch
(
x
=
token_data
,
handle
=
None
,
num_tokens_per_rank
=
num_tokens_per_rank
,
num_tokens_per_rdma_rank
=
num_tokens_per_rdma_rank
,
is_token_in_rank
=
is_token_in_rank
,
num_tokens_per_expert
=
expert_num_tokens
,
num_tokens_per_expert
=
dispatch_
expert_num_tokens
,
topk_idx
=
rank_topk_ids
,
topk_weights
=
rank_topk_weights
,
# expert_alignment rounds the number of tokens per expert
...
...
@@ -98,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment
=
1
,
config
=
self
.
_get_dispatch_config
(),
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
async_finish
=
True
,
allocate_on_comm_stream
=
False
,
)
return
lambda
:
self
.
_receiver
(
event
,
has_scales
,
token_data
,
expert_topk_ids
,
num_experts
,
expert_num_tokens_per_expert_list
,
expert_topk_weights
,
token_scales
,
quant_config
,
)
def
_receiver
(
self
,
event
:
deep_ep
.
EventOverlap
,
has_scales
:
bool
,
token_data
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
torch
.
Tensor
,
expert_topk_ids
:
torch
.
Tensor
|
None
,
num_experts
:
int
,
expert_num_tokens_per_expert_list
:
list
[
int
],
expert_topk_weights
:
torch
.
Tensor
|
None
,
a1_scale
:
torch
.
Tensor
|
None
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
if
has_scales
:
expert_x
,
expert_x_scale
=
token_data
...
...
@@ -117,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert
expert_topk_ids
is
not
None
expert_topk_ids
=
torch
.
where
(
expert_topk_ids
==
-
1
,
num_experts
-
1
if
self
.
rank_expert_offset
==
0
else
0
,
expert_topk_ids
+
self
.
rank_expert_offset
)
expert_topk_ids
+
self
.
rank_expert_offset
,
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
.
make_from_list
(
expert_num_tokens_per_expert_list
,
device
=
expert_x
.
device
)
def
prepare
(
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if
not
quant_config
.
per_act_token_quant
:
# Quantize after dispatch.
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
quant_dtype
=
quant_config
.
quant_dtype
,
per_act_token_quant
=
False
,
block_shape
=
quant_config
.
block_shape
,
)
return
(
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
expert_topk_ids
,
expert_topk_weights
,
)
def
prepare_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -136,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
mk
.
ReceiverType
:
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
if
quant_config
.
per_act_token_quant
:
...
...
@@ -156,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
if
a1q_scale
is
not
None
and
a1q_scale
.
numel
()
==
1
:
a1q_scale
=
a1q_scale
.
view
(
1
,
1
)
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
=
self
.
_do_dispatch
(
else
:
a1q
=
a1
a1q_scale
=
None
return
self
.
_do_dispatch
(
tokens
=
a1q
,
token_scales
=
a1q_scale
,
rank_topk_ids
=
topk_ids
,
rank_topk_weights
=
topk_weights
,
num_experts
=
num_experts
)
else
:
# DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16.
(
expert_x
,
_
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
=
self
.
_do_dispatch
(
tokens
=
a1
,
token_scales
=
None
,
rank_topk_ids
=
topk_ids
,
rank_topk_weights
=
topk_weights
,
num_experts
=
num_experts
)
# quantize now
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
quant_dtype
=
quant_config
.
quant_dtype
,
per_act_token_quant
=
False
,
block_shape
=
quant_config
.
block_shape
)
num_experts
=
num_experts
,
quant_config
=
quant_config
,
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
receiver
=
self
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
)
return
receiver
()
def
_apply_weights_and_reduce
(
self
,
num_tokens
:
int
,
fused_expert_output
:
torch
.
Tensor
,
...
...
@@ -210,31 +286,88 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
out
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
do_async
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
Callable
|
None
:
assert
self
.
handle
is
not
None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if
fused_expert_output
.
numel
()
!=
0
and
apply_weights_and_reduce
:
fused_expert_output
=
self
.
_apply_weights_and_reduce
(
num_tokens
=
topk_ids
.
size
(
0
),
fused_expert_output
=
fused_expert_output
,
topk_weights
=
topk_weights
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
output_dtype
=
output
.
dtype
)
# if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
# fused_expert_output = self._apply_weights_and_reduce(
# num_tokens=topk_ids.size(0),
# fused_expert_output=fused_expert_output,
# topk_weights=topk_weights,
# apply_router_weight_on_input=apply_router_weight_on_input,
# output_dtype=output.dtype)
combined_x
,
_
,
event
=
self
.
buffer
.
combine
(
# HT combine only supports BF16
x
=
fused_expert_output
,
handle
=
self
.
handle
,
topk_weights
=
None
,
config
=
self
.
_get_combine_config
(),
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
async_finish
=
do_async
,
allocate_on_comm_stream
=
False
,
)
if
do_async
:
def
_receiver
():
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
_receiver
else
:
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
None
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
Callable
:
receiver
=
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
do_async
=
True
,
apply_weights_and_reduce
=
apply_weights_and_reduce
,
)
assert
receiver
is
not
None
return
receiver
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
None
:
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
do_async
=
False
,
apply_weights_and_reduce
=
apply_weights_and_reduce
,
)
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
3833018c
...
...
@@ -115,7 +115,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
x
,
x_scales
def
prepare
(
def
prepare
_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -126,9 +126,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
Callable
,
mk
.
ReceiverType
]:
hidden_size
=
a1
.
size
(
1
)
assert
hidden_size
in
self
.
SUPPORTED_HIDDEN_SIZES
,
\
(
f
"Hidden Size
{
hidden_size
}
not in supported list of hidden sizes"
...
...
@@ -148,25 +146,74 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
# Dispatch
expert_x
,
expert_num_tokens
,
self
.
handle
,
event
,
hook
=
\
self
.
buffer
.
low_latency_dispatch
(
a1
,
expert_x
,
expert_num_tokens
,
self
.
handle
s
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
a1
,
topk_ids
,
self
.
max_tokens_per_rank
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
use_int8
=
self
.
use_int8_dispatch
,
async_finish
=
False
,
return_recv_hook
=
False
)
return_recv_hook
=
True
,
)
return
(
hook
,
lambda
:
self
.
_receiver
(
expert_x
,
expert_num_tokens
,
a1_scale
,
a1
.
dtype
,
quant_config
,
),
)
def
_receiver
(
self
,
expert_x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
expert_num_tokens
:
torch
.
Tensor
,
a1_scale
:
torch
.
Tensor
|
None
,
a1_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_dtype
,
quant_config
)
expert_
x
,
expert_x_scale
=
self
.
_do_quant
(
expert_
x
,
a1_scale
,
a2_scale
,
a1
.
dtype
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
expert_num_tokens
)
expert_
tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_
num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
None
,
None
)
return
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
None
,
None
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
hook
,
receiver
=
self
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
)
hook
()
return
receiver
()
def
_finalize
(
self
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
3833018c
...
...
@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from
enum
import
Enum
from
math
import
prod
from
typing
import
Optional
,
final
from
dataclasses
import
dataclass
from
collections.abc
import
Callable
import
torch
...
...
@@ -95,6 +97,50 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts
=
"batched_experts"
,
@
dataclass
class
ExpertTokensMetadata
:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens
:
torch
.
Tensor
expert_num_tokens_cpu
:
torch
.
Tensor
|
None
@
staticmethod
def
make_from_list
(
expert_num_tokens_list
:
list
[
int
],
device
:
str
)
->
"ExpertTokensMetadata"
:
expert_num_tokens_cpu
=
torch
.
tensor
(
expert_num_tokens_list
,
device
=
"cpu"
,
dtype
=
torch
.
int32
)
return
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens_cpu
.
to
(
device
,
non_blocking
=
True
),
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
)
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType
=
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
ExpertTokensMetadata
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]
ReceiverType
=
Callable
[[],
PrepareResultType
]
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class
FusedMoEPrepareAndFinalize
(
ABC
):
"""
...
...
@@ -880,8 +926,19 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
(
a1q
,
a1q_scale
,
expert_num_tokens
,
_expert_topk_ids
,
_expert_topk_weights
)
=
self
.
prepare_finalize
.
prepare
(
# (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
# _expert_topk_weights) = self.prepare_finalize.prepare(
# a1,
# a1_scale,
# a2_scale,
# topk_weights,
# topk_ids,
# global_num_experts,
# expert_map,
# apply_router_weight_on_input,
# self.fused_experts.quant_config,
# )
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
...
...
@@ -892,12 +949,35 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
)
hook
,
receiver
=
(
prepare_ret
if
isinstance
(
prepare_ret
,
tuple
)
else
(
None
,
prepare_ret
)
)
if
hook
is
not
None
:
hook
()
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
_expert_topk_ids
,
_expert_topk_weights
,
)
=
receiver
()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids
=
topk_ids
if
_expert_topk_ids
is
None
else
_expert_topk_ids
topk_weights
=
(
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
if
a1q
.
numel
()
==
0
:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out
=
torch
.
empty_like
(
a1q
).
to
(
dtype
=
a1
.
dtype
)
else
:
fused_out
=
self
.
fused_experts
.
apply
(
None
,
a1
,
...
...
@@ -918,18 +998,15 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
workspace13
=
None
,
workspace2
=
None
,
use_nn_moe
=
use_nn_moe
,
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens
=
expert_
tokens_meta
.
expert_
num_tokens
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
expert_num_tokens_cpu
=
expert_tokens_meta
.
expert_num_tokens_cpu
)
shared_output
=
None
if
self
.
shared_experts
is
None
:
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
False
)
else
:
hook
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
Fals
e
)
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
Tru
e
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
View file @
3833018c
...
...
@@ -85,6 +85,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
expert_num_tokens_cpu
:
torch
.
Tensor
=
None
,
):
assert
self
.
fused_experts
is
not
None
...
...
@@ -107,4 +108,5 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
q_x
=
q_hidden_states
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
)
vllm/model_executor/layers/fused_moe/utils.py
View file @
3833018c
...
...
@@ -11,6 +11,7 @@ from triton.language.extra import libdevice
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
round_up
try
:
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
...
...
@@ -276,8 +277,8 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if
block_shape
is
None
:
assert
per_act_token
,
\
"int8 quantization only supports block or channel-wise"
#
assert per_act_token, \
#
"int8 quantization only supports block or channel-wise"
if
expert_num_tokens
is
None
:
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
...
...
@@ -361,3 +362,572 @@ def _validate_scale_shape(
assert
block_shape
is
not
None
expected
=
(
a
.
shape
[
0
],
cdiv
(
a
.
shape
[
1
],
block_shape
[
1
]))
assert
a_scale
.
shape
==
expected
,
f
"
{
a_scale
.
shape
}
==
{
expected
}
"
@
triton
.
jit
def
_count_expert_num_tokens
(
topk_ids_ptr
,
expert_num_tokens_ptr
,
num_experts
,
topk_numel
,
expert_map
,
HAS_EXPERT_MAP
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
curr_expert
=
tl
.
program_id
(
0
)
offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
topk_ids_ptrs
=
topk_ids_ptr
+
offsets
acc
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
tl
.
int32
)
for
x
in
range
(
tl
.
cdiv
(
topk_numel
,
BLOCK_SIZE
)):
mask
=
offsets
<
(
topk_numel
-
x
*
BLOCK_SIZE
)
expert_ids
=
tl
.
load
(
topk_ids_ptrs
,
mask
=
mask
,
other
=-
1
)
if
HAS_EXPERT_MAP
:
expert_map_ptrs
=
expert_map
+
expert_ids
expert_map_mask
=
expert_ids
>=
0
expert_ids
=
tl
.
load
(
expert_map_ptrs
,
mask
=
expert_map_mask
,
other
=-
1
)
has_curr_expert
=
tl
.
where
(
expert_ids
==
curr_expert
,
1
,
0
)
acc
=
acc
+
has_curr_expert
topk_ids_ptrs
+=
BLOCK_SIZE
if
curr_expert
<
num_experts
:
tl
.
store
(
expert_num_tokens_ptr
+
curr_expert
,
tl
.
sum
(
acc
))
def
count_expert_num_tokens
(
topk_ids
:
torch
.
Tensor
,
num_local_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
)
->
torch
.
Tensor
:
"""
Count the number to tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert
topk_ids
.
dtype
.
is_signed
,
"The kernel uses -1 to represent invalid topk_ids"
expert_num_tokens
=
torch
.
empty
(
(
num_local_experts
),
device
=
topk_ids
.
device
,
dtype
=
torch
.
int32
)
grid
=
num_local_experts
BLOCK_SIZE
=
min
(
topk_ids
.
numel
(),
1024
)
BLOCK_SIZE
=
triton
.
next_power_of_2
(
BLOCK_SIZE
)
_count_expert_num_tokens
[(
grid
,)](
topk_ids
,
expert_num_tokens
,
num_local_experts
,
topk_ids
.
numel
(),
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
return
expert_num_tokens
def
expert_num_tokens_round_up_and_sum
(
expert_num_tokens
:
torch
.
Tensor
,
alignment
:
int
)
->
int
:
# Round up each element in expert_num_tokens to the nearest multiple of
# alignment.
ent
=
(
expert_num_tokens
.
to
(
torch
.
int64
)
+
(
alignment
-
1
))
//
alignment
*
alignment
return
torch
.
sum
(
ent
).
item
()
def
compute_aligned_M
(
M
:
int
,
num_topk
:
int
,
local_num_experts
:
int
,
alignment
:
int
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
expert_num_tokens_cpu
is
not
None
:
return
expert_num_tokens_round_up_and_sum
(
expert_num_tokens_cpu
,
alignment
=
alignment
)
# expert_num_tokens information is not available on the cpu.
# compute the max required size.
M_sum
=
(
M
*
num_topk
)
+
local_num_experts
*
(
alignment
-
1
)
M_sum
=
round_up
(
M_sum
,
alignment
)
return
M_sum
@
triton
.
jit
def
apply_expert_map
(
expert_id
,
expert_map
):
if
expert_id
!=
-
1
:
expert_id
=
tl
.
load
(
expert_map
+
expert_id
).
to
(
expert_id
.
dtype
)
return
expert_id
@
triton
.
jit
def
round_up_256
(
x
:
int
)
->
int
:
y
=
256
return
((
x
+
y
-
1
)
//
y
)
*
y
@
triton
.
jit
def
round_up_128
(
x
:
int
)
->
int
:
y
=
128
return
((
x
+
y
-
1
)
//
y
)
*
y
@
triton
.
jit
def
_fwd_kernel_ep_scatter_1
(
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
:
tl
.
constexpr
,
BLOCK_E
:
tl
.
constexpr
,
BLOCK_EXPERT_NUM
:
tl
.
constexpr
,
):
cur_expert
=
tl
.
program_id
(
0
)
offset_cumsum
=
tl
.
arange
(
0
,
BLOCK_EXPERT_NUM
)
tokens_per_expert
=
tl
.
load
(
num_recv_tokens_per_expert
+
offset_cumsum
,
mask
=
offset_cumsum
<
num_experts
,
other
=
0
,
)
#tokens_per_expert = round_up_128(tokens_per_expert)
tokens_per_expert
=
round_up_256
(
tokens_per_expert
)
cumsum
=
tl
.
cumsum
(
tokens_per_expert
)
-
tokens_per_expert
#if cur_expert == 0:
tl
.
store
(
expert_start_loc
+
offset_cumsum
,
cumsum
,
mask
=
offset_cumsum
<
num_experts
)
tl
.
debug_barrier
()
#cur_expert_start = cumsum[cur_expert]
cur_expert_start
=
tl
.
load
(
expert_start_loc
+
cur_expert
)
cur_expert_token_num
=
tl
.
load
(
num_recv_tokens_per_expert
+
cur_expert
)
m_indices_start_ptr
=
m_indices
+
cur_expert_start
off_expert
=
tl
.
arange
(
0
,
BLOCK_E
)
for
start_m
in
tl
.
range
(
0
,
cur_expert_token_num
,
BLOCK_E
,
num_stages
=
4
):
tl
.
store
(
m_indices_start_ptr
+
start_m
+
off_expert
,
cur_expert
,
mask
=
start_m
+
off_expert
<
cur_expert_token_num
)
@
triton
.
jit
def
_fwd_kernel_ep_scatter_2
(
total_token_num
,
expert_start_loc
,
recv_x
,
recv_x_stride0
,
recv_x_stride1
,
recv_x_scale
,
recv_x_scale_stride0
,
recv_x_scale_stride1
,
recv_topk
,
recv_topk_stride0
,
recv_topk_stride1
,
output_tensor
,
output_tensor_stride0
,
output_tensor_stride1
,
output_tensor_scale
,
output_tensor_scale_stride0
,
output_tensor_scale_stride1
,
output_index
,
output_index_stride0
,
output_index_stride1
,
topk_num
:
tl
.
constexpr
,
expert_map
,
HAS_EXPERT_MAP
:
tl
.
constexpr
,
HIDDEN_SIZE
:
tl
.
constexpr
,
HIDDEN_SIZE_PAD
:
tl
.
constexpr
,
SCALE_HIDDEN_SIZE
:
tl
.
constexpr
,
SCALE_HIDDEN_SIZE_PAD
:
tl
.
constexpr
,
):
# start_token_id = tl.program_id(0)
# grid_num = tl.num_programs(0)
# offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
# mask = offset_in < HIDDEN_SIZE
# offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
# mask_s = offset_in_s < SCALE_HIDDEN_SIZE
# for token_id in range(start_token_id, total_token_num, grid_num):
# to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
# to_copy_s = tl.load(
# recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
# )
# for topk_index in tl.range(0, topk_num, 1, num_stages=4):
# expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
# if HAS_EXPERT_MAP:
# expert_id = apply_expert_map(expert_id, expert_map)
# if expert_id >= 0:
# dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
# tl.store(
# output_index + token_id * output_index_stride0 + topk_index,
# dest_token_index,
# )
# output_tensor_ptr = (
# output_tensor + dest_token_index * output_tensor_stride0
# )
# output_tensor_scale_ptr = (
# output_tensor_scale + dest_token_index * output_tensor_scale_stride0
# )
# tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
# tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
start_token_id
=
tl
.
program_id
(
0
)
grid_num
=
tl
.
num_programs
(
0
)
offset_in
=
tl
.
arange
(
0
,
HIDDEN_SIZE_PAD
)
mask
=
offset_in
<
HIDDEN_SIZE
index_in_s
=
tl
.
arange
(
0
,
SCALE_HIDDEN_SIZE_PAD
)
mask_s
=
index_in_s
<
SCALE_HIDDEN_SIZE
for
token_id_int32
in
range
(
start_token_id
,
total_token_num
,
grid_num
):
token_id
=
token_id_int32
.
to
(
tl
.
int64
)
to_copy
=
tl
.
load
(
recv_x
+
token_id
*
recv_x_stride0
+
offset_in
,
mask
=
mask
)
to_copy_s
=
tl
.
load
(
recv_x_scale
+
token_id
*
recv_x_scale_stride0
+
index_in_s
*
recv_x_scale_stride1
,
mask
=
mask_s
,
)
for
topk_idx_int32
in
tl
.
range
(
0
,
topk_num
,
1
,
num_stages
=
4
):
topk_index
=
topk_idx_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk
+
token_id
*
recv_topk_stride0
+
topk_index
)
if
HAS_EXPERT_MAP
:
expert_id
=
apply_expert_map
(
expert_id
,
expert_map
)
if
expert_id
>=
0
:
dest_token_index_int32
=
tl
.
atomic_add
(
expert_start_loc
+
expert_id
,
1
)
dest_token_index
=
dest_token_index_int32
.
to
(
tl
.
int64
)
tl
.
store
(
output_index
+
token_id
*
output_index_stride0
+
topk_index
,
dest_token_index_int32
,
)
output_tensor_ptr
=
(
output_tensor
+
dest_token_index
*
output_tensor_stride0
)
output_tensor_scale_ptr
=
(
output_tensor_scale
+
dest_token_index
*
output_tensor_scale_stride0
)
tl
.
store
(
output_tensor_ptr
+
offset_in
,
to_copy
,
mask
=
mask
)
tl
.
store
(
output_tensor_scale_ptr
+
index_in_s
*
output_tensor_scale_stride1
,
to_copy_s
,
mask
=
mask_s
,
)
@
torch
.
no_grad
()
def
ep_scatter
(
recv_x
:
torch
.
Tensor
,
recv_x_scale
:
torch
.
Tensor
,
recv_topk
:
torch
.
Tensor
,
num_recv_tokens_per_expert
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
expert_start_loc
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
output_tensor_scale
:
torch
.
Tensor
,
m_indices
:
torch
.
Tensor
,
output_index
:
torch
.
Tensor
,
):
#BLOCK_E = 128 # token num of per expert is aligned to 128
#BLOCK_D = 128 # block size of quantization
BLOCK_E
=
256
# token num of per expert is aligned to 256
num_warps
=
8
num_experts
=
num_recv_tokens_per_expert
.
shape
[
0
]
hidden_size
=
recv_x
.
shape
[
1
]
scale_hidden_size
=
recv_x_scale
.
shape
[
-
1
]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid
=
num_experts
assert
m_indices
.
shape
[
0
]
%
BLOCK_E
==
0
_fwd_kernel_ep_scatter_1
[(
grid
,)](
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
=
num_experts
,
num_warps
=
num_warps
,
BLOCK_E
=
BLOCK_E
,
BLOCK_EXPERT_NUM
=
triton
.
next_power_of_2
(
num_experts
),
)
grid
=
min
(
recv_topk
.
shape
[
0
],
1024
*
8
)
_fwd_kernel_ep_scatter_2
[(
grid
,)](
recv_topk
.
shape
[
0
],
expert_start_loc
,
recv_x
,
recv_x
.
stride
(
0
),
recv_x
.
stride
(
1
),
recv_x_scale
,
recv_x_scale
.
stride
(
0
),
recv_x_scale
.
stride
(
1
),
recv_topk
,
recv_topk
.
stride
(
0
),
recv_topk
.
stride
(
1
),
output_tensor
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
output_tensor_scale
,
output_tensor_scale
.
stride
(
0
),
output_tensor_scale
.
stride
(
1
),
output_index
,
output_index
.
stride
(
0
),
output_index
.
stride
(
1
),
topk_num
=
recv_topk
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
HIDDEN_SIZE
=
hidden_size
,
HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
),
SCALE_HIDDEN_SIZE
=
scale_hidden_size
,
#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
scale_hidden_size
)
#triton.next_power_of_2(hidden_size // BLOCK_D),
)
return
@
triton
.
jit
def
_fwd_kernel_ep_gather
(
total_token_num
,
input_tensor
,
input_tensor_stride0
,
input_tensor_stride1
,
recv_topk_ids
,
recv_topk_ids_stride0
,
recv_topk_ids_stride1
,
recv_topk_weight
,
recv_topk_weight_stride0
,
recv_topk_weight_stride1
,
input_index
,
input_index_stride0
,
input_index_stride1
,
output_tensor
,
output_tensor_stride0
,
output_tensor_stride1
,
topk_num
:
tl
.
constexpr
,
expert_map
,
HAS_EXPERT_MAP
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
# cur_block = tl.program_id(0)
# start_cur_token = tl.program_id(1)
# grid_num = tl.num_programs(1)
# for cur_token in range(start_cur_token, total_token_num, grid_num):
# off_d = tl.arange(0, BLOCK_D)
# accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
# for topk_index in range(0, topk_num):
# expert_id = tl.load(
# recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
# )
# if HAS_EXPERT_MAP:
# expert_id = apply_expert_map(expert_id, expert_map)
# if expert_id >= 0:
# source_token_index = tl.load(
# input_index + cur_token * input_index_stride0 + topk_index
# )
# acc_weight = tl.load(
# recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
# )
# tmp = tl.load(
# input_tensor
# + source_token_index * input_tensor_stride0
# + cur_block * BLOCK_D
# + off_d
# )
# accumulator += tmp.to(tl.float32) * acc_weight
# tl.store(
# output_tensor
# + cur_token * output_tensor_stride0
# + cur_block * BLOCK_D
# + off_d,
# accumulator.to(output_tensor.dtype.element_ty),
# )
cur_block_int32
=
tl
.
program_id
(
0
)
cur_block
=
cur_block_int32
.
to
(
tl
.
int64
)
start_cur_token_int32
=
tl
.
program_id
(
1
)
grid_num
=
tl
.
num_programs
(
1
)
for
cur_token_int32
in
range
(
start_cur_token_int32
,
total_token_num
,
grid_num
):
cur_token
=
cur_token_int32
.
to
(
tl
.
int64
)
off_d
=
tl
.
arange
(
0
,
BLOCK_D
)
accumulator
=
tl
.
zeros
([
BLOCK_D
],
dtype
=
tl
.
float32
)
for
topk_index_int32
in
range
(
0
,
topk_num
):
topk_index
=
topk_index_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk_ids
+
cur_token
*
recv_topk_ids_stride0
+
topk_index
)
if
HAS_EXPERT_MAP
:
expert_id
=
apply_expert_map
(
expert_id
,
expert_map
)
if
expert_id
>=
0
:
source_token_index_int32
=
tl
.
load
(
input_index
+
cur_token
*
input_index_stride0
+
topk_index
)
source_token_index
=
source_token_index_int32
.
to
(
tl
.
int64
)
acc_weight
=
tl
.
load
(
recv_topk_weight
+
cur_token
*
recv_topk_weight_stride0
+
topk_index
)
tmp
=
tl
.
load
(
input_tensor
+
source_token_index
*
input_tensor_stride0
+
cur_block
*
BLOCK_D
+
off_d
)
accumulator
+=
tmp
.
to
(
tl
.
float32
)
*
acc_weight
tl
.
store
(
output_tensor
+
cur_token
*
output_tensor_stride0
+
cur_block
*
BLOCK_D
+
off_d
,
accumulator
.
to
(
output_tensor
.
dtype
.
element_ty
),
)
@
torch
.
no_grad
()
def
ep_gather
(
input_tensor
:
torch
.
Tensor
,
recv_topk_ids
:
torch
.
Tensor
,
recv_topk_weight
:
torch
.
Tensor
,
input_index
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
output_tensor
:
torch
.
Tensor
,
):
num_warps
=
2
num_tokens
=
output_tensor
.
shape
[
0
]
hidden_size
=
input_tensor
.
shape
[
1
]
BLOCK_D
=
min
(
hidden_size
,
1024
)
assert
hidden_size
%
BLOCK_D
==
0
grid
=
(
triton
.
cdiv
(
hidden_size
,
BLOCK_D
),
min
(
num_tokens
,
1024
))
_fwd_kernel_ep_gather
[
grid
](
num_tokens
,
input_tensor
,
input_tensor
.
stride
(
0
),
input_tensor
.
stride
(
1
),
recv_topk_ids
,
recv_topk_ids
.
stride
(
0
),
recv_topk_ids
.
stride
(
1
),
recv_topk_weight
,
recv_topk_weight
.
stride
(
0
),
recv_topk_weight
.
stride
(
1
),
input_index
,
input_index
.
stride
(
0
),
input_index
.
stride
(
1
),
output_tensor
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
topk_num
=
recv_topk_ids
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
BLOCK_D
=
BLOCK_D
,
)
return
def
deepgemm_moe_permute
(
aq
:
torch
.
Tensor
,
aq_scale
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
block_shape
:
list
[
int
],
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
aq_out
:
torch
.
Tensor
|
None
=
None
,
):
assert
aq
.
ndim
==
2
assert
topk_ids
.
dtype
.
is_signed
,
"The kernel uses -1 to represent invalid topk_ids"
H
=
aq
.
size
(
1
)
device
=
aq
.
device
block_m
=
block_shape
[
0
]
M_sum
=
compute_aligned_M
(
M
=
topk_ids
.
size
(
0
),
num_topk
=
topk_ids
.
size
(
1
),
local_num_experts
=
local_num_experts
,
alignment
=
block_m
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
)
expert_start_loc
=
torch
.
empty
(
(
local_num_experts
),
device
=
device
,
dtype
=
torch
.
int32
)
assert
aq_out
is
None
or
aq_out
.
shape
==
(
M_sum
,
H
)
if
aq_out
is
None
:
aq_out
=
torch
.
empty
((
M_sum
,
H
),
device
=
device
,
dtype
=
aq
.
dtype
)
aq_scale_out
=
torch
.
empty
(
(
M_sum
,
aq_scale
.
shape
[
-
1
]),
device
=
device
,
dtype
=
torch
.
float32
#(M_sum, H // block_k), device=device, dtype=torch.float32
)
# maybe_has_empty_blocks = expert_num_tokens_cpu is None
# expert_ids_init = torch.zeros# if maybe_has_empty_blocks else torch.empty
# expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
expert_ids
=
torch
.
full
(
(
M_sum
,),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
inv_perm
=
torch
.
empty
(
topk_ids
.
shape
,
device
=
device
,
dtype
=
torch
.
int32
)
if
expert_num_tokens
is
None
:
expert_num_tokens
=
count_expert_num_tokens
(
topk_ids
,
local_num_experts
,
expert_map
)
ep_scatter
(
recv_x
=
aq
,
recv_x_scale
=
aq_scale
,
recv_topk
=
topk_ids
,
num_recv_tokens_per_expert
=
expert_num_tokens
,
expert_start_loc
=
expert_start_loc
,
expert_map
=
expert_map
,
output_tensor
=
aq_out
,
output_tensor_scale
=
aq_scale_out
,
m_indices
=
expert_ids
,
output_index
=
inv_perm
,
)
return
aq_out
,
aq_scale_out
,
expert_ids
,
inv_perm
def
deepgemm_unpermute_and_reduce
(
a
:
torch
.
Tensor
,
# Grouped gemm output
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
inv_perm
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
output
:
torch
.
Tensor
,
):
return
ep_gather
(
input_tensor
=
a
,
recv_topk_ids
=
topk_ids
,
recv_topk_weight
=
topk_weights
,
input_index
=
inv_perm
,
expert_map
=
expert_map
,
output_tensor
=
output
,
)
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
3833018c
...
...
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig
,
FusedMoeWeightScaleSupported
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
,
compute_aligned_M
,
deepgemm_moe_permute
,
deepgemm_unpermute_and_reduce
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
)
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
_moe_permute
)
from
vllm.utils
import
round_up
try
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
fuse_silu_mul_quant_ep
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_w8a8_gemm_nt_contig_asm
,
fuse_silu_mul_quant_ep
,
fuse_silu_mul_quant
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -88,19 +91,21 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
use_deepep_ll
=
self
.
use_deepep
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
#
self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
self
.
block_shape
=
[
256
,
256
]
self
.
use_deepgemm
=
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
self
.
use_deepep
_ll
:
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size_per_partition
self
.
K
=
hidden_size
...
...
@@ -154,7 +159,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deep
ep_ll
:
if
not
self
.
use_deep
gemm
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
...
...
@@ -165,7 +170,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deep
ep_ll
:
if
not
self
.
use_deep
gemm
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
...
...
@@ -175,7 +180,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
groupgemm_workspace_shapes
(
self
,
def
masked_
groupgemm_workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
...
...
@@ -198,7 +203,26 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
def
w8a8_groupgemm_forward
(
self
,
def
contiguous_groupgemm_workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_num_tokens_cpu
:
torch
.
Tensor
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
self
.
block_shape
is
not
None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
block_m
=
self
.
block_shape
[
0
]
M_sum
=
compute_aligned_M
(
M
,
topk
,
local_num_experts
,
block_m
,
expert_num_tokens_cpu
)
assert
M_sum
%
block_m
==
0
workspace1
=
(
M_sum
,
max
(
N
,
K
))
workspace2
=
(
M_sum
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
,
M_sum
)
def
w8a8_groupgemm_masked_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -217,6 +241,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
...
...
@@ -227,7 +252,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
N
,
K
=
self
.
N
,
self
.
K
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
)
=
self
.
groupgemm_workspace_shapes
(
workspace_dtype
)
=
self
.
masked_
groupgemm_workspace_shapes
(
x
,
q_x
,
max_num_tokens
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
...
...
@@ -266,6 +291,93 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return
fused_out
def
w8a8_groupgemm_contiguous_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
local_num_experts
=
w1
.
size
(
0
)
a1q
=
q_x
N
,
K
=
self
.
N
,
self
.
K
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
,
M_sum
)
=
self
.
contiguous_groupgemm_workspace_shapes
(
x
,
q_x
,
topk_ids
.
size
(
0
),
N
,
K
,
topk_ids
.
size
(
1
),
global_num_experts
,
local_num_experts
,
expert_num_tokens_cpu
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
x
.
device
,
dtype
=
workspace_dtype
)
workspace2
=
torch
.
empty
(
prod
(
workspace2_shape
),
device
=
x
.
device
,
dtype
=
workspace_dtype
)
mm1_out
=
_resize_cache
(
workspace13
,
(
M_sum
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
a1q_perm
=
_resize_cache
(
workspace2
.
view
(
dtype
=
a1q
.
dtype
),
(
M_sum
,
K
))
a1q
,
a1q_scale
,
expert_ids
,
inv_perm
=
deepgemm_moe_permute
(
aq
=
a1q
,
aq_scale
=
a1_scale
,
topk_ids
=
topk_ids
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
block_shape
=
self
.
block_shape
,
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
aq_out
=
a1q_perm
,
)
# if expert_map is not None:
# # DeepGemm (Grouped Contiguous) kernel needs a valid B index
# # for all rows of A. To that effect, simply compute with
# # the 0th weight matrix.
# # Note that this relies on the fact that corresponding topk
# # weights would be 0 during weight multiplication.
# expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm
(
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
mm1_out
,
expert_ids
)
a2q
,
a2q_scale
=
fuse_silu_mul_quant
(
mm1_out
)
m_grouped_w8a8_gemm_nt_contig_asm
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
if
apply_router_weight_on_input
:
topk_weights
=
torch
.
ones_like
(
topk_weights
)
deepgemm_unpermute_and_reduce
(
a
=
mm2_out
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inv_perm
=
inv_perm
,
expert_map
=
expert_map
,
output
=
fused_out
,
)
return
fused_out
def
fused_moe_forward
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -286,6 +398,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
return
fused_experts_impl_int8_marlin
(
hidden_states
=
x
if
q_x
is
None
else
q_x
,
...
...
@@ -398,7 +511,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return
TritonOrGroupGemmExperts
(
use_int8_w8a8
=
True
,
per_act_token_quant
=
True
,
fused_experts
=
self
.
w8a8_groupgemm_forward
fused_experts
=
self
.
w8a8_groupgemm_
masked_
forward
)
else
:
logger
.
debug
(
...
...
@@ -407,5 +520,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
False
)
return
TritonOrGroupGemmExperts
(
fused_experts
=
self
.
fused_moe_forward
use_int8_w8a8
=
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
,
fused_experts
=
self
.
w8a8_groupgemm_contiguous_forward
if
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
else
self
.
fused_moe_forward
)
\ No newline at end of file
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
3833018c
...
...
@@ -168,6 +168,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
ep_size
=
get_ep_group
().
world_size
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
...
...
@@ -352,7 +354,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m
=
max_num_tokens
#expected_m = max_num_tokens
ori_bs
=
x
.
shape
[
0
]
expected_m
=
ori_bs
*
self
.
ep_size
m_grouped_w4a8_gemm_nt_masked
((
q_x
,
a1_scale
),
(
w1
,
w1_scale
),
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
3833018c
...
...
@@ -174,10 +174,11 @@ class DeepseekV2MoE(nn.Module):
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
self
.
use_deepep_ll
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
self
.
use_deepep
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
if
not
self
.
use_deepep
_ll
:
if
not
self
.
use_deepep
:
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
MoriMoE
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
...
...
@@ -250,7 +251,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
not
self
.
use_mori_ep
and
not
self
.
use_deepep
_ll
:
if
not
self
.
use_mori_ep
and
not
self
.
use_deepep
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
...
...
@@ -285,7 +286,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
if
self
.
use_deepep
_ll
:
if
self
.
use_deepep
:
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
...
...
@@ -717,8 +718,9 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
dp_size
=
get_dp_group
().
world_size
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
use_deepep_ll
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
(
config
.
n_routed_experts
is
not
None
...
...
@@ -847,7 +849,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
,
residual
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
_ll
and
self
.
tp_size
>
1
:
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
ori_bs
=
hidden_states
.
shape
[
0
]
...
...
@@ -860,7 +862,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
_ll
and
self
.
tp_size
>
1
:
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
,
dim
=
0
).
contiguous
()
hidden_states
=
hidden_states
[:
ori_bs
,
:].
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment