Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8943d3db
Commit
8943d3db
authored
Dec 17, 2025
by
yangql
Browse files
解决deep的auto冲突
parents
0d3ae2fc
ab1acdce
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1121 additions
and
222 deletions
+1121
-222
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+3
-0
vllm/envs.py
vllm/envs.py
+4
-0
vllm/model_executor/layers/fused_moe/deepep_auto_prepare_finalize.py
...executor/layers/fused_moe/deepep_auto_prepare_finalize.py
+53
-41
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+209
-76
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+68
-21
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+132
-48
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
.../model_executor/layers/fused_moe/triton_group_gemm_moe.py
+2
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+502
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+130
-16
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+5
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+13
-17
No files found.
vllm/distributed/device_communicators/all2all.py
View file @
8943d3db
...
...
@@ -173,6 +173,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
if
self
.
internode
:
num_rdma_bytes
=
int
(
1e9
/
2
)
#1024 * 1024 * 1024
num_qps_per_rank
=
30
#self.num_sms // 2
self
.
num_sms
=
30
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
...
...
@@ -184,6 +185,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
else
:
num_rdma_bytes
=
0
num_qps_per_rank
=
1
self
.
num_sms
=
60
assert
num_rdma_bytes
is
not
None
assert
num_qps_per_rank
is
not
None
...
...
@@ -192,6 +194,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
False
,
num_qps_per_rank
=
num_qps_per_rank
,
allow_mnnvl
=
envs
.
VLLM_ALLOW_MNNVL
,
explicitly_destroy
=
False
)
def
get_handle
(
self
,
kwargs
):
...
...
vllm/envs.py
View file @
8943d3db
...
...
@@ -180,6 +180,7 @@ if TYPE_CHECKING:
VLLM_USE_PD_SPLIT
:
bool
=
False
VLLM_USE_PP_BALANCE
:
bool
=
False
VLLM_USE_ZERO_MTP
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1181,6 +1182,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_ZERO_MTP"
:
lambda
:
(
os
.
getenv
(
'VLLM_USE_ZERO_MTP'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
}
# --8<-- [end:env-vars-definition]
...
...
vllm/model_executor/layers/fused_moe/deepep_auto_prepare_finalize.py
View file @
8943d3db
...
...
@@ -21,11 +21,13 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
super
().
__init__
()
self
.
ht_prepare_finalize
=
ht_prepare_finalize
self
.
ll_prepare_finalize
=
ll_prepare_finalize
self
.
_current_phase
=
"decode"
# default to
prefill (HT
)
self
.
_current_phase
=
"decode"
# default to
decode (LL
)
def
_get_current_prepare_finalize
(
self
)
->
mk
.
FusedMoEPrepareAndFinalize
:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available
# Try to infer phase from forward_context if available:
# - 有 decode tokens -> 使用 LL (decode)
# - 否则默认 HT (prefill)
try
:
forward_context
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
...
...
@@ -36,44 +38,60 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else
:
attn_metadata
=
None
if
attn_metadata
is
not
None
and
hasattr
(
attn_metadata
,
'num_prefill_tokens'
)
and
hasattr
(
attn_metadata
,
'num_decode_tokens'
):
# Only use prefill mode when BOTH conditions are met:
# 1. There are prefill tokens and no decode tokens
# 2. skip_cuda_graphs is True
is_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
>
0
and
attn_metadata
.
num_decode_tokens
==
0
skip_cuda_graphs
=
forward_context
.
skip_cuda_graphs
# Only use prefill (HT) when both conditions are satisfied
self
.
_current_phase
=
"prefill"
if
(
is_prefill_tokens
and
skip_cuda_graphs
)
else
"decode"
if
attn_metadata
is
not
None
and
hasattr
(
attn_metadata
,
"num_decode_tokens"
):
# 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
self
.
_current_phase
=
(
"decode"
if
attn_metadata
.
num_decode_tokens
>
0
else
"prefill"
)
except
Exception
:
# If forward_context is not available, use stored phase
pass
# Prefill uses HT, decode uses LL
# print("
self._current_phase
",self._current_phase)
# if self._current_phase == "prefill":
if
self
.
_current_phase
==
"prefill"
:
print
(
"************prefill***********"
)
# return self.ht_prepare_finalize
# else:
return
self
.
ll_prepare_finalize
#
return self.ll_prepare_finalize
return
self
.
ht_prepare_finalize
@
property
def
activation_format
(
self
)
->
mk
.
FusedMoEActivationFormat
:
# Use the current prepare_finalize's activation format
# Note: HT uses Standard, LL uses BatchedExperts
# Dynamically return based on current phase
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
return
prepare_finalize
.
activation_format
pf
=
self
.
_get_current_prepare_finalize
()
try
:
return
pf
.
activation_format
except
NotImplementedError
:
# Fallback to standard format if underlying impl does not provide it.
return
mk
.
FusedMoEActivationFormat
.
Standard
def
topk_indices_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
# Both HT and LL return int64
return
torch
.
int64
pf
=
self
.
_get_current_prepare_finalize
()
return
pf
.
topk_indices_dtype
()
def
max_num_tokens_per_rank
(
self
)
->
Optional
[
int
]:
# LL has a limit, HT returns None
return
self
.
ll_prepare_finalize
.
max_num_tokens_per_rank
()
pf
=
self
.
_get_current_prepare_finalize
()
return
pf
.
max_num_tokens_per_rank
()
def
num_dispatchers
(
self
)
->
int
:
# Both should return the same value
return
self
.
ht_prepare_finalize
.
num_dispatchers
()
pf
=
self
.
_get_current_prepare_finalize
()
return
pf
.
num_dispatchers
()
def
prepare_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
):
pf
=
self
.
_get_current_prepare_finalize
()
return
pf
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
)
def
prepare
(
self
,
...
...
@@ -88,9 +106,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
"""Route prepare call to the appropriate implementation."""
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
return
prepare_finalize
.
prepare
(
pf
=
self
.
_get_current_prepare_finalize
()
return
pf
.
prepare
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
)
...
...
@@ -103,9 +120,8 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
"""Route finalize call to the appropriate implementation."""
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
return
prepare_finalize
.
finalize
(
pf
=
self
.
_get_current_prepare_finalize
()
return
pf
.
finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
)
...
...
@@ -118,15 +134,11 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
):
"""Route finalize_async call to the appropriate implementation if available."""
prepare_finalize
=
self
.
_get_current_prepare_finalize
()
if
hasattr
(
prepare_finalize
,
'finalize_async'
):
return
prepare_finalize
.
finalize_async
(
pf
=
self
.
_get_current_prepare_finalize
()
if
hasattr
(
pf
,
"finalize_async"
):
return
pf
.
finalize_async
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
)
else
:
# Fallback to synchronous finalize
return
prepare_finalize
.
finalize
(
return
pf
.
finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
)
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
8943d3db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
collections.abc
import
Callable
import
deep_ep
import
torch
...
...
@@ -58,39 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
None
return
deep_ep
.
Buffer
.
get_combine_config
(
self
.
dp_size
)
def sync(self):
    """Synchronize by calling ``dist.barrier()`` with no arguments.

    NOTE(review): ``dist`` is presumably ``torch.distributed``; its import
    is not visible in this part of the file - confirm.
    """
    # A device-wide synchronize was used here previously and is disabled:
    # torch.cuda.synchronize()
    dist.barrier()
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
Optional
[
torch
.
Tensor
],
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
torch
.
Tensor
|
None
,
rank_topk_ids
:
torch
.
Tensor
,
rank_topk_weights
:
torch
.
Tensor
,
num_experts
:
int
):
rank_topk_weights
:
torch
.
Tensor
,
num_experts
:
int
,
quant_config
:
FusedMoEQuantConfig
,
)
->
Callable
:
has_scales
=
token_scales
is
not
None
(
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
expert_num_tokens
,
is_token_in_rank
,
event
)
=
self
.
buffer
.
get_dispatch_layout
(
(
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
dispatch_expert_num_tokens
,
is_token_in_rank
,
event
,
)
=
self
.
buffer
.
get_dispatch_layout
(
topk_idx
=
rank_topk_ids
,
num_experts
=
num_experts
,
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
allocate_on_comm_stream
=
False
,
)
token_data
=
tokens
if
has_scales
:
token_data
=
(
tokens
,
token_scales
)
(
token_data
,
expert_topk_ids
,
expert_topk_weights
,
expert_num_tokens_per_expert_list
,
self
.
handle
,
event
token_data
,
expert_topk_ids
,
expert_topk_weights
,
expert_num_tokens_per_expert_list
,
self
.
handle
,
event
,
)
=
self
.
buffer
.
dispatch
(
x
=
token_data
,
handle
=
None
,
num_tokens_per_rank
=
num_tokens_per_rank
,
num_tokens_per_rdma_rank
=
num_tokens_per_rdma_rank
,
is_token_in_rank
=
is_token_in_rank
,
num_tokens_per_expert
=
expert_num_tokens
,
num_tokens_per_expert
=
dispatch_
expert_num_tokens
,
topk_idx
=
rank_topk_ids
,
topk_weights
=
rank_topk_weights
,
# expert_alignment rounds the number of tokens per expert
...
...
@@ -98,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment
=
1
,
config
=
self
.
_get_dispatch_config
(),
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
async_finish
=
True
,
allocate_on_comm_stream
=
False
,
)
return
lambda
:
self
.
_receiver
(
event
,
has_scales
,
token_data
,
expert_topk_ids
,
num_experts
,
expert_num_tokens_per_expert_list
,
expert_topk_weights
,
token_scales
,
quant_config
,
)
def
_receiver
(
self
,
event
:
deep_ep
.
EventOverlap
,
has_scales
:
bool
,
token_data
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
torch
.
Tensor
,
expert_topk_ids
:
torch
.
Tensor
|
None
,
num_experts
:
int
,
expert_num_tokens_per_expert_list
:
list
[
int
],
expert_topk_weights
:
torch
.
Tensor
|
None
,
a1_scale
:
torch
.
Tensor
|
None
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
if
has_scales
:
expert_x
,
expert_x_scale
=
token_data
...
...
@@ -117,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert
expert_topk_ids
is
not
None
expert_topk_ids
=
torch
.
where
(
expert_topk_ids
==
-
1
,
num_experts
-
1
if
self
.
rank_expert_offset
==
0
else
0
,
expert_topk_ids
+
self
.
rank_expert_offset
)
expert_topk_ids
+
self
.
rank_expert_offset
,
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
.
make_from_list
(
expert_num_tokens_per_expert_list
,
device
=
expert_x
.
device
)
def
prepare
(
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if
not
quant_config
.
per_act_token_quant
:
# Quantize after dispatch.
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
quant_dtype
=
quant_config
.
quant_dtype
,
per_act_token_quant
=
False
,
block_shape
=
quant_config
.
block_shape
,
)
return
(
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
expert_topk_ids
,
expert_topk_weights
,
)
def
prepare_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -136,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
mk
.
ReceiverType
:
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
if
quant_config
.
per_act_token_quant
:
...
...
@@ -156,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
if
a1q_scale
is
not
None
and
a1q_scale
.
numel
()
==
1
:
a1q_scale
=
a1q_scale
.
view
(
1
,
1
)
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
=
self
.
_do_dispatch
(
else
:
a1q
=
a1
a1q_scale
=
None
return
self
.
_do_dispatch
(
tokens
=
a1q
,
token_scales
=
a1q_scale
,
rank_topk_ids
=
topk_ids
,
rank_topk_weights
=
topk_weights
,
num_experts
=
num_experts
)
else
:
# DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16.
(
expert_x
,
_
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
=
self
.
_do_dispatch
(
tokens
=
a1
,
token_scales
=
None
,
rank_topk_ids
=
topk_ids
,
rank_topk_weights
=
topk_weights
,
num_experts
=
num_experts
)
# quantize now
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
quant_dtype
=
quant_config
.
quant_dtype
,
per_act_token_quant
=
False
,
block_shape
=
quant_config
.
block_shape
)
num_experts
=
num_experts
,
quant_config
=
quant_config
,
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """Synchronous prepare: start the dispatch and immediately collect it.

    All arguments are forwarded unchanged, in order, to ``prepare_async``.
    The callable it returns is invoked right away and its result is
    returned to the caller.
    """
    # prepare_async hands back a zero-argument receiver; calling it yields
    # the prepare result.
    return self.prepare_async(
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )()
def
_apply_weights_and_reduce
(
self
,
num_tokens
:
int
,
fused_expert_output
:
torch
.
Tensor
,
...
...
@@ -210,31 +286,88 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
out
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
do_async
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
Callable
|
None
:
assert
self
.
handle
is
not
None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if
fused_expert_output
.
numel
()
!=
0
and
apply_weights_and_reduce
:
fused_expert_output
=
self
.
_apply_weights_and_reduce
(
num_tokens
=
topk_ids
.
size
(
0
),
fused_expert_output
=
fused_expert_output
,
topk_weights
=
topk_weights
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
output_dtype
=
output
.
dtype
)
# if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
# fused_expert_output = self._apply_weights_and_reduce(
# num_tokens=topk_ids.size(0),
# fused_expert_output=fused_expert_output,
# topk_weights=topk_weights,
# apply_router_weight_on_input=apply_router_weight_on_input,
# output_dtype=output.dtype)
combined_x
,
_
,
event
=
self
.
buffer
.
combine
(
# HT combine only supports BF16
x
=
fused_expert_output
,
handle
=
self
.
handle
,
topk_weights
=
None
,
config
=
self
.
_get_combine_config
(),
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
async_finish
=
do_async
,
allocate_on_comm_stream
=
False
,
)
if
do_async
:
def
_receiver
():
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
_receiver
else
:
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
None
def finalize_async(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    apply_weights_and_reduce: bool = True,
) -> Callable:
    """Start the combine asynchronously and return its receiver callable.

    Delegates to ``_finalize`` with ``do_async=True``. The returned callable
    must be invoked by the caller to complete the finalize step.
    """
    pending = self._finalize(
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        do_async=True,
        apply_weights_and_reduce=apply_weights_and_reduce,
    )
    # In async mode _finalize is expected to always hand back a receiver.
    assert pending is not None
    return pending
def finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
    apply_weights_and_reduce: bool = True,
) -> None:
    """Run the combine synchronously.

    Delegates to ``_finalize`` with ``do_async=False`` and discards whatever
    it returns; this method itself always returns ``None``.
    """
    _ = self._finalize(
        output,
        fused_expert_output,
        topk_weights,
        topk_ids,
        apply_router_weight_on_input,
        do_async=False,
        apply_weights_and_reduce=apply_weights_and_reduce,
    )
    return None
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
8943d3db
...
...
@@ -115,7 +115,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
x
,
x_scales
def
prepare
(
def
prepare
_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -126,9 +126,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
Callable
,
mk
.
ReceiverType
]:
hidden_size
=
a1
.
size
(
1
)
assert
hidden_size
in
self
.
SUPPORTED_HIDDEN_SIZES
,
\
(
f
"Hidden Size
{
hidden_size
}
not in supported list of hidden sizes"
...
...
@@ -148,25 +146,74 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
# Dispatch
expert_x
,
expert_num_tokens
,
self
.
handle
,
event
,
hook
=
\
self
.
buffer
.
low_latency_dispatch
(
a1
,
expert_x
,
expert_num_tokens
,
self
.
handle
s
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
a1
,
topk_ids
,
self
.
max_tokens_per_rank
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
use_int8
=
self
.
use_int8_dispatch
,
async_finish
=
False
,
return_recv_hook
=
False
)
return_recv_hook
=
True
,
)
return
(
hook
,
lambda
:
self
.
_receiver
(
expert_x
,
expert_num_tokens
,
a1_scale
,
a1
.
dtype
,
quant_config
,
),
)
def
_receiver
(
self
,
expert_x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
expert_num_tokens
:
torch
.
Tensor
,
a1_scale
:
torch
.
Tensor
|
None
,
a1_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_dtype
,
quant_config
)
expert_
x
,
expert_x_scale
=
self
.
_do_quant
(
expert_
x
,
a1_scale
,
a2_scale
,
a1
.
dtype
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
expert_num_tokens
)
expert_
tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_
num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
None
,
None
)
return
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
None
,
None
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """Synchronous prepare built on top of ``prepare_async``.

    ``prepare_async`` returns a ``(hook, receiver)`` pair. The hook is
    called first, then the receiver, whose result is returned.
    """
    recv_hook, build_result = self.prepare_async(
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    # The hook has to run before the receiver is asked for its result.
    recv_hook()
    return build_result()
def
_finalize
(
self
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
8943d3db
...
...
@@ -4,13 +4,15 @@ from abc import ABC, abstractmethod
from
enum
import
Enum
from
math
import
prod
from
typing
import
Optional
,
final
from
dataclasses
import
dataclass
from
collections.abc
import
Callable
import
torch
import
vllm.envs
as
envs
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
,
async_tensor_h2d
#
# This file defines a set of base classes used to make MoE kernels more modular.
...
...
@@ -95,6 +97,57 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts
=
"batched_experts"
,
@
dataclass
class
ExpertTokensMetadata
:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens
:
torch
.
Tensor
expert_num_tokens_cpu
:
torch
.
Tensor
|
None
@
staticmethod
def
make_from_list
(
expert_num_tokens_list
:
list
[
int
],
device
:
str
)
->
"ExpertTokensMetadata"
:
# expert_num_tokens_cpu = torch.tensor(
# expert_num_tokens_list, device="cpu", dtype=torch.int32
# )
expert_num_tokens_cpu
=
torch
.
tensor
(
expert_num_tokens_list
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
expert_num_tokens
=
expert_num_tokens_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
return
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
)
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType
=
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
ExpertTokensMetadata
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]
ReceiverType
=
Callable
[[],
PrepareResultType
]
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class
FusedMoEPrepareAndFinalize
(
ABC
):
"""
...
...
@@ -880,8 +933,19 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
(
a1q
,
a1q_scale
,
expert_num_tokens
,
_expert_topk_ids
,
_expert_topk_weights
)
=
self
.
prepare_finalize
.
prepare
(
# (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
# _expert_topk_weights) = self.prepare_finalize.prepare(
# a1,
# a1_scale,
# a2_scale,
# topk_weights,
# topk_ids,
# global_num_experts,
# expert_map,
# apply_router_weight_on_input,
# self.fused_experts.quant_config,
# )
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
...
...
@@ -892,12 +956,35 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
)
hook
,
receiver
=
(
prepare_ret
if
isinstance
(
prepare_ret
,
tuple
)
else
(
None
,
prepare_ret
)
)
if
hook
is
not
None
:
hook
()
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
_expert_topk_ids
,
_expert_topk_weights
,
)
=
receiver
()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids
=
topk_ids
if
_expert_topk_ids
is
None
else
_expert_topk_ids
topk_weights
=
(
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
if
a1q
.
numel
()
==
0
:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out
=
torch
.
empty_like
(
a1q
).
to
(
dtype
=
a1
.
dtype
)
else
:
fused_out
=
self
.
fused_experts
.
apply
(
None
,
a1
,
...
...
@@ -918,18 +1005,15 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
workspace13
=
None
,
workspace2
=
None
,
use_nn_moe
=
use_nn_moe
,
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens
=
expert_
tokens_meta
.
expert_
num_tokens
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
expert_num_tokens_cpu
=
expert_tokens_meta
.
expert_num_tokens_cpu
)
shared_output
=
None
if
self
.
shared_experts
is
None
:
self
.
prepare_finalize
.
finalize
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
False
)
else
:
hook
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
Fals
e
)
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
Tru
e
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
View file @
8943d3db
...
...
@@ -85,6 +85,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
expert_num_tokens_cpu
:
torch
.
Tensor
=
None
,
):
assert
self
.
fused_experts
is
not
None
...
...
@@ -107,4 +108,5 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
q_x
=
q_hidden_states
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
)
vllm/model_executor/layers/fused_moe/utils.py
View file @
8943d3db
...
...
@@ -11,6 +11,7 @@ from triton.language.extra import libdevice
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.utils
import
round_up
try
:
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
...
...
@@ -276,8 +277,8 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if
block_shape
is
None
:
assert
per_act_token
,
\
"int8 quantization only supports block or channel-wise"
#
assert per_act_token, \
#
"int8 quantization only supports block or channel-wise"
if
expert_num_tokens
is
None
:
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
...
...
@@ -361,3 +362,502 @@ def _validate_scale_shape(
assert
block_shape
is
not
None
expected
=
(
a
.
shape
[
0
],
cdiv
(
a
.
shape
[
1
],
block_shape
[
1
]))
assert
a_scale
.
shape
==
expected
,
f
"
{
a_scale
.
shape
}
==
{
expected
}
"
@triton.jit
def _count_expert_num_tokens(
    topk_ids_ptr,
    expert_num_tokens_ptr,
    num_experts,
    topk_numel,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Count how many entries of the flat topk_ids buffer refer to one expert.
    # Each program instance (axis 0) handles the expert whose index equals
    # its program id and writes a single int32 count for it.
    #
    # topk_ids_ptr          - start of the topk ids, read as topk_numel
    #                         consecutive elements.
    # expert_num_tokens_ptr - output; element [curr_expert] receives the count.
    # num_experts           - programs with id >= num_experts store nothing.
    # expert_map            - lookup table indexed by expert id; only read
    #                         when HAS_EXPERT_MAP is true.
    # BLOCK_SIZE            - number of ids examined per loop iteration.
    curr_expert = tl.program_id(0)

    offsets = tl.arange(0, BLOCK_SIZE)
    topk_ids_ptrs = topk_ids_ptr + offsets

    # Per-lane partial counts; reduced to a scalar at the end.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
        # Lanes past the end of the buffer are masked off and read as -1,
        # which never equals a program id.
        mask = offsets < (topk_numel - x * BLOCK_SIZE)
        expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
        if HAS_EXPERT_MAP:
            # Translate ids through expert_map. Negative ids are not looked
            # up and stay -1.
            expert_map_ptrs = expert_map + expert_ids
            expert_map_mask = expert_ids >= 0
            expert_ids = tl.load(expert_map_ptrs,
                                 mask=expert_map_mask,
                                 other=-1)

        has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
        acc = acc + has_curr_expert
        # Advance to the next BLOCK_SIZE-wide window of ids.
        topk_ids_ptrs += BLOCK_SIZE

    if curr_expert < num_experts:
        tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
def
count_expert_num_tokens
(
topk_ids
:
torch
.
Tensor
,
num_local_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
)
->
torch
.
Tensor
:
"""
Count the number to tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert
topk_ids
.
dtype
.
is_signed
,
"The kernel uses -1 to represent invalid topk_ids"
expert_num_tokens
=
torch
.
empty
(
(
num_local_experts
),
device
=
topk_ids
.
device
,
dtype
=
torch
.
int32
)
grid
=
num_local_experts
BLOCK_SIZE
=
min
(
topk_ids
.
numel
(),
1024
)
BLOCK_SIZE
=
triton
.
next_power_of_2
(
BLOCK_SIZE
)
_count_expert_num_tokens
[(
grid
,)](
topk_ids
,
expert_num_tokens
,
num_local_experts
,
topk_ids
.
numel
(),
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
return
expert_num_tokens
def expert_num_tokens_round_up_and_sum(expert_num_tokens: torch.Tensor,
                                       alignment: int) -> int:
    """Pad every per-expert count up to a multiple of ``alignment`` and sum.

    The counts are widened to int64 before the arithmetic. The result is
    returned as a Python scalar via ``.item()``.
    """
    as_int64 = expert_num_tokens.to(torch.int64)
    # Number of alignment-sized blocks needed for each expert (ceil-div).
    num_blocks = (as_int64 + (alignment - 1)) // alignment
    return (num_blocks * alignment).sum().item()
def compute_aligned_M(
    M: int,
    num_topk: int,
    local_num_experts: int,
    alignment: int,
    expert_num_tokens_cpu: Optional[torch.Tensor] = None,
):
    """Return the aligned row count for ``M`` tokens with ``num_topk`` picks.

    When ``expert_num_tokens_cpu`` is given, the result is exact: each
    per-expert count is rounded up to ``alignment`` and the results are
    summed. Otherwise an upper bound is returned that assumes every local
    expert wastes up to ``alignment - 1`` rows of padding.
    """
    if expert_num_tokens_cpu is None:
        # Per-expert counts are not available on the CPU, so fall back to
        # the maximum size that could be required.
        worst_case = (M * num_topk) + local_num_experts * (alignment - 1)
        return round_up(worst_case, alignment)

    return expert_num_tokens_round_up_and_sum(expert_num_tokens_cpu,
                                              alignment=alignment)
@triton.jit
def apply_expert_map(expert_id, expert_map):
    """Translate `expert_id` through `expert_map`, leaving -1 untouched.

    For any id other than -1 the value stored at `expert_map + expert_id` is
    loaded and cast back to the dtype of `expert_id`. An id of -1 is returned
    unchanged, so no load is issued for it.
    """
    if expert_id != -1:
        expert_id = tl.load(expert_map + expert_id).to(expert_id.dtype)
    return expert_id
@triton.jit
def round_up_256(x: int) -> int:
    """Round `x` up to the nearest multiple of 256."""
    align = 256
    num_chunks = (x + align - 1) // align
    return num_chunks * align
@triton.jit
def round_up_128(x: int) -> int:
    """Round `x` up to the nearest multiple of 128."""
    align = 128
    num_chunks = (x + align - 1) // align
    return num_chunks * align
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    """Compute per-expert start offsets and tag each expert's rows.

    The program index along axis 0 is used as the expert id. Every program:
      1. loads all `num_experts` token counts (lanes beyond `num_experts`
         read 0),
      2. rounds each count up to a multiple of 256 and takes the exclusive
         prefix sum, which is stored to `expert_start_loc`,
      3. reads back its own start offset and writes its expert id into
         `m_indices[start : start + count]`, where `count` is the unpadded
         token count of that expert.

    Padding slots between `count` and the rounded-up count are not written
    by this kernel.
    """
    cur_expert = tl.program_id(0)

    # One lane per (padded) expert slot; lanes >= num_experts are masked off.
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    # The alignment is hard-coded to 256 here (a 128 variant, round_up_128,
    # was used previously and is disabled).
    tokens_per_expert = round_up_256(tokens_per_expert)
    # Exclusive prefix sum: inclusive cumsum minus the element itself.
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    # Every program stores the full offset table (not only program 0), so the
    # read-back below sees values written by this same program.
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
    # NOTE(review): presumably orders the store above before the scalar load
    # below within this program - confirm against the Triton docs.
    tl.debug_barrier()
    # The offset is re-read from memory rather than indexed out of `cumsum`.
    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    # Fill this expert's row range with its id, BLOCK_E entries at a time;
    # the mask stops at the real (unpadded) token count.
    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
            mask=start_m + off_expert < cur_expert_token_num
        )
@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    """Copy each token's row and scales into the slots of its selected experts.

    Program `p` handles tokens p, p + num_programs, p + 2 * num_programs, ...
    For every (token, top-k slot) whose expert id is >= 0 (after the optional
    `expert_map` translation), a destination row is reserved by atomically
    incrementing `expert_start_loc[expert_id]`; the token's hidden vector and
    scale vector are copied to that row and the row index is recorded in
    `output_index[token, slot]`.

    Side effect: `expert_start_loc` is advanced by the number of rows written
    for each expert, so after this kernel it no longer holds start offsets.

    NOTE(review): `recv_x_stride1`, `recv_topk_stride1`,
    `output_tensor_stride1` and `output_index_stride1` are accepted but not
    used in the pointer arithmetic below, i.e. the innermost dimension of
    those tensors is addressed with unit stride - confirm callers always pass
    tensors that are contiguous in the last dimension.
    """
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)

    # Lanes over the hidden dimension, padded to HIDDEN_SIZE_PAD and masked.
    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE

    # Lanes over the scale dimension, padded likewise.
    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = index_in_s < SCALE_HIDDEN_SIZE

    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
        # Offsets are computed in int64.
        token_id = token_id_int32.to(tl.int64)
        # The source row and its scales are loaded once and reused for every
        # expert the token is routed to.
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale + token_id * recv_x_scale_stride0 + index_in_s * recv_x_scale_stride1,
            mask=mask_s,
        )

        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
            topk_index = topk_idx_int32.to(tl.int64)
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            if HAS_EXPERT_MAP:
                expert_id = apply_expert_map(expert_id, expert_map)

            # Negative ids are skipped entirely.
            if expert_id >= 0:
                # atomic_add returns the previous value, which becomes this
                # copy's destination row.
                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
                dest_token_index = dest_token_index_int32.to(tl.int64)

                # Remember where this (token, slot) pair was placed.
                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index_int32,
                )
                output_tensor_ptr = (output_tensor + dest_token_index * output_tensor_stride0)
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(
                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
                    to_copy_s,
                    mask=mask_s,
                )
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_map: torch.Tensor | None,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
):
    """Group token rows by expert using two Triton kernel launches.

    The first launch (one program per expert) fills `expert_start_loc` with
    256-aligned start offsets and writes expert ids into `m_indices`. The
    second launch copies every routed token row and its scales into
    `output_tensor` / `output_tensor_scale` and records the destination row
    of each (token, top-k slot) pair in `output_index`.

    All outputs are written in place; nothing is returned.
    """
    # The number of rows per expert is aligned to 256.
    rows_per_expert_align = 256
    warps = 8

    expert_count = num_recv_tokens_per_expert.shape[0]
    hidden = recv_x.shape[1]
    scale_hidden = recv_x_scale.shape[-1]

    assert m_indices.shape[0] % rows_per_expert_align == 0

    # Launch 1: one program per expert.
    _fwd_kernel_ep_scatter_1[(expert_count, )](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=expert_count,
        num_warps=warps,
        BLOCK_E=rows_per_expert_align,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(expert_count),
    )

    # Launch 2: at most 8192 programs, each striding over the tokens.
    token_count = recv_topk.shape[0]
    scatter_grid = (min(token_count, 1024 * 8), )
    has_expert_map = expert_map is not None

    _fwd_kernel_ep_scatter_2[scatter_grid](
        token_count,
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        expert_map=expert_map,
        HAS_EXPERT_MAP=has_expert_map,
        num_warps=warps,
        HIDDEN_SIZE=hidden,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden),
        SCALE_HIDDEN_SIZE=scale_hidden,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden),
    )
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    expert_map,
    HAS_EXPERT_MAP: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Weighted reduction of expert-grouped rows back into per-token rows.

    Grid axis 0 selects a BLOCK_D-wide slice of the hidden dimension; grid
    axis 1 strides over tokens. For each token, the rows named by
    `input_index[token, k]` are read from `input_tensor`, scaled by
    `recv_topk_weight[token, k]` and accumulated in float32 for every top-k
    slot whose expert id is >= 0 (after the optional `expert_map`
    translation). The sum is cast to the output element type and stored to
    `output_tensor[token]`.

    A token with no valid slot gets zeros written for this slice, since the
    accumulator starts at zero and the store is unconditional.

    NOTE(review): loads and stores over the hidden dimension are unmasked and
    the `*_stride1` arguments are not used, so this presumably relies on the
    caller guaranteeing hidden_size % BLOCK_D == 0 and unit stride in the
    last dimension - confirm at call sites.
    """
    cur_block_int32 = tl.program_id(0)
    cur_block = cur_block_int32.to(tl.int64)

    start_cur_token_int32 = tl.program_id(1)

    grid_num = tl.num_programs(1)

    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
        # Offsets are computed in int64.
        cur_token = cur_token_int32.to(tl.int64)

        off_d = tl.arange(0, BLOCK_D)
        # Accumulate in float32 regardless of the input dtype.
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)

        for topk_index_int32 in range(0, topk_num):
            topk_index = topk_index_int32.to(tl.int64)

            expert_id = tl.load(recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index)
            if HAS_EXPERT_MAP:
                expert_id = apply_expert_map(expert_id, expert_map)

            # Slots with a negative expert id contribute nothing.
            if expert_id >= 0:
                source_token_index_int32 = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                source_token_index = source_token_index_int32.to(tl.int64)

                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                tmp = tl.load(
                    input_tensor + source_token_index * input_tensor_stride0 + cur_block * BLOCK_D + off_d
                )
                accumulator += tmp.to(tl.float32) * acc_weight

        tl.store(
            output_tensor + cur_token * output_tensor_stride0 + cur_block * BLOCK_D + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )
@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    expert_map: torch.Tensor | None,
    output_tensor: torch.Tensor,
):
    """Launch the gather kernel that reduces expert rows back to token rows.

    `output_tensor` is written in place, one row per token, as the
    top-k-weighted sum of the rows of `input_tensor` selected by
    `input_index`. Nothing is returned.

    The hidden dimension is processed in tiles of min(hidden_size, 1024)
    columns and must be an exact multiple of that tile width.
    """
    warps = 2
    token_count = output_tensor.shape[0]
    hidden = input_tensor.shape[1]

    tile_width = min(hidden, 1024)
    assert hidden % tile_width == 0

    # Axis 0: tiles of the hidden dimension. Axis 1: at most 1024 programs
    # striding over the tokens.
    num_tiles = triton.cdiv(hidden, tile_width)
    launch_grid = (num_tiles, min(token_count, 1024))
    has_expert_map = expert_map is not None

    _fwd_kernel_ep_gather[launch_grid](
        token_count,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        expert_map=expert_map,
        HAS_EXPERT_MAP=has_expert_map,
        num_warps=warps,
        BLOCK_D=tile_width,
    )
def deepgemm_moe_permute(
    aq: torch.Tensor,
    aq_scale: torch.Tensor,
    topk_ids: torch.Tensor,
    local_num_experts: int,
    expert_map: torch.Tensor | None,
    block_shape: list[int],
    expert_num_tokens: Optional[torch.Tensor] = None,
    expert_num_tokens_cpu: Optional[torch.Tensor] = None,
    aq_out: torch.Tensor | None = None,
    M_sum: int | None = None,
):
    """Permute quantized activations into expert-contiguous order.

    Parameters:
    - aq: 2-D quantized activations, one row per token.
    - aq_scale: scales for `aq`; its last dimension sets the width of the
        permuted scale buffer.
    - topk_ids: signed tensor of expert ids per token (-1 marks invalid).
    - local_num_experts: number of experts on this rank.
    - expert_map: optional global-to-local expert id mapping.
    - block_shape: block_shape[0] is the row alignment used when M_sum has
        to be derived here.
    - expert_num_tokens: optional per-expert token counts; counted from
        topk_ids when omitted.
    - expert_num_tokens_cpu: optional CPU copy of the counts, used only to
        size the output exactly.
    - aq_out: optional preallocated (M_sum, H) output buffer.
    - M_sum: optional precomputed number of permuted rows.

    Returns:
    (aq_out, aq_scale_out, expert_ids, inv_perm). `expert_ids` has one int32
    entry per permuted row and is pre-filled with -1; `inv_perm` is an int32
    tensor shaped like `topk_ids`.
    """
    assert aq.ndim == 2
    assert topk_ids.dtype.is_signed, (
        "The kernel uses -1 to represent invalid topk_ids")

    hidden = aq.size(1)
    dev = aq.device
    row_alignment = block_shape[0]

    if M_sum is None:
        M_sum = compute_aligned_M(
            M=topk_ids.size(0),
            num_topk=topk_ids.size(1),
            local_num_experts=local_num_experts,
            alignment=row_alignment,
            expert_num_tokens_cpu=expert_num_tokens_cpu,
        )

    # Activation buffer: validate the caller's, or allocate a fresh one.
    if aq_out is None:
        aq_out = torch.empty((M_sum, hidden), device=dev, dtype=aq.dtype)
    else:
        assert aq_out.shape == (M_sum, hidden)

    # Scratch and output buffers consumed by ep_scatter.
    start_offsets = torch.empty(local_num_experts,
                                device=dev,
                                dtype=torch.int32)
    scale_out = torch.empty((M_sum, aq_scale.shape[-1]),
                            device=dev,
                            dtype=torch.float32)
    # Rows that ep_scatter does not tag keep the -1 sentinel.
    row_expert_ids = torch.full((M_sum, ),
                                -1,
                                dtype=torch.int32,
                                device=dev)
    inverse_permutation = torch.empty(topk_ids.shape,
                                      device=dev,
                                      dtype=torch.int32)

    if expert_num_tokens is None:
        expert_num_tokens = count_expert_num_tokens(topk_ids,
                                                    local_num_experts,
                                                    expert_map)

    ep_scatter(
        recv_x=aq,
        recv_x_scale=aq_scale,
        recv_topk=topk_ids,
        num_recv_tokens_per_expert=expert_num_tokens,
        expert_start_loc=start_offsets,
        expert_map=expert_map,
        output_tensor=aq_out,
        output_tensor_scale=scale_out,
        m_indices=row_expert_ids,
        output_index=inverse_permutation,
    )

    return aq_out, scale_out, row_expert_ids, inverse_permutation
def deepgemm_unpermute_and_reduce(
    a: torch.Tensor,  # Grouped gemm output
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    inv_perm: torch.Tensor,
    expert_map: torch.Tensor | None,
    output: torch.Tensor,
):
    """Undo the expert permutation and apply the top-k weighted reduction.

    Forwards to `ep_gather`, which writes the reduced rows into `output` in
    place; whatever `ep_gather` returns is passed through to the caller.
    """
    gather_result = ep_gather(
        input_tensor=a,
        recv_topk_ids=topk_ids,
        recv_topk_weight=topk_weights,
        input_index=inv_perm,
        expert_map=expert_map,
        output_tensor=output,
    )
    return gather_result
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
8943d3db
...
...
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig
,
FusedMoeWeightScaleSupported
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
,
compute_aligned_M
,
deepgemm_moe_permute
,
deepgemm_unpermute_and_reduce
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
)
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
_moe_permute
)
from
vllm.utils
import
round_up
try
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
fuse_silu_mul_quant_ep
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_w8a8_gemm_nt_contig_asm
,
fuse_silu_mul_quant_ep
,
fuse_silu_mul_quant
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -84,26 +87,27 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_size
=
get_dp_group
().
world_size
self
.
ep_size
=
get_ep_group
().
world_size
backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
backend
==
"deepep_high_throughput"
or
\
backend
==
"deepep_low_latency"
or
\
backend
==
"deepep_auto"
)
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
self
.
use_deepep_ll
=
self
.
use_deepep
and
(
backend
==
"deepep_low_latency"
or
\
(
backend
==
"deepep_auto"
))
#self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
self
.
block_shape
=
[
256
,
256
]
self
.
use_deepgemm
=
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
or
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
self
.
use_deepep
_ll
:
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size_per_partition
self
.
K
=
hidden_size
...
...
@@ -157,7 +161,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deep
ep_ll
:
if
not
self
.
use_deep
gemm
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
...
...
@@ -168,7 +172,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deep
ep_ll
:
if
not
self
.
use_deep
gemm
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
...
...
@@ -178,7 +182,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
groupgemm_workspace_shapes
(
self
,
def
masked_
groupgemm_workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
...
...
@@ -201,7 +205,26 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
def
w8a8_groupgemm_forward
(
self
,
def contiguous_groupgemm_workspace_shapes(
    self,
    a: torch.Tensor,
    aq: torch.Tensor,
    M: int,
    N: int,
    K: int,
    topk: int,
    global_num_experts: int,
    local_num_experts: int,
    expert_num_tokens_cpu: Optional[torch.Tensor]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype,
           int]:
    """Return buffer shapes for the contiguous grouped-GEMM forward path.

    Parameters:
    - a: unquantized activations; only its dtype is used.
    - aq: quantized activations (not read here).
    - M: number of input tokens.
    - N, K: GEMM dimensions; the second workspace is sized with N // 2.
    - topk: number of experts selected per token.
    - global_num_experts: not read here.
    - local_num_experts: number of experts on this rank, used to bound the
        per-expert padding.
    - expert_num_tokens_cpu: optional per-expert token counts on the CPU.
        When None, compute_aligned_M falls back to a worst-case size.

    Returns:
    (workspace1_shape, workspace2_shape, output_shape, workspace_dtype,
    M_sum), where M_sum is the padded row count shared by both workspaces.
    """
    assert self.block_shape is not None
    # Rows are padded per local expert to a multiple of block_shape[0].
    block_m = self.block_shape[0]
    M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
                              expert_num_tokens_cpu)
    assert M_sum % block_m == 0

    workspace1 = (M_sum, max(N, K))
    workspace2 = (M_sum, max(N // 2, K))
    output = (M, K)
    return (workspace1, workspace2, output, a.dtype, M_sum)
def
w8a8_groupgemm_masked_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -220,6 +243,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
...
...
@@ -230,7 +254,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
N
,
K
=
self
.
N
,
self
.
K
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
)
=
self
.
groupgemm_workspace_shapes
(
workspace_dtype
)
=
self
.
masked_
groupgemm_workspace_shapes
(
x
,
q_x
,
max_num_tokens
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
...
...
@@ -269,6 +293,94 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return
fused_out
def w8a8_groupgemm_contiguous_forward(self,
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    expert_num_tokens: Optional[torch.Tensor] = None,
    use_nn_moe: Optional[bool] = False,
    routed_scaling_factor: Optional[float] = None,
    shared_output: Optional[torch.Tensor] = None,
    q_x: Optional[torch.Tensor] = None,
    expert_num_tokens_cpu: Optional[torch.Tensor] = None,
    **_
):
    """Run the MoE experts with the contiguous (expert-grouped) W8A8 GEMM.

    Pipeline: permute the quantized activations `q_x` into expert-contiguous
    order, run the first grouped GEMM with `w1`, apply the fused SiLU-mul +
    quantization, run the second grouped GEMM with `w2`, then un-permute and
    reduce with the top-k weights.

    Returns the reduced output tensor (a view into an internal workspace).

    NOTE(review): `activation`, `a2_scale`, `use_nn_moe`,
    `routed_scaling_factor` and `shared_output` are accepted but never read
    in this body; the activation is always the fused SiLU-mul kernel.
    NOTE(review): `q_x` defaults to None but is dereferenced unconditionally
    (`a1q.dtype`), so callers presumably always supply it - confirm.
    """
    # NOTE(review): `mk` is not referenced anywhere below.
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    local_num_experts = w1.size(0)
    a1q = q_x
    N, K = self.N, self.K

    (workspace13_shape, workspace2_shape, fused_out_shape,
     workspace_dtype, M_sum) = self.contiguous_groupgemm_workspace_shapes(
        x, q_x, topk_ids.size(0), N, K, topk_ids.size(1),
        global_num_experts, local_num_experts, expert_num_tokens_cpu)

    workspace13 = torch.empty(prod(workspace13_shape),
                              device=x.device,
                              dtype=workspace_dtype)
    workspace2 = torch.empty(prod(workspace2_shape),
                             device=x.device,
                             dtype=workspace_dtype)

    # The two workspaces are deliberately shared between several views:
    #   workspace13 -> mm1_out (first GEMM output) and fused_out (final result)
    #   workspace2  -> a1q_perm (permuted input) and mm2_out (second GEMM output)
    # The statement order below keeps each view's last read ahead of the
    # write that reuses its storage, so do not reorder.
    mm1_out = _resize_cache(workspace13, (M_sum, N))
    mm2_out = _resize_cache(workspace2, (M_sum, K))
    fused_out = _resize_cache(workspace13, fused_out_shape)

    # Reinterpret workspace2 with the quantized dtype to hold the permuted
    # activations.
    a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))
    a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
        aq=a1q,
        aq_scale=a1_scale,
        topk_ids=topk_ids,
        local_num_experts=local_num_experts,
        expert_map=expert_map,
        block_shape=self.block_shape,
        expert_num_tokens=expert_num_tokens,
        expert_num_tokens_cpu=expert_num_tokens_cpu,
        aq_out=a1q_perm,
        M_sum=M_sum)

    # Disabled workaround, kept for reference: remapping -1 entries of
    # expert_ids to expert 0 so every row has a valid weight index.
    # if expert_map is not None:
    #     # DeepGemm (Grouped Contiguous) kernel needs a valid B index
    #     # for all rows of A. To that effect, simply compute with
    #     # the 0th weight matrix.
    #     # Note that this relies on the fact that corresponding topk
    #     # weights would be 0 during weight multiplication.
    #     expert_ids = torch.where(expert_ids == -1, 0, expert_ids)

    # First grouped GEMM: permuted activations x w1 -> mm1_out.
    m_grouped_w8a8_gemm_nt_contig_asm((a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)

    # Fused SiLU-mul and re-quantization of the first GEMM's output.
    a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)

    # Second grouped GEMM: activated values x w2 -> mm2_out.
    m_grouped_w8a8_gemm_nt_contig_asm((a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

    # When the router weights are applied on the input side, use unit
    # weights for the final reduction.
    if apply_router_weight_on_input:
        topk_weights = torch.ones_like(topk_weights)

    deepgemm_unpermute_and_reduce(
        a=mm2_out,
        topk_ids=topk_ids,
        topk_weights=topk_weights,
        inv_perm=inv_perm,
        expert_map=expert_map,
        output=fused_out,
    )

    return fused_out
def
fused_moe_forward
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -289,6 +401,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
return
fused_experts_impl_int8_marlin
(
hidden_states
=
x
if
q_x
is
None
else
q_x
,
...
...
@@ -401,7 +514,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
return
TritonOrGroupGemmExperts
(
use_int8_w8a8
=
True
,
per_act_token_quant
=
True
,
fused_experts
=
self
.
w8a8_groupgemm_forward
fused_experts
=
self
.
w8a8_groupgemm_
masked_
forward
)
else
:
logger
.
debug
(
...
...
@@ -410,5 +523,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
False
)
return
TritonOrGroupGemmExperts
(
fused_experts
=
self
.
fused_moe_forward
use_int8_w8a8
=
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
,
fused_experts
=
self
.
w8a8_groupgemm_contiguous_forward
if
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
else
self
.
fused_moe_forward
)
\ No newline at end of file
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
8943d3db
...
...
@@ -168,6 +168,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
ep_size
=
get_ep_group
().
world_size
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
...
...
@@ -352,7 +354,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m
=
max_num_tokens
#expected_m = max_num_tokens
ori_bs
=
x
.
shape
[
0
]
expected_m
=
ori_bs
*
self
.
ep_size
m_grouped_w4a8_gemm_nt_masked
((
q_x
,
a1_scale
),
(
w1
,
w1_scale
),
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
8943d3db
...
...
@@ -174,14 +174,12 @@ class DeepseekV2MoE(nn.Module):
dp_size
=
get_dp_group
().
world_size
self
.
use_mori_ep
=
parallel_config
.
enable_expert_parallel
and
dp_size
>
1
and
envs
.
VLLM_ALL2ALL_BACKEND
==
'mori'
self
.
enable_expert_parallel
=
parallel_config
.
enable_expert_parallel
backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
use_deepep_ll
=
(
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
(
backend
==
"deepep_low_latency"
or
backend
==
"deepep_auto"
)
)
self
.
use_deepep
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
if
not
self
.
use_deepep
_ll
:
if
not
self
.
use_deepep
:
moe_cls
=
FusedMoE
if
not
self
.
use_mori_ep
else
MoriMoE
self
.
experts
=
moe_cls
(
num_experts
=
config
.
n_routed_experts
,
...
...
@@ -254,7 +252,7 @@ class DeepseekV2MoE(nn.Module):
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
not
self
.
use_mori_ep
and
not
self
.
use_deepep
_ll
:
if
not
self
.
use_mori_ep
and
not
self
.
use_deepep
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
...
...
@@ -289,7 +287,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
else
:
if
self
.
use_deepep
_ll
:
if
self
.
use_deepep
:
shared_output
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
...
...
@@ -721,12 +719,10 @@ class DeepseekV2DecoderLayer(nn.Module):
self
.
dp_size
=
get_dp_group
().
world_size
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
use_deepep_ll
=
(
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
(
backend
==
"deepep_low_latency"
or
backend
==
"deepep_auto"
)
)
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
(
config
.
n_routed_experts
is
not
None
...
...
@@ -855,7 +851,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
,
residual
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
_ll
and
self
.
tp_size
>
1
:
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
ori_bs
=
hidden_states
.
shape
[
0
]
...
...
@@ -868,7 +864,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
self
.
use_deepep
_ll
and
self
.
tp_size
>
1
:
DeepseekV2MoE
)
and
self
.
use_deepep
and
self
.
tp_size
>
1
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
,
dim
=
0
).
contiguous
()
hidden_states
=
hidden_states
[:
ori_bs
,
:].
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment