Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8223f750
Commit
8223f750
authored
Nov 16, 2025
by
luopl
Browse files
feat: implement int8 quantization
parent
34bf6014
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
2438 additions
and
387 deletions
+2438
-387
vllm/model_executor/layers/fused_moe/fused_moe_step3vw8a16.py
.../model_executor/layers/fused_moe/fused_moe_step3vw8a16.py
+2327
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/groupwise_quant.py
vllm/model_executor/layers/quantization/groupwise_quant.py
+105
-379
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+3
-8
No files found.
vllm/model_executor/layers/fused_moe/fused_moe_step3vw8a16.py
0 → 100644
View file @
8223f750
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE kernel."""
import
functools
import
json
import
os
import
math
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
List
,
Optional
,
Tuple
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
# yapf: disable
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
get_config_quant_dtype
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
_valid_cutlass_block_scaled_grouped_gemm
,
run_cutlass_block_scaled_fused_experts
)
# yapf: enable
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
_valid_deep_gemm
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
try
:
from
lmslim.layers.gemm.int8_utils
import
(
per_token_group_quant_int8
,
per_token_quant_int8
)
from
lmslim.layers.fused_moe.fuse_moe_int8
import
(
fused_experts_impl_int8
,
get_w8a8moe_json
)
from
lmslim.layers.fused_moe.fuse_moe_w4a8
import
fused_experts_impl_w4a8
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
)
from
vllm.model_executor.layers.fused_moe.utils
import
(
_resize_cache
,
moe_kernel_quantize_input
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger
=
init_logger
(
__name__
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
moe_cache_singleton
=
None
@
torch
.
compile
def
moe_sum_reduce_torch_compile
(
x
,
out
,
routed_scaling_factor
):
torch
.
sum
(
x
,
dim
=
1
,
out
=
out
)
out
.
mul_
(
routed_scaling_factor
)
@
triton
.
jit
def
_moe_sum_reduce_kernel
(
input_ptr
,
input_stride_0
,
input_stride_1
,
input_stride_2
,
output_ptr
,
output_stride_0
,
output_stride_1
,
token_num
:
int
,
topk_num
:
int
,
hidden_dim
:
int
,
routed_scaling_factor
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DIM
:
tl
.
constexpr
,
NUM_STAGE
:
tl
.
constexpr
,
):
input_stride_0
=
tl
.
cast
(
input_stride_0
,
dtype
=
tl
.
int64
)
input_stride_1
=
tl
.
cast
(
input_stride_1
,
dtype
=
tl
.
int64
)
output_stride_0
=
tl
.
cast
(
output_stride_0
,
dtype
=
tl
.
int64
)
token_block_id
=
tl
.
program_id
(
0
)
dim_block_id
=
tl
.
program_id
(
1
)
token_start
=
token_block_id
*
BLOCK_M
token_end
=
min
((
token_block_id
+
1
)
*
BLOCK_M
,
token_num
)
dim_start
=
dim_block_id
*
BLOCK_DIM
dim_end
=
min
((
dim_block_id
+
1
)
*
BLOCK_DIM
,
hidden_dim
)
offs_dim
=
dim_start
+
tl
.
arange
(
0
,
BLOCK_DIM
)
for
token_index
in
range
(
token_start
,
token_end
):
accumulator
=
tl
.
zeros
((
BLOCK_DIM
,),
dtype
=
tl
.
float32
)
input_t_ptr
=
input_ptr
+
token_index
*
input_stride_0
+
offs_dim
for
i
in
tl
.
range
(
0
,
topk_num
,
num_stages
=
NUM_STAGE
):
tmp
=
tl
.
load
(
input_t_ptr
+
i
*
input_stride_1
,
mask
=
offs_dim
<
dim_end
,
other
=
0.0
)
accumulator
+=
tmp
accumulator
=
accumulator
*
routed_scaling_factor
store_t_ptr
=
output_ptr
+
token_index
*
output_stride_0
+
offs_dim
tl
.
store
(
store_t_ptr
,
accumulator
.
to
(
input_ptr
.
dtype
.
element_ty
),
mask
=
offs_dim
<
dim_end
,
)
def
moe_sum_reduce_triton
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
routed_scaling_factor
:
float
):
assert
input
.
is_contiguous
()
assert
output
.
is_contiguous
()
token_num
,
topk_num
,
hidden_dim
=
input
.
shape
assert
output
.
shape
[
0
]
==
token_num
and
output
.
shape
[
1
]
==
hidden_dim
if
token_num
<=
32
:
BLOCK_M
=
1
BLOCK_DIM
=
512
NUM_STAGE
=
2
num_warps
=
4
elif
token_num
<=
128
:
BLOCK_M
=
1
BLOCK_DIM
=
1024
NUM_STAGE
=
0
num_warps
=
2
elif
token_num
<=
4096
:
BLOCK_M
=
1
BLOCK_DIM
=
2048
NUM_STAGE
=
0
num_warps
=
2
else
:
BLOCK_M
=
1
BLOCK_DIM
=
2048
NUM_STAGE
=
2
num_warps
=
8
grid
=
(
triton
.
cdiv
(
token_num
,
BLOCK_M
),
triton
.
cdiv
(
hidden_dim
,
BLOCK_DIM
),
)
_moe_sum_reduce_kernel
[
grid
](
input
,
*
input
.
stride
(),
output
,
*
output
.
stride
(),
token_num
=
token_num
,
topk_num
=
topk_num
,
hidden_dim
=
hidden_dim
,
routed_scaling_factor
=
routed_scaling_factor
,
BLOCK_M
=
BLOCK_M
,
BLOCK_DIM
=
BLOCK_DIM
,
NUM_STAGE
=
NUM_STAGE
,
num_warps
=
num_warps
,
)
return
def
moe_reduce_dispatch
(
intermediate_cache3
:
torch
.
Tensor
,
out_hidden_states
:
torch
.
Tensor
,
begin_chunk_idx
:
int
,
end_chunk_idx
:
int
,
):
inter_cache_view
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
)
n
=
intermediate_cache3
.
shape
[
0
]
# 根据 n 大小选择不同的 reduce 实现
if
1
<=
n
<=
4
:
moe_sum_reduce_torch_compile
(
inter_cache_view
,
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
1.0
)
elif
4
<
n
<=
1024
:
moe_sum_reduce_triton
(
inter_cache_view
,
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
1.0
)
elif
1024
<
n
<=
32768
:
ops
.
moe_sum_opt1
(
inter_cache_view
,
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
else
:
ops
.
moe_sum
(
inter_cache_view
,
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
def
get_moe_cache
(
top_k_num
,
N
,
K
,
device
,
dtype
):
global
moe_cache_singleton
if
moe_cache_singleton
is
None
:
moe_cache_singleton
=
torch
.
empty
(
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
*
top_k_num
*
max
(
N
,
K
),
device
=
device
,
dtype
=
dtype
)
logger
.
info
(
f
"Initializing moe_cache_singleton shape:
{
moe_cache_singleton
.
shape
}
, memory:
{
moe_cache_singleton
.
element_size
()
*
moe_cache_singleton
.
numel
()
/
1024
**
2
:.
2
f
}
MB"
)
return
moe_cache_singleton
@
triton
.
jit
def
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
):
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
compute_type
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
def
fused_moe_kernel_awq
(
# Pointers to matrices
a_ptr
,
# [4, 7168]
b_ptr
,
# [256, 512, 3584]
c_ptr
,
# (8, 8, 512)
b_scale_ptr
,
# (256, 512, 56)
b_zp_ptr
,
# (256, 256, 56)
topk_weights_ptr
,
sorted_token_ids_ptr
,
# [0, 1, 2, 3, 4]
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
EM
,
# pading后的总索引长度
num_valid_tokens
,
# 有效索引的上限
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
#1
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsk
,
#1
stride_bsn
,
stride_bze
,
stride_bzk
,
stride_bzn
,
block_k_diviable
:
tl
.
constexpr
,
group_size
:
tl
.
constexpr
,
# 128
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
# [block_m]
token_mask
=
offs_token
<
num_valid_tokens
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# [block_n]
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
# 0, 1, 2, ...... , 127 # # [block_k]
offs_k2
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
2
)
# 0, 1, 2, ...... , 127 # # [block_k]
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
# [block_m, block_k]
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
# b_ptrs = b_ptr + off_experts * stride_be + \
# (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_bn
[:,
None
]
*
stride_bn
+
(
offs_k2
[
None
,
:])
*
stride_bk
# tl.device_print("stride_bn",stride_bsn)>1
# tl.device_print("stride_bk",stride_bk)=1
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
# 0, 4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
if
not
has_zp
and
use_int8_w8a16
:
b_zp_num
=
128
elif
has_zp
and
use_int4_w4a16
:
b_zp_shifter
=
(
offs_bn
[
None
,
:]
%
2
)
*
4
# 0, 4
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
if
not
block_k_diviable
:
k_mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
k_other
=
0.0
else
:
k_mask
=
None
k_other
=
None
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
)
if
use_int4_w4a16
:
b
=
tl
.
interleave
(
b
,
b
)
b
=
b
.
trans
()
b
=
(
b
>>
b_shifter
)
&
0xF
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsk
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsn
qzeros_scles
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
scales_int16
=
tl
.
cast
(
qzeros_scles
,
tl
.
uint16
)
b_scale
=
tl
.
cast
(
scales_int16
,
tl
.
float16
,
bitcast
=
True
)
# tl.device_print("b_scale dequant",b_scale)
mid
=
qzeros_scles
>>
16
# b_zp = tl.cast(mid,tl.float16,bitcast=False)
b_zp
=
tl
.
cast
(
mid
,
tl
.
float16
)
# b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
# tl.device_print("bzp",b_zp)
# We accumulate along the K dimension.
b
=
((
b
-
b_zp
)
*
b_scale
).
to
(
tl
.
float16
)
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
if
use_int4_w4a16
:
b_ptrs
+=
(
BLOCK_SIZE_K
//
2
)
*
stride_bk
else
:
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
def
fused_moe_kernel_gptq_awq
(
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
b_scale_ptr
,
b_zp_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_bse
,
stride_bsk
,
stride_bsn
,
stride_bze
,
stride_bzk
,
stride_bzn
,
block_k_diviable
:
tl
.
constexpr
,
group_size
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int64
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
if
not
has_zp
and
use_int8_w8a16
:
b_zp_num
=
128
elif
has_zp
and
use_int4_w4a16
:
b_zp_shifter
=
(
offs_bn
[
None
,
:]
%
2
)
*
4
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
if
not
block_k_diviable
:
k_mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
k_other
=
0.0
else
:
k_mask
=
None
k_other
=
None
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
)
if
use_int4_w4a16
:
b
=
(
b
>>
b_shifter
)
&
0xF
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsn
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsk
b_scale
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_scale
=
b_scale
.
to
(
tl
.
float32
)
if
has_zp
and
use_int4_w4a16
:
offs_k_true
=
(
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
b_zp_ptrs
=
b_zp_ptr
+
off_experts
*
stride_bze
+
\
(
offs_bn
[
None
,
:]
//
2
)
*
stride_bzn
+
\
offs_k_true
*
stride_bzk
b_zp
=
tl
.
load
(
b_zp_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_zp
=
((
b_zp
>>
b_zp_shifter
)
&
0xF
)
b_zp
=
b_zp
.
to
(
tl
.
float32
)
elif
has_zp
and
use_int8_w8a16
:
offs_k_true
=
(
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
b_zp_ptrs
=
b_zp_ptr
+
off_experts
*
stride_bze
+
\
offs_bn
[
None
,
:]
*
stride_bzn
+
\
offs_k_true
*
stride_bzk
b_zp
=
tl
.
load
(
b_zp_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_zp
=
b_zp
.
to
(
tl
.
float32
)
# We accumulate along the K dimension.
if
has_zp
:
b
=
((
b
.
to
(
tl
.
float32
)
-
b_zp
)
*
b_scale
).
to
(
compute_type
)
else
:
b
=
((
b
.
to
(
tl
.
float32
)
-
b_zp_num
)
*
b_scale
).
to
(
compute_type
)
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
if
use_int4_w4a16
:
b_ptrs
+=
(
BLOCK_SIZE_K
//
2
)
*
stride_bk
else
:
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
@
triton
.
jit
def
fused_moe_kernel
(
# Pointers to matrices
a_ptr
,
b_ptr
,
c_ptr
,
a_scale_ptr
,
b_scale_ptr
,
topk_weights_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
num_tokens_post_padded_ptr
,
# Matrix dimensions
N
,
K
,
EM
,
num_valid_tokens
,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am
,
stride_ak
,
stride_be
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_asm
,
stride_ask
,
stride_bse
,
stride_bsk
,
stride_bsn
,
# Block size for block-wise quantization
group_n
:
tl
.
constexpr
,
group_k
:
tl
.
constexpr
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid
=
tl
.
program_id
(
axis
=
0
)
# num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# group_id = pid // num_pid_in_group
# first_pid_m = group_id * GROUP_SIZE_M
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# pid_n = (pid % num_pid_in_group) // group_size_m
if
GROUP_SIZE_M
==
1
:
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
pid_m
=
pid
//
num_pid_n
pid_n
=
pid
%
num_pid_n
else
:
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
((
pid
%
num_pid_in_group
)
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
if
use_int8_w8a16
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
if
use_fp8_w8a8
or
use_int8_w8a8
:
# block-wise
if
group_k
>
0
and
group_n
>
0
:
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
offs_bsn
=
offs_bn
//
group_n
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bsn
*
stride_bsn
)
# channel-wise
elif
per_channel_quant
:
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
b_scale
=
tl
.
load
(
b_scale_ptrs
)
# Load per-token scale for activations
a_scale_ptrs
=
a_scale_ptr
+
(
offs_token
//
top_k
)
*
stride_asm
a_scale
=
tl
.
load
(
a_scale_ptrs
,
mask
=
token_mask
,
other
=
0.0
)[:,
None
]
# tensor-wise
else
:
a_scale
=
tl
.
load
(
a_scale_ptr
)
b_scale
=
tl
.
load
(
b_scale_ptr
+
off_experts
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
# We accumulate along the K dimension.
if
use_int8_w8a16
:
accumulator
=
tl
.
dot
(
a
,
b
.
to
(
compute_type
),
acc
=
accumulator
)
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_scale
=
tl
.
load
(
a_scale_ptrs
+
offs_ks
*
stride_ask
,
mask
=
token_mask
,
other
=
0.0
)
b_scale
=
tl
.
load
(
b_scale_ptrs
+
offs_ks
*
stride_bsk
)
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
else
:
if
use_fp8_w8a8
:
# acc used to enable fp8_fast_accum
accumulator
=
tl
.
dot
(
a
,
b
,
acc
=
accumulator
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
else
:
accumulator
+=
tl
.
dot
(
a
,
b
)
# Advance the ptrs to the next K block.
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
if
use_int8_w8a16
:
accumulator
=
(
accumulator
*
b_scale
).
to
(
compute_type
)
elif
use_fp8_w8a8
or
use_int8_w8a8
:
if
group_k
>
0
and
group_n
>
0
:
accumulator
=
accumulator
.
to
(
compute_type
)
else
:
accumulator
=
(
accumulator
*
a_scale
*
b_scale
).
to
(
compute_type
)
else
:
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
B_zp
:
Optional
[
torch
.
Tensor
],
topk_weights
:
Optional
[
torch
.
Tensor
],
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
mul_routed_weight
:
bool
,
top_k
:
int
,
config
:
dict
[
str
,
Any
],
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
use_fp8_w8a8
or
use_int8_w8a8
:
assert
B_scale
is
not
None
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
size
(
-
2
),
block_shape
[
0
])
==
B_scale
.
size
(
-
2
))
assert
(
block_shape
is
None
or
triton
.
cdiv
(
B
.
size
(
-
1
),
block_shape
[
1
])
==
B_scale
.
size
(
-
1
))
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
M
=
A
.
size
(
0
)
num_tokens
=
M
*
top_k
EM
=
sorted_token_ids
.
size
(
0
)
if
A
.
size
(
0
)
<
config
[
"BLOCK_SIZE_M"
]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique, so
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM
=
min
(
sorted_token_ids
.
size
(
0
),
A
.
size
(
0
)
*
top_k
*
config
[
'BLOCK_SIZE_M'
])
grid
=
lambda
META
:
(
triton
.
cdiv
(
EM
,
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
size
(
1
)
if
not
use_nn_moe
else
B
.
size
(
2
),
META
[
'BLOCK_SIZE_N'
]),
)
if
(
use_int8_w8a16
or
use_int4_w4a16
)
and
\
block_shape
is
not
None
and
block_shape
[
1
]
>
0
:
assert
B_scale
is
not
None
and
B_scale
.
ndim
==
3
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
if
os
.
environ
.
get
(
'moe_wna16_use_cuda'
)
==
'1'
:
use_moe_wna16_cuda
=
should_moe_wna16_use_cuda
(
num_valid_tokens
=
num_tokens
,
group_size
=
block_shape
[
1
],
num_experts
=
B
.
size
(
0
),
bit
=
4
if
use_int4_w4a16
else
8
)
config
=
config
.
copy
()
config
.
update
(
get_moe_wna16_block_config
(
config
=
config
,
use_moe_wna16_cuda
=
use_moe_wna16_cuda
,
num_valid_tokens
=
num_tokens
,
size_k
=
A
.
size
(
1
),
size_n
=
B
.
size
(
1
),
num_experts
=
B
.
size
(
1
),
group_size
=
block_shape
[
1
],
real_top_k
=
top_k
,
block_size_m
=
config
[
"BLOCK_SIZE_M"
]))
if
use_moe_wna16_cuda
:
bit
=
4
if
use_int4_w4a16
else
8
ops
.
moe_wna16_gemm
(
A
,
C
,
B
,
B_scale
,
B_zp
,
topk_weights
if
mul_routed_weight
else
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
top_k
,
config
[
"BLOCK_SIZE_M"
],
config
[
"BLOCK_SIZE_N"
],
config
[
"BLOCK_SIZE_K"
],
bit
)
return
if
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
:
fused_moe_kernel_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
size
(
1
),
A
.
size
(
1
),
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
size
(
1
)
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
size
(
1
),
A
.
size
(
1
),
EM
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
size
(
1
)
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
# config = config.copy()
# BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
# if block_shape is not None:
# BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
# block_shape[1]))
fused_moe_kernel
[
grid
](
A
,
B
,
C
,
A_scale
,
B_scale
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
size
(
1
)
if
not
use_nn_moe
else
B
.
size
(
2
),
A
.
size
(
1
),
EM
,
num_tokens
,
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
)
if
not
use_nn_moe
else
B
.
stride
(
1
),
B
.
stride
(
1
)
if
not
use_nn_moe
else
B
.
stride
(
2
),
C
.
stride
(
1
),
C
.
stride
(
2
),
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
per_channel_quant
=
per_channel_quant
,
**
config
,
)
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def
get_config_file_name
(
E
:
int
,
N
:
int
,
dtype
:
Optional
[
str
],
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
str
:
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
dtype_selector
=
""
if
not
dtype
else
f
",dtype=
{
dtype
}
"
block_shape_selector
=
(
""
if
not
block_shape
or
not
all
(
block_shape
)
else
f
",block_shape=
{
block_shape
}
"
).
replace
(
" "
,
""
)
if
not
use_nn_moe
:
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}{
block_shape_selector
}
.json"
# noqa: E501
else
:
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}{
block_shape_selector
}
_nn.json"
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@
functools
.
lru_cache
def
get_moe_configs
(
E
:
int
,
N
:
int
,
dtype
:
Optional
[
str
],
block_n
:
Optional
[
int
]
=
None
,
block_k
:
Optional
[
int
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
Optional
[
Dict
[
int
,
Any
]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the fused_moe kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
block_shape
=
[
block_n
,
block_k
]
if
block_n
and
block_k
else
None
json_file_name
=
get_config_file_name
(
E
,
N
,
dtype
,
block_shape
,
use_nn_moe
=
use_nn_moe
)
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
config_file_path_120
=
config_file_path
.
replace
(
".json"
,
"_120.json"
)
if
os
.
path
.
exists
(
config_file_path_120
):
with
open
(
config_file_path_120
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
config_file_path_120
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
config_file_path
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
# If no optimized configuration is available, we will use the default
# configuration
logger
.
warning
(
(
"Using default MoE config. Performance might be sub-optimal! "
"Config file not found at %s"
),
config_file_path
)
return
None
def
get_moe_wna16_block_config
(
config
:
dict
[
str
,
int
],
use_moe_wna16_cuda
:
bool
,
num_valid_tokens
:
int
,
size_k
:
int
,
size_n
:
int
,
num_experts
:
int
,
group_size
:
int
,
real_top_k
:
int
,
block_size_m
:
int
):
if
"BLOCK_SIZE_N"
in
config
and
"BLOCK_SIZE_K"
in
config
:
# optimal block config is set
return
{}
if
not
use_moe_wna16_cuda
:
# triton moe wna16 kernel
if
num_valid_tokens
//
real_top_k
==
1
:
# if bs=1, use a smaller BLOCK_SIZE_N
return
{
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
}
else
:
return
{
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
}
else
:
# cuda moe wna16 kernel
# set default block_size 128, and increase them when num_blocks
# is too large.
block_size_n
=
128
block_size_k
=
128
if
block_size_k
<=
group_size
:
block_size_k
=
group_size
num_n_blocks
=
size_k
//
block_size_k
num_k_blocks
=
size_n
//
block_size_k
num_m_blocks
=
(
num_valid_tokens
+
block_size_m
-
1
)
/
block_size_m
+
\
num_experts
if
num_valid_tokens
//
real_top_k
<=
block_size_m
:
num_m_blocks
=
min
(
num_m_blocks
,
num_valid_tokens
)
num_blocks
=
num_m_blocks
*
num_n_blocks
*
num_k_blocks
if
size_k
%
256
==
0
and
num_blocks
>=
256
and
\
block_size_k
<
256
:
block_size_k
=
256
num_blocks
=
num_blocks
//
(
256
//
block_size_k
)
if
num_m_blocks
<=
16
and
size_k
%
(
block_size_k
*
2
)
==
0
and
\
size_k
%
(
block_size_k
*
2
)
==
0
and
block_size_k
<=
512
and
\
num_blocks
>=
512
:
block_size_k
=
block_size_k
*
2
num_blocks
=
num_blocks
//
2
if
num_blocks
>
1024
:
block_size_n
=
256
num_n_blocks
=
num_n_blocks
//
2
num_blocks
=
num_blocks
//
2
if
size_n
<=
1024
and
num_blocks
>=
1024
:
# The kernel performance got much better with BLOCK_SIZE_N=1024
# when num_blocks is large, event when N is small.
# Not sure why, maybe it force the CUDA SM process only one block
# at the same time.
block_size_n
=
1024
return
{
"BLOCK_SIZE_N"
:
block_size_n
,
"BLOCK_SIZE_K"
:
block_size_k
}
def
should_moe_wna16_use_cuda
(
num_valid_tokens
:
int
,
group_size
:
int
,
num_experts
:
int
,
bit
:
int
):
return
bit
==
4
and
group_size
in
[
32
,
64
,
128
]
and
\
num_valid_tokens
/
num_experts
<=
6
def
get_default_config
(
M
:
int
,
E
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
dtype
:
Optional
[
str
],
is_marlin
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
dict
[
str
,
int
]:
if
dtype
==
"fp8_w8a8"
and
block_shape
is
not
None
:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
# num_stages=3 can cause triton.runtime.errors.OutOfResources
# on ROCm, set it to 2 instead.
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
block_shape
[
0
],
"BLOCK_SIZE_K"
:
block_shape
[
1
],
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
if
not
current_platform
.
is_rocm
()
else
2
,
}
# elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
# # moe wna16 kernels
# # only set BLOCK_SIZE_M
# # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
# bit = 4 if dtype == "int4_w4a16" else 8
# use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
# block_shape[1], E, bit)
# if use_moe_wna16_cuda:
# config = {"BLOCK_SIZE_M": min(16, M)}
# elif M <= 20:
# config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
# elif M <= 40:
# config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
# else:
# config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
elif
is_marlin
:
for
block_size_m
in
[
8
,
16
,
32
,
48
,
64
]:
if
M
*
topk
/
E
/
block_size_m
<
0.9
:
break
return
{
"BLOCK_SIZE_M"
:
block_size_m
}
elif
M
<=
E
:
config
=
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
}
else
:
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
}
if
use_nn_moe
:
config
[
"num_ldmatrixes"
]
=
1
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
tuple
[
int
,
...],
w2_shape
:
tuple
[
int
,
...],
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
is_marlin
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
dict
[
str
,
int
]:
from
vllm.model_executor.layers.fused_moe
import
get_config
override_config
=
get_config
()
if
override_config
:
config
=
override_config
else
:
# First try to load optimal config from the file
if
not
use_nn_moe
:
E
,
_
,
N
=
w2_shape
else
:
E
,
N
,
_
=
w2_shape
# if dtype == "int4_w4a16":
# N = N * 2
block_n
=
block_shape
[
0
]
if
block_shape
else
0
block_k
=
block_shape
[
1
]
if
block_shape
else
0
configs
=
get_moe_configs
(
E
,
N
,
dtype
,
block_n
,
block_k
,
use_nn_moe
=
use_nn_moe
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
]
if
not
use_nn_moe
else
w1_shape
[
1
],
top_k
,
dtype
,
is_marlin
,
block_shape
,
use_nn_moe
=
use_nn_moe
)
return
config
def
vllm_topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_indices
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
renormalize
:
bool
)
->
tuple
[
torch
.
Tensor
,
...]:
ops
.
topk_softmax
(
topk_weights
,
topk_indices
,
token_expert_indices
,
gating_output
,
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_indices
def
dispatch_topk_func
()
->
Callable
[...,
tuple
[
torch
.
Tensor
,
...]]:
# if is_rocm_aiter_moe_enabled():
# from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
# return rocm_aiter_topk_softmax
return
vllm_topk_softmax
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
indices_type
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
size
(
0
)
==
gating_output
.
size
(
0
),
(
"Number of tokens mismatch"
)
M
,
_
=
hidden_states
.
size
()
topk_weights
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
topk_ids
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
if
indices_type
is
None
else
indices_type
,
device
=
hidden_states
.
device
)
token_expert_indices
=
torch
.
empty
(
M
,
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
gating_output_float
=
gating_output
.
float
()
# TODO(woosuk): Optimize this.
topk_func
=
dispatch_topk_func
()
topk_weights
,
topk_ids
=
topk_func
(
topk_weights
,
topk_ids
,
token_expert_indices
,
gating_output_float
,
renormalize
)
return
topk_weights
,
topk_ids
,
token_expert_indices
def
is_power_of_two
(
n
):
return
n
>
0
and
math
.
log2
(
n
).
is_integer
()
# This is used by the Deepseek-V2 and Deepseek-V3 model
@
torch
.
compile
(
dynamic
=
True
,
backend
=
current_platform
.
simple_compile_backend
)
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
hidden_states
.
size
(
0
)
==
gating_output
.
size
(
0
),
(
"Number of tokens mismatch"
)
if
scoring_func
==
"softmax"
:
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
elif
scoring_func
==
"sigmoid"
:
scores
=
gating_output
.
sigmoid
()
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
num_token
=
scores
.
size
(
0
)
if
e_score_correction_bias
is
not
None
:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores
=
scores
scores
=
scores
+
e_score_correction_bias
.
unsqueeze
(
0
)
group_scores
=
(
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
topk
(
2
,
dim
=-
1
)[
0
].
sum
(
dim
=-
1
))
else
:
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
False
)[
1
]
# [n, top_k_group]
group_mask
=
torch
.
zeros_like
(
group_scores
)
# [n, n_group]
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# [n, n_group]
score_mask
=
group_mask
.
unsqueeze
(
-
1
).
expand
(
num_token
,
num_expert_group
,
scores
.
size
(
-
1
)
//
num_expert_group
).
reshape
(
num_token
,
-
1
)
# [n, e]
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
float
(
"-inf"
))
# [n, e]
if
e_score_correction_bias
is
not
None
:
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
topk
,
dim
=-
1
,
sorted
=
False
)[
1
]
# Use original unbiased scores for the routing weights
topk_weights
=
original_scores
.
gather
(
1
,
topk_ids
)
else
:
topk_weights
,
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
topk
,
dim
=-
1
,
sorted
=
False
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
)
->
Optional
[
str
]:
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a8
:
return
"int8_w8a8"
elif
use_int8_w8a16
:
return
"int8_w8a16"
elif
use_int4_w4a16
:
return
"int4_w4a16"
elif
use_int4_w4a8
:
return
"int4_w4a8"
elif
dtype
==
torch
.
float
:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return
"float32"
return
None
def
inplace_fused_experts_step3v
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
use_step3v_w8a16
:
Optional
[
bool
]
=
False
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
,
use_step3v_w8a16
)
def
inplace_fused_experts_step3v_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
None
:
pass
direct_register_custom_op
(
op_name
=
"inplace_fused_experts_step3v"
,
op_func
=
inplace_fused_experts_step3v
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_step3v_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
def
outplace_fused_experts_step3v
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
)
def
outplace_fused_experts_step3v_fake
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
Optional
[
str
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"outplace_fused_experts_step3v"
,
op_func
=
outplace_fused_experts_step3v
,
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_step3v_fake
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
def
torch_vllm_inplace_fused_experts
(
**
kwargs
)
->
torch
.
Tensor
:
torch
.
ops
.
vllm
.
inplace_fused_experts_step3v
(
**
kwargs
)
hidden_states
=
kwargs
[
'hidden_states'
]
return
hidden_states
def
torch_vllm_outplace_fused_experts
(
**
kwargs
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
**
kwargs
)
def
dispatch_fused_experts_func
(
inplace
:
bool
)
->
Callable
[...,
torch
.
Tensor
]:
if
inplace
:
return
torch_vllm_inplace_fused_experts
return
torch_vllm_outplace_fused_experts
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
allow_deep_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
use_step3v_w8a16
:
bool
=
False
,)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N
=
w1
.
size
(
1
)
if
(
allow_deep_gemm
and
use_fp8_w8a8
and
N
>
512
and
_valid_deep_gemm
(
hidden_states
,
w1
,
w2
)):
assert
apply_router_weight_on_input
is
False
return
deep_gemm_moe_fp8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
elif
(
allow_cutlass_block_scaled_grouped_gemm
and
use_fp8_w8a8
and
_valid_cutlass_block_scaled_grouped_gemm
(
hidden_states
,
w1
,
w2
)):
assert
apply_router_weight_on_input
is
False
return
run_cutlass_block_scaled_fused_experts
(
a
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
)
else
:
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
use_step3v_w8a16
=
use_step3v_w8a16
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
apply_router_weight_on_input
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
use_step3v_w8a16
:
Optional
[
bool
]
=
False
)
->
torch
.
Tensor
:
num_tokens
=
hidden_states
.
size
(
0
)
if
use_step3v_w8a16
:
#step3v w8a16
use_nn_moe
=
True
if
use_nn_moe
:
E
,
_
,
N
=
w1
.
size
()
else
:
E
,
N
,
_
=
w1
.
size
()
K
=
w2
.
size
(
1
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
top_k_num
=
topk_ids
.
size
(
1
)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
cache13
=
get_moe_cache
(
top_k_num
,
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
],
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
else
:
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
is
True
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
cache13
=
cache13
,
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
True
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
inplace
,
cache13
=
cache13
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a8
=
True
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
#
if
use_int4_w4a16
:
assert
hidden_states
.
size
(
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
elif
use_nn_moe
:
assert
hidden_states
.
size
(
1
)
==
w1
.
size
(
1
),
"Hidden size mismatch"
else
:
assert
hidden_states
.
size
(
1
)
==
w1
.
size
(
2
),
(
f
"Hidden size mismatch
{
hidden_states
.
size
(
1
)
}
!=
{
w1
.
size
(
2
)
}
"
)
assert
topk_weights
.
size
()
==
topk_ids
.
size
(),
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
dtype
=
hidden_states
.
dtype
)
qtype
=
get_config_quant_dtype
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
size
(),
w2
.
size
(),
top_k_num
,
config_dtype
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
)
config
=
get_config_func
(
M
)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1
=
cache13
[:
M
*
top_k_num
*
N
].
view
(
M
,
top_k_num
,
N
)
intermediate_cache3
=
cache13
[:
M
*
top_k_num
*
(
K
if
not
use_nn_moe
else
w2
.
shape
[
2
])].
view
(
M
,
top_k_num
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
])
# This needs separate memory since it's used concurrently with cache1
intermediate_cache2
=
torch
.
empty
((
M
*
top_k_num
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
compute_type
=
tl
.
bfloat16
elif
hidden_states
.
dtype
==
torch
.
float16
:
compute_type
=
tl
.
float16
elif
hidden_states
.
dtype
==
torch
.
float32
:
compute_type
=
tl
.
float32
else
:
raise
ValueError
(
f
"Unsupported compute_type:
{
hidden_states
.
dtype
}
"
)
if
inplace
:
out_hidden_states
=
hidden_states
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
size
()
if
tokens_in_chunk
==
0
:
break
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
topk_ids
.
size
(
1
)]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
if
not
use_int8_w8a8
:
config
=
get_config_func
(
tokens_in_chunk
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
qcurr_hidden_states
,
a1q_scale
=
moe_kernel_quantize_input
(
A
=
curr_hidden_states
,
A_scale
=
a1_scale
,
quant_dtype
=
qtype
,
per_act_token_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
if
use_int4_w4a16
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
,
curr_hidden_states
.
shape
[
0
]))
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
if
use_step3v_w8a16
:
#step3v w8a16
w1
=
(
w1
.
reshape
([
w1_scale
.
shape
[
0
],
w1_scale
.
shape
[
1
],
-
1
,
w1_scale
.
shape
[
-
1
]])
*
w1_scale
.
unsqueeze
(
2
)).
reshape
(
w1
.
shape
)
#step3v w8反量化
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
None
,
None
,
w1_zp
,
curr_topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
apply_router_weight_on_input
,
top_k_num
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
else
:
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
a1q_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
apply_router_weight_on_input
,
top_k_num
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
torch
.
ops
.
_C
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
qintermediate_cache2
,
a2q_scale
=
moe_kernel_quantize_input
(
A
=
intermediate_cache2
,
A_scale
=
a2_scale
,
quant_dtype
=
qtype
,
per_act_token_quant
=
per_channel_quant
,
block_shape
=
block_shape
)
if
use_step3v_w8a16
:
#step3v w8a16
w2
=
(
w2
.
reshape
([
w2_scale
.
shape
[
0
],
w2_scale
.
shape
[
1
],
-
1
,
w2_scale
.
shape
[
-
1
]])
*
w2_scale
.
unsqueeze
(
2
)).
reshape
(
w2
.
shape
)
#step3v w8反量化
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
None
,
None
,
w2_zp
,
curr_topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
not
apply_router_weight_on_input
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
else
:
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
intermediate_cache3
,
a2q_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
not
apply_router_weight_on_input
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
)
if
envs
.
VLLM_USE_LIGHTOP
:
from
lightop
import
op
as
op
op
.
moe_sum
(
input
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
output
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
bias
=
shared_output
[
begin_chunk_idx
:
end_chunk_idx
],
expert_mask
=
None
,
num_local_tokens
=
None
,
factor
=
routed_scaling_factor
)
# else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx])
# if shared_output is not None:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# else:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else
:
if
envs
.
VLLM_USE_LIGHTOP_MOE_SUM
:
from
lightop
import
op
as
op
op
.
moe_sum
(
input
=
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
output
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
bias
=
None
,
expert_mask
=
None
,
num_local_tokens
=
None
,
factor
=
1.0
)
elif
envs
.
VLLM_USE_OPT_MOE_SUM
:
moe_reduce_dispatch
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
begin_chunk_idx
,
end_chunk_idx
)
else
:
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
size
()),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
def
fused_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_step3v_w8a16
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
a1.
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if
use_grouped_topk
:
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
,
num_expert_group
,
topk_group
)
elif
custom_routing_function
is
None
:
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
else
:
topk_weights
,
topk_ids
=
custom_routing_function
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
return
fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
inplace
=
inplace
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
,
use_step3v_w8a16
=
use_step3v_w8a16
)
class
TritonExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a8
:
bool
=
False
,
per_act_token_quant
:
bool
=
False
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
):
super
().
__init__
(
FusedMoEQuantConfig
.
make
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
))
self
.
use_fp8_w8a8
=
use_fp8_w8a8
self
.
use_int4_w4a16
=
use_int4_w4a16
self
.
use_int8_w8a8
=
use_int8_w8a8
self
.
use_int8_w8a16
=
use_int8_w8a16
self
.
use_int4_w4a8
=
use_int4_w4a8
@
property
def
activation_formats
(
self
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
)
def
supports_chunking
(
self
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
True
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
workspace1
=
(
M
,
topk
,
max
(
N
*
2
,
K
))
workspace2
=
(
M
,
topk
,
N
)
output
=
(
M
,
topk
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_num_tokens
:
Optional
[
torch
.
Tensor
],
):
# Check constraints.
if
self
.
use_int4_w4a16
:
assert
hidden_states
.
size
(
-
1
)
//
2
==
w1
.
size
(
2
),
(
"Hidden size mismatch"
)
else
:
assert
hidden_states
.
size
(
-
1
)
==
w1
.
size
(
2
),
\
(
f
"Hidden size mismatch
{
hidden_states
.
size
(
-
1
)
}
"
f
"!=
{
w1
.
size
(
2
)
}
"
)
assert
hidden_states
.
is_contiguous
(
),
"Hidden_states must be contiguous"
assert
hidden_states
.
dim
()
==
2
assert
w1
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
w2
.
stride
(
-
1
)
==
1
,
"Stride of last dimension must be 1"
assert
hidden_states
.
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
]
E
,
num_tokens
,
N
,
K
,
top_k_num
=
mk
.
_moe_problem_size
(
hidden_states
,
w1
,
w2
,
topk_ids
)
if
global_num_experts
==
-
1
:
global_num_experts
=
E
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
dtype
=
hidden_states
.
dtype
)
config
=
try_get_optimal_moe_config
(
w1
.
size
(),
w2
.
size
(),
top_k_num
,
config_dtype
,
num_tokens
,
block_shape
=
self
.
block_shape
,
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
:
compute_type
=
tl
.
bfloat16
elif
hidden_states
.
dtype
==
torch
.
float16
:
compute_type
=
tl
.
float16
elif
hidden_states
.
dtype
==
torch
.
float32
:
compute_type
=
tl
.
float32
elif
hidden_states
.
dtype
==
torch
.
float8_e4m3fn
:
compute_type
=
tl
.
bfloat16
else
:
raise
ValueError
(
f
"Unsupported compute_type:
{
hidden_states
.
dtype
}
"
)
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1
=
_resize_cache
(
workspace13
,
(
num_tokens
,
top_k_num
,
N
))
intermediate_cache2
=
_resize_cache
(
workspace2
,
(
num_tokens
*
top_k_num
,
N
//
2
))
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
intermediate_cache1
,
a1q_scale
,
w1_scale
,
w1_zp
,
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
top_k_num
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
self
.
activation
(
activation
,
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
a2q_scale
:
Optional
[
torch
.
Tensor
]
=
None
qintermediate_cache2
,
a2q_scale
=
moe_kernel_quantize_input
(
intermediate_cache2
,
a2_scale
,
self
.
quant_dtype
,
self
.
per_act_token_quant
,
self
.
block_shape
)
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
output
,
a2q_scale
,
w2_scale
,
w2_zp
,
None
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_int8_w8a8
=
self
.
use_int8_w8a8
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a8
=
self
.
use_int4_w4a8
,
per_channel_quant
=
self
.
per_act_token_quant
,
block_shape
=
self
.
block_shape
)
def
modular_triton_fused_moe
(
use_fp8_w8a8
:
bool
,
use_int8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
use_int4_w4a8
:
bool
,
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
)
->
mk
.
FusedMoEModularKernel
:
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
TritonExperts
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int4_w4a8
=
use_int4_w4a8
,
per_act_token_quant
=
per_act_token_quant
,
block_shape
=
block_shape
,
),
)
vllm/model_executor/layers/quantization/__init__.py
View file @
8223f750
...
...
@@ -33,6 +33,7 @@ QuantizationMethods = Literal[
"ipex"
,
"quark"
,
"moe_wna16"
,
"groupwise-quant"
,
"torchao"
,
"auto-round"
,
"rtn"
,
...
...
@@ -120,6 +121,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.blockwise_int8
import
BlockInt8Config
from
.slimquant_w4a8
import
SlimQuantW4A8Int8Config
from
.slimquant_w4a8_marlin
import
SlimQuantW4A8Int8MarlinConfig
from
.groupwise_quant
import
GroupwiseQuantConfig
method_to_config
:
dict
[
str
,
type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
...
...
@@ -152,6 +154,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round"
:
AutoRoundConfig
,
"rtn"
:
RTNConfig
,
"blockwise_int8"
:
BlockInt8Config
,
"groupwise-quant"
:
GroupwiseQuantConfig
,
"slimquant_w4a8"
:
SlimQuantW4A8Int8Config
,
"slimquant_w4a8_marlin"
:
SlimQuantW4A8Int8MarlinConfig
,
}
...
...
vllm/model_executor/layers/quantization/groupwise_quant.py
View file @
8223f750
...
...
@@ -3,18 +3,16 @@ from typing import Any, Callable, Dict, List, Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.fused_moe.optimus_moe
import
(
# noqa: F401
optimus_moe_int8
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
#from vllm.model_executor.layers.fused_moe.optimus_moe import ( # noqa: F401
# optimus_moe_int8)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
direct_register_custom_op
class
GroupwiseQuantConfig
(
QuantizationConfig
):
"""Config class for Groupwise Quantization.
...
...
@@ -159,92 +157,7 @@ class GroupwiseQuantLinearMethod(LinearMethodBase):
has_num_experts
=
any
(
"num_experts"
in
name
for
name
in
layer_keys
)
if
not
has_num_experts
:
layer
.
register_parameter
(
"num_experts"
,
None
)
if
self
.
quant_config
.
weight_bits
==
4
:
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
assert
output_size_per_partition
%
self
.
quant_config
.
pack_factor
==
0
if
num_experts
:
weight_shape
=
[
num_experts
,
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
]
scale_shape
=
[
num_experts
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
else
:
weight_shape
=
[
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
]
scale_shape
=
[
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
qweight
=
Parameter
(
torch
.
empty
(
*
weight_shape
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
scales
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
zeros
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
if
num_experts
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
"packed_dim"
:
2
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
set_weight_attrs
(
zeros
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
else
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
set_weight_attrs
(
zeros
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"zeros"
,
zeros
)
set_weight_attrs
(
zeros
,
extra_weight_attrs
)
elif
self
.
quant_config
.
weight_bits
==
8
:
if
self
.
quant_config
.
weight_bits
==
8
:
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
if
num_experts
:
weight_shape
=
[
...
...
@@ -300,68 +213,6 @@ class GroupwiseQuantLinearMethod(LinearMethodBase):
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
elif
self
.
quant_config
.
weight_bits
==
6
:
assert
input_size_per_partition
%
self
.
quant_config
.
group_size
==
0
if
num_experts
:
weight_shape
=
[
num_experts
,
output_size_per_partition
,
input_size_per_partition
,
]
scale_shape
=
[
num_experts
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
else
:
weight_shape
=
[
output_size_per_partition
,
input_size_per_partition
]
scale_shape
=
[
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
]
qweight
=
Parameter
(
torch
.
zeros
(
*
weight_shape
,
device
=
"cpu"
,
# hack for fp6 weight is stored in float16, to avoid cuda oom
dtype
=
torch
.
float16
,
),
requires_grad
=
False
,
)
scales
=
Parameter
(
torch
.
empty
(
*
scale_shape
,
device
=
"cuda"
,
dtype
=
torch
.
float16
,
),
requires_grad
=
False
,
)
if
num_experts
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
2
,
"output_dim"
:
1
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
1
,
"output_dim"
:
2
,
})
else
:
set_weight_attrs
(
qweight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
})
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
...
...
@@ -372,157 +223,78 @@ class GroupwiseQuantLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
not
hasattr
(
layer
,
"qweight"
):
return
if
self
.
quant_config
.
weight_bits
==
4
:
num_experts
=
layer
.
num_experts
qweight
=
layer
.
qweight
zeros
=
layer
.
zeros
scales
=
layer
.
scales
if
num_experts
:
qscales_list
=
[]
for
i
in
range
(
num_experts
):
qweight_processed
,
qscales
=
torch
.
ops
.
Optimus
.
GemmInt4GroupQuantWeight
(
qweight
[
i
],
zeros
[
i
],
scales
[
i
]
+
zeros
[
i
],
self
.
quant_config
.
group_size
)
qweight
[
i
].
copy_
(
qweight_processed
)
qscales_list
.
append
(
qscales
)
qscales
=
Parameter
(
torch
.
stack
(
qscales_list
),
requires_grad
=
False
)
layer
.
register_parameter
(
"qscales"
,
qscales
)
else
:
qweight_processed
,
qscales
=
torch
.
ops
.
Optimus
.
GemmInt4GroupQuantWeight
(
qweight
,
zeros
,
scales
+
zeros
,
self
.
quant_config
.
group_size
)
qweight
.
copy_
(
qweight_processed
)
qscales
=
Parameter
(
qscales
,
requires_grad
=
False
)
layer
.
register_parameter
(
"qscales"
,
qscales
)
layer
.
_parameters
.
pop
(
"zeros"
)
layer
.
_parameters
.
pop
(
"scales"
)
elif
self
.
quant_config
.
weight_bits
==
8
:
num_experts
=
layer
.
num_experts
qweight
=
layer
.
qweight
if
num_experts
:
for
i
in
range
(
num_experts
):
qweight
[
i
].
copy_
(
torch
.
ops
.
Optimus
.
FpAIntBPreprocessWeightGPU
(
qweight
[
i
].
t
().
contiguous
(),
torch
.
int8
))
else
:
qweight
.
copy_
(
torch
.
ops
.
Optimus
.
FpAIntBPreprocessWeightGPU
(
qweight
.
t
().
contiguous
(),
torch
.
int8
))
elif
self
.
quant_config
.
weight_bits
==
6
:
if
self
.
quant_config
.
weight_bits
==
8
:
num_experts
=
layer
.
num_experts
qweight
=
layer
.
qweight
layer
.
_parameters
.
pop
(
"qweight"
)
assert
qweight
.
shape
[
-
1
]
%
8
==
0
if
num_experts
:
qweight_processed
=
torch
.
empty
(
qweight
.
shape
[
0
],
qweight
.
shape
[
1
],
qweight
.
shape
[
2
]
*
6
//
8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
else
:
qweight_processed
=
torch
.
empty
(
qweight
.
shape
[
0
],
qweight
.
shape
[
1
]
*
6
//
8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
if
num_experts
:
for
i
in
range
(
num_experts
):
qweight_processed
[
i
]
=
torch
.
ops
.
Optimus
.
fp6_preprocess_weight
(
qweight
[
i
].
cpu
()).
cuda
()
qweight_processed
=
Parameter
(
qweight_processed
,
requires_grad
=
False
)
layer
.
register_parameter
(
"qweight"
,
qweight_processed
)
qweight
[
i
].
copy_
(
qweight
[
i
].
contiguous
())
else
:
qweight_processed
=
Parameter
(
torch
.
ops
.
Optimus
.
fp6_preprocess_weight
(
qweight
.
cpu
()).
cuda
(),
requires_grad
=
False
)
layer
.
register_parameter
(
"qweight"
,
qweight_processed
)
qweight
.
copy_
(
qweight
.
contiguous
())
else
:
raise
NotImplementedError
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_idx
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
weight_bits
==
4
:
qweight
=
layer
.
qweight
qscales
=
layer
.
qscales
num_experts
=
layer
.
num_experts
if
num_experts
:
assert
expert_idx
is
not
None
,
"expert_idx is None"
qweight
=
qweight
[
expert_idx
]
qscales
=
qscales
[
expert_idx
]
out
=
torch
.
ops
.
vllm
.
optimus_gemm_int4_group
(
x
,
qweight
,
qscales
,
bias
,
None
,
# Placeholder for a fifth argument that is None
out
=
output
)
if
residual
is
not
None
:
out
+=
residual
return
out
elif
self
.
quant_config
.
weight_bits
==
8
:
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_idx
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
weight_bits
==
8
:
qweight
=
layer
.
qweight
scales
=
layer
.
scales
num_experts
=
layer
.
num_experts
if
num_experts
:
assert
expert_idx
is
not
None
,
"expert_idx is None"
qweight
=
qweight
[
expert_idx
]
scales
=
scales
[
expert_idx
]
group_size
=
self
.
quant_config
.
group_size
input_size_per_partition
=
qweight
.
size
(
0
)
output_size_per_partition
=
qweight
.
size
(
1
)
qweight
=
qweight
.
to
(
torch
.
float32
)
scales
=
scales
.
to
(
torch
.
float32
)
scales_expanded
=
scales
.
repeat_interleave
(
group_size
,
dim
=
0
)
scales_expanded
=
scales_expanded
[:
input_size_per_partition
]
weight
=
qweight
*
scales_expanded
weight
=
weight
.
to
(
x
.
dtype
)
if
residual
is
not
None
:
assert
output
is
None
or
output
is
residual
out
=
torch
.
ops
.
vllm
.
optimus_fp_aintb_gemm
(
x
,
qweight
,
torch
.
int8
,
# Placeholder for dtype argument
scales
,
residual
,
out
=
residual
)
if
get_tensor_model_parallel_world_size
()
>
1
and
get_tensor_model_parallel_rank
()
!=
0
:
beta
=
0.0
else
:
beta
=
1.0
if
x
.
dim
()
==
2
:
torch
.
addmm
(
residual
,
x
,
weight
.
t
(),
beta
=
beta
,
out
=
residual
)
elif
x
.
dim
()
>=
3
:
hx
=
x
.
size
(
-
1
)
hr
=
residual
.
size
(
-
1
)
torch
.
addmm
(
residual
.
view
(
-
1
,
hr
),
x
.
view
(
-
1
,
hx
),
weight
.
t
(),
beta
=
beta
,
out
=
residual
.
view
(
-
1
,
hr
))
else
:
raise
AssertionError
(
f
"unrecognized tensor dimensions:
{
x
.
dim
()
}
"
)
if
bias
is
not
None
:
out
+=
bias
residual
+=
bias
return
residual
else
:
out
=
torch
.
ops
.
vllm
.
optimus_fp_aintb_gemm
(
x
,
qweight
,
torch
.
int8
,
# Placeholder for dtype argument
scales
,
bias
,
out
=
output
)
return
out
elif
self
.
quant_config
.
weight_bits
==
6
:
qweight
=
layer
.
qweight
scales
=
layer
.
scales
num_experts
=
layer
.
num_experts
if
num_experts
:
assert
expert_idx
is
not
None
,
"expert_idx is None"
qweight
=
qweight
[
expert_idx
]
scales
=
scales
[
expert_idx
]
if
x
.
dtype
!=
torch
.
bfloat16
:
if
output
is
None
:
output
=
torch
.
empty
(
x
.
shape
[
0
],
qweight
.
shape
[
0
],
device
=
x
.
device
,
dtype
=
torch
.
bfloat16
)
if
output
is
not
None
:
if
bias
is
not
None
:
# always separate bias add when output is provided
torch
.
matmul
(
x
,
weight
.
t
(),
out
=
output
)
output
.
add_
(
bias
)
return
output
return
torch
.
matmul
(
x
,
weight
.
t
(),
out
=
output
)
else
:
output
=
output
.
to
(
torch
.
bfloat16
)
out
=
torch
.
ops
.
vllm
.
optimus_fp6_linear
(
x
,
qweight
,
scales
,
4
,
# Placeholder for fp6_format_code
out
=
output
)
if
bias
is
not
None
:
out
+=
bias
if
residual
is
not
None
:
out
+=
residual
return
out
return
torch
.
nn
.
functional
.
linear
(
x
,
weight
.
t
(),
bias
)
else
:
raise
NotImplementedError
...
...
@@ -603,19 +375,20 @@ class GroupwiseInt8MoeMethod(FusedMoEMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Process weights similar to GroupwiseQuantLinearMethod for 8-bit case
num_experts
=
layer
.
w13_weight
.
shape
[
0
]
#
num_experts = layer.w13_weight.shape[0]
for
expert
in
range
(
num_experts
):
# Preprocess w13 weight (gate and up combined)
layer
.
w13_weight
[
expert
].
copy_
(
torch
.
ops
.
Optimus
.
FpAIntBPreprocessWeightGPU
(
layer
.
w13_weight
[
expert
].
t
().
contiguous
(),
torch
.
int8
))
#
for expert in range(num_experts):
#
# Preprocess w13 weight (gate and up combined)
#
layer.w13_weight[expert].copy_(
#
torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
#
layer.w13_weight[expert].t().contiguous(), torch.int8))
# Preprocess w2 weight (down)
layer
.
w2_weight
[
expert
].
copy_
(
torch
.
ops
.
Optimus
.
FpAIntBPreprocessWeightGPU
(
layer
.
w2_weight
[
expert
].
t
().
contiguous
(),
torch
.
int8
))
# # Preprocess w2 weight (down)
# layer.w2_weight[expert].copy_(
# torch.ops.Optimus.FpAIntBPreprocessWeightGPU(
# layer.w2_weight[expert].t().contiguous(), torch.int8))
pass
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -632,95 +405,48 @@ class GroupwiseInt8MoeMethod(FusedMoEMethodBase):
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
str
=
"silu"
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
optimus_moe_int8
(
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ExpertsInt8MoEMethod` yet."
)
from
vllm.model_executor.layers.fused_moe.fused_moe_step3vw8a16
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
global_num_experts
=
global_num_experts
,
norm_expert_weight
=
renormalize
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
#复用bf16 moe逻辑
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
)
# Wrapper and Fake Functions for Optimus::GemmInt4Group
def
optimus_gemm_int4_group
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qscales
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
Optimus
.
GemmInt4Group
(
x
,
qweight
,
qscales
,
bias
,
None
,
out
=
out
)
def
optimus_gemm_int4_group_fake
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qscales
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output_shape
=
list
(
x
.
shape
[:
-
1
])
+
[
qscales
.
shape
[
-
1
]]
if
out
is
not
None
:
return
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
return
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
# Wrapper and Fake Functions for Optimus::FpAIntBGemm
def
optimus_fp_aintb_gemm
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
dtype_arg
:
torch
.
dtype
,
scales
:
torch
.
Tensor
,
bias_or_residual
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
Optimus
.
FpAIntBGemm
(
x
,
qweight
,
dtype_arg
,
scales
,
bias_or_residual
,
"identity"
,
out
=
out
)
def
optimus_fp_aintb_gemm_fake
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
dtype_arg
:
torch
.
dtype
,
scales
:
torch
.
Tensor
,
bias_or_residual
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output_shape
=
list
(
x
.
shape
[:
-
1
])
+
[
qweight
.
shape
[
-
1
]]
if
out
is
not
None
:
return
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
return
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
# Wrapper and Fake Functions for Optimus::fp6_linear
def
optimus_fp6_linear
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
fp6_format_code
:
int
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
ops
.
Optimus
.
fp6_linear
(
x
,
qweight
,
scales
,
fp6_format_code
,
out
=
out
)
def
optimus_fp6_linear_fake
(
x
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
fp6_format_code
:
int
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output_channels
=
scales
.
shape
[
-
1
]
output_shape
=
list
(
x
.
shape
[:
-
1
])
+
[
output_channels
]
output_dtype
=
x
.
dtype
if
x
.
dtype
!=
torch
.
bfloat16
:
output_dtype
=
torch
.
bfloat16
if
out
is
not
None
:
return
torch
.
empty
(
output_shape
,
dtype
=
output_dtype
,
device
=
x
.
device
)
return
torch
.
empty
(
output_shape
,
dtype
=
output_dtype
,
device
=
x
.
device
)
direct_register_custom_op
(
op_name
=
"optimus_gemm_int4_group"
,
op_func
=
optimus_gemm_int4_group
,
mutates_args
=
[
"out"
],
fake_impl
=
optimus_gemm_int4_group_fake
,
)
direct_register_custom_op
(
op_name
=
"optimus_fp_aintb_gemm"
,
op_func
=
optimus_fp_aintb_gemm
,
mutates_args
=
[
"out"
,
"bias_or_residual"
],
fake_impl
=
optimus_fp_aintb_gemm_fake
,
)
direct_register_custom_op
(
op_name
=
"optimus_fp6_linear"
,
op_func
=
optimus_fp6_linear
,
mutates_args
=
[
"out"
],
fake_impl
=
optimus_fp6_linear_fake
,
)
\ No newline at end of file
use_nn_moe
=
True
,
global_num_experts
=
global_num_experts
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
use_step3v_w8a16
=
True
)
vllm/platforms/rocm.py
View file @
8223f750
...
...
@@ -16,17 +16,14 @@ from vllm.utils import cuda_device_count_stateless
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
from
vllm.utils
import
is_kme
,
SUPPORT_TC
from
vllm.utils
import
SUPPORT_TC
if
not
SUPPORT_TC
:
os
.
environ
[
'VLLM_USE_V1'
]
=
'0'
os
.
environ
[
'VLLM_USE_FLASH_ATTN_PA'
]
=
'0'
os
.
environ
[
'VLLM_USE_FLASH_MLA'
]
=
'0'
if
is_kme
:
os
.
environ
[
'VLLM_USE_FLASH_ATTN_PA'
]
=
'0'
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
...
...
@@ -190,7 +187,7 @@ class RocmPlatform(Platform):
device_control_env_var
:
str
=
"CUDA_VISIBLE_DEVICES"
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"groupwise-quant"
,
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
]
...
...
@@ -304,8 +301,6 @@ class RocmPlatform(Platform):
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
if
is_kme
:
os
.
environ
[
'VLLM_USE_TRITON_FLASH_ATTN'
]
=
'1'
logger
.
info
(
"Using ROCmFlashAttention backend."
)
return
"vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"
# noqa: E501
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment