Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
289f98c6
Commit
289f98c6
authored
Feb 08, 2026
by
王敏
Browse files
[feat]适配w8a8 deepep,接入lightop版deepgemm
parent
e807ec39
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
582 additions
and
167 deletions
+582
-167
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+9
-4
vllm/model_executor/layers/fused_moe/all2all_utils.py
vllm/model_executor/layers/fused_moe/all2all_utils.py
+3
-0
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+62
-23
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+9
-6
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+59
-22
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
+113
-68
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+2
-2
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+12
-3
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+3
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+8
-0
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+61
-29
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+143
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+85
-8
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+13
-0
No files found.
vllm/distributed/device_communicators/all2all.py
View file @
289f98c6
...
...
@@ -259,7 +259,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self
.
num_sms
=
2
0
self
.
num_sms
=
3
0
def
get_handle
(
self
,
kwargs
):
raise
NotImplementedError
...
...
@@ -292,16 +292,21 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def
_make_all2all_kwargs
(
self
)
->
dict
[
Any
,
Any
]:
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes
=
envs
.
VLLM_DEEPEP_BUFFER_SIZE_MB
*
1024
*
1024
#num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_nvl_bytes
=
int
(
2e9
/
2
)
#1024 * 1024 * 1024
num_rdma_bytes
=
None
num_qps_per_rank
=
None
if
self
.
internode
and
not
envs
.
VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE
:
num_rdma_bytes
=
envs
.
VLLM_DEEPEP_BUFFER_SIZE_MB
*
1024
*
1024
num_qps_per_rank
=
self
.
num_sms
//
2
# num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
# num_qps_per_rank = self.num_sms // 2
num_rdma_bytes
=
int
(
1e9
/
2
)
#1024 * 1024 * 1024
num_qps_per_rank
=
30
#self.num_sms // 2
self
.
num_sms
=
30
else
:
num_rdma_bytes
=
0
num_qps_per_rank
=
1
self
.
num_sms
=
60
assert
num_rdma_bytes
is
not
None
assert
num_qps_per_rank
is
not
None
...
...
vllm/model_executor/layers/fused_moe/all2all_utils.py
View file @
289f98c6
...
...
@@ -162,6 +162,8 @@ def maybe_make_prepare_finalize(
and
quant_config
.
block_shape
==
DEEPEP_QUANT_BLOCK_SHAPE
)
use_int8_dispatch
=
quant_config
.
quant_dtype
==
torch
.
int8
prepare_finalize
=
DeepEPLLPrepareAndFinalize
(
handle
,
max_tokens_per_rank
=
moe
.
max_num_tokens
,
...
...
@@ -170,6 +172,7 @@ def maybe_make_prepare_finalize(
global_to_physical
=
global_to_physical
,
physical_to_global
=
physical_to_global
,
local_expert_global_ids
=
local_expert_global_ids
,
use_int8_dispatch
=
use_int8_dispatch
,
)
elif
moe
.
use_mori_kernels
:
assert
quant_config
is
not
None
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
289f98c6
...
...
@@ -32,6 +32,16 @@ from vllm.utils.deep_gemm import (
)
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.import_utils
import
has_deep_gemm
from
lightop
import
fuse_silu_mul_quant_ep
if
has_deep_gemm
():
from
deep_gemm
import
m_grouped_w8a8_gemm_nt_masked
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
logger
=
init_logger
(
__name__
)
...
...
@@ -267,6 +277,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_config
:
FusedMoEQuantConfig
,
max_num_tokens
:
int
,
num_dispatchers
:
int
,
N
:
int
=
-
1
,
K
:
int
=
-
1
,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
...
...
@@ -279,8 +291,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
assert
self
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
assert
self
.
quant_config
.
use_fp8_w8a8
if
quant_config
.
use_fp8_w8a8
:
assert
self
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
self
.
N
=
N
self
.
K
=
K
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -398,6 +414,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
):
assert
expert_tokens_meta
is
not
None
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
...
...
@@ -408,12 +425,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q
=
hidden_states
_
,
N
,
K
=
w1
.
size
()
assert
w2
.
size
(
1
)
==
K
#
assert w2.size(1) == K
E
,
max_num_tokens
,
N
,
K
,
_
=
self
.
moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
if
self
.
N
>
0
:
N
=
self
.
N
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
expected_m
=
self
.
estimate_expected_m
(
...
...
@@ -422,25 +442,44 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk
=
topk_ids
.
size
(
-
1
),
)
fp8_m_grouped_gemm_nt_masked
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
,
)
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
fp8_m_grouped_gemm_nt_masked
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
,
)
quant_scale_fmt
=
DeepGemmQuantScaleFMT
.
from_oracle
()
a2q
,
a2q_scale
=
persistent_masked_m_silu_mul_quant
(
workspace1
,
expert_num_tokens
,
quant_scale_fmt
=
quant_scale_fmt
,
)
quant_scale_fmt
=
DeepGemmQuantScaleFMT
.
from_oracle
()
a2q
,
a2q_scale
=
persistent_masked_m_silu_mul_quant
(
workspace1
,
expert_num_tokens
,
quant_scale_fmt
=
quant_scale_fmt
,
)
fp8_m_grouped_gemm_nt_masked
(
(
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
,
)
fp8_m_grouped_gemm_nt_masked
(
(
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
,
)
elif
self
.
quant_config
.
use_int8_w8a8
:
m_grouped_w8a8_gemm_nt_masked
((
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
workspace1
,
expert_num_tokens
,
expected_m
,
)
assert
expert_num_tokens
is
not
None
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
m_grouped_w8a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
)
else
:
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
vllm/model_executor/layers/fused_moe/config.py
View file @
289f98c6
...
...
@@ -87,7 +87,7 @@ def _quant_flags_to_group_shape(
"""
a_shape
:
GroupShape
|
None
w_shape
:
GroupShape
|
None
if
block_shape
is
not
None
:
if
block_shape
is
not
None
and
quant_dtype
!=
torch
.
int8
:
assert
not
per_act_token_quant
assert
not
per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first
...
...
@@ -211,10 +211,10 @@ class FusedMoEQuantConfig:
_w1
:
FusedMoEQuantDesc
_w2
:
FusedMoEQuantDesc
def
__post_init__
(
self
):
assert
not
self
.
per_act_token_quant
or
self
.
block_shape
is
None
,
(
"illegal quantization"
)
#
def __post_init__(self):
#
assert not self.per_act_token_quant or self.block_shape is None, (
#
"illegal quantization"
#
)
#
# Convenience accessors for various properties.
...
...
@@ -246,6 +246,9 @@ class FusedMoEQuantConfig:
@
property
def
block_shape
(
self
)
->
list
[
int
]
|
None
:
if
self
.
use_int8_w8a8
:
return
[
256
,
256
]
if
(
self
.
_a1
.
shape
is
not
None
and
self
.
_a1
.
shape
!=
GroupShape
.
PER_TENSOR
...
...
@@ -569,7 +572,7 @@ def int8_w8a8_moe_quant_config(
a2_scale
=
a2_scale
,
per_act_token_quant
=
per_act_token_quant
,
per_out_ch_quant
=
False
,
block_shape
=
None
,
block_shape
=
[
256
,
256
]
,
)
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
289f98c6
...
...
@@ -37,6 +37,12 @@ from vllm.utils.deep_gemm import (
)
from
vllm.utils.import_utils
import
has_deep_gemm
from
lightop
import
fuse_silu_mul_quant
if
has_deep_gemm
():
from
deep_gemm
import
m_grouped_i8_gemm_nt_contiguous
else
:
from
lightop
import
m_grouped_w8a8_gemm_nt_contig_asm
as
m_grouped_i8_gemm_nt_contiguous
logger
=
init_logger
(
__name__
)
...
...
@@ -113,12 +119,22 @@ def _valid_deep_gemm(
class
DeepGemmExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
):
def
__init__
(
self
,
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
,
N
:
int
=
-
1
,
K
:
int
=
-
1
,):
super
().
__init__
(
moe_config
=
moe_config
,
quant_config
=
quant_config
)
assert
quant_config
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
assert
quant_config
.
quant_dtype
==
torch
.
float8_e4m3fn
assert
not
quant_config
.
per_act_token_quant
assert
not
quant_config
.
per_out_ch_quant
if
quant_config
.
use_fp8_w8a8
or
quant_config
.
use_fp8_w8a16
:
assert
quant_config
.
block_shape
==
get_mk_alignment_for_contiguous_layout
()
assert
quant_config
.
quant_dtype
==
torch
.
float8_e4m3fn
assert
not
quant_config
.
per_act_token_quant
assert
not
quant_config
.
per_out_ch_quant
self
.
N
=
N
self
.
K
=
K
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -241,6 +257,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
):
assert
a1q_scale
is
not
None
assert
a2_scale
is
None
...
...
@@ -255,19 +272,24 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if
global_num_experts
==
-
1
:
global_num_experts
=
local_num_experts
assert
w2
.
size
(
1
)
==
K
#assert w2.size(1) == K
if
self
.
N
>
0
:
N
=
self
.
N
K
=
self
.
K
use_fp8
=
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
M_sum
=
compute_aligned_M
(
M
=
topk_ids
.
size
(
0
),
num_topk
=
topk_ids
.
size
(
1
),
local_num_experts
=
local_num_experts
,
alignment
=
get_mk_alignment_for_contiguous_layout
()[
0
],
alignment
=
get_mk_alignment_for_contiguous_layout
()[
0
]
if
use_fp8
else
self
.
block_shape
[
0
]
,
expert_tokens_meta
=
expert_tokens_meta
,
)
a1q_perm
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M_sum
,
K
)
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
if
use_fp8
else
a1q
.
dtype
),
(
M_sum
,
K
)
)
a1q
,
a1q_scale
,
expert_ids
,
inv_perm
=
deepgemm_moe_permute
(
aq
=
a1q
,
aq_scale
=
a1q_scale
,
...
...
@@ -280,22 +302,37 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert
a1q
.
size
(
0
)
==
M_sum
mm1_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
N
))
m_grouped_fp8_gemm_nt_contiguous
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
mm1_out
,
expert_ids
)
activation_out_dim
=
self
.
adjust_N_for_activation
(
N
,
activation
)
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M_sum
,
activation_out_dim
)
)
a2q
,
a2q_scale
=
self
.
_act_mul_quant
(
input
=
mm1_out
.
view
(
-
1
,
N
),
output
=
quant_out
,
activation
=
activation
)
if
use_fp8
:
m_grouped_fp8_gemm_nt_contiguous
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
mm1_out
,
expert_ids
)
activation_out_dim
=
self
.
adjust_N_for_activation
(
N
,
activation
)
quant_out
=
_resize_cache
(
workspace13
.
view
(
dtype
=
torch
.
float8_e4m3fn
),
(
M_sum
,
activation_out_dim
)
)
a2q
,
a2q_scale
=
self
.
_act_mul_quant
(
input
=
mm1_out
.
view
(
-
1
,
N
),
output
=
quant_out
,
activation
=
activation
)
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
m_grouped_fp8_gemm_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
mm2_out
,
expert_ids
)
elif
self
.
quant_config
.
use_int8_w8a8
:
m_grouped_i8_gemm_nt_contiguous
(
(
a1q
,
a1q_scale
),
(
w1
,
self
.
w1_scale
),
mm1_out
,
expert_ids
)
a2q
,
a2q_scale
=
fuse_silu_mul_quant
(
mm1_out
)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
m_grouped_i8_gemm_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
mm2_out
,
expert_ids
)
else
:
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
mm2_out
=
_resize_cache
(
workspace2
,
(
M_sum
,
K
))
m_grouped_fp8_gemm_nt_contiguous
(
(
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
mm2_out
,
expert_ids
)
if
apply_router_weight_on_input
:
topk_weights
=
torch
.
ones_like
(
topk_weights
)
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
View file @
289f98c6
...
...
@@ -13,6 +13,8 @@ from vllm.triton_utils import tl, triton
from
vllm.utils.deep_gemm
import
get_mk_alignment_for_contiguous_layout
from
vllm.utils.math_utils
import
round_up
from
lightop
import
op
def
expert_num_tokens_round_up_and_sum
(
expert_num_tokens
:
torch
.
Tensor
,
alignment
:
int
...
...
@@ -57,6 +59,12 @@ def round_up_128(x: int) -> int:
return
((
x
+
y
-
1
)
//
y
)
*
y
@
triton
.
jit
def
round_up_256
(
x
:
int
)
->
int
:
y
=
256
return
((
x
+
y
-
1
)
//
y
)
*
y
@
triton
.
jit
def
_fwd_kernel_ep_scatter_1
(
num_recv_tokens_per_expert
,
...
...
@@ -74,26 +82,27 @@ def _fwd_kernel_ep_scatter_1(
mask
=
offset_cumsum
<
num_experts
,
other
=
0
,
)
tokens_per_expert
=
round_up_128
(
tokens_per_expert
)
#tokens_per_expert = round_up_128(tokens_per_expert)
tokens_per_expert
=
round_up_256
(
tokens_per_expert
)
cumsum
=
tl
.
cumsum
(
tokens_per_expert
)
-
tokens_per_expert
#if cur_expert == 0:
tl
.
store
(
expert_start_loc
+
offset_cumsum
,
cumsum
,
mask
=
offset_cumsum
<
num_experts
)
tl
.
debug_barrier
()
#cur_expert_start = cumsum[cur_expert]
cur_expert_start
=
tl
.
load
(
expert_start_loc
+
cur_expert
)
cur_expert_token_num
=
tl
.
load
(
num_recv_tokens_per_expert
+
cur_expert
)
m_indices_start_ptr
=
m_indices
+
cur_expert_start
off_expert
=
tl
.
arange
(
0
,
BLOCK_E
)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for
start_m
in
tl
.
range
(
0
,
cur_expert_token_num
,
BLOCK_E
,
num_stages
=
4
):
offs
=
start_m
+
off_expert
mask
=
offs
<
cur_expert_token_num
tl
.
store
(
m_indices_start_ptr
+
offs
,
m_indices_start_ptr
+
start_m
+
off_expert
,
cur_expert
,
mask
=
mask
,
mask
=
start_m
+
off_expert
<
cur_expert_token_num
)
...
...
@@ -133,26 +142,32 @@ def _fwd_kernel_ep_scatter_2(
offset_in
=
tl
.
arange
(
0
,
HIDDEN_SIZE_PAD
)
mask
=
offset_in
<
HIDDEN_SIZE
offset
_in_s
=
tl
.
arange
(
0
,
SCALE_HIDDEN_SIZE_PAD
)
mask_s
=
offset
_in_s
<
SCALE_HIDDEN_SIZE
index
_in_s
=
tl
.
arange
(
0
,
SCALE_HIDDEN_SIZE_PAD
)
mask_s
=
index
_in_s
<
SCALE_HIDDEN_SIZE
for
token_id
in
range
(
start_token_id
,
total_token_num
,
grid_num
):
for
token_id_int32
in
range
(
start_token_id
,
total_token_num
,
grid_num
):
token_id
=
token_id_int32
.
to
(
tl
.
int64
)
to_copy
=
tl
.
load
(
recv_x
+
token_id
*
recv_x_stride0
+
offset_in
,
mask
=
mask
)
to_copy_s
=
tl
.
load
(
recv_x_scale
+
token_id
*
recv_x_scale_stride0
+
offset_in_s
,
mask
=
mask_s
recv_x_scale
+
token_id
*
recv_x_scale_stride0
+
index_in_s
*
recv_x_scale_stride1
,
mask
=
mask_s
,
)
for
topk_index
in
tl
.
range
(
0
,
topk_num
,
1
,
num_stages
=
4
):
for
topk_idx_int32
in
tl
.
range
(
0
,
topk_num
,
1
,
num_stages
=
4
):
topk_index
=
topk_idx_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk
+
token_id
*
recv_topk_stride0
+
topk_index
)
if
HAS_EXPERT_MAP
:
expert_id
=
apply_expert_map
(
expert_id
,
expert_map
)
if
expert_id
>=
0
:
dest_token_index
=
tl
.
atomic_add
(
expert_start_loc
+
expert_id
,
1
)
dest_token_index_int32
=
tl
.
atomic_add
(
expert_start_loc
+
expert_id
,
1
)
dest_token_index
=
dest_token_index_int32
.
to
(
tl
.
int64
)
tl
.
store
(
output_index
+
token_id
*
output_index_stride0
+
topk_index
,
dest_token_index
,
dest_token_index
_int32
,
)
output_tensor_ptr
=
(
output_tensor
+
dest_token_index
*
output_tensor_stride0
...
...
@@ -161,7 +176,11 @@ def _fwd_kernel_ep_scatter_2(
output_tensor_scale
+
dest_token_index
*
output_tensor_scale_stride0
)
tl
.
store
(
output_tensor_ptr
+
offset_in
,
to_copy
,
mask
=
mask
)
tl
.
store
(
output_tensor_scale_ptr
+
offset_in_s
,
to_copy_s
,
mask
=
mask_s
)
tl
.
store
(
output_tensor_scale_ptr
+
index_in_s
*
output_tensor_scale_stride1
,
to_copy_s
,
mask
=
mask_s
,
)
@
torch
.
no_grad
()
...
...
@@ -177,58 +196,71 @@ def ep_scatter(
m_indices
:
torch
.
Tensor
,
output_index
:
torch
.
Tensor
,
):
BLOCK_E
=
128
# token num of per expert is aligned to 128
BLOCK_D
=
128
# block size of quantization
# BLOCK_E = 128 # token num of per expert is aligned to 128
# BLOCK_D = 128 # block size of quantization
BLOCK_E
=
256
# token num of per expert is aligned to 256
num_warps
=
8
num_experts
=
num_recv_tokens_per_expert
.
shape
[
0
]
hidden_size
=
recv_x
.
shape
[
1
]
scale_hidden_size
=
recv_x_scale
.
shape
[
-
1
]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid
=
num_experts
assert
m_indices
.
shape
[
0
]
%
BLOCK_E
==
0
_fwd_kernel_ep_scatter_1
[(
grid
,)](
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
=
num_experts
,
num_warps
=
num_warps
,
BLOCK_E
=
BLOCK_E
,
BLOCK_EXPERT_NUM
=
triton
.
next_power_of_2
(
num_experts
),
)
if
hasattr
(
op
,
"ep_scatter"
):
op
.
ep_scatter
(
recv_x
,
recv_x_scale
,
recv_topk
,
expert_map
,
num_recv_tokens_per_expert
,
output_tensor
,
output_tensor_scale
,
m_indices
,
output_index
,
num_experts
,
BLOCK_E
)
else
:
_fwd_kernel_ep_scatter_1
[(
grid
,)](
num_recv_tokens_per_expert
,
expert_start_loc
,
m_indices
,
num_experts
=
num_experts
,
num_warps
=
num_warps
,
BLOCK_E
=
BLOCK_E
,
BLOCK_EXPERT_NUM
=
triton
.
next_power_of_2
(
num_experts
),
)
grid
=
min
(
recv_topk
.
shape
[
0
],
1024
*
8
)
_fwd_kernel_ep_scatter_2
[(
grid
,)](
recv_topk
.
shape
[
0
],
expert_start_loc
,
recv_x
,
recv_x
.
stride
(
0
),
recv_x
.
stride
(
1
),
recv_x_scale
,
recv_x_scale
.
stride
(
0
),
recv_x_scale
.
stride
(
1
),
recv_topk
,
recv_topk
.
stride
(
0
),
recv_topk
.
stride
(
1
),
output_tensor
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
output_tensor_scale
,
output_tensor_scale
.
stride
(
0
),
output_tensor_scale
.
stride
(
1
),
output_index
,
output_index
.
stride
(
0
),
output_index
.
stride
(
1
),
topk_num
=
recv_topk
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
HIDDEN_SIZE
=
hidden_size
,
HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
),
SCALE_HIDDEN_SIZE
=
hidden_size
//
BLOCK_D
,
SCALE_HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
//
BLOCK_D
),
)
grid
=
min
(
recv_topk
.
shape
[
0
],
1024
*
8
)
_fwd_kernel_ep_scatter_2
[(
grid
,)](
recv_topk
.
shape
[
0
],
expert_start_loc
,
recv_x
,
recv_x
.
stride
(
0
),
recv_x
.
stride
(
1
),
recv_x_scale
,
recv_x_scale
.
stride
(
0
),
recv_x_scale
.
stride
(
1
),
recv_topk
,
recv_topk
.
stride
(
0
),
recv_topk
.
stride
(
1
),
output_tensor
,
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
output_tensor_scale
,
output_tensor_scale
.
stride
(
0
),
output_tensor_scale
.
stride
(
1
),
output_index
,
output_index
.
stride
(
0
),
output_index
.
stride
(
1
),
topk_num
=
recv_topk
.
shape
[
1
],
expert_map
=
expert_map
,
HAS_EXPERT_MAP
=
expert_map
is
not
None
,
num_warps
=
num_warps
,
HIDDEN_SIZE
=
hidden_size
,
HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
hidden_size
),
# SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
# SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
SCALE_HIDDEN_SIZE
=
scale_hidden_size
,
SCALE_HIDDEN_SIZE_PAD
=
triton
.
next_power_of_2
(
scale_hidden_size
),
)
return
...
...
@@ -255,25 +287,34 @@ def _fwd_kernel_ep_gather(
HAS_EXPERT_MAP
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
cur_block
=
tl
.
program_id
(
0
)
start_cur_token
=
tl
.
program_id
(
1
)
cur_block_int32
=
tl
.
program_id
(
0
)
cur_block
=
cur_block_int32
.
to
(
tl
.
int64
)
start_cur_token_int32
=
tl
.
program_id
(
1
)
grid_num
=
tl
.
num_programs
(
1
)
for
cur_token
in
range
(
start_cur_token
,
total_token_num
,
grid_num
):
for
cur_token_int32
in
range
(
start_cur_token_int32
,
total_token_num
,
grid_num
):
cur_token
=
cur_token_int32
.
to
(
tl
.
int64
)
off_d
=
tl
.
arange
(
0
,
BLOCK_D
)
accumulator
=
tl
.
zeros
([
BLOCK_D
],
dtype
=
tl
.
float32
)
for
topk_index
in
range
(
0
,
topk_num
):
for
topk_index_int32
in
range
(
0
,
topk_num
):
topk_index
=
topk_index_int32
.
to
(
tl
.
int64
)
expert_id
=
tl
.
load
(
recv_topk_ids
+
cur_token
*
recv_topk_ids_stride0
+
topk_index
)
if
HAS_EXPERT_MAP
:
expert_id
=
apply_expert_map
(
expert_id
,
expert_map
)
if
expert_id
>=
0
:
source_token_index
=
tl
.
load
(
source_token_index
_int32
=
tl
.
load
(
input_index
+
cur_token
*
input_index_stride0
+
topk_index
)
source_token_index
=
source_token_index_int32
.
to
(
tl
.
int64
)
acc_weight
=
tl
.
load
(
recv_topk_weight
+
cur_token
*
recv_topk_weight_stride0
+
topk_index
)
...
...
@@ -350,7 +391,8 @@ def deepgemm_moe_permute(
H
=
aq
.
size
(
1
)
device
=
aq
.
device
block_m
,
block_k
=
get_mk_alignment_for_contiguous_layout
()
#block_m, block_k = get_mk_alignment_for_contiguous_layout()
block_m
=
256
M_sum
=
compute_aligned_M
(
M
=
topk_ids
.
size
(
0
),
...
...
@@ -368,8 +410,11 @@ def deepgemm_moe_permute(
if
aq_out
is
None
:
aq_out
=
torch
.
empty
((
M_sum
,
H
),
device
=
device
,
dtype
=
aq
.
dtype
)
# aq_scale_out = torch.empty(
# (M_sum, H // block_k), device=device, dtype=torch.float32
# )
aq_scale_out
=
torch
.
empty
(
(
M_sum
,
H
//
block_k
),
device
=
device
,
dtype
=
torch
.
float32
(
M_sum
,
aq_scale
.
shape
[
-
1
]
),
device
=
device
,
dtype
=
torch
.
float32
)
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
289f98c6
...
...
@@ -225,7 +225,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if
not
quant_config
.
is_block_quantized
:
if
not
quant_config
.
is_block_quantized
and
not
quant_config
.
is_per_act_token
:
# Quantize after dispatch.
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
...
...
@@ -266,7 +266,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
if
quant_config
.
is_block_quantized
:
if
quant_config
.
is_block_quantized
or
quant_config
.
is_per_act_token
:
# Quant and Dispatch
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
a1
,
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
289f98c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
typing
import
Callable
,
Optional
import
deep_ep
import
torch
...
...
@@ -91,12 +92,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
global_to_physical
:
torch
.
Tensor
|
None
=
None
,
physical_to_global
:
torch
.
Tensor
|
None
=
None
,
local_expert_global_ids
:
torch
.
Tensor
|
None
=
None
,
use_int8_dispatch
:
bool
=
False
):
super
().
__init__
()
self
.
buffer
=
buffer
self
.
max_tokens_per_rank
=
max_tokens_per_rank
self
.
use_fp8_dispatch
=
use_fp8_dispatch
self
.
use_int8_dispatch
=
use_int8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
...
...
@@ -168,6 +171,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x
:
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
a1_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
self
.
use_fp8_dispatch
:
block_k
=
(
...
...
@@ -183,6 +187,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dequant to get back the tokens in the datatype we dispatched in.
x_fp8
,
x_scales
=
x
x
=
dequant_fp8
(
x_fp8
,
x_scales
).
to
(
dtype
=
a1_dtype
)
elif
self
.
use_int8_dispatch
:
x
,
x_scales
=
x
return
x
,
x_scales
assert
isinstance
(
x
,
(
torch
.
Tensor
,
tuple
))
q_dtype
=
quant_config
.
quant_dtype
...
...
@@ -214,7 +221,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts
,
max_tokens
,
hidden_dim
=
x
.
size
()
# TODO (varun): Optimization - Use a batched version of quant
x
=
x
.
view
((
-
1
,
hidden_dim
))
if
expert_num_tokens
is
None
:
x
=
x
.
view
((
-
1
,
hidden_dim
))
x
,
x_scales
=
moe_kernel_quantize_input
(
x
,
quant_config
.
a1_scale
,
...
...
@@ -294,7 +302,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dispatch_topk_ids
,
self
.
max_tokens_per_rank
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
use_int8
=
self
.
use_int8_dispatch
,
round_scale
=
self
.
use_ue8m0_dispatch
,
use_ue8m0
=
self
.
use_ue8m0_dispatch
,
**
(
dict
(
use_nvfp4
=
True
)
if
use_nvfp4
else
dict
()),
...
...
@@ -327,7 +336,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_dtype
:
torch
.
dtype
,
quant_config
:
FusedMoEQuantConfig
,
)
->
mk
.
PrepareResultType
:
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_dtype
,
quant_config
)
expert_x
,
expert_x_scale
=
self
.
_do_quant
(
expert_x
,
a1_dtype
,
quant_config
,
expert_num_tokens
)
expert_tokens_meta
=
mk
.
ExpertTokensMetadata
(
expert_num_tokens
=
expert_num_tokens
,
expert_num_tokens_cpu
=
None
...
...
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
View file @
289f98c6
...
...
@@ -54,6 +54,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method
.
select_gemm_impl
(
prepare_finalize
,
moe_layer
),
shared_experts
,
moe_parallel_config
=
moe_layer
.
moe_parallel_config
,
N
=
old_quant_method
.
N
if
hasattr
(
old_quant_method
,
"N"
)
else
-
1
,
K
=
old_quant_method
.
K
if
hasattr
(
old_quant_method
,
"K"
)
else
-
1
,
),
)
...
...
@@ -95,6 +97,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
fused_experts
(
hidden_states
=
x
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
289f98c6
...
...
@@ -281,6 +281,9 @@ def maybe_roundup_hidden_size(
return
hidden_size
first
=
True
# --8<-- [start:fused_moe]
@
CustomOp
.
register
(
"fused_moe"
)
class
FusedMoE
(
CustomOp
):
...
...
@@ -398,6 +401,11 @@ class FusedMoE(CustomOp):
# Expert mapping used in self.load_weights
self
.
expert_mapping
=
expert_mapping
global
first
if
first
:
print
(
f
"###################self.global_num_experts:
{
self
.
global_num_experts
}
, self.logical_num_experts:
{
self
.
logical_num_experts
}
self.moe_parallel_config:
{
self
.
moe_parallel_config
}
"
)
first
=
False
# Round up hidden size if needed.
hidden_size
=
maybe_roundup_hidden_size
(
hidden_size
,
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
289f98c6
...
...
@@ -785,6 +785,21 @@ def _slice_scales(
return
None
_alt_stream
:
torch
.
cuda
.
Stream
|
None
=
None
def
alt_stream
()
->
torch
.
cuda
.
Stream
|
None
:
"""
Ensures aux_stream is initialized only once
"""
global
_alt_stream
# TODO: validate this works properly on ROCm platform.
if
_alt_stream
is
None
:
_alt_stream
=
torch
.
cuda
.
Stream
()
return
_alt_stream
@
final
class
FusedMoEModularKernel
(
torch
.
nn
.
Module
):
"""
...
...
@@ -805,6 +820,8 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts
:
FusedMoEPermuteExpertsUnpermute
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
,
moe_parallel_config
:
FusedMoEParallelConfig
|
None
=
None
,
N
:
int
=
-
1
,
K
:
int
=
-
1
,
):
super
().
__init__
()
self
.
prepare_finalize
=
prepare_finalize
...
...
@@ -831,6 +848,12 @@ class FusedMoEModularKernel(torch.nn.Module):
f
"
{
fused_experts
.
__class__
.
__name__
}
."
f
"
{
fused_experts
.
activation_format
()
}
"
)
self
.
N
=
N
self
.
K
=
K
if
self
.
shared_experts
is
not
None
:
self
.
alt_stream
=
alt_stream
()
self
.
alt_event
=
torch
.
cuda
.
Event
()
def
_post_init_setup
(
self
):
"""
...
...
@@ -1136,10 +1159,10 @@ class FusedMoEModularKernel(torch.nn.Module):
_
,
M_full
,
N
,
K
,
top_k
=
self
.
fused_experts
.
moe_problem_size
(
a1q
,
w1
,
w2
,
topk_ids
)
if
self
.
N
>
0
:
N
=
self
.
N
K
=
self
.
K
if
use_nn_moe
:
N
=
w1
.
size
(
2
)
num_chunks
,
CHUNK_SIZE
=
self
.
_chunk_info
(
M_full
)
def
input_chunk_range
(
chunk_idx
:
int
)
->
tuple
[
int
,
int
]:
...
...
@@ -1244,37 +1267,46 @@ class FusedMoEModularKernel(torch.nn.Module):
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
else
:
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
self
.
alt_event
.
record
()
if
self
.
shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook
,
receiver
=
(
finalize_ret
if
isinstance
(
finalize_ret
,
tuple
)
else
(
None
,
finalize_ret
)
)
current_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
self
.
alt_stream
):
self
.
alt_stream
.
wait_event
(
self
.
alt_event
)
finalize_ret
=
self
.
prepare_finalize
.
finalize_async
(
output
,
fused_out
,
topk_weights
,
topk_ids
,
apply_router_weight_on_input
,
self
.
fused_experts
.
finalize_weight_and_reduce_impl
(),
)
if
hook
is
not
None
:
if
dbo_enabled
():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook
(
hook
)
dbo_yield
()
else
:
hook
()
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
# receiver (see finalize_async docstring)
hook
,
receiver
=
(
finalize_ret
if
isinstance
(
finalize_ret
,
tuple
)
else
(
None
,
finalize_ret
)
)
if
hook
is
not
None
:
if
dbo_enabled
():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook
(
hook
)
dbo_yield
()
else
:
hook
()
receiver
()
receiver
()
self
.
alt_event
.
record
()
current_stream
.
wait_event
(
self
.
alt_event
)
if
self
.
shared_experts
is
None
:
return
output
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
289f98c6
...
...
@@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
math
import
prod
from
typing
import
Optional
import
torch
import
torch.nn.functional
as
F
from
triton.language.extra
import
libdevice
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
...
...
@@ -154,11 +156,147 @@ def _fp8_quantize(
return
A
,
A_scale
@
triton
.
jit
def
_per_token_quant_int8_one_kernel_opt
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
T_dim
,
has_tokens_per_expert
:
tl
.
constexpr
,
tokens_per_expert_ptr
,
BLOCK
:
tl
.
constexpr
):
row_id
=
tl
.
program_id
(
0
)
if
has_tokens_per_expert
:
e
=
row_id
//
T_dim
t
=
row_id
%
T_dim
num_valid_tokens_for_e
=
tl
.
load
(
tokens_per_expert_ptr
+
e
)
if
t
>=
num_valid_tokens_for_e
:
return
cols
=
tl
.
arange
(
0
,
BLOCK
)
mask
=
cols
<
N
x
=
tl
.
load
(
x_ptr
+
row_id
*
stride_x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
x
)),
1e-10
)
scale_x
=
absmax
/
127
x_q
=
x
*
(
127
/
absmax
)
x_q
=
libdevice
.
nearbyint
(
x_q
).
to
(
tl
.
int8
)
tl
.
store
(
xq_ptr
+
row_id
*
stride_xq
+
cols
,
x_q
,
mask
=
mask
)
tl
.
store
(
scale_ptr
+
row_id
,
scale_x
)
@
triton
.
jit
def
_per_token_quant_int8_kernel_opt
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
E_dim
,
T_dim
,
has_tokens_per_expert
:
tl
.
constexpr
,
tokens_per_expert_ptr
,
BLOCK
:
tl
.
constexpr
):
token_idx_start
=
tl
.
program_id
(
0
)
grid_size
=
tl
.
num_programs
(
0
)
num_total_tokens
=
E_dim
*
T_dim
for
token_idx
in
range
(
token_idx_start
,
num_total_tokens
,
grid_size
):
is_valid_token
=
True
if
has_tokens_per_expert
:
e
=
token_idx
//
T_dim
t
=
token_idx
%
T_dim
num_valid_tokens_for_e
=
tl
.
load
(
tokens_per_expert_ptr
+
e
)
if
t
>=
num_valid_tokens_for_e
:
is_valid_token
=
False
if
is_valid_token
:
cols
=
tl
.
arange
(
0
,
BLOCK
)
mask
=
cols
<
N
x
=
tl
.
load
(
x_ptr
+
token_idx
*
stride_x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
x
)),
1e-10
)
scale_x
=
absmax
/
127
x_q
=
x
*
(
127
/
absmax
)
x_q
=
libdevice
.
nearbyint
(
x_q
).
to
(
tl
.
int8
)
tl
.
store
(
xq_ptr
+
token_idx
*
stride_xq
+
cols
,
x_q
,
mask
=
mask
)
tl
.
store
(
scale_ptr
+
token_idx
,
scale_x
)
def
per_token_quant_int8_triton_opt
(
x
:
torch
.
Tensor
,
tokens_per_expert
:
Optional
[
torch
.
Tensor
]
=
None
):
if
x
.
dim
()
!=
3
:
raise
ValueError
(
f
"Input must be 3D [E, T, H], but got
{
x
.
shape
}
"
)
E
,
T
,
H
=
x
.
shape
N
=
H
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
torch
.
int8
)
scales
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
1
,
),
device
=
x
.
device
,
dtype
=
torch
.
float32
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
if
T
>=
4096
:
num_warps
=
1
num_tokens
=
E
*
T
grid_opt
=
num_tokens
if
E
==
16
and
T
>=
1024
:
grid_opt
=
max
(
1
,
num_tokens
//
(
T
//
256
))
_per_token_quant_int8_kernel_opt
[(
grid_opt
,
)](
x
,
x_q
,
scales
,
stride_x
=
x
.
stride
(
-
2
),
stride_xq
=
x_q
.
stride
(
-
2
),
N
=
N
,
E_dim
=
E
,
T_dim
=
T
,
has_tokens_per_expert
=
tokens_per_expert
is
not
None
,
tokens_per_expert_ptr
=
tokens_per_expert
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
else
:
_per_token_quant_int8_one_kernel_opt
[(
grid_opt
,
)](
x
,
x_q
,
scales
,
stride_x
=
x
.
stride
(
-
2
),
stride_xq
=
x_q
.
stride
(
-
2
),
N
=
N
,
T_dim
=
T
,
has_tokens_per_expert
=
tokens_per_expert
is
not
None
,
tokens_per_expert_ptr
=
tokens_per_expert
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
return
x_q
,
scales
def
_int8_quantize
(
A
:
torch
.
Tensor
,
A_scale
:
torch
.
Tensor
|
None
,
per_act_token
:
bool
,
block_shape
:
list
[
int
]
|
None
=
None
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Perform int8 quantization on the inputs. If a block_shape
...
...
@@ -168,9 +306,12 @@ def _int8_quantize(
# If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static
if
block_shape
is
None
:
if
block_shape
is
None
or
per_act_token
:
assert
per_act_token
,
"int8 quantization only supports block or channel-wise"
A
,
A_scale
=
per_token_quant_int8
(
A
)
if
expert_num_tokens
is
None
:
A
,
A_scale
=
per_token_quant_int8
(
A
)
else
:
A
,
A_scale
=
per_token_quant_int8_triton_opt
(
A
,
expert_num_tokens
)
else
:
assert
not
per_act_token
assert
len
(
block_shape
)
==
2
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
289f98c6
...
...
@@ -6,15 +6,26 @@ from enum import Enum
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.quantization
import
(
QuantizationStrategy
)
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
get_ep_group
,
get_dp_group
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEActivationFormat
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
,
FusedMoEConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
get_w8a8_int8_marlin_weights
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
)
get_w8a8_int8_marlin_weights
,
w8a8_nt_kpack2_marlin_weight
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
int8_w8a8_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoEPermuteExpertsUnpermute
,
FusedMoEPrepareAndFinalize
,
FusedMoeWeightScaleSupported
,
)
try
:
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
except
Exception
:
...
...
@@ -74,14 +85,38 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
raise
ValueError
(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
vllm_config
=
get_current_vllm_config
()
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_size
=
get_dp_group
().
world_size
self
.
use_deepep
=
self
.
dp_size
>
1
and
parallel_config
.
enable_expert_parallel
and
\
(
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_high_throughput"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
or
\
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
)
if
self
.
use_deepep
:
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
self
.
num_dispatchers
=
all2all_manager
.
world_size
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
return
None
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
return
int8_w8a8_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
per_act_token_quant
=
True
,
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
self
.
use_deepep
:
self
.
N
=
2
*
intermediate_size_per_partition
self
.
K
=
hidden_size
params_dtype
=
torch
.
int8
...
...
@@ -133,14 +168,20 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
if
not
self
.
use_deepep
:
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
else
:
w1_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
if
not
self
.
use_deepep
:
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
else
:
w2_marlin_in
=
w8a8_nt_kpack2_marlin_weight
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
...
...
@@ -176,4 +217,40 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
)
\ No newline at end of file
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
DeepGemmExperts
,
)
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
max_num_tokens_per_rank
=
prepare_finalize
.
max_num_tokens_per_rank
()
assert
max_num_tokens_per_rank
is
not
None
logger
.
debug
(
"BatchedDeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
BatchedDeepGemmExperts
(
moe_config
=
self
.
moe
,
max_num_tokens
=
max_num_tokens_per_rank
,
num_dispatchers
=
prepare_finalize
.
num_dispatchers
(),
quant_config
=
self
.
moe_quant_config
,
N
=
self
.
N
,
K
=
self
.
K
)
else
:
logger
.
debug
(
"DeepGemmExperts(%s)"
,
self
.
__class__
.
__name__
)
return
DeepGemmExperts
(
moe_config
=
self
.
moe
,
quant_config
=
self
.
moe_quant_config
,
N
=
self
.
N
,
K
=
self
.
K
)
\ No newline at end of file
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
289f98c6
...
...
@@ -30,6 +30,19 @@ def get_w8a8_int8_marlin_weights(
return
weight
def
w8a8_nt_kpack2_marlin_weight
(
w8a8_w
,
# [size_n, size_k// 2 ]
k_tile
=
16
,
n_tile
=
16
,
):
assert
w8a8_w
.
dtype
==
torch
.
int8
,
"w8a8_w 必须是 int8 类型"
size_n
,
size_k
=
w8a8_w
.
shape
assert
size_n
%
k_tile
==
0
and
size_k
%
n_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
n_tile
,
n_tile
,
size_k
//
k_tile
,
k_tile
))
w8a8_w
=
w8a8_w
.
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
return
w8a8_w
def
sparse_cutlass_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment