Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bb94d2e5
Commit
bb94d2e5
authored
Apr 07, 2025
by
yangql
Browse files
增加fused-moe int4/int8的支持,以及deepseek精度问题的修复
parent
087254b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
536 additions
and
163 deletions
+536
-163
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+506
-140
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+30
-23
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
bb94d2e5
...
...
@@ -14,17 +14,248 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
if
device_name
==
'K100_AI'
and
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
stage1_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#12
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
]
stage2_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#12
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#16
]
else
:
stage1_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#11
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#12
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
]
stage2_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#12
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#16
]
@
triton
.
jit
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
):
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
compute_type
)
def
fused_moe_kernel_awq
(
# Pointers to matrices
a_ptr
,
# [4, 7168]
b_ptr
,
# [256, 512, 3584]
c_ptr
,
# (8, 8, 512)
b_scale_ptr
,
# (256, 512, 56)
b_zp_ptr
,
# (256, 256, 56)
topk_weights_ptr
,
sorted_token_ids_ptr
,
# [0, 1, 2, 3, 4]
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
EM
,
# pading后的总索引长度
num_valid_tokens
,
# 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
#1
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsk
,
#1
stride_bsn
,
stride_bze
,
stride_bzk
,
stride_bzn
,
block_k_diviable
:
tl
.
constexpr
,
group_size
:
tl
.
constexpr
,
# 128
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
# [block_m]
token_mask
=
offs_token
<
num_valid_tokens
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# [block_n]
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
# 0, 1, 2, ...... , 127 # # [block_k]
offs_k2
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
2
)
# 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
# [block_m, block_k]
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_bn
[:,
None
]
*
stride_bn
+
(
offs_k2
[
None
,
:])
*
stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
# 0, 4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
if
not
has_zp
and
use_int8_w8a16
:
b_zp_num
=
128
elif
has_zp
and
use_int4_w4a16
:
b_zp_shifter
=
(
offs_bn
[
None
,
:]
%
2
)
*
4
# 0, 4
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
if
not
block_k_diviable
:
k_mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
k_other
=
0.0
else
:
k_mask
=
None
k_other
=
None
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
)
if
use_int4_w4a16
:
b
=
tl
.
interleave
(
b
,
b
)
b
=
b
.
trans
()
b
=
(
b
>>
b_shifter
)
&
0xF
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsk
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsn
qzeros_scles
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
scales_int16
=
tl
.
cast
(
qzeros_scles
,
tl
.
uint16
)
b_scale
=
tl
.
cast
(
scales_int16
,
tl
.
float16
,
bitcast
=
True
)
# tl.device_print("b_scale dequant",b_scale)
mid
=
qzeros_scles
>>
16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp
=
tl
.
cast
(
mid
,
tl
.
float16
)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b
=
((
b
-
b_zp
)
*
b_scale
).
to
(
tl
.
float16
)
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
if
use_int4_w4a16
:
b_ptrs
+=
(
BLOCK_SIZE_K
//
2
)
*
stride_bk
else
:
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
...
...
@@ -287,6 +518,7 @@ def fused_moe_kernel(
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
...
...
@@ -340,7 +572,6 @@ def fused_moe_kernel(
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
...
...
@@ -355,22 +586,13 @@ def fused_moe_kernel(
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
if
use_int8_w8a16
:
...
...
@@ -378,7 +600,7 @@ def fused_moe_kernel(
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
...
...
@@ -407,7 +629,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension.
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
...
...
@@ -433,7 +655,7 @@ def fused_moe_kernel(
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
:
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
...
...
@@ -591,7 +813,8 @@ def moe_align_block_size(
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
=
None
expert_map
:
torch
.
Tensor
=
None
,
num_token
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Aligns the token distribution across experts to be compatible with block
...
...
@@ -634,15 +857,20 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
if
num_token
:
if
num_token
<
block_size
:
max_num_tokens_padded
=
min
(
topk_ids
.
numel
()
*
block_size
,
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
))
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
full
((
max_num_tokens_padded
,),
fill_value
=
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids
=
torch
.
zeros
((
max_num_m_blocks
,
),
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
...
...
@@ -693,6 +921,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -711,6 +940,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_int8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
...
...
@@ -733,77 +975,117 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
block_shape
is
not
None
and
block_shape
[
1
]
>
0
:
assert
B_scale
is
not
None
and
B_scale
.
ndim
==
3
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
if
os
.
environ
.
get
(
'moe_wna16_use_cuda'
)
==
'1'
:
use_moe_wna16_cuda
=
should_moe_wna16_use_cuda
(
num_valid_tokens
=
topk_ids
.
numel
(),
group_size
=
block_shape
[
1
],
num_experts
=
B
.
shape
[
0
],
bit
=
4
if
use_int4_w4a16
else
8
)
config
=
config
.
copy
()
config
.
update
(
get_moe_wna16_block_config
(
config
=
config
,
use_moe_wna16_cuda
=
use_moe_wna16_cuda
,
num_valid_tokens
=
topk_ids
.
numel
(),
size_k
=
A
.
shape
[
1
],
size_n
=
B
.
shape
[
1
],
num_experts
=
B
.
shape
[
1
],
group_size
=
block_shape
[
1
],
real_top_k
=
topk_ids
.
shape
[
1
],
block_size_m
=
config
[
"BLOCK_SIZE_M"
]))
if
use_moe_wna16_cuda
:
bit
=
4
if
use_int4_w4a16
else
8
ops
.
moe_wna16_gemm
(
A
,
C
,
B
,
B_scale
,
B_zp
,
topk_weights
if
mul_routed_weight
else
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
top_k
,
config
[
"BLOCK_SIZE_M"
],
config
[
"BLOCK_SIZE_N"
],
config
[
"BLOCK_SIZE_K"
],
bit
)
return
if
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
:
fused_moe_kernel_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
use_moe_wna16_cuda
=
should_moe_wna16_use_cuda
(
num_valid_tokens
=
topk_ids
.
numel
(),
group_size
=
block_shape
[
1
],
num_experts
=
B
.
shape
[
0
],
bit
=
4
if
use_int4_w4a16
else
8
)
config
=
config
.
copy
()
config
.
update
(
get_moe_wna16_block_config
(
config
=
config
,
use_moe_wna16_cuda
=
use_moe_wna16_cuda
,
num_valid_tokens
=
topk_ids
.
numel
(),
size_k
=
A
.
shape
[
1
],
size_n
=
B
.
shape
[
1
],
num_experts
=
B
.
shape
[
1
],
group_size
=
block_shape
[
1
],
real_top_k
=
topk_ids
.
shape
[
1
],
block_size_m
=
config
[
"BLOCK_SIZE_M"
]))
if
use_moe_wna16_cuda
:
bit
=
4
if
use_int4_w4a16
else
8
ops
.
moe_wna16_gemm
(
A
,
C
,
B
,
B_scale
,
B_zp
,
topk_weights
if
mul_routed_weight
else
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
top_k
,
config
[
"BLOCK_SIZE_M"
],
config
[
"BLOCK_SIZE_N"
],
config
[
"BLOCK_SIZE_K"
],
bit
)
return
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
config
=
config
.
copy
()
BLOCK_SIZE_K
=
config
.
pop
(
"BLOCK_SIZE_K"
)
if
block_shape
is
not
None
:
BLOCK_SIZE_K
=
min
(
BLOCK_SIZE_K
,
min
(
block_shape
[
0
],
block_shape
[
1
]))
#
config = config.copy()
#
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
#
if block_shape is not None:
#
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
#
block_shape[1]))
fused_moe_kernel
[
grid
](
A
,
B
,
...
...
@@ -841,8 +1123,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
#
BLOCK_SIZE_K=BLOCK_SIZE_K,
**
config
,
)
...
...
@@ -869,7 +1152,7 @@ def get_moe_configs(
dtype
:
Optional
[
str
],
block_n
:
Optional
[
int
]
=
None
,
block_k
:
Optional
[
int
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
Optional
[
Dict
[
int
,
Any
]]:
"""
Return optimized configurations for the fused MoE kernel.
...
...
@@ -887,6 +1170,15 @@ def get_moe_configs(
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
config_file_path_120
=
config_file_path
.
replace
(
".json"
,
"_120.json"
)
if
os
.
path
.
exists
(
config_file_path_120
):
with
open
(
config_file_path_120
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
config_file_path_120
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
...
...
@@ -975,7 +1267,7 @@ def get_default_config(
dtype
:
Optional
[
str
],
is_marlin
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
Dict
[
str
,
int
]:
if
dtype
==
"fp8_w8a8"
and
block_shape
is
not
None
:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
...
...
@@ -988,21 +1280,21 @@ def get_default_config(
"num_warps"
:
4
,
"num_stages"
:
3
,
}
elif
dtype
in
[
"int4_w4a16"
,
"int8_w8a16"
]
and
block_shape
is
not
None
:
# moe wna16 kernels
# only set BLOCK_SIZE_M
# BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
bit
=
4
if
dtype
==
"int4_w4a16"
else
8
use_moe_wna16_cuda
=
should_moe_wna16_use_cuda
(
M
*
topk
,
block_shape
[
1
],
E
,
bit
)
if
use_moe_wna16_cuda
:
config
=
{
"BLOCK_SIZE_M"
:
min
(
16
,
M
)}
elif
M
<=
20
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
1
}
elif
M
<=
40
:
config
=
{
"BLOCK_SIZE_M"
:
32
,
"GROUP_SIZE_M"
:
1
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"GROUP_SIZE_M"
:
1
}
#
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
#
# moe wna16 kernels
#
# only set BLOCK_SIZE_M
#
# BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
#
bit = 4 if dtype == "int4_w4a16" else 8
#
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
#
block_shape[1], E, bit)
#
if use_moe_wna16_cuda:
#
config = {"BLOCK_SIZE_M": min(16, M)}
#
elif M <= 20:
#
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
#
elif M <= 40:
#
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
#
else:
#
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
...
...
@@ -1043,8 +1335,8 @@ def try_get_optimal_moe_config(
E
,
_
,
N
=
w2_shape
else
:
E
,
N
,
_
=
w2_shape
if
dtype
==
"int4_w4a16"
:
N
=
N
*
2
#
if dtype == "int4_w4a16":
#
N = N * 2
block_n
=
block_shape
[
0
]
if
block_shape
else
0
block_k
=
block_shape
[
1
]
if
block_shape
else
0
configs
=
get_moe_configs
(
E
,
N
,
dtype
,
block_n
,
block_k
,
use_nn_moe
=
use_nn_moe
)
...
...
@@ -1159,9 +1451,12 @@ def grouped_topk(hidden_states: torch.Tensor,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a8
:
return
"int8_w8a8"
elif
use_int8_w8a16
:
return
"int8_w8a16"
elif
use_int4_w4a16
:
...
...
@@ -1180,6 +1475,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1193,7 +1489,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1207,6 +1503,7 @@ def inplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1238,6 +1535,7 @@ def outplace_fused_experts(
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1251,7 +1549,7 @@ def outplace_fused_experts(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
False
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1265,6 +1563,7 @@ def outplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1296,6 +1595,7 @@ def fused_experts(hidden_states: torch.Tensor,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1312,14 +1612,14 @@ def fused_experts(hidden_states: torch.Tensor,
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
return
hidden_states
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
)
...
...
@@ -1332,6 +1632,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1373,22 +1674,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
top_k_num
,
config_dtype
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
)
if
not
use_int8_w8a8
:
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
top_k_num
,
config_dtype
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
)
config
=
get_config_func
(
M
)
config
=
get_config_func
(
M
)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
...
...
@@ -1442,10 +1745,43 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
config
=
stage1_best_config
[
m
-
1
]
elif
m
<=
32
:
config
=
stage1_best_config
[
15
]
elif
m
<=
64
:
config
=
stage1_best_config
[
16
]
elif
m
<
256
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
# sorted_token_ids, expert_ids, num_tokens_post_padded = (
# moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
# global_num_experts, expert_map))
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
,
curr_hidden_states
.
shape
[
0
]))
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
...
...
@@ -1463,6 +1799,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
,
...
...
@@ -1476,7 +1813,33 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
config
=
stage2_best_config
[
m
-
1
]
elif
m
<=
32
:
config
=
stage2_best_config
[
15
]
elif
m
<=
64
:
config
=
stage2_best_config
[
16
]
elif
m
<
256
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
...
...
@@ -1493,6 +1856,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
,
...
...
@@ -1517,6 +1881,7 @@ def fused_moe(
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
...
...
@@ -1598,6 +1963,7 @@ def fused_moe(
inplace
=
inplace
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
global_num_experts
=
global_num_experts
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
bb94d2e5
...
...
@@ -164,21 +164,28 @@ class DeepseekV2MoE(nn.Module):
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# This is a special case to avoid FP16 overflow
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
# if hidden_states.dtype != torch.float16:
# final_hidden_states = self.experts(
# hidden_states=hidden_states,
# router_logits=router_logits) * self.routed_scaling_factor
# else:
# # This is a special case to avoid FP16 overflow
# final_hidden_states = self.experts(hidden_states=hidden_states,
# router_logits=router_logits)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
if
shared_output
is
not
None
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# This is a special case to avoid FP16 overflow
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
final_hidden_states
=
final_hidden_states
+
shared_output
# if shared_output is not None:
# if hidden_states.dtype != torch.float16:
# final_hidden_states = final_hidden_states + shared_output
# else:
# # This is a special case to avoid FP16 overflow
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
...
...
@@ -571,18 +578,18 @@ class DeepseekV2DecoderLayer(nn.Module):
)
# Fully Connected
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
# This is a special case to avoid FP16 overflow
hidden_states
*=
1.
/
self
.
routed_scaling_factor
#
if isinstance(self.mlp, DeepseekV2MoE) and \
#
hidden_states.dtype == torch.float16:
#
# This is a special case to avoid FP16 overflow
#
hidden_states *= 1. / self.routed_scaling_factor
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
# This is a special case to avoid FP16 overflow
hidden_states
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
#
if isinstance(self.mlp, DeepseekV2MLP) and \
#
hidden_states.dtype == torch.float16:
#
# This is a special case to avoid FP16 overflow
#
hidden_states *= 1. / self.routed_scaling_factor
#
residual *= 1. / self.routed_scaling_factor
return
hidden_states
,
residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment