Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fe6d3b05
Commit
fe6d3b05
authored
Apr 02, 2025
by
zhuwenwen
Browse files
remove fused_moe of quantization
parent
68826ce6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
474 deletions
+87
-474
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+87
-474
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
fe6d3b05
...
@@ -14,255 +14,17 @@ from vllm import _custom_ops as ops
...
@@ -14,255 +14,17 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
)
per_token_group_quant_fp8
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
if
device_name
==
'K100_AI'
and
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
stage1_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#12
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
]
stage2_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#12
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#16
]
else
:
stage1_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#11
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#12
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
]
stage2_best_config
=
[
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#0
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#2
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#3
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#4
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#5
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#6
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#7
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#9
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#10
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#11
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#12
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#16
]
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel_awq
(
# Pointers to matrices
a_ptr
,
# [4, 7168]
b_ptr
,
# [256, 512, 3584]
c_ptr
,
# (8, 8, 512)
b_scale_ptr
,
# (256, 512, 56)
b_zp_ptr
,
# (256, 256, 56)
topk_weights_ptr
,
sorted_token_ids_ptr
,
# [0, 1, 2, 3, 4]
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
EM
,
# pading后的总索引长度
num_valid_tokens
,
# 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
#1
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsk
,
#1
stride_bsn
,
stride_bze
,
stride_bzk
,
stride_bzn
,
block_k_diviable
:
tl
.
constexpr
,
group_size
:
tl
.
constexpr
,
# 128
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
# [block_m]
token_mask
=
offs_token
<
num_valid_tokens
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# [block_n]
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
# 0, 1, 2, ...... , 127 # # [block_k]
offs_k2
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
2
)
# 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
# [block_m, block_k]
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_bn
[:,
None
]
*
stride_bn
+
(
offs_k2
[
None
,
:])
*
stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
# 0, 4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
if
not
has_zp
and
use_int8_w8a16
:
b_zp_num
=
128
elif
has_zp
and
use_int4_w4a16
:
b_zp_shifter
=
(
offs_bn
[
None
,
:]
%
2
)
*
4
# 0, 4
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
if
not
block_k_diviable
:
k_mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
k_other
=
0.0
else
:
k_mask
=
None
k_other
=
None
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
)
if
use_int4_w4a16
:
b
=
tl
.
interleave
(
b
,
b
)
b
=
b
.
trans
()
b
=
(
b
>>
b_shifter
)
&
0xF
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsk
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsn
qzeros_scles
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
scales_int16
=
tl
.
cast
(
qzeros_scles
,
tl
.
uint16
)
b_scale
=
tl
.
cast
(
scales_int16
,
tl
.
float16
,
bitcast
=
True
)
# tl.device_print("b_scale dequant",b_scale)
mid
=
qzeros_scles
>>
16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp
=
tl
.
cast
(
mid
,
tl
.
float16
)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b
=
((
b
-
b_zp
)
*
b_scale
).
to
(
tl
.
float16
)
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
if
use_int4_w4a16
:
b_ptrs
+=
(
BLOCK_SIZE_K
//
2
)
*
stride_bk
else
:
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
):
compute_type
):
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
compute_type
)
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
compute_type
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
None
,
:]
...
@@ -525,7 +287,6 @@ def fused_moe_kernel(
...
@@ -525,7 +287,6 @@ def fused_moe_kernel(
top_k
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
use_int8_w8a16
:
tl
.
constexpr
):
"""
"""
Implements the fused computation for a Mixture of Experts (MOE) using
Implements the fused computation for a Mixture of Experts (MOE) using
...
@@ -579,6 +340,7 @@ def fused_moe_kernel(
...
@@ -579,6 +340,7 @@ def fused_moe_kernel(
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
# ----------------------------------------------------------
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# We will advance this pointer as we move in the K direction
...
@@ -616,7 +378,7 @@ def fused_moe_kernel(
...
@@ -616,7 +378,7 @@ def fused_moe_kernel(
None
,
:]
*
stride_bsn
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
or
use_int8_w8a8
:
if
use_fp8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
offs_bsn
=
offs_bn
//
group_n
...
@@ -645,7 +407,7 @@ def fused_moe_kernel(
...
@@ -645,7 +407,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension.
# We accumulate along the K dimension.
if
use_int8_w8a16
:
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
or
use_int8_w8a8
:
elif
use_fp8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
offs_ks
=
k_start
//
group_k
...
@@ -671,7 +433,7 @@ def fused_moe_kernel(
...
@@ -671,7 +433,7 @@ def fused_moe_kernel(
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_int8_w8a16
:
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
or
use_int8_w8a8
:
elif
use_fp8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
else
:
...
@@ -829,8 +591,7 @@ def moe_align_block_size(
...
@@ -829,8 +591,7 @@ def moe_align_block_size(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
block_size
:
int
,
num_experts
:
int
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
=
None
,
expert_map
:
torch
.
Tensor
=
None
num_token
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Aligns the token distribution across experts to be compatible with block
Aligns the token distribution across experts to be compatible with block
...
@@ -873,13 +634,6 @@ def moe_align_block_size(
...
@@ -873,13 +634,6 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
by block_size for proper block matrix operations.
"""
"""
if
num_token
:
if
num_token
<
block_size
:
max_num_tokens_padded
=
min
(
topk_ids
.
numel
()
*
block_size
,
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
))
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
full
((
max_num_tokens_padded
,),
fill_value
=
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -939,7 +693,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -939,7 +693,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
:
Dict
[
str
,
Any
],
config
:
Dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -958,19 +711,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -958,19 +711,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a8
:
assert
B_scale
is
not
None
if
block_shape
is
None
:
A
,
A_scale
=
ops
.
scaled_int8_quant
(
A
,
A_scale
)
else
:
assert
len
(
block_shape
)
==
2
block_n
,
block_k
=
block_shape
[
0
],
block_shape
[
1
]
A
,
A_scale
=
per_token_group_quant_int8
(
A
,
block_k
)
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a16
or
use_int4_w4a16
:
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
...
@@ -1021,45 +761,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -1021,45 +761,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config
[
"BLOCK_SIZE_K"
],
bit
)
config
[
"BLOCK_SIZE_K"
],
bit
)
return
return
if
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
:
fused_moe_kernel_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
fused_moe_kernel_gptq_awq
[
grid
](
fused_moe_kernel_gptq_awq
[
grid
](
A
,
A
,
B
,
B
,
...
@@ -1140,7 +841,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -1140,7 +841,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k
=
top_k
,
top_k
=
top_k
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
BLOCK_SIZE_K
=
BLOCK_SIZE_K
,
**
config
,
**
config
,
...
@@ -1161,7 +861,6 @@ def get_config_file_name(E: int,
...
@@ -1161,7 +861,6 @@ def get_config_file_name(E: int,
else
:
else
:
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}{
block_shape_selector
}
_nn.json"
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}{
block_shape_selector
}
_nn.json"
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@
functools
.
lru_cache
@
functools
.
lru_cache
def
get_moe_configs
(
def
get_moe_configs
(
...
@@ -1170,7 +869,7 @@ def get_moe_configs(
...
@@ -1170,7 +869,7 @@ def get_moe_configs(
dtype
:
Optional
[
str
],
dtype
:
Optional
[
str
],
block_n
:
Optional
[
int
]
=
None
,
block_n
:
Optional
[
int
]
=
None
,
block_k
:
Optional
[
int
]
=
None
,
block_k
:
Optional
[
int
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
Optional
[
Dict
[
int
,
Any
]]:
)
->
Optional
[
Dict
[
int
,
Any
]]:
"""
"""
Return optimized configurations for the fused MoE kernel.
Return optimized configurations for the fused MoE kernel.
...
@@ -1188,15 +887,6 @@ def get_moe_configs(
...
@@ -1188,15 +887,6 @@ def get_moe_configs(
config_file_path
=
os
.
path
.
join
(
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
config_file_path_120
=
config_file_path
.
replace
(
".json"
,
"_120.json"
)
if
os
.
path
.
exists
(
config_file_path_120
):
with
open
(
config_file_path_120
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
config_file_path_120
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
if
os
.
path
.
exists
(
config_file_path
):
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
logger
.
info
(
"Using configuration from %s for MoE layer."
,
...
@@ -1285,7 +975,7 @@ def get_default_config(
...
@@ -1285,7 +975,7 @@ def get_default_config(
dtype
:
Optional
[
str
],
dtype
:
Optional
[
str
],
is_marlin
:
bool
,
is_marlin
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
Dict
[
str
,
int
]:
)
->
Dict
[
str
,
int
]:
if
dtype
==
"fp8_w8a8"
and
block_shape
is
not
None
:
if
dtype
==
"fp8_w8a8"
and
block_shape
is
not
None
:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
...
@@ -1341,7 +1031,7 @@ def try_get_optimal_moe_config(
...
@@ -1341,7 +1031,7 @@ def try_get_optimal_moe_config(
M
:
int
,
M
:
int
,
is_marlin
:
bool
=
False
,
is_marlin
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
use_nn_moe
:
Optional
[
bool
]
=
False
,
):
):
from
vllm.model_executor.layers.fused_moe
import
get_config
from
vllm.model_executor.layers.fused_moe
import
get_config
override_config
=
get_config
()
override_config
=
get_config
()
...
@@ -1469,12 +1159,9 @@ def grouped_topk(hidden_states: torch.Tensor,
...
@@ -1469,12 +1159,9 @@ def grouped_topk(hidden_states: torch.Tensor,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
return
"fp8_w8a8"
return
"fp8_w8a8"
elif
use_int8_w8a8
:
return
"int8_w8a8"
elif
use_int8_w8a16
:
elif
use_int8_w8a16
:
return
"int8_w8a16"
return
"int8_w8a16"
elif
use_int4_w4a16
:
elif
use_int4_w4a16
:
...
@@ -1493,7 +1180,6 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1493,7 +1180,6 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1505,12 +1191,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1505,12 +1191,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
)
block_shape
,
use_nn_moe
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1521,7 +1207,6 @@ def inplace_fused_experts_fake(
...
@@ -1521,7 +1207,6 @@ def inplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1533,7 +1218,7 @@ def inplace_fused_experts_fake(
...
@@ -1533,7 +1218,7 @@ def inplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
pass
pass
...
@@ -1553,7 +1238,6 @@ def outplace_fused_experts(
...
@@ -1553,7 +1238,6 @@ def outplace_fused_experts(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1565,13 +1249,12 @@ def outplace_fused_experts(
...
@@ -1565,13 +1249,12 @@ def outplace_fused_experts(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
False
,
activation
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
a2_scale
,
block_shape
,
use_nn_moe
)
use_nn_moe
,)
def
outplace_fused_experts_fake
(
def
outplace_fused_experts_fake
(
...
@@ -1582,7 +1265,6 @@ def outplace_fused_experts_fake(
...
@@ -1582,7 +1265,6 @@ def outplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1594,7 +1276,7 @@ def outplace_fused_experts_fake(
...
@@ -1594,7 +1276,7 @@ def outplace_fused_experts_fake(
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
@@ -1614,7 +1296,6 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1614,7 +1296,6 @@ def fused_experts(hidden_states: torch.Tensor,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1626,23 +1307,21 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -1626,23 +1307,21 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
if
inplace
:
if
inplace
:
torch
.
ops
.
vllm
.
inplace_fused_experts
(
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
block_shape
,
use_nn_moe
)
use_nn_moe
,)
return
hidden_states
return
hidden_states
else
:
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
activation
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
block_shape
,
use_nn_moe
)
use_nn_moe
,)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
@@ -1653,7 +1332,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1653,7 +1332,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1665,7 +1343,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1665,7 +1343,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
):
use_nn_moe
:
Optional
[
bool
]
=
False
):
# Check constraints.
# Check constraints.
if
use_int4_w4a16
:
if
use_int4_w4a16
:
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
...
@@ -1684,12 +1362,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1684,12 +1362,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
]
]
num_tokens
,
_
=
hidden_states
.
shape
num_tokens
,
_
=
hidden_states
.
shape
if
use_nn_moe
:
if
use_nn_moe
:
E
,
_
,
N
=
w1
.
shape
E
,
_
,
N
=
w1
.
shape
else
:
else
:
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
global_num_experts
=
E
global_num_experts
=
E
top_k_num
=
topk_ids
.
shape
[
1
]
top_k_num
=
topk_ids
.
shape
[
1
]
...
@@ -1697,9 +1373,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1697,9 +1373,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
not
use_int8_w8a8
:
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
...
@@ -1708,7 +1382,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1708,7 +1382,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
try_get_optimal_moe_config
,
try_get_optimal_moe_config
,
w1
.
shape
,
w1
.
shape
,
w2
.
shape
,
w2
.
shape
,
topk_
ids
.
shape
[
1
]
,
top
_
k_
num
,
config_dtype
,
config_dtype
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
...
@@ -1718,13 +1392,13 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1718,13 +1392,13 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# We can reuse the memory between these because by the time we need
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
# cache3, we're done with cache1
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
w2
.
shape
[
1
]),
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
]
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
(
M
,
topk_ids
.
shape
[
1
],
N
))
(
M
,
topk_ids
.
shape
[
1
],
N
))
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
(
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
])].
view
(
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
(
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
])].
view
(
(
M
,
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
]))
(
M
,
topk_ids
.
shape
[
1
],
w2
.
shape
[
1
]
if
not
use_nn_moe
else
w2
.
shape
[
2
]
))
# This needs separate memory since it's used concurrently with cache1
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2
=
torch
.
empty
((
M
*
top_k_num
,
N
//
2
),
intermediate_cache2
=
torch
.
empty
((
M
*
top_k_num
,
N
//
2
),
...
@@ -1769,39 +1443,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1769,39 +1443,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
config
=
stage1_best_config
[
m
-
1
]
elif
m
<=
32
:
config
=
stage1_best_config
[
15
]
elif
m
<=
64
:
config
=
stage1_best_config
[
16
]
elif
m
<
256
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
,
curr_hidden_states
.
shape
[
0
]))
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
w1
,
...
@@ -1819,7 +1463,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1819,7 +1463,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
...
@@ -1834,33 +1477,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1834,33 +1477,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
else
:
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
if
use_int8_w8a8
:
m
=
curr_hidden_states
.
shape
[
0
]
if
m
<=
16
:
config
=
stage2_best_config
[
m
-
1
]
elif
m
<=
32
:
config
=
stage2_best_config
[
15
]
elif
m
<=
64
:
config
=
stage2_best_config
[
16
]
elif
m
<
256
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
invoke_fused_moe_kernel
(
intermediate_cache2
,
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
...
@@ -1877,7 +1493,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1877,7 +1493,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
,
config
,
compute_type
=
compute_type
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
...
@@ -1902,7 +1517,6 @@ def fused_moe(
...
@@ -1902,7 +1517,6 @@ def fused_moe(
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
global_num_experts
:
int
=
-
1
,
...
@@ -1984,7 +1598,6 @@ def fused_moe(
...
@@ -1984,7 +1598,6 @@ def fused_moe(
inplace
=
inplace
,
inplace
=
inplace
,
activation
=
activation
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a16
=
use_int4_w4a16
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
...
@@ -1996,4 +1609,4 @@ def fused_moe(
...
@@ -1996,4 +1609,4 @@ def fused_moe(
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,)
use_nn_moe
=
use_nn_moe
)
\ No newline at end of file
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment