Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
27b78c73
Unverified
Commit
27b78c73
authored
Jan 29, 2025
by
Jinzhen Lin
Committed by
GitHub
Jan 29, 2025
Browse files
[Kernel] add triton fused moe kernel for gptq/awq (#12185)
parent
b02fd288
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
874 additions
and
55 deletions
+874
-55
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+91
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+354
-53
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+5
-2
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+424
-0
No files found.
tests/kernels/test_moe.py
View file @
27b78c73
...
...
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
...
...
@@ -55,6 +57,95 @@ def test_fused_moe(
rtol
=
0
)
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        dtype: torch.dtype, group_size: int, has_zp: bool,
                        weight_bits: int):
    """Compare the fused W4A16/W8A16 MoE path against the torch reference.

    Random fp16/bf16 expert weights are quantized per expert with
    ``quantize_weights``; the packed weights, scales and (optionally)
    zero points are fed to ``fused_moe`` while the dequantized weights
    returned by the quantizer are fed to ``torch_moe``. Both outputs must
    agree within an absolute tolerance of 2e-2.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values are packed into every uint8.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        # Fail loudly instead of hitting an UnboundLocalError further down.
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # Quantize every expert of w1 first, then every expert of w2.
    expert_tensors = (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    for w, w_ref, w_qweight, w_scales, w_qzeros in expert_tensors:
        for expert_id in range(e):
            # quantize_weights receives the transposed expert matrix, so
            # every result is transposed back before it is stored.
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False)
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack neighbouring 4-bit values into one byte: the odd
                # index goes to the high nibble, the even one to the low.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    triton_output = fused_moe(a,
                              w1_qweight,
                              w2_qweight,
                              score,
                              topk,
                              renormalize=False,
                              use_int4_w4a16=weight_bits == 4,
                              use_int8_w8a16=weight_bits == 8,
                              w1_scale=w1_scales,
                              w2_scale=w2_scales,
                              w1_zp=w1_qzeros if has_zp else None,
                              w2_zp=w2_qzeros if has_zp else None,
                              block_shape=[0, group_size])
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
torch
.
inference_mode
()
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
27b78c73
...
...
@@ -19,6 +19,206 @@ from vllm.utils import direct_register_custom_op
logger
=
init_logger
(
__name__
)
@triton.jit
def fused_moe_kernel_gptq_awq(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        b_scale_ptr,
        b_zp_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N: tl.constexpr,
        K: tl.constexpr,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        # Strides of the weight scales (bs*) and zero points (bz*):
        # e = expert, k = quantization group along K, n = output column.
        stride_bse,
        stride_bsk,
        stride_bsn,
        stride_bze,
        stride_bzk,
        stride_bzn,
        # When False, scale / zero-point loads are masked along K.
        block_k_diviable: tl.constexpr,
        # Number of consecutive K elements sharing one scale / zero point.
        group_size: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        has_zp: tl.constexpr,
        use_int4_w4a16: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.

    Quantization handling:
    - B holds quantized weights. Each loaded block is dequantized in float32
        as `(b - zero_point) * scale` and cast to `compute_type` before the
        dot product.
    - With `use_int4_w4a16`, B is addressed with the K index halved and the
        4-bit value is selected with a shift of 0 or 4 bits, i.e. two values
        share one loaded element along K.
    - With `has_zp`, zero points are read from `b_zp_ptr` (for int4 the N
        index is halved and the nibble is selected by a shift); otherwise a
        constant offset of 8 (int4) or 128 (int8) is subtracted.
    - Scales and zero points are indexed along K by
        `(k index) // group_size`.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Blocks that start past the padded token count have nothing to compute.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Entries >= num_valid_tokens are masked out of every load and the store.
    token_mask = offs_token < num_valid_tokens

    # `% N` wraps column offsets into range; the final store is masked by
    # `offs_cn < N` instead.
    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `offs_token // top_k` maps a (token, expert-slot) index to its row of A.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)

    if use_int4_w4a16:
        # K index is halved for addressing; `b_shifter` (0 or 4) selects the
        # nibble that belongs to each K index.
        b_ptrs = b_ptr + off_experts * stride_be + \
            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    # Without a zero-point tensor a constant offset is subtracted instead
    # (8 for int4, 128 for int8).
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # int4 zero points: N index is halved for addressing, this shift
        # selects the nibble for each output column.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.

        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        # NOTE(review): B is loaded without a K mask; out-of-range K rows are
        # cancelled by the zeroed rows of `a` / masked scales — confirm that
        # the read itself stays in bounds for non-divisible K.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        # One scale per (expert, output column, K group).
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
            offs_bn[None, :] * stride_bsn + \
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            # `offs_k_true` is the quantization-group index along K.
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                (offs_bn[None, :] // 2) * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = ((b_zp >> b_zp_shifter) & 0xF)
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                offs_bn[None, :] * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # We accumulate along the K dimension.
        # Dequantize in float32, then cast to the compute dtype for the dot.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            # Half as many B elements per K block because of the packing.
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each output row by its routing weight (0 for masked tokens).
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@
triton
.
jit
def
fused_moe_kernel
(
# Pointers to matrices
...
...
@@ -266,6 +466,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
C
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
B_scale
:
Optional
[
torch
.
Tensor
],
B_zp
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
...
...
@@ -277,6 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
compute_type
:
tl
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
use_int4_w4a16
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
assert
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
...
...
@@ -292,50 +494,108 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
A_scale
.
shape
[
-
1
]
assert
triton
.
cdiv
(
B
.
shape
[
-
2
],
block_n
)
==
B_scale
.
shape
[
-
2
]
assert
triton
.
cdiv
(
B
.
shape
[
-
1
],
block_k
)
==
B_scale
.
shape
[
-
1
]
elif
use_int8_w8a16
:
elif
use_int8_w8a16
or
use_int4_w4a16
:
assert
B_scale
is
not
None
assert
block_shape
is
None
or
block_shape
[
0
]
==
0
else
:
assert
A_scale
is
None
assert
B_scale
is
None
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
'BLOCK_SIZE_N'
]),
)
EM
=
sorted_token_ids
.
shape
[
0
]
if
A
.
shape
[
0
]
<
config
[
"BLOCK_SIZE_M"
]:
# optimize for small batch_size.
# We assume that the topk_ids of each token are unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM
=
min
(
sorted_token_ids
.
shape
[
0
],
A
.
shape
[
0
]
*
top_k
*
config
[
'BLOCK_SIZE_M'
])
grid
=
lambda
META
:
(
triton
.
cdiv
(
EM
,
META
[
'BLOCK_SIZE_M'
])
*
triton
.
cdiv
(
B
.
shape
[
1
],
META
[
'BLOCK_SIZE_N'
]),
)
if
(
use_int8_w8a16
or
use_int4_w4a16
)
and
\
block_shape
is
not
None
and
block_shape
[
1
]
>
0
:
assert
B_scale
is
not
None
and
B_scale
.
ndim
==
3
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
fused_moe_kernel
[
grid
](
A
,
B
,
C
,
A_scale
,
B_scale
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
B
.
shape
[
2
],
sorted_token_ids
.
shape
[
0
],
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
fused_moe_kernel
[
grid
](
A
,
B
,
C
,
A_scale
,
B_scale
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
A_scale
.
stride
(
0
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
A_scale
.
stride
(
1
)
if
A_scale
is
not
None
and
A_scale
.
ndim
==
2
else
0
,
B_scale
.
stride
(
0
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
B_scale
.
stride
(
2
)
if
B_scale
is
not
None
and
B_scale
.
ndim
==
3
else
0
,
B_scale
.
stride
(
1
)
if
B_scale
is
not
None
and
B_scale
.
ndim
>=
2
else
0
,
0
if
block_shape
is
None
else
block_shape
[
0
],
0
if
block_shape
is
None
else
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
def
get_config_file_name
(
E
:
int
,
N
:
int
,
dtype
:
Optional
[
str
])
->
str
:
...
...
@@ -432,7 +692,7 @@ def try_get_optimal_moe_config(
# NOTE: For block-wise quant,
# BLOCK_K must be divisible by block_shape[1]
# BLOCK_N and BLOCK_M has no requirements
if
block_shape
is
not
None
:
if
block_shape
is
not
None
and
block_shape
[
0
]
!=
0
:
config
[
"BLOCK_SIZE_N"
]
=
block_shape
[
0
]
config
[
"BLOCK_SIZE_K"
]
=
block_shape
[
1
]
return
config
...
...
@@ -531,12 +791,15 @@ def grouped_topk(hidden_states: torch.Tensor,
def
get_config_dtype_str
(
dtype
:
torch
.
dtype
,
use_int4_w4a16
:
Optional
[
bool
]
=
False
,
use_int8_w8a16
:
Optional
[
bool
]
=
False
,
use_fp8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_fp8_w8a8
:
return
"fp8_w8a8"
elif
use_int8_w8a16
:
return
"int8_w8a16"
elif
use_int4_w4a16
:
return
"int4_w8a16"
elif
dtype
==
torch
.
float
:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
...
...
@@ -551,14 +814,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w
2
_scale
,
a1_scale
,
a2_scale
,
block_shape
)
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w
1
_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
inplace_fused_experts_fake
(
...
...
@@ -569,8 +835,11 @@ def inplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
...
...
@@ -593,14 +862,18 @@ def outplace_fused_experts(
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
,
block_shape
)
False
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
outplace_fused_experts_fake
(
...
...
@@ -611,8 +884,11 @@ def outplace_fused_experts_fake(
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -635,8 +911,11 @@ def fused_experts(hidden_states: torch.Tensor,
inplace
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
...
...
@@ -644,16 +923,15 @@ def fused_experts(hidden_states: torch.Tensor,
torch
.
ops
.
vllm
.
inplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
return
hidden_states
else
:
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
w1_scale
,
w2_scale
,
a1_scale
,
a2_scale
,
block_shape
)
return
torch
.
ops
.
vllm
.
outplace_fused_experts
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
)
def
fused_experts_impl
(
hidden_states
:
torch
.
Tensor
,
...
...
@@ -664,13 +942,21 @@ def fused_experts_impl(hidden_states: torch.Tensor,
inplace
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
):
# Check constraints.
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
if
use_int4_w4a16
:
assert
hidden_states
.
shape
[
1
]
//
2
==
w1
.
shape
[
2
],
"Hidden size mismatch"
else
:
assert
hidden_states
.
shape
[
1
]
==
w1
.
shape
[
2
],
"Hidden size mismatch"
assert
topk_weights
.
shape
==
topk_ids
.
shape
,
"topk shape mismatch"
assert
hidden_states
.
is_contiguous
(),
"Hidden_states must be contiguous"
assert
w1
.
is_contiguous
(),
"Expert weights1 must be contiguous"
...
...
@@ -687,6 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
config_dtype
=
get_config_dtype_str
(
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
dtype
=
hidden_states
.
dtype
)
get_config_func
=
functools
.
partial
(
...
...
@@ -755,6 +1042,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache1
,
a1_scale
,
w1_scale
,
w1_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
...
...
@@ -766,6 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
...
...
@@ -776,6 +1065,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache3
,
a2_scale
,
w2_scale
,
w2_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
...
...
@@ -787,6 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
compute_type
=
compute_type
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
block_shape
=
block_shape
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
...
...
@@ -808,8 +1099,11 @@ def fused_moe(
custom_routing_function
:
Optional
[
Callable
]
=
None
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_zp
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -834,8 +1128,12 @@ def fused_moe(
note: Deepseekv2 model uses grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
activation to compute the inner products for w1 and w2.
Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
...
...
@@ -873,8 +1171,11 @@ def fused_moe(
inplace
=
inplace
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w1_zp
=
w1_zp
,
w2_zp
=
w2_zp
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
)
vllm/model_executor/layers/quantization/__init__.py
View file @
27b78c73
...
...
@@ -26,7 +26,8 @@ QUANTIZATION_METHODS: List[str] = [
"experts_int8"
,
"neuron_quant"
,
"ipex"
,
"quark"
"quark"
,
"moe_wna16"
]
# The customized quantization methods which will be added to this dict.
...
...
@@ -94,6 +95,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.ipex_quant
import
IPEXConfig
from
.marlin
import
MarlinConfig
from
.modelopt
import
ModelOptFp8Config
from
.moe_wna16
import
MoeWNA16Config
from
.neuron_quant
import
NeuronQuantConfig
from
.qqq
import
QQQConfig
from
.tpu_int8
import
Int8TpuConfig
...
...
@@ -121,7 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
"quark"
:
QuarkConfig
"quark"
:
QuarkConfig
,
"moe_wna16"
:
MoeWNA16Config
,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/moe_wna16.py
0 → 100644
View file @
27b78c73
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.awq
import
(
AWQConfig
,
AWQLinearMethod
)
from
vllm.model_executor.layers.quantization.awq_marlin
import
(
AWQMarlinConfig
,
AWQMarlinLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.gptq
import
(
GPTQConfig
,
GPTQLinearMethod
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
,
GPTQMarlinLinearMethod
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
class
MoeWNA16Config
(
QuantizationConfig
):
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
def
__init__
(
self
,
linear_quant_method
:
str
,
weight_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
lm_head_quantized
:
bool
,
modules_to_not_convert
:
Optional
[
List
[
str
]],
full_config
:
Dict
[
str
,
Any
])
->
None
:
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
has_zp
=
has_zp
self
.
bit8_pack_factor
=
8
//
self
.
weight_bits
self
.
lm_head_quantized
=
lm_head_quantized
self
.
linear_quant_method
=
linear_quant_method
self
.
full_config
=
full_config
self
.
use_marlin
=
False
if
self
.
linear_quant_method
==
"gptq"
:
self
.
use_marlin
=
GPTQMarlinConfig
.
is_gptq_marlin_compatible
(
full_config
)
elif
self
.
linear_quant_method
==
"awq"
:
capability_tuple
=
current_platform
.
get_device_capability
()
device_capability
=
(
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
())
awq_min_capability
=
AWQConfig
.
get_min_capability
()
if
device_capability
<
awq_min_capability
:
raise
ValueError
(
"The quantization method moe_wna16 + awq is not supported "
"for the current GPU. "
f
"Minimum capability:
{
awq_min_capability
}
. "
f
"Current capability:
{
device_capability
}
."
)
self
.
use_marlin
=
AWQMarlinConfig
.
is_awq_marlin_compatible
(
full_config
)
else
:
raise
ValueError
(
"moe_wna16 only support gptq and awq."
)
if
modules_to_not_convert
is
None
:
self
.
modules_to_not_convert
=
[]
else
:
self
.
modules_to_not_convert
=
modules_to_not_convert
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"moe_wna16"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
70
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"MoeWNA16Config"
:
linear_quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
if
linear_quant_method
==
"gptq"
:
has_zp
=
not
cls
.
get_from_keys
(
config
,
[
"sym"
])
modules_to_not_convert
=
[]
elif
linear_quant_method
==
"awq"
:
has_zp
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
modules_to_not_convert
=
cls
.
get_from_keys
(
config
,
[
"modules_to_not_convert"
])
else
:
raise
ValueError
(
"moe_wna16 only support gptq and awq."
)
return
cls
(
linear_quant_method
,
weight_bits
,
group_size
,
has_zp
,
lm_head_quantized
,
modules_to_not_convert
,
config
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_moe_wna16_compatible
(
hf_quant_cfg
)
if
can_convert
and
user_quant
==
"moe_wna16"
:
return
cls
.
get_name
()
return
None
@classmethod
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
    """Report whether a checkpoint quantization config can use moe_wna16.

    A "gptq" config qualifies when ``desc_act`` is falsy and ``bits`` is
    4 or 8. An "awq" config qualifies when ``bits`` is 4 and the current
    device capability is at least ``AWQConfig.get_min_capability()``.
    Any other ``quant_method`` is not compatible.
    """
    # Pull the relevant fields out of the quant config.
    method = quant_config.get("quant_method", "").lower()
    bits = quant_config.get("bits")
    act_order = quant_config.get("desc_act")

    # The platform is queried unconditionally; a missing capability is
    # mapped to -1 so that it compares below any real minimum.
    capability = current_platform.get_device_capability()
    capability_int = -1 if capability is None else capability.to_int()
    min_awq_capability = AWQConfig.get_min_capability()

    if method == "gptq":
        return not act_order and bits in [4, 8]
    if method == "awq":
        return bits == 4 and capability_int >= min_awq_capability
    return False
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Pick the quantization method object for ``layer``.

    Layers whose ``prefix`` matches ``self.modules_to_not_convert`` get an
    ``UnquantizedLinearMethod``; ``FusedMoE`` layers get a
    ``MoeWNA16Method``. Every other layer gets the GPTQ or AWQ linear
    method (the Marlin variant when ``self.use_marlin`` is set), built
    from ``self.full_config``.

    Raises:
        ValueError: If ``self.linear_quant_method`` is neither "gptq" nor
            "awq".
    """
    if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
        return UnquantizedLinearMethod()
    if isinstance(layer, FusedMoE):
        return MoeWNA16Method(self)

    backend = self.linear_quant_method
    if backend == "gptq":
        method_cls, config_cls = (
            (GPTQMarlinLinearMethod, GPTQMarlinConfig)
            if self.use_marlin else (GPTQLinearMethod, GPTQConfig))
    elif backend == "awq":
        method_cls, config_cls = (
            (AWQMarlinLinearMethod, AWQMarlinConfig)
            if self.use_marlin else (AWQLinearMethod, AWQConfig))
    else:
        raise ValueError("moe_wna16 only support gptq and awq.")

    # Delegate to the underlying method, configured from the full
    # original quantization config.
    return method_cls(config_cls.from_config(self.full_config))
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
    """Return True if any listed module name occurs as a substring of
    ``prefix``, otherwise False."""
    for skipped_name in modules_to_not_convert:
        if skipped_name in prefix:
            return True
    return False
class MoeWNA16Method(FusedMoEMethodBase):
    """Linear method for MOE WNA16 (W8A16/W4A16) quantization.

    Args:
        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
    """

    def __init__(self, quant_config: MoeWNA16Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register the quantized expert parameters on ``layer``.

        Registers ``w13_qweight``/``w2_qweight`` (uint8), ``w13_scales``/
        ``w2_scales`` (``params_dtype``) and, when the config has zero
        points, ``w13_qzeros``/``w2_qzeros`` (uint8). For GPTQ, empty
        placeholder parameters are registered for checkpoint tensors that
        are not used. Also stores ``quant_config``, ``group_size`` and
        ``group_size_div_factor`` on ``layer`` and replaces the
        ``weight_loader`` entry of ``extra_weight_attrs`` with a wrapped
        loader from ``get_weight_loader``.
        """
        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        group_size = self.quant_config.group_size
        group_size_div_factor = 1

        # make intermediate_size and hidden_size divisible by group_size
        # we reduce the group size to ensure that
        # and we would repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or \
                hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32, (
                "group_size fell below 32 while trying to divide "
                "intermediate_size_per_partition and hidden_size")
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": False
        })

        # Every parameter below is loaded through the wrapped loader, which
        # converts checkpoint tensors before handing them to the original.
        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // bit8_pack_factor,
            dtype=torch.uint8),
                                         requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // bit8_pack_factor,
            dtype=torch.uint8),
                                        requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // group_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // group_size,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            w13_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition // bit8_pack_factor,
                hidden_size // group_size,
                dtype=torch.uint8),
                                            requires_grad=False)
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                hidden_size // bit8_pack_factor,
                intermediate_size_per_partition // group_size,
                dtype=torch.uint8),
                                           requires_grad=False)
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some param are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(torch.empty((0, ),
                                                       dtype=torch.int32),
                                           requires_grad=False)
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run the fused quantized expert kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        result is passed to ``fused_experts`` together with the quantized
        weights, scales and (when the config has zero points) the packed
        zero points registered by ``create_weights``.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp

        # layer.group_size is the (possibly reduced) group size chosen in
        # create_weights, not necessarily the checkpoint's group size.
        return fused_experts(x,
                             layer.w13_qweight,
                             layer.w2_qweight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_int4_w4a16=weight_bits == 4,
                             use_int8_w8a16=weight_bits == 8,
                             w1_scale=layer.w13_scales,
                             w2_scale=layer.w2_scales,
                             w1_zp=layer.w13_qzeros if has_zp else None,
                             w2_zp=layer.w2_qzeros if has_zp else None,
                             block_shape=[0, layer.group_size])

    @staticmethod
    def get_weight_loader(layer, weight_loader):
        """Return a loader that converts GPTQ/AWQ checkpoint tensors.

        The returned function converts ``loaded_weight`` to the layout of
        the parameters created in ``create_weights``, writes packed zero
        points into ``param`` directly, and forwards everything else to the
        original ``weight_loader``.
        """

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #         (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            # shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)
            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            # Unpack each byte into its two 4-bit values, add 1 to each
            # (see "add 1 to gptq qzeros to align with awq" below), then
            # pack the pair back into one byte.
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                    loaded_weight: torch.Tensor,
                                    weight_name: str, shard_id: str,
                                    expert_id: int):
            # g_idx (and qzeros when the config has no zero points) only
            # have placeholder parameters; nothing is loaded for them.
            if "g_idx" in weight_name:
                return
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)

            # convert gptq and awq weight to a standard format
            if layer.quant_config.linear_quant_method == "awq":
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight,
                                                       "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    loaded_weight = loaded_weight.T.contiguous().view(
                        torch.uint8)
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(
                            loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    loaded_weight = loaded_weight.T

            # repeat the qzeros/scales to fit new group size
            # (parenthesised: the factor check must guard BOTH the qzeros
            # and the scales case; `and` binds tighter than `or`)
            if layer.group_size_div_factor > 1 and (
                    "qzeros" in weight_name or "scales" in weight_name):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)

            if "w13_qzeros" in weight_name:
                # Keep only this tensor-parallel rank's slice of dim 0.
                tensor = loaded_weight.view(layer.tp_size, -1,
                                            loaded_weight.size(1))[tp_rank]
                # w1 fills the first half of the fused parameter and the
                # other shard the second half. The split point is taken
                # from the parameter itself so it stays correct for any
                # pack factor.
                half = param.data.size(1) // 2
                if shard_id == "w1":
                    param.data[expert_id, :half] = tensor
                else:
                    param.data[expert_id, half:] = tensor
            elif "w2_qzeros" in weight_name:
                # Keep only this tensor-parallel rank's slice of dim 1.
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
            else:
                weight_loader(param, loaded_weight, weight_name, shard_id,
                              expert_id)

        return moe_wna16_weight_loader
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
2
*
intermediate_size_per_partition
//
bit8_pack_factor
,
hidden_size
//
group_size
,
dtype
=
torch
.
uint8
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_qzeros"
,
w13_qzeros
)
set_weight_attrs
(
w13_qzeros
,
extra_weight_attrs
)
w2_qzeros
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
//
bit8_pack_factor
,
intermediate_size_per_partition
//
group_size
,
dtype
=
torch
.
uint8
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_qzeros"
,
w2_qzeros
)
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
if
self
.
quant_config
.
linear_quant_method
==
"gptq"
:
# some param are unused, but we need to init them in order to
# load weights
invalid_param_keys
=
[
"w13_g_idx"
,
"w2_g_idx"
]
if
not
self
.
quant_config
.
has_zp
:
invalid_param_keys
+=
[
"w13_qzeros"
,
"w2_qzeros"
]
for
key
in
invalid_param_keys
:
param
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
0
,
),
dtype
=
torch
.
int32
),
requires_grad
=
False
)
layer
.
register_parameter
(
key
,
param
)
set_weight_attrs
(
param
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
weight_bits
=
self
.
quant_config
.
weight_bits
has_zp
=
self
.
quant_config
.
has_zp
return
fused_experts
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
w1_zp
=
layer
.
w13_qzeros
if
has_zp
else
None
,
w2_zp
=
layer
.
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
layer
.
group_size
])
@
staticmethod
def
get_weight_loader
(
layer
,
weight_loader
):
def
convert_awq_tensor
(
tensor
,
tensor_type
):
# convert awq qweight/qzeros to a standard format (assume int4)
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
# qzeros: (k // group_size, n // pack_factor_bit32) ->
# (n // pack_factor_bit8, k // group_size)
# pack_factor_bit32 = 32 // weight_bits
# pack_factor_bit8 = 8 // weight_bits
# 0. suppose origin shape (a, b), dtype int32
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
size0
=
tensor
.
size
(
0
)
tensor
=
tensor
.
view
(
torch
.
uint8
)
# 2. unpack to uint4 (only when weight_bits == 4)
# shape (a, 4 * b) -> (a, 4 * b, 2)
shifter
=
torch
.
tensor
([
0
,
4
],
dtype
=
torch
.
uint8
,
device
=
tensor
.
device
)
tensor
=
(
tensor
[:,
:,
None
]
>>
shifter
)
&
0xF
# 3. change order, see
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
# shape -> (a, 4 * b * pack_factor_bit8)
reverse_awq_pack_order
=
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]
tensor
=
tensor
.
view
(
-
1
,
8
)[:,
reverse_awq_pack_order
]
tensor
=
tensor
.
view
(
size0
,
-
1
)
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
tensor
=
tensor
.
T
.
contiguous
()
# 5. repack (only when weight_bits == 4)
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
# qzeros shape -> (4 * b, a)
if
tensor_type
==
"qweight"
:
tensor
=
tensor
[:,
1
::
2
]
*
16
+
tensor
[:,
::
2
]
elif
tensor_type
==
"qzeros"
:
tensor
=
tensor
[
1
::
2
,
:]
*
16
+
tensor
[::
2
,
:]
return
tensor
def
convert_gptq_int4_qzeros
(
tensor
):
tensor
=
tensor
.
view
(
torch
.
uint8
)
shifter
=
torch
.
tensor
([
0
,
4
],
dtype
=
torch
.
uint8
,
device
=
tensor
.
device
)
tensor
=
(
tensor
[:,
:,
None
]
>>
shifter
)
&
0xF
tensor
=
tensor
+
1
tensor
=
tensor
[:,
:,
0
]
+
tensor
[:,
:,
1
]
*
16
return
tensor
def
moe_wna16_weight_loader
(
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
):
if
"g_idx"
in
weight_name
:
return
if
not
layer
.
quant_config
.
has_zp
and
"qzeros"
in
weight_name
:
return
device
=
get_tp_group
().
device
tp_rank
=
get_tensor_model_parallel_rank
()
loaded_weight
=
loaded_weight
.
to
(
device
)
shard_size
=
layer
.
intermediate_size_per_partition
# convert gptq and awq weight to a standard format
if
layer
.
quant_config
.
linear_quant_method
==
"awq"
:
assert
layer
.
quant_config
.
weight_bits
==
4
if
"weight"
in
weight_name
:
loaded_weight
=
convert_awq_tensor
(
loaded_weight
,
"qweight"
)
elif
"zeros"
in
weight_name
:
loaded_weight
=
convert_awq_tensor
(
loaded_weight
,
"qzeros"
)
else
:
loaded_weight
=
loaded_weight
.
T
elif
layer
.
quant_config
.
linear_quant_method
==
"gptq"
:
assert
layer
.
quant_config
.
weight_bits
in
[
4
,
8
]
if
"weight"
in
weight_name
:
loaded_weight
=
loaded_weight
.
T
.
contiguous
().
view
(
torch
.
uint8
)
elif
"zeros"
in
weight_name
:
# add 1 to gptq qzeros to align with awq
loaded_weight
=
loaded_weight
.
view
(
torch
.
uint8
)
if
layer
.
quant_config
.
weight_bits
==
4
:
loaded_weight
=
convert_gptq_int4_qzeros
(
loaded_weight
).
T
else
:
loaded_weight
=
loaded_weight
.
T
+
1
else
:
loaded_weight
=
loaded_weight
.
T
# repeat the qzeros/scales to fit new group size
if
layer
.
group_size_div_factor
>
1
and
\
"qzeros"
in
weight_name
or
"scales"
in
weight_name
:
loaded_weight
=
loaded_weight
.
repeat_interleave
(
layer
.
group_size_div_factor
,
1
)
if
"w13_qzeros"
in
weight_name
:
tensor
=
loaded_weight
.
view
(
layer
.
tp_size
,
-
1
,
loaded_weight
.
size
(
1
))[
tp_rank
]
if
shard_id
==
"w1"
:
param
.
data
[
expert_id
,
:
shard_size
//
2
]
=
tensor
else
:
param
.
data
[
expert_id
,
shard_size
//
2
:]
=
tensor
elif
"w2_qzeros"
in
weight_name
:
param
.
data
[
expert_id
]
=
loaded_weight
.
view
(
loaded_weight
.
size
(
0
),
layer
.
tp_size
,
-
1
)[:,
tp_rank
]
else
:
weight_loader
(
param
,
loaded_weight
,
weight_name
,
shard_id
,
expert_id
)
return
moe_wna16_weight_loader
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment