Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13130b89
Commit
13130b89
authored
Dec 18, 2025
by
王敏
Browse files
[feat]合入基于deepep的大EP
parent
06106338
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2367 additions
and
175 deletions
+2367
-175
vllm/config.py
vllm/config.py
+3
-1
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+20
-7
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+3
-2
vllm/envs.py
vllm/envs.py
+5
-0
vllm/forward_context.py
vllm/forward_context.py
+2
-1
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+5
-0
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+1
-1
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+207
-74
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+137
-38
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+1
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+66
-14
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+396
-0
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/prepare_finalize.py
vllm/model_executor/layers/fused_moe/prepare_finalize.py
+1
-0
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+58
-0
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
.../model_executor/layers/fused_moe/triton_group_gemm_moe.py
+112
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+761
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+36
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+339
-18
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+213
-14
No files found.
vllm/config.py
View file @
13130b89
...
...
@@ -4785,8 +4785,10 @@ class VllmConfig:
# add for spec decode
if
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
num_lookahead_slots
>
0
:
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
mtp_
batch_size_capture_list
=
list
(
map
(
lambda
x
:
x
*
(
1
+
self
.
speculative_config
.
num_lookahead_slots
),
batch_size_capture_list
))
batch_size_capture_list
=
sorted
(
set
(
batch_size_capture_list
+
mtp_batch_size_capture_list
))
batch_size_capture_list
=
[
i
for
i
in
batch_size_capture_list
if
i
==
1
or
i
%
(
1
+
self
.
speculative_config
.
num_lookahead_slots
)
==
0
]
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
...
...
vllm/distributed/device_communicators/all2all.py
View file @
13130b89
...
...
@@ -8,7 +8,7 @@ import torch.distributed as dist
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.utils
import
has_deep_ep
,
has_pplx
import
vllm.envs
as
envs
from
.base_device_communicator
import
All2AllManagerBase
,
Cache
logger
=
init_logger
(
__name__
)
...
...
@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self
.
num_sms
=
2
0
self
.
num_sms
=
3
0
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
...
...
@@ -166,16 +166,26 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def
_make_all2all_kwargs
(
self
)
->
dict
[
Any
,
Any
]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes
=
1024
*
1024
*
1024
num_nvl_bytes
=
int
(
2e9
/
2
)
#
1024 * 1024 * 1024
num_rdma_bytes
=
None
num_qps_per_rank
=
None
if
self
.
internode
:
num_rdma_bytes
=
1024
*
1024
*
1024
num_qps_per_rank
=
self
.
num_sms
//
2
num_rdma_bytes
=
int
(
1e9
/
2
)
#1024 * 1024 * 1024
num_qps_per_rank
=
30
#self.num_sms // 2
self
.
num_sms
=
30
# import deep_ep
# num_nvl_bytes, num_rdma_bytes = 0, 0
# hidden_size = 7168
# hidden_bytes = hidden_size * 2
# for config in (deep_ep.Buffer.get_dispatch_config(self.cpu_group.size()), deep_ep.Buffer.get_combine_config(self.cpu_group.size())):
# num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_nvl_bytes)
# num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, self.cpu_group.size()), num_rdma_bytes)
else
:
num_rdma_bytes
=
0
num_qps_per_rank
=
1
self
.
num_sms
=
60
assert
num_rdma_bytes
is
not
None
assert
num_qps_per_rank
is
not
None
...
...
@@ -183,7 +193,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_nvl_bytes
=
num_nvl_bytes
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
False
,
num_qps_per_rank
=
num_qps_per_rank
)
num_qps_per_rank
=
num_qps_per_rank
,
explicitly_destroy
=
False
)
def
get_handle
(
self
,
kwargs
):
...
...
@@ -244,7 +255,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_nvl_bytes
=
num_nvl_bytes
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_qps_per_rank
)
num_qps_per_rank
=
num_qps_per_rank
,
allow_mnnvl
=
envs
.
VLLM_ALLOW_MNNVL
,
)
def
get_handle
(
self
,
kwargs
):
"""
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
13130b89
...
...
@@ -237,10 +237,11 @@ class DeviceCommunicatorBase:
moe_modules
=
[
module
for
module
in
model
.
modules
()
if
module
.
__class__
.
__name__
==
"FusedMoE"
if
(
module
.
__class__
.
__name__
==
"FusedMoE"
or
module
.
__class__
.
__name__
==
"SharedFusedMoE"
)
]
for
module
in
moe_modules
:
module
.
quant_method
.
init_prepare_finalize
(
module
.
moe_config
,
module
.
quant_method
.
init_prepare_finalize
(
module
,
module
.
moe_config
,
module
.
quant_config
)
def
dispatch
(
...
...
vllm/envs.py
View file @
13130b89
...
...
@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
bool
=
False
VLLM_USE_FUSED_RMS_ROPE
:
bool
=
False
VLLM_USE_MARLIN_W16A16_MOE
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
def
get_default_cache_root
():
return
os
.
getenv
(
...
...
@@ -1275,6 +1276,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_MARLIN_W16A16_MOE"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
}
...
...
vllm/forward_context.py
View file @
13130b89
...
...
@@ -136,7 +136,8 @@ def set_forward_context(
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
Optional
[
DPMetadata
]
=
None
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
if
dp_size
>
1
and
(
use_navie_ep
=
envs
.
VLLM_ALL2ALL_BACKEND
==
'naive'
and
dp_size
>
1
and
vllm_config
.
parallel_config
.
enable_expert_parallel
if
use_navie_ep
and
dp_size
>
1
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
):
dp_metadata
=
DPMetadata
.
make
(
vllm_config
.
parallel_config
,
attn_metadata
,
num_tokens
or
0
,
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
13130b89
...
...
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEActivationFormat
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
)
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.triton_utils
import
HAS_TRITON
_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
...
...
@@ -38,6 +39,7 @@ __all__ = [
"FusedMoEPrepareAndFinalize"
,
"override_config"
,
"get_config"
,
"SharedFusedMoE"
,
]
if
HAS_TRITON
:
...
...
@@ -59,6 +61,8 @@ if HAS_TRITON:
get_config_file_name
,
grouped_topk
)
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.triton_group_gemm_moe
import
(
TritonOrGroupGemmExperts
)
__all__
+=
[
"fused_moe"
,
...
...
@@ -75,4 +79,5 @@ if HAS_TRITON:
"BatchedDeepGemmExperts"
,
"TritonOrDeepGemmExperts"
,
"BatchedTritonOrDeepGemmExperts"
,
"TritonOrGroupGemmExperts"
,
]
vllm/model_executor/layers/fused_moe/config.py
View file @
13130b89
...
...
@@ -54,7 +54,7 @@ def get_config_quant_dtype(
)
->
Optional
[
torch
.
dtype
]:
if
use_fp8_w8a8
:
return
torch
.
float8_e4m3fn
elif
use_int8_w8a8
:
elif
use_int8_w8a8
or
use_int4_w4a8
:
return
torch
.
int8
return
None
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
13130b89
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
collections.abc
import
Callable
import
deep_ep
import
torch
import
torch.distributed
as
dist
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
(
moe_kernel_quantize_input
)
from
vllm.distributed.parallel_state
import
get_ep_group
class
DeepEPHTPrepareAndFinalize
(
mk
.
FusedMoEPrepareAndFinalize
):
...
...
@@ -55,35 +59,49 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
None
return
deep_ep
.
Buffer
.
get_combine_config
(
self
.
dp_size
)
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
Optional
[
torch
.
Tensor
],
def
_do_dispatch
(
self
,
tokens
:
torch
.
Tensor
,
token_scales
:
torch
.
Tensor
|
None
,
rank_topk_ids
:
torch
.
Tensor
,
rank_topk_weights
:
torch
.
Tensor
,
num_experts
:
int
):
rank_topk_weights
:
torch
.
Tensor
,
num_experts
:
int
,
quant_config
:
FusedMoEQuantConfig
,
)
->
Callable
:
has_scales
=
token_scales
is
not
None
(
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
expert_num_tokens
,
is_token_in_rank
,
event
)
=
self
.
buffer
.
get_dispatch_layout
(
(
num_tokens_per_rank
,
num_tokens_per_rdma_rank
,
dispatch_expert_num_tokens
,
is_token_in_rank
,
event
,
)
=
self
.
buffer
.
get_dispatch_layout
(
topk_idx
=
rank_topk_ids
,
num_experts
=
num_experts
,
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
allocate_on_comm_stream
=
False
,
)
token_data
=
tokens
if
has_scales
:
token_data
=
(
tokens
,
token_scales
)
(
token_data
,
expert_topk_ids
,
expert_topk_weights
,
expert_num_tokens_per_expert_list
,
self
.
handle
,
event
token_data
,
expert_topk_ids
,
expert_topk_weights
,
expert_num_tokens_per_expert_list
,
self
.
handle
,
event
,
)
=
self
.
buffer
.
dispatch
(
x
=
token_data
,
handle
=
None
,
num_tokens_per_rank
=
num_tokens_per_rank
,
num_tokens_per_rdma_rank
=
num_tokens_per_rdma_rank
,
is_token_in_rank
=
is_token_in_rank
,
num_tokens_per_expert
=
expert_num_tokens
,
num_tokens_per_expert
=
dispatch_
expert_num_tokens
,
topk_idx
=
rank_topk_ids
,
topk_weights
=
rank_topk_weights
,
# expert_alignment rounds the number of tokens per expert
...
...
@@ -91,8 +109,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment
=
1
,
config
=
self
.
_get_dispatch_config
(),
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
async_finish
=
True
,
allocate_on_comm_stream
=
False
,
)
return
lambda
:
self
.
_receiver
(
event
,
has_scales
,
token_data
,
expert_topk_ids
,
num_experts
,
expert_num_tokens_per_expert_list
,
expert_topk_weights
,
token_scales
,
quant_config
,
)
def
_receiver
(
self
,
event
:
deep_ep
.
EventOverlap
,
has_scales
:
bool
,
token_data
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
torch
.
Tensor
,
expert_topk_ids
:
torch
.
Tensor
|
None
,
num_experts
:
int
,
expert_num_tokens_per_expert_list
:
list
[
int
],
expert_topk_weights
:
torch
.
Tensor
|
None
,
a1_scale
:
torch
.
Tensor
|
None
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
if
has_scales
:
expert_x
,
expert_x_scale
=
token_data
...
...
@@ -110,15 +156,45 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert
expert_topk_ids
is
not
None
expert_topk_ids
=
torch
.
where
(
expert_topk_ids
==
-
1
,
num_experts
-
1
if
self
.
rank_expert_offset
==
0
else
0
,
expert_topk_ids
+
self
.
rank_expert_offset
)
expert_topk_ids
+
self
.
rank_expert_offset
,
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
.
make_from_list
(
expert_num_tokens_per_expert_list
,
device
=
expert_x
.
device
)
def
prepare
(
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if
not
quant_config
.
per_act_token_quant
:
# Quantize after dispatch.
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
quant_dtype
=
quant_config
.
quant_dtype
,
per_act_token_quant
=
False
,
block_shape
=
quant_config
.
block_shape
,
)
return
(
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
expert_topk_ids
,
expert_topk_weights
,
)
def
prepare_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -129,14 +205,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
mk
.
ReceiverType
:
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
if
quant_config
.
per_act_token_quant
:
...
...
@@ -149,35 +224,43 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
if
a1q_scale
is
not
None
and
a1q_scale
.
numel
()
==
1
:
a1q_scale
=
a1q_scale
.
view
(
1
,
1
)
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
=
self
.
_do_dispatch
(
else
:
a1q
=
a1
a1q_scale
=
None
return
self
.
_do_dispatch
(
tokens
=
a1q
,
token_scales
=
a1q_scale
,
rank_topk_ids
=
topk_ids
,
rank_topk_weights
=
topk_weights
,
num_experts
=
num_experts
)
else
:
# DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16.
(
expert_x
,
_
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
=
self
.
_do_dispatch
(
tokens
=
a1
,
token_scales
=
None
,
rank_topk_ids
=
topk_ids
,
rank_topk_weights
=
topk_weights
,
num_experts
=
num_experts
)
# quantize now
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
quant_dtype
=
quant_config
.
quant_dtype
,
per_act_token_quant
=
False
,
block_shape
=
quant_config
.
block_shape
)
num_experts
=
num_experts
,
quant_config
=
quant_config
,
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
expert_topk_ids
,
expert_topk_weights
)
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
receiver
=
self
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
)
return
receiver
()
def
_apply_weights_and_reduce
(
self
,
num_tokens
:
int
,
fused_expert_output
:
torch
.
Tensor
,
...
...
@@ -203,29 +286,79 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return
out
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
)
->
None
:
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
do_async
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
Callable
|
None
:
assert
self
.
handle
is
not
None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
if
fused_expert_output
.
numel
()
!=
0
:
fused_expert_output
=
self
.
_apply_weights_and_reduce
(
num_tokens
=
topk_ids
.
size
(
0
),
fused_expert_output
=
fused_expert_output
,
topk_weights
=
topk_weights
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
output_dtype
=
output
.
dtype
)
combined_x
,
_
,
event
=
self
.
buffer
.
combine
(
# HT combine only supports BF16
x
=
fused_expert_output
,
handle
=
self
.
handle
,
topk_weights
=
None
,
config
=
self
.
_get_combine_config
(),
previous_event
=
None
,
async_finish
=
False
,
allocate_on_comm_stream
=
False
)
async_finish
=
do_async
,
allocate_on_comm_stream
=
False
,
)
if
do_async
:
def
_receiver
():
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
_receiver
else
:
# Respect inplace outputs.
output
.
copy_
(
combined_x
,
non_blocking
=
True
)
return
None
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
Callable
:
receiver
=
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
do_async
=
True
,
apply_weights_and_reduce
=
apply_weights_and_reduce
,
)
assert
receiver
is
not
None
return
receiver
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
,
)
->
None
:
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
do_async
=
False
,
apply_weights_and_reduce
=
apply_weights_and_reduce
,
)
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
13130b89
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
,
Union
from
collections.abc
import
Callable
import
deep_ep
import
torch
...
...
@@ -44,12 +45,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
buffer
:
deep_ep
.
Buffer
,
max_tokens_per_rank
:
int
,
num_dispatchers
:
int
,
use_fp8_dispatch
:
bool
=
False
):
use_fp8_dispatch
:
bool
=
False
,
use_int8_dispatch
:
bool
=
False
):
super
().
__init__
()
self
.
buffer
=
buffer
self
.
max_tokens_per_rank
=
max_tokens_per_rank
self
.
use_fp8_dispatch
=
use_fp8_dispatch
self
.
use_int8_dispatch
=
use_int8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
...
...
@@ -71,17 +74,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def
_do_quant
(
self
,
x
:
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
]
,
x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
a1_dtype
:
torch
.
dtype
,
quant_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]],
quant_config
:
FusedMoEQuantConfig
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
block_k
=
block_shape
[
1
]
if
block_shape
is
not
None
else
None
if
self
.
use_fp8_dispatch
:
block_k
=
(
quant_config
.
block_shape
[
1
]
if
quant_config
.
block_shape
is
not
None
else
None
)
if
block_k
==
DEEPEP_QUANT_BLOCK_SIZE
:
# DeepEP kernels did the quantization for us.
x
,
x_scales
=
x
...
...
@@ -96,19 +101,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts
,
max_tokens
,
hidden_dim
=
x
.
size
()
# TODO (varun): Optimization - Use a batched version of quant
if
expert_num_tokens
is
None
:
x
=
x
.
view
((
-
1
,
hidden_dim
))
x
,
x_scales
=
moe_kernel_quantize_input
(
x
,
a1_scale
,
quant_dtype
,
per_act_token_quant
,
block_shape
)
x
,
x_scales
=
moe_kernel_quantize_input
(
x
,
a1_scale
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
,
expert_num_tokens
)
x
=
x
.
view
((
num_experts
,
-
1
,
hidden_dim
))
if
quant_dtype
is
not
None
:
if
quant_config
.
quant_dtype
is
not
None
:
assert
x_scales
is
not
None
x_scales
=
normalize_batched_scales_shape
(
x_scales
,
num_experts
)
return
x
,
x_scales
def
prepare
(
def
prepare
_async
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -119,9 +130,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
Callable
,
mk
.
ReceiverType
]:
hidden_size
=
a1
.
size
(
1
)
assert
hidden_size
in
self
.
SUPPORTED_HIDDEN_SIZES
,
\
(
f
"Hidden Size
{
hidden_size
}
not in supported list of hidden sizes"
...
...
@@ -141,29 +150,86 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk
=
topk_ids
.
size
(
1
)
# TODO: this only works for topK=1, will need to update for topK>1
assert
topk
==
1
,
(
"apply_router_weight_on_input is only implemented for topk=1"
)
"apply_router_weight_on_input is only implemented for topk=1"
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
# Dispatch
expert_x
,
expert_num_tokens
,
self
.
handle
,
event
,
hook
=
\
self
.
buffer
.
low_latency_dispatch
(
a1
,
expert_x
,
expert_num_tokens
,
self
.
handle
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
a1
,
topk_ids
,
self
.
max_tokens_per_rank
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
use_int8
=
self
.
use_int8_dispatch
,
async_finish
=
False
,
return_recv_hook
=
False
)
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_scale
,
a2_scale
,
a1
.
dtype
,
quant_config
.
quant_dtype
,
quant_config
.
per_act_token_quant
,
quant_config
.
block_shape
)
return_recv_hook
=
True
,
)
return
(
hook
,
lambda
:
self
.
_receiver
(
expert_x
,
expert_num_tokens
,
a1_scale
,
a1
.
dtype
,
quant_config
,
),
)
def
_receiver
(
self
,
expert_x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
expert_num_tokens
:
torch
.
Tensor
,
a1_scale
:
torch
.
Tensor
|
None
,
a1_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_scale
,
a1_dtype
,
quant_config
,
expert_num_tokens
)
return
(
expert_x
,
expert_x_scale
,
expert_num_tokens
,
None
,
None
)
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
)
->
None
:
return
expert_x
,
expert_x_scale
,
expert_tokens_meta
,
None
,
None
def
prepare
(
self
,
a1
:
torch
.
Tensor
,
a1_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
hook
,
receiver
=
self
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
num_experts
,
expert_map
,
apply_router_weight_on_input
,
quant_config
,
)
hook
()
return
receiver
()
def
_finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
,
do_async
:
bool
,
)
->
Callable
:
do_recv_hook
=
do_async
assert
self
.
handle
is
not
None
combine_topk_weights
=
topk_weights
...
...
@@ -172,12 +238,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights
=
torch
.
ones_like
(
topk_weights
)
# TODO (varun) : Enable zero copy mode
_
,
event
,
hook
=
self
.
buffer
.
low_latency_combine
(
_
,
_
,
recv_
hook
=
self
.
buffer
.
low_latency_combine
(
fused_expert_output
,
topk_ids
,
combine_topk_weights
,
self
.
handle
,
async_finish
=
False
,
zero_copy
=
False
,
return_recv_hook
=
False
,
out
=
output
)
return_recv_hook
=
do_recv_hook
,
out
=
output
,
)
return
recv_hook
def
finalize_async
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
return
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
,
do_async
=
True
,
)
def
finalize
(
self
,
output
:
torch
.
Tensor
,
fused_expert_output
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
self
.
_finalize
(
output
,
fused_expert_output
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
,
do_async
=
False
,
)
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
13130b89
...
...
@@ -596,6 +596,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
num_tokens
=
topk_ids
.
size
(
0
)
num_local_experts
=
fused_expert_output
.
size
(
0
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
13130b89
...
...
@@ -29,7 +29,8 @@ from vllm.model_executor.layers.fused_moe.config import (
# yapf: enable
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEActivationFormat
,
FusedMoEModularKernel
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
)
DeepGemmDisabledFusedMoEModularKernel
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -40,7 +41,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.utils
import
direct_register_custom_op
,
has_deep_ep
,
has_pplx
from
vllm.utils
import
direct_register_custom_op
,
has_deep_ep
,
has_pplx
,
has_deep_gemm
from
vllm
import
_custom_ops
as
ops
...
...
@@ -91,7 +92,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
raise
NotImplementedError
def
init_prepare_finalize
(
self
,
moe
:
FusedMoEConfig
,
def
init_prepare_finalize
(
self
,
layer
,
moe
:
FusedMoEConfig
,
quant_config
:
Optional
[
QuantizationConfig
]):
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
...
...
@@ -170,6 +171,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
and
moe
.
quant_config
.
block_shape
==
DEEPEP_QUANT_BLOCK_SHAPE
)
use_int8_dispatch
=
False
#moe.quant_config.quant_dtype == torch.int8
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize
=
DeepEPLLPrepareAndFinalize
(
...
...
@@ -177,6 +180,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank
=
moe
.
max_num_tokens
,
num_dispatchers
=
all2all_manager
.
world_size
,
use_fp8_dispatch
=
use_fp8_dispatch
,
use_int8_dispatch
=
use_int8_dispatch
,
)
self
.
topk_indices_dtype
=
None
...
...
@@ -184,10 +188,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
logger
.
debug
(
"%s"
,
prepare_finalize
.
__class__
.
__name__
)
self
.
topk_indices_dtype
=
prepare_finalize
.
topk_indices_dtype
()
experts
=
self
.
select_gemm_impl
(
prepare_finalize
,
moe
)
if
has_deep_gemm
():
self
.
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
else
:
self
.
fused_experts
=
DeepGemmDisabledFusedMoEModularKernel
(
prepare_finalize
,
experts
,
shared_experts
=
layer
.
shared_experts
if
hasattr
(
layer
,
"shared_experts"
)
else
None
,
)
def
select_gemm_impl
(
self
,
...
...
@@ -898,6 +910,10 @@ class FusedMoE(torch.nn.Module):
def
use_deepep_ll_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_deepep_ll_kernels
@
property
def
shared_experts
(
self
)
->
Optional
[
torch
.
nn
.
Module
]:
return
None
def
_load_per_tensor_weight_scale
(
self
,
shard_id
:
str
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
...
...
@@ -1445,9 +1461,13 @@ class FusedMoE(torch.nn.Module):
assert
i_q
is
None
and
i_s
is
None
,
"moe.quant fused not support TPU now"
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
if
self
.
shared_experts
is
None
:
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
,
shared_output
,
i_q
,
i_s
)
else
:
return
torch
.
ops
.
vllm
.
moe_forward_shared
(
hidden_states
,
router_logits
,
self
.
layer_name
,
shared_output
)
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
full_router_logits
:
torch
.
Tensor
):
...
...
@@ -1531,13 +1551,14 @@ class FusedMoE(torch.nn.Module):
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
assert
self
.
quant_method
is
not
None
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
):
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
):
#
or self.moe_parallel_config.use_deepep_ll_kernels):
return
self
.
forward_impl_chunked
(
hidden_states
,
router_logits
)
do_naive_dispatch_combine
:
bool
=
(
self
.
dp_size
>
1
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
)
and
not
self
.
moe_parallel_config
.
use_deepep_ht_kernels
and
not
self
.
moe_parallel_config
.
use_deepep_ll_kernels
)
if
do_naive_dispatch_combine
:
hidden_states
,
router_logits
=
get_ep_group
().
dispatch
(
hidden_states
,
router_logits
)
...
...
@@ -1694,3 +1715,34 @@ direct_register_custom_op(
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
def
moe_forward_shared
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
shared_experts
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
shared_output
)
def
moe_forward_shared_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
shared_out
=
torch
.
empty_like
(
hidden_states
)
fused_out
=
torch
.
empty_like
(
hidden_states
)
return
shared_out
,
fused_out
direct_register_custom_op
(
op_name
=
"moe_forward_shared"
,
op_func
=
moe_forward_shared
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
moe_forward_shared_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,),
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
13130b89
...
...
@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from
enum
import
Enum
from
math
import
prod
from
typing
import
Optional
,
final
from
dataclasses
import
dataclass
from
collections.abc
import
Callable
import
torch
...
...
@@ -95,6 +97,54 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts
=
"batched_experts"
,
@
dataclass
class
ExpertTokensMetadata
:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens
:
torch
.
Tensor
expert_num_tokens_cpu
:
torch
.
Tensor
|
None
@
staticmethod
def
make_from_list
(
expert_num_tokens_list
:
list
[
int
],
device
:
str
)
->
"ExpertTokensMetadata"
:
expert_num_tokens_cpu
=
torch
.
tensor
(
expert_num_tokens_list
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
True
)
expert_num_tokens
=
expert_num_tokens_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
return
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
)
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType
=
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
ExpertTokensMetadata
|
None
,
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
,
]
ReceiverType
=
Callable
[[],
PrepareResultType
]
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class
FusedMoEPrepareAndFinalize
(
ABC
):
"""
...
...
@@ -149,6 +199,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
"""
Perform any combine plus apply weights and perform a reduction on the
...
...
@@ -357,6 +408,168 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise
NotImplementedError
class
CustomizedFusedMoEPermuteExpertsUnpermute
(
ABC
):
"""
An abstract base class for the [Permute-Experts-Unpermute] step described
above.
"""
def
__init__
(
self
,
quant_config
:
Optional
[
FusedMoEQuantConfig
],
):
if
quant_config
is
not
None
:
self
.
quant_config
=
quant_config
else
:
self
.
quant_config
=
FusedMoEQuantConfig
()
@
property
@
abstractmethod
def
activation_formats
(
self
)
->
tuple
[
FusedMoEActivationFormat
,
FusedMoEActivationFormat
]:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
"""
raise
NotImplementedError
@
property
def
quant_dtype
(
self
)
->
Optional
[
torch
.
dtype
]:
return
self
.
quant_config
.
quant_dtype
@
property
def
block_shape
(
self
)
->
Optional
[
list
[
int
]]:
return
self
.
quant_config
.
block_shape
@
property
def
per_act_token_quant
(
self
)
->
bool
:
return
self
.
quant_config
.
per_act_token_quant
@
property
def
per_out_ch_quant
(
self
)
->
bool
:
return
self
.
quant_config
.
per_out_ch_quant
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@
abstractmethod
def
supports_chunking
(
self
)
->
bool
:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise
NotImplementedError
@
abstractmethod
def
supports_expert_map
(
self
)
->
bool
:
"""
A flag indicating whether or not this class supports expert maps
"""
raise
NotImplementedError
@
abstractmethod
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
raise
NotImplementedError
def
activation
(
self
,
activation
:
str
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
)
->
None
:
assert
output
.
size
(
-
1
)
*
2
==
input
.
size
(
-
1
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
output
,
input
)
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
output
,
input
)
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
def
enable_chunking
(
self
):
return
envs
.
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING
and
\
self
.
supports_chunking
()
@
abstractmethod
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
):
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
"""
raise
NotImplementedError
def
_chunk_scales
(
scales
:
Optional
[
torch
.
Tensor
],
start
:
int
,
end
:
int
)
->
Optional
[
torch
.
Tensor
]:
if
scales
is
not
None
:
...
...
@@ -596,3 +809,186 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_ids
,
apply_router_weight_on_input
)
return
output
@
final
class
DeepGemmDisabledFusedMoEModularKernel
(
torch
.
nn
.
Module
):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
is compatible with the `fused_experts` function in fused_moe.py.
It takes care of managing any required scratch space.
Note: Instances of this class should only be used for a single model
layer due to any layer specific state that may be used by the component
objects.
"""
def
__init__
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
fused_experts
:
CustomizedFusedMoEPermuteExpertsUnpermute
,
shared_experts
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
):
super
().
__init__
()
self
.
prepare_finalize
=
prepare_finalize
self
.
fused_experts
=
fused_experts
self
.
shared_experts
=
shared_experts
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
# f"{prepare_finalize.activation_format} == "
# f"{fused_experts.__class__.__name__}."
# f"{fused_experts.activation_formats[0]}")
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
**
_
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
w1.
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
a1
=
hidden_states
if
inplace
and
self
.
shared_experts
is
None
:
output
=
hidden_states
else
:
output
=
torch
.
zeros_like
(
hidden_states
)
local_num_experts
=
w1
.
size
(
0
)
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
prepare_ret
=
self
.
prepare_finalize
.
prepare_async
(
a1
,
a1_scale
,
a2_scale
,
topk_weights
,
topk_ids
,
global_num_experts
,
expert_map
,
apply_router_weight_on_input
,
self
.
fused_experts
.
quant_config
,
)
hook
,
receiver
=
(
prepare_ret
if
isinstance
(
prepare_ret
,
tuple
)
else
(
None
,
prepare_ret
)
)
if
hook
is
not
None
:
hook
()
(
a1q
,
a1q_scale
,
expert_tokens_meta
,
_expert_topk_ids
,
_expert_topk_weights
,
)
=
receiver
()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids
=
topk_ids
if
_expert_topk_ids
is
None
else
_expert_topk_ids
topk_weights
=
(
topk_weights
if
_expert_topk_weights
is
None
else
_expert_topk_weights
)
if
a1q
.
numel
()
==
0
:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
# kernels. CUDAGraph compatible all2all kernels like the pplx
# kernels and the DeepEP low-latency kernels are always batched
# and can never run into the tensor.numel() == 0 case.
fused_out
=
torch
.
empty_like
(
a1q
).
to
(
dtype
=
a1
.
dtype
)
else
:
fused_out
=
self
.
fused_experts
.
apply
(
None
,
a1
,
a1q
,
w1
,
w2
,
topk_ids
,
topk_weights
=
topk_weights
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1q_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
workspace13
=
None
,
workspace2
=
None
,
use_nn_moe
=
use_nn_moe
,
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
expert_num_tokens_cpu
=
expert_tokens_meta
.
expert_num_tokens_cpu
)
shared_output
=
None
hook
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
apply_weights_and_reduce
=
True
)
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
if
hook
is
not
None
:
hook
()
if
self
.
shared_experts
is
not
None
:
return
(
shared_output
,
output
)
return
output
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
13130b89
...
...
@@ -207,6 +207,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
...
...
vllm/model_executor/layers/fused_moe/prepare_finalize.py
View file @
13130b89
...
...
@@ -61,6 +61,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
apply_router_weight_on_input
:
bool
,
apply_weights_and_reduce
:
bool
=
True
)
->
None
:
_moe_unpermute_and_reduce
(
output
,
fused_expert_output
,
None
,
topk_weights
,
apply_router_weight_on_input
)
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
0 → 100644
View file @
13130b89
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
# TODO(bnell): Add shared + fused combo function? e.g. +
class
SharedFusedMoE
(
FusedMoE
):
"""
A FusedMoE operation that also computes the results of shared experts.
If an all2all communicator is being used the shared expert computation
can be interleaved with the fused all2all dispatch communication step.
"""
def
__init__
(
self
,
shared_experts
:
torch
.
nn
.
Module
,
use_overlapped
:
bool
=
True
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
_shared_experts
=
shared_experts
self
.
use_overlapped
=
use_overlapped
@
property
def
shared_experts
(
self
)
->
Optional
[
torch
.
nn
.
Module
]:
return
self
.
_shared_experts
if
self
.
use_overlapped
else
None
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
if
not
self
.
use_overlapped
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
# Reduce outputs if necessary, since the MLP should
# have been created with reduce_results=False.
if
(
self
.
reduce_results
and
self
.
tp_size
>
1
and
self
.
must_reduce_shared_expert_outputs
()):
shared_out
=
tensor_model_parallel_all_reduce
(
shared_out
)
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
else
:
# Matrix multiply.
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
)
return
fused_out
vllm/model_executor/layers/fused_moe/triton_group_gemm_moe.py
0 → 100644
View file @
13130b89
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
,
_valid_deep_gemm
,
_valid_deep_gemm_shape
)
class
TritonOrGroupGemmExperts
(
mk
.
CustomizedFusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
fused_experts
=
None
):
super
().
__init__
(
FusedMoEQuantConfig
.
make
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int4_w4a8
=
use_int4_w4a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
))
self
.
fused_experts
=
fused_experts
@
property
def
activation_formats
(
self
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
)
def
supports_chunking
(
self
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
True
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
raise
NotImplementedError
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
q_hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
topk_weights
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
expert_num_tokens_cpu
:
torch
.
Tensor
=
None
,
):
assert
self
.
fused_experts
is
not
None
return
self
.
fused_experts
(
x
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
False
,
activation
=
activation
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1q_scale
,
a2_scale
=
a2_scale
,
expert_num_tokens
=
expert_num_tokens
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
q_x
=
q_hidden_states
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
)
vllm/model_executor/layers/fused_moe/utils.py
View file @
13130b89
...
...
@@ -4,10 +4,21 @@ from math import prod
from
typing
import
Optional
import
torch
from
torch
import
nn
import
triton
import
triton.language
as
tl
from
triton.language.extra
import
libdevice
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.utils
import
round_up
try
:
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
...
...
@@ -48,12 +59,221 @@ def _fp8_quantize(
return
A
,
A_scale
# @triton.jit
# def _per_token_quant_int8_kernel_opt(
# x_ptr,
# xq_ptr,
# scale_ptr,
# stride_x,
# stride_xq,
# N,
# T_dim,
# has_tokens_per_expert: tl.constexpr,
# tokens_per_expert_ptr,
# BLOCK: tl.constexpr
# ):
# row_id = tl.program_id(0)
# if has_tokens_per_expert:
# e = row_id // T_dim
# t = row_id % T_dim
# num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
# if t >= num_valid_tokens_for_e:
# return
# cols = tl.arange(0, BLOCK)
# mask = cols < N
# x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
# other=0.0).to(tl.float32)
# absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
# scale_x = absmax / 127
# x_q = x * (127 / absmax)
# x_q = libdevice.nearbyint(x_q).to(tl.int8)
# tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
# tl.store(scale_ptr + row_id, scale_x)
# def per_token_quant_int8_triton_opt(x: torch.Tensor,
# tokens_per_expert: Optional[torch.Tensor] = None):
# """
# Python wrapper for the Triton kernel.
# """
# if x.dim() != 3:
# raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
# E, T, H = x.shape
# M = E * T
# N = H
# x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
# scales = torch.empty(x.shape[:-1] + (1, ),
# device=x.device,
# dtype=torch.float32)
# BLOCK = triton.next_power_of_2(N)
# num_warps = min(max(BLOCK // 256, 1), 8)
# grid_opt = M
# _per_token_quant_int8_kernel_opt[(grid_opt, )](
# x,
# x_q,
# scales,
# stride_x=x.stride(-2),
# stride_xq=x_q.stride(-2),
# N=N,
# T_dim=T,
# has_tokens_per_expert= tokens_per_expert is not None,
# tokens_per_expert_ptr=tokens_per_expert,
# BLOCK=BLOCK,
# num_warps=num_warps,
# num_stages=1,
# )
# return x_q, scales
@
triton
.
jit
def
_per_token_quant_int8_one_kernel_opt
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
T_dim
,
has_tokens_per_expert
:
tl
.
constexpr
,
tokens_per_expert_ptr
,
BLOCK
:
tl
.
constexpr
):
row_id
=
tl
.
program_id
(
0
)
if
has_tokens_per_expert
:
e
=
row_id
//
T_dim
t
=
row_id
%
T_dim
num_valid_tokens_for_e
=
tl
.
load
(
tokens_per_expert_ptr
+
e
)
if
t
>=
num_valid_tokens_for_e
:
return
cols
=
tl
.
arange
(
0
,
BLOCK
)
mask
=
cols
<
N
x
=
tl
.
load
(
x_ptr
+
row_id
*
stride_x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
x
)),
1e-10
)
scale_x
=
absmax
/
127
x_q
=
x
*
(
127
/
absmax
)
x_q
=
libdevice
.
nearbyint
(
x_q
).
to
(
tl
.
int8
)
tl
.
store
(
xq_ptr
+
row_id
*
stride_xq
+
cols
,
x_q
,
mask
=
mask
)
tl
.
store
(
scale_ptr
+
row_id
,
scale_x
)
@
triton
.
jit
def
_per_token_quant_int8_kernel_opt
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
E_dim
,
T_dim
,
has_tokens_per_expert
:
tl
.
constexpr
,
tokens_per_expert_ptr
,
BLOCK
:
tl
.
constexpr
):
token_idx_start
=
tl
.
program_id
(
0
)
grid_size
=
tl
.
num_programs
(
0
)
num_total_tokens
=
E_dim
*
T_dim
for
token_idx
in
range
(
token_idx_start
,
num_total_tokens
,
grid_size
):
is_valid_token
=
True
if
has_tokens_per_expert
:
e
=
token_idx
//
T_dim
t
=
token_idx
%
T_dim
num_valid_tokens_for_e
=
tl
.
load
(
tokens_per_expert_ptr
+
e
)
if
t
>=
num_valid_tokens_for_e
:
is_valid_token
=
False
if
is_valid_token
:
cols
=
tl
.
arange
(
0
,
BLOCK
)
mask
=
cols
<
N
x
=
tl
.
load
(
x_ptr
+
token_idx
*
stride_x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
x
)),
1e-10
)
scale_x
=
absmax
/
127
x_q
=
x
*
(
127
/
absmax
)
x_q
=
libdevice
.
nearbyint
(
x_q
).
to
(
tl
.
int8
)
tl
.
store
(
xq_ptr
+
token_idx
*
stride_xq
+
cols
,
x_q
,
mask
=
mask
)
tl
.
store
(
scale_ptr
+
token_idx
,
scale_x
)
def
per_token_quant_int8_triton_opt
(
x
:
torch
.
Tensor
,
tokens_per_expert
:
Optional
[
torch
.
Tensor
]
=
None
):
if
x
.
dim
()
!=
3
:
raise
ValueError
(
f
"Input must be 3D [E, T, H], but got
{
x
.
shape
}
"
)
E
,
T
,
H
=
x
.
shape
N
=
H
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
torch
.
int8
)
scales
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
1
,
),
device
=
x
.
device
,
dtype
=
torch
.
float32
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
if
T
>=
4096
:
num_warps
=
1
num_tokens
=
E
*
T
grid_opt
=
num_tokens
if
E
==
16
and
T
>=
1024
:
grid_opt
=
max
(
1
,
num_tokens
//
(
T
//
256
))
_per_token_quant_int8_kernel_opt
[(
grid_opt
,
)](
x
,
x_q
,
scales
,
stride_x
=
x
.
stride
(
-
2
),
stride_xq
=
x_q
.
stride
(
-
2
),
N
=
N
,
E_dim
=
E
,
T_dim
=
T
,
has_tokens_per_expert
=
tokens_per_expert
is
not
None
,
tokens_per_expert_ptr
=
tokens_per_expert
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
else
:
_per_token_quant_int8_one_kernel_opt
[(
grid_opt
,
)](
x
,
x_q
,
scales
,
stride_x
=
x
.
stride
(
-
2
),
stride_xq
=
x_q
.
stride
(
-
2
),
N
=
N
,
T_dim
=
T
,
has_tokens_per_expert
=
tokens_per_expert
is
not
None
,
tokens_per_expert_ptr
=
tokens_per_expert
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
return
x_q
,
scales
def
_int8_quantize
(
A
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
per_act_token
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Perform int8 quantization on the inputs. If a block_shape
...
...
@@ -64,9 +284,12 @@ def _int8_quantize(
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if
block_shape
is
None
:
assert
per_act_token
,
\
"int8 quantization only supports block or channel-wise"
# assert per_act_token, \
# "int8 quantization only supports block or channel-wise"
if
expert_num_tokens
is
None
:
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
A
,
A_scale
=
per_token_quant_int8_triton_opt
(
A
,
expert_num_tokens
)
else
:
assert
not
per_act_token
assert
len
(
block_shape
)
==
2
...
...
@@ -83,11 +306,12 @@ def moe_kernel_quantize_input(
quant_dtype
:
Optional
[
torch
.
dtype
],
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
quant_dtype
==
torch
.
float8_e4m3fn
:
return
_fp8_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
)
elif
quant_dtype
==
torch
.
int8
:
return
_int8_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
)
return
_int8_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
,
expert_num_tokens
)
else
:
return
A
,
A_scale
...
...
@@ -145,3 +369,536 @@ def _validate_scale_shape(
assert
block_shape
is
not
None
expected
=
(
a
.
shape
[
0
],
cdiv
(
a
.
shape
[
1
],
block_shape
[
1
]))
assert
a_scale
.
shape
==
expected
,
f
"
{
a_scale
.
shape
}
==
{
expected
}
"
@
triton
.
jit
def
_count_expert_num_tokens
(
topk_ids_ptr
,
expert_num_tokens_ptr
,
num_experts
,
topk_numel
,
expert_map
,
HAS_EXPERT_MAP
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
curr_expert
=
tl
.
program_id
(
0
)
offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
topk_ids_ptrs
=
topk_ids_ptr
+
offsets
acc
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
tl
.
int32
)
for
x
in
range
(
tl
.
cdiv
(
topk_numel
,
BLOCK_SIZE
)):
mask
=
offsets
<
(
topk_numel
-
x
*
BLOCK_SIZE
)
expert_ids
=
tl
.
load
(
topk_ids_ptrs
,
mask
=
mask
,
other
=-
1
)
if
HAS_EXPERT_MAP
:
expert_map_ptrs
=
expert_map
+
expert_ids
expert_map_mask
=
expert_ids
>=
0
expert_ids
=
tl
.
load
(
expert_map_ptrs
,
mask
=
expert_map_mask
,
other
=-
1
)
has_curr_expert
=
tl
.
where
(
expert_ids
==
curr_expert
,
1
,
0
)
acc
=
acc
+
has_curr_expert
topk_ids_ptrs
+=
BLOCK_SIZE
if
curr_expert
<
num_experts
:
tl
.
store
(
expert_num_tokens_ptr
+
curr_expert
,
tl
.
sum
(
acc
))
def
count_expert_num_tokens
(
topk_ids
:
torch
.
Tensor
,
num_local_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
)
->
torch
.
Tensor
:
"""
Count the number to tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert
topk_ids
.
dtype
.
is_signed
,
"The kernel uses -1 to represent invalid topk_ids"
expert_num_tokens
=
torch
.
empty
(
(
num_local_experts
),
device
=
topk_ids
.
device
,
dtype
=
torch
.
int32
)
grid
=
num_local_experts
BLOCK_SIZE
=
min
(
topk_ids
.
numel
(),
1024
)
BLOCK_SIZE
=
triton
.
next_power_of_2
(
BLOCK_SIZE
)
_count_expert_num_tokens
[(
grid
,)](
topk_ids
,
expert_num_tokens
,
num_local_experts
,
topk_ids
.
numel
(),
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
BLOCK_SIZE
=
BLOCK_SIZE
,
)
return
expert_num_tokens
def
expert_num_tokens_round_up_and_sum
(
expert_num_tokens
:
torch
.
Tensor
,
alignment
:
int
)
->
int
:
# Round up each element in expert_num_tokens to the nearest multiple of
# alignment.
ent
=
(
expert_num_tokens
.
to
(
torch
.
int64
)
+
(
alignment
-
1
))
//
alignment
*
alignment
return
torch
.
sum
(
ent
).
item
()
def
compute_aligned_M
(
M
:
int
,
num_topk
:
int
,
local_num_experts
:
int
,
alignment
:
int
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
expert_num_tokens_cpu
is
not
None
:
return
expert_num_tokens_round_up_and_sum
(
expert_num_tokens_cpu
,
alignment
=
alignment
)
# expert_num_tokens information is not available on the cpu.
# compute the max required size.
M_sum
=
(
M
*
num_topk
)
+
local_num_experts
*
(
alignment
-
1
)
M_sum
=
round_up
(
M_sum
,
alignment
)
return
M_sum
@
triton
.
jit
def
apply_expert_map
(
expert_id
,
expert_map
):
if
expert_id
!=
-
1
:
expert_id
=
tl
.
load
(
expert_map
+
expert_id
).
to
(
expert_id
.
dtype
)
return
expert_id
@
triton
.
jit
def
round_up_256
(
x
:
int
)
->
int
:
y
=
256
return
((
x
+
y
-
1
)
//
y
)
*
y
@
triton
.
jit
def
round_up_128
(
x
:
int
)
->
int
:
y
=
128
return
((
x
+
y
-
1
)
//
y
)
*
y
@
triton
.
jit
def
_fwd_kernel_ep_scatter_1
(
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
:
tl
.
constexpr
,
BLOCK_E
:
tl
.
constexpr
,
BLOCK_EXPERT_NUM
:
tl
.
constexpr
,
):
cur_expert
=
tl
.
program_id
(
0
)
offset_cumsum
=
tl
.
arange
(
0
,
BLOCK_EXPERT_NUM
)
tokens_per_expert
=
tl
.
load
(
num_recv_tokens_per_expert
+
offset_cumsum
,
mask
=
offset_cumsum
<
num_experts
,
other
=
0
,
)
#tokens_per_expert = round_up_128(tokens_per_expert)
tokens_per_expert
=
round_up_256
(
tokens_per_expert
)
cumsum
=
tl
.
cumsum
(
tokens_per_expert
)
-
tokens_per_expert
#if cur_expert == 0:
tl
.
store
(
expert_start_loc
+
offset_cumsum
,
cumsum
,
mask
=
offset_cumsum
<
num_experts
)
tl
.
debug_barrier
()
#cur_expert_start = cumsum[cur_expert]
cur_expert_start
=
tl
.
load
(
expert_start_loc
+
cur_expert
)
cur_expert_token_num
=
tl
.
load
(
num_recv_tokens_per_expert
+
cur_expert
)
m_indices_start_ptr
=
m_indices
+
cur_expert_start
off_expert
=
tl
.
arange
(
0
,
BLOCK_E
)
for
start_m
in
tl
.
range
(
0
,
cur_expert_token_num
,
BLOCK_E
,
num_stages
=
4
):
tl
.
store
(
m_indices_start_ptr
+
start_m
+
off_expert
,
cur_expert
,
mask
=
start_m
+
off_expert
<
cur_expert_token_num
)
@
triton
.
jit
def
_fwd_kernel_ep_scatter_2
(
total_token_num
,
expert_start_loc
,
recv_x
,
recv_x_stride0
,
recv_x_stride1
,
recv_x_scale
,
recv_x_scale_stride0
,
recv_x_scale_stride1
,
recv_topk
,
recv_topk_stride0
,
recv_topk_stride1
,
output_tensor
,
output_tensor_stride0
,
output_tensor_stride1
,
output_tensor_scale
,
output_tensor_scale_stride0
,
output_tensor_scale_stride1
,
output_index
,
output_index_stride0
,
output_index_stride1
,
topk_num
:
tl
.
constexpr
,
expert_map
,
HAS_EXPERT_MAP
:
tl
.
constexpr
,
HIDDEN_SIZE
:
tl
.
constexpr
,
HIDDEN_SIZE_PAD
:
tl
.
constexpr
,
SCALE_HIDDEN_SIZE
:
tl
.
constexpr
,
SCALE_HIDDEN_SIZE_PAD
:
tl
.
constexpr
,
):
start_token_id
=
tl
.
program_id
(
0
)
grid_num
=
tl
.
num_programs
(
0
)
offset_in
=
tl
.
arange
(
0
,
HIDDEN_SIZE_PAD
)
mask
=
offset_in
<
HIDDEN_SIZE
index_in_s
=
tl
.
arange
(
0
,
SCALE_HIDDEN_SIZE_PAD
)
mask_s
=
index_in_s
<
SCALE_HIDDEN_SIZE
for
token_id_int32
in
range
(
start_token_id
,
total_token_num
,
grid_num
):
token_id
=
token_id_int32
.
to
(
tl
.
int64
)
to_copy
=
tl
.
load
(
recv_x
+
token_id
*
recv_x_stride0
+
offset_in
,
mask
=
mask
)
to_copy_s
=
tl
.
load
(
recv_x_scale
+
token_id
*
recv_x_scale_stride0
+
index_in_s
*
recv_x_scale_stride1
,
mask
=
mask_s
,
)
for
topk_idx_int32
in
tl
.
range
(
0
,
topk_num
,
1
,
num_stages
=
4
):
topk_index
=
topk_idx_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk
+
token_id
*
recv_topk_stride0
+
topk_index
)
if
HAS_EXPERT_MAP
:
expert_id
=
apply_expert_map
(
expert_id
,
expert_map
)
if
expert_id
>=
0
:
dest_token_index_int32
=
tl
.
atomic_add
(
expert_start_loc
+
expert_id
,
1
)
dest_token_index
=
dest_token_index_int32
.
to
(
tl
.
int64
)
tl
.
store
(
output_index
+
token_id
*
output_index_stride0
+
topk_index
,
dest_token_index_int32
,
)
output_tensor_ptr
=
(
output_tensor
+
dest_token_index
*
output_tensor_stride0
)
output_tensor_scale_ptr
=
(
output_tensor_scale
+
dest_token_index
*
output_tensor_scale_stride0
)
tl
.
store
(
output_tensor_ptr
+
offset_in
,
to_copy
,
mask
=
mask
)
tl
.
store
(
output_tensor_scale_ptr
+
index_in_s
*
output_tensor_scale_stride1
,
to_copy_s
,
mask
=
mask_s
,
)
@
torch
.
no_grad
()
def
ep_scatter
(
recv_x
:
torch
.
Tensor
,
recv_x_scale
:
torch
.
Tensor
,
recv_topk
:
torch
.
Tensor
,
num_recv_tokens_per_expert
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
expert_start_loc
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
output_tensor_scale
:
torch
.
Tensor
,
m_indices
:
torch
.
Tensor
,
output_index
:
torch
.
Tensor
,
):
#BLOCK_E = 128 # token num of per expert is aligned to 128
#BLOCK_D = 128 # block size of quantization
BLOCK_E
=
256
# token num of per expert is aligned to 256
num_warps
=
8
num_experts
=
num_recv_tokens_per_expert
.
shape
[
0
]
hidden_size
=
recv_x
.
shape
[
1
]
scale_hidden_size
=
recv_x_scale
.
shape
[
-
1
]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid
=
num_experts
assert
m_indices
.
shape
[
0
]
%
BLOCK_E
==
0
_fwd_kernel_ep_scatter_1
[(
grid
,)](
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
=
num_experts
,
num_warps
=
num_warps
,
BLOCK_E
=
BLOCK_E
,
BLOCK_EXPERT_NUM
=
triton
.
next_power_of_2
(
num_experts
),
)
grid
=
min
(
recv_topk
.
shape
[
0
],
1024
*
8
)
_fwd_kernel_ep_scatter_2
[(
grid
,)](
recv_topk
.
shape
[
0
],
expert_start_loc
,
recv_x
,
recv_x
.
stride
(
0
),
recv_x
.
stride
(
1
),
recv_x_scale
,
recv_x_scale
.
stride
(
0
),
recv_x_scale
.
stride
(
1
),
recv_topk
,
recv_topk
.
stride
(
0
),
recv_topk
.
stride
(
1
),
output_tensor
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
output_tensor_scale
,
output_tensor_scale
.
stride
(
0
),
output_tensor_scale
.
stride
(
1
),
output_index
,
output_index
.
stride
(
0
),
output_index
.
stride
(
1
),
topk_num
=
recv_topk
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
HIDDEN_SIZE
=
hidden_size
,
HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
),
SCALE_HIDDEN_SIZE
=
scale_hidden_size
,
#hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
scale_hidden_size
)
#triton.next_power_of_2(hidden_size // BLOCK_D),
)
return
@
triton
.
jit
def
_fwd_kernel_ep_gather
(
total_token_num
,
input_tensor
,
input_tensor_stride0
,
input_tensor_stride1
,
recv_topk_ids
,
recv_topk_ids_stride0
,
recv_topk_ids_stride1
,
recv_topk_weight
,
recv_topk_weight_stride0
,
recv_topk_weight_stride1
,
input_index
,
input_index_stride0
,
input_index_stride1
,
output_tensor
,
output_tensor_stride0
,
output_tensor_stride1
,
topk_num
:
tl
.
constexpr
,
expert_map
,
HAS_EXPERT_MAP
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
cur_block_int32
=
tl
.
program_id
(
0
)
cur_block
=
cur_block_int32
.
to
(
tl
.
int64
)
start_cur_token_int32
=
tl
.
program_id
(
1
)
grid_num
=
tl
.
num_programs
(
1
)
for
cur_token_int32
in
range
(
start_cur_token_int32
,
total_token_num
,
grid_num
):
cur_token
=
cur_token_int32
.
to
(
tl
.
int64
)
off_d
=
tl
.
arange
(
0
,
BLOCK_D
)
accumulator
=
tl
.
zeros
([
BLOCK_D
],
dtype
=
tl
.
float32
)
for
topk_index_int32
in
range
(
0
,
topk_num
):
topk_index
=
topk_index_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk_ids
+
cur_token
*
recv_topk_ids_stride0
+
topk_index
)
if
HAS_EXPERT_MAP
:
expert_id
=
apply_expert_map
(
expert_id
,
expert_map
)
if
expert_id
>=
0
:
source_token_index_int32
=
tl
.
load
(
input_index
+
cur_token
*
input_index_stride0
+
topk_index
)
source_token_index
=
source_token_index_int32
.
to
(
tl
.
int64
)
acc_weight
=
tl
.
load
(
recv_topk_weight
+
cur_token
*
recv_topk_weight_stride0
+
topk_index
)
tmp
=
tl
.
load
(
input_tensor
+
source_token_index
*
input_tensor_stride0
+
cur_block
*
BLOCK_D
+
off_d
)
accumulator
+=
tmp
.
to
(
tl
.
float32
)
*
acc_weight
tl
.
store
(
output_tensor
+
cur_token
*
output_tensor_stride0
+
cur_block
*
BLOCK_D
+
off_d
,
accumulator
.
to
(
output_tensor
.
dtype
.
element_ty
),
)
@
torch
.
no_grad
()
def
ep_gather
(
input_tensor
:
torch
.
Tensor
,
recv_topk_ids
:
torch
.
Tensor
,
recv_topk_weight
:
torch
.
Tensor
,
input_index
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
output_tensor
:
torch
.
Tensor
,
):
num_warps
=
2
num_tokens
=
output_tensor
.
shape
[
0
]
hidden_size
=
input_tensor
.
shape
[
1
]
BLOCK_D
=
min
(
hidden_size
,
1024
)
assert
hidden_size
%
BLOCK_D
==
0
grid
=
(
triton
.
cdiv
(
hidden_size
,
BLOCK_D
),
min
(
num_tokens
,
1024
))
_fwd_kernel_ep_gather
[
grid
](
num_tokens
,
input_tensor
,
input_tensor
.
stride
(
0
),
input_tensor
.
stride
(
1
),
recv_topk_ids
,
recv_topk_ids
.
stride
(
0
),
recv_topk_ids
.
stride
(
1
),
recv_topk_weight
,
recv_topk_weight
.
stride
(
0
),
recv_topk_weight
.
stride
(
1
),
input_index
,
input_index
.
stride
(
0
),
input_index
.
stride
(
1
),
output_tensor
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
topk_num
=
recv_topk_ids
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
BLOCK_D
=
BLOCK_D
,
)
return
def
deepgemm_moe_permute
(
aq
:
torch
.
Tensor
,
aq_scale
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
local_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
block_shape
:
list
[
int
],
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
aq_out
:
torch
.
Tensor
|
None
=
None
,
M_sum
:
int
|
None
=
None
,
):
assert
aq
.
ndim
==
2
assert
topk_ids
.
dtype
.
is_signed
,
"The kernel uses -1 to represent invalid topk_ids"
H
=
aq
.
size
(
1
)
device
=
aq
.
device
block_m
=
block_shape
[
0
]
if
M_sum
is
None
:
M_sum
=
compute_aligned_M
(
M
=
topk_ids
.
size
(
0
),
num_topk
=
topk_ids
.
size
(
1
),
local_num_experts
=
local_num_experts
,
alignment
=
block_m
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
)
expert_start_loc
=
torch
.
empty
(
(
local_num_experts
),
device
=
device
,
dtype
=
torch
.
int32
)
assert
aq_out
is
None
or
aq_out
.
shape
==
(
M_sum
,
H
)
if
aq_out
is
None
:
aq_out
=
torch
.
empty
((
M_sum
,
H
),
device
=
device
,
dtype
=
aq
.
dtype
)
aq_scale_out
=
torch
.
empty
(
(
M_sum
,
aq_scale
.
shape
[
-
1
]),
device
=
device
,
dtype
=
torch
.
float32
)
expert_ids
=
torch
.
full
(
(
M_sum
,),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
inv_perm
=
torch
.
empty
(
topk_ids
.
shape
,
device
=
device
,
dtype
=
torch
.
int32
)
if
expert_num_tokens
is
None
:
expert_num_tokens
=
count_expert_num_tokens
(
topk_ids
,
local_num_experts
,
expert_map
)
ep_scatter
(
recv_x
=
aq
,
recv_x_scale
=
aq_scale
,
recv_topk
=
topk_ids
,
num_recv_tokens_per_expert
=
expert_num_tokens
,
expert_start_loc
=
expert_start_loc
,
expert_map
=
expert_map
,
output_tensor
=
aq_out
,
output_tensor_scale
=
aq_scale_out
,
m_indices
=
expert_ids
,
output_index
=
inv_perm
,
)
return
aq_out
,
aq_scale_out
,
expert_ids
,
inv_perm
def
deepgemm_unpermute_and_reduce
(
a
:
torch
.
Tensor
,
# Grouped gemm output
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
inv_perm
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
output
:
torch
.
Tensor
,
):
return
ep_gather
(
input_tensor
=
a
,
recv_topk_ids
=
topk_ids
,
recv_topk_weight
=
topk_weights
,
input_index
=
inv_perm
,
expert_map
=
expert_map
,
output_tensor
=
output
,
)
class
EPSharedExperts
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
expect_tp_size
=
1
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
expect_tp_size
=
1
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
vllm/model_executor/layers/linear.py
View file @
13130b89
...
...
@@ -485,9 +485,13 @@ class ColumnParallelLinear(LinearBase):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
self
.
expect_tp_size
=
expect_tp_size
self
.
tp_size
=
self
.
expect_tp_size
self
.
input_size_per_partition
=
input_size
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
...
...
@@ -749,10 +753,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
self
.
eps
=
eps
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
self
.
expect_tp_size
=
expect_tp_size
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
super
().
__init__
(
input_size
=
input_size
,
output_size
=
sum
(
output_sizes
),
...
...
@@ -762,7 +774,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
prefix
=
prefix
,
return_bias
=
return_bias
)
return_bias
=
return_bias
,
expect_tp_size
=
expect_tp_size
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -859,6 +872,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert
loaded_shard_id
<
len
(
self
.
output_sizes
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
expect_tp_size
is
not
None
and
self
.
expect_tp_size
==
1
:
tp_rank
=
0
tp_size
=
1
if
output_dim
is
not
None
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
...
...
@@ -975,6 +993,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
expect_tp_size
is
not
None
and
self
.
expect_tp_size
==
1
:
tp_size
=
1
if
hasattr
(
param
,
"expect_tp_size"
):
param
.
expect_tp_size
=
self
.
expect_tp_size
if
isinstance
(
param
,
BlockQuantScaleParameter
):
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8LinearMethod
,
Fp8MoEMethod
)
...
...
@@ -1405,10 +1428,16 @@ class RowParallelLinear(LinearBase):
prefix
:
str
=
""
,
*
,
return_bias
:
bool
=
True
,
expect_tp_size
:
Optional
[
int
]
=
None
,
):
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
expect_tp_size
is
not
None
:
self
.
tp_rank
=
0
self
.
tp_size
=
1
self
.
expect_tp_size
=
expect_tp_size
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
...
...
@@ -1454,6 +1483,10 @@ class RowParallelLinear(LinearBase):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
expect_tp_size
is
not
None
:
tp_rank
=
0
tp_size
=
1
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
...
...
@@ -1506,6 +1539,8 @@ class RowParallelLinear(LinearBase):
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
expect_tp_size
is
not
None
and
hasattr
(
param
,
"expect_tp_size"
):
param
.
expect_tp_size
=
self
.
expect_tp_size
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
13130b89
...
...
@@ -4,17 +4,30 @@
import
enum
from
enum
import
Enum
from
typing
import
Callable
,
Optional
,
List
from
math
import
prod
import
torch
from
compressed_tensors.quantization
import
(
QuantizationStrategy
)
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
get_tensor_model_parallel_world_size
,
get_ep_group
,
get_dp_group
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoEConfig
,
FusedMoeWeightScaleSupported
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
,
compute_aligned_M
,
deepgemm_moe_permute
,
deepgemm_unpermute_and_reduce
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
)
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
)
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
_moe_permute
)
from
vllm.utils
import
round_up
try
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
m_grouped_w8a8_gemm_nt_contig_asm
,
fuse_silu_mul_quant_ep
,
fuse_silu_mul_quant
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -69,11 +82,32 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
self
.
fused_experts
=
self
.
fused_moe_forward
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_size
=
get_dp_group
().
world_size
self
.
ep_size
=
get_ep_group
().
world_size
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
use_deepgemm
=
False
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
self
.
block_shape
=
[
256
,
256
]
self
.
use_deepgemm
=
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size_per_partition
self
.
K
=
hidden_size
params_dtype
=
torch
.
int8
# WEIGHTS
...
...
@@ -124,19 +158,270 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deepgemm
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
if
not
self
.
use_deepgemm
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
masked_groupgemm_workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,):
assert
a
.
dim
()
==
2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
a
.
size
(
0
)
if
self
.
max_num_tokens_per_rank
is
None
else
self
.
max_num_tokens_per_rank
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
def
contiguous_groupgemm_workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_num_tokens_cpu
:
torch
.
Tensor
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
assert
self
.
block_shape
is
not
None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
block_m
=
self
.
block_shape
[
0
]
M_sum
=
compute_aligned_M
(
M
,
topk
,
local_num_experts
,
block_m
,
expert_num_tokens_cpu
)
assert
M_sum
%
block_m
==
0
workspace1
=
(
M_sum
,
max
(
N
,
K
))
workspace2
=
(
M_sum
,
max
(
N
//
2
,
K
))
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
,
M_sum
)
def
w8a8_groupgemm_masked_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
local_num_experts
=
w1
.
size
(
0
)
E
,
max_num_tokens
,
_
,
_
,
top_k
=
mk
.
_moe_problem_size
(
q_x
,
w1
,
w2
,
topk_ids
)
N
,
K
=
self
.
N
,
self
.
K
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
)
=
self
.
masked_groupgemm_workspace_shapes
(
x
,
q_x
,
max_num_tokens
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
x
.
device
,
dtype
=
workspace_dtype
)
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
#expected_m = max_num_tokens
ori_bs
=
x
.
shape
[
0
]
expected_m
=
ori_bs
*
self
.
ep_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked
((
q_x
,
a1_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
,
)
assert
expert_num_tokens
is
not
None
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
m_grouped_w8a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
fused_out
,
expert_num_tokens
,
expected_m
)
return
fused_out
def
w8a8_groupgemm_contiguous_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
local_num_experts
=
w1
.
size
(
0
)
a1q
=
q_x
N
,
K
=
self
.
N
,
self
.
K
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
,
M_sum
)
=
self
.
contiguous_groupgemm_workspace_shapes
(
x
,
q_x
,
topk_ids
.
size
(
0
),
N
,
K
,
topk_ids
.
size
(
1
),
global_num_experts
,
local_num_experts
,
expert_num_tokens_cpu
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
x
.
device
,
dtype
=
workspace_dtype
)
workspace2
=
torch
.
empty
(
prod
(
workspace2_shape
),
device
=
x
.
device
,
dtype
=
workspace_dtype
)
mm1_out
=
_resize_cache
(
workspace13
,
(
M_sum
,
N
))
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
act_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
N
//
2
))
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
a1q
.
dtype
),
(
M_sum
,
N
//
2
)
)
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
a1q_perm
=
_resize_cache
(
workspace2
.
view
(
dtype
=
a1q
.
dtype
),
(
M_sum
,
K
))
a1q
,
a1q_scale
,
expert_ids
,
inv_perm
=
deepgemm_moe_permute
(
aq
=
a1q
,
aq_scale
=
a1_scale
,
topk_ids
=
topk_ids
,
local_num_experts
=
local_num_experts
,
expert_map
=
expert_map
,
block_shape
=
self
.
block_shape
,
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
expert_num_tokens_cpu
,
aq_out
=
a1q_perm
,
M_sum
=
M_sum
)
m_grouped_w8a8_gemm_nt_contig_asm
(
(
a1q
,
a1q_scale
),
(
w1
,
w1_scale
),
mm1_out
,
expert_ids
)
a2q
,
a2q_scale
=
fuse_silu_mul_quant
(
mm1_out
)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm
(
(
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
mm2_out
,
expert_ids
)
if
apply_router_weight_on_input
:
topk_weights
=
torch
.
ones_like
(
topk_weights
)
deepgemm_unpermute_and_reduce
(
a
=
mm2_out
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
inv_perm
=
inv_perm
,
expert_map
=
expert_map
,
output
=
fused_out
,
)
return
fused_out
def
fused_moe_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens_cpu
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
return
fused_experts_impl_int8_marlin
(
hidden_states
=
x
if
q_x
is
None
else
q_x
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
i_q
=
i_q
,
i_s
=
i_s
)
def
apply
(
self
,
...
...
@@ -164,14 +449,14 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
...
...
@@ -184,27 +469,63 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
torch
.
int64
if
self
.
use_deepep
else
None
,)
return
fused_experts
_impl_int8_marlin
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
return
self
.
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w1_scale
=
(
layer
.
w13_weight_scale
)
,
w2_scale
=
(
layer
.
w2_weight_scale
)
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
i_q
=
i_q
,
i_s
=
i_s
)
i_s
=
i_s
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe
import
(
TritonOrGroupGemmExperts
)
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
(
prepare_finalize
.
max_num_tokens_per_rank
())
assert
max_num_tokens_per_rank
is
not
None
self
.
max_num_tokens_per_rank
=
max_num_tokens_per_rank
logger
.
debug
(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
max_num_tokens_per_rank
,
None
,
True
)
return
TritonOrGroupGemmExperts
(
use_int8_w8a8
=
True
,
per_act_token_quant
=
True
,
fused_experts
=
self
.
w8a8_groupgemm_masked_forward
)
else
:
logger
.
debug
(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
None
,
False
)
return
TritonOrGroupGemmExperts
(
use_int8_w8a8
=
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
,
fused_experts
=
self
.
w8a8_groupgemm_contiguous_forward
if
envs
.
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
else
self
.
fused_moe_forward
)
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
View file @
13130b89
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
os
from
math
import
prod
import
torch
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
get_tensor_model_parallel_world_size
,
get_ep_group
,
get_dp_group
from
vllm.logger
import
init_logger
from
vllm.config
import
get_current_vllm_config
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
...
...
@@ -13,16 +19,25 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_weight_repack_impl
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEConfig
,
FusedMoEMethodBase
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoeWeightScaleSupported
)
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
from
lightop
import
m_grouped_w4a8_gemm_nt_masked
,
fuse_silu_mul_quant
,
fuse_silu_mul_quant_ep
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
logger
=
init_logger
(
__name__
)
class
MarlinMoeWorkspace
:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...
...
@@ -145,6 +160,21 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
self
.
fused_experts
=
self
.
w4a8_fused_moe_marlin_forward
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
dp_size
=
get_dp_group
().
world_size
self
.
use_deepep
=
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
self
.
ep_size
=
get_ep_group
().
world_size
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
def
create_weights
(
self
,
...
...
@@ -157,6 +187,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
):
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size
self
.
K
=
hidden_size
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
...
...
@@ -209,6 +243,139 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer
.
w13_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w13_weight
),
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w4a8_weight_repack_impl
(
layer
.
w2_weight
),
requires_grad
=
False
)
def
w4a8_fused_moe_marlin_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
x
if
q_x
is
None
else
q_x
,
w1
,
w2
,
topk_ids
=
topk_ids
,
topk_weights
=
topk_weights
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
def
groupgemm_workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,):
assert
a
.
dim
()
==
2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers
=
self
.
num_dispatchers
num_experts
=
local_num_experts
max_num_tokens
=
a
.
size
(
0
)
if
self
.
max_num_tokens_per_rank
is
None
else
self
.
max_num_tokens_per_rank
workspace13
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
max
(
K
,
N
))
workspace2
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
(
N
//
2
))
output
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
return
(
workspace13
,
workspace2
,
output
,
a
.
dtype
)
def
w4a8_groupgemm_marlin_forward
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
q_x
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
local_num_experts
=
w1
.
size
(
0
)
E
,
max_num_tokens
,
_
,
_
,
top_k
=
mk
.
_moe_problem_size
(
q_x
,
w1
,
w2
,
topk_ids
)
N
,
K
=
self
.
N
,
self
.
K
(
workspace13_shape
,
workspace2_shape
,
fused_out_shape
,
workspace_dtype
)
=
self
.
groupgemm_workspace_shapes
(
x
,
q_x
,
max_num_tokens
,
N
,
K
,
top_k
,
global_num_experts
,
local_num_experts
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
x
.
device
,
dtype
=
workspace_dtype
)
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
fused_out
=
_resize_cache
(
workspace13
,
fused_out_shape
)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
#expected_m = max_num_tokens
ori_bs
=
x
.
shape
[
0
]
expected_m
=
ori_bs
*
self
.
ep_size
m_grouped_w4a8_gemm_nt_masked
((
q_x
,
a1_scale
),
(
w1
,
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
,
)
assert
expert_num_tokens
is
not
None
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
m_grouped_w4a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
w2_scale
),
fused_out
,
expert_num_tokens
,
expected_m
)
return
fused_out
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -233,9 +400,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet."
)
#
if enable_eplb:
#
raise NotImplementedError(
#
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -248,30 +415,62 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
torch
.
int64
if
self
.
use_deepep
else
None
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
return
self
.
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe
import
(
TritonOrGroupGemmExperts
)
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
(
prepare_finalize
.
max_num_tokens_per_rank
())
self
.
max_num_tokens_per_rank
=
max_num_tokens_per_rank
assert
max_num_tokens_per_rank
is
not
None
logger
.
info
(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
max_num_tokens_per_rank
,
None
,
True
)
return
TritonOrGroupGemmExperts
(
use_int4_w4a8
=
True
,
per_act_token_quant
=
True
,
fused_experts
=
self
.
w4a8_groupgemm_marlin_forward
)
else
:
logger
.
info
(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s"
,
self
.
__class__
.
__name__
,
None
,
False
)
return
TritonOrGroupGemmExperts
(
# use_int4_w4a8=True,
# per_act_token_quant=True,
fused_experts
=
self
.
w4a8_fused_moe_marlin_forward
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment