change / sglang · Commits

Unverified commit 7750b91c, authored Jul 18, 2025 by Hubert Lu, committed by GitHub on Jul 18, 2025.

[AMD] Add triton awq_dequantize kernel to support AWQ on ROCm (#7661)

Parent: c8f31042

Showing 5 changed files with 530 additions and 3 deletions (+530, -3):
python/sglang/srt/layers/quantization/awq.py         +10   -2
python/sglang/srt/layers/quantization/awq_triton.py  +339  -0
python/sglang/srt/models/deepseek_v2.py              +5    -1
test/srt/run_suite.py                                +1    -0
test/srt/test_awq_dequant.py                         +175  -0
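The core of the change is how AWQ packs eight 4-bit weights into each int32: logical weight i lives at nibble slot [0, 4, 1, 5, 2, 6, 3, 7][i], so both new kernels build a reverse-order tensor [0, 4, 1, 5, 2, 6, 3, 7] to restore logical order while unpacking. A minimal, CPU-only sketch of that layout (illustrative only, not part of the commit):

```python
# Pack eight 4-bit weights into one 32-bit word using AWQ's interleaved layout,
# then recover logical order with the same reverse-order permutation the Triton
# kernels below construct as `reverse_awq_order_tensor`.
logical = [1, 2, 3, 4, 5, 6, 7, 8]      # eight 4-bit values, logical order
slots = [0, 4, 1, 5, 2, 6, 3, 7]        # nibble slot holding each logical value
packed = 0
for value, slot in zip(logical, slots):
    packed |= value << (4 * slot)

nibbles = [(packed >> shift) & 0xF for shift in range(0, 32, 4)]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
assert [nibbles[j] for j in AWQ_REVERSE_ORDER] == logical
```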
python/sglang/srt/layers/quantization/awq.py
```diff
@@ -43,11 +43,20 @@ try:
 except ImportError:
     ops = None
 
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, fused_marlin_moe
+elif _is_hip:
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_triton as awq_dequantize,
+    )
+    warnings.warn(f"HIP does not support fused_marlin_moe currently.")
+else:
+    warnings.warn(f"Only CUDA and HIP support AWQ currently.")
 
 logger = logging.getLogger(__name__)
```

```diff
@@ -398,7 +407,6 @@ class AWQLinearMethod(LinearMethodBase):
         pack_factor = self.quant_config.pack_factor
         out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
         out = awq_dequantize(qweight, scales, qzeros)
         out = torch.matmul(reshaped_x, out)
```
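On ROCm, `awq_dequantize` in this apply path now resolves to `awq_dequantize_triton`, so the dequantize-then-matmul flow above is unchanged. A shape-only sketch of that flow (CPU stand-ins, illustrative assumptions only):

```python
import torch

# AWQ packs 8 four-bit weights per int32, so dequantize widens the last
# dimension by pack_factor = 8 before the ordinary matmul.
K, n_packed, pack_factor = 128, 16, 8
x = torch.rand(4, K)
qweight = torch.zeros(K, n_packed, dtype=torch.int32)

out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
w = torch.zeros(K, n_packed * pack_factor)  # stand-in for awq_dequantize(...)
out = torch.matmul(x.reshape(-1, x.shape[-1]), w).reshape(out_shape)
assert out.shape == (4, n_packed * pack_factor)
```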
python/sglang/srt/layers/quantization/awq_triton.py (new file, mode 100644)
```python
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import triton
import triton.language as tl

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@triton.jit
def awq_dequantize_kernel(
    qweight_ptr,  # quantized matrix
    scales_ptr,  # scales, per group
    zeros_ptr,  # zeros, per group
    group_size,  # Should always be one of the supported group sizes
    result_ptr,  # Output matrix
    num_cols,  # input num cols in qweight
    num_rows,  # input num rows in qweight
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
):
    # Setup the pids.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    # Compute offsets and masks for qweight_ptr.
    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

    masks_y = offsets_y < num_rows
    masks_x = offsets_x < num_cols
    masks = masks_y[:, None] & masks_x[None, :]

    # Compute offsets and masks for result output ptr.
    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
    result_offsets = (
        8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :]
    )

    result_masks_y = result_offsets_y < num_rows
    result_masks_x = result_offsets_x < num_cols * 8
    result_masks = result_masks_y[:, None] & result_masks_x[None, :]

    # Load the weights.
    iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
    iweights = tl.interleave(iweights, iweights)
    iweights = tl.interleave(iweights, iweights)
    iweights = tl.interleave(iweights, iweights)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Use this to compute a set of shifts that can be used to unpack and
    # reorder the values in iweights and zeros.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    iweights = (iweights >> shifts) & 0xF

    # Compute zero offsets and masks.
    zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

    zero_masks_y = zero_offsets_y < num_rows // group_size
    zero_masks_x = zero_offsets_x < num_cols
    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

    # Load the zeros.
    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    zeros = (zeros >> shifts) & 0xF

    # Compute scale offsets and masks.
    scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
    scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
    scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :]

    scale_masks_y = scale_offsets_y < num_rows // group_size
    scale_masks_x = scale_offsets_x < num_cols * 8
    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

    # Load the scales.
    scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
    scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Dequantize.
    iweights = (iweights - zeros) * scales
    iweights = iweights.to(result_ptr.type.element_ty)

    # Finally, store.
    tl.store(result_ptr + result_offsets, iweights, result_masks)


@triton.jit
def awq_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    zeros_ptr,
    scales_ptr,
    M,
    N,
    K,
    group_size,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = c_ptr.type.element_ty

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
    # accumulator = tl.arange(0, BLOCK_SIZE_N)
    # accumulator = tl.broadcast_to(accumulator[None, :],
    #                               (BLOCK_SIZE_M, BLOCK_SIZE_N))
    # accumulator = accumulator & 0x0
    # accumulator = accumulator.to(accumulator_dtype)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Create the necessary shifts to use to unpack.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_bn = offsets_bn < N // 8

    offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_zn = offsets_zn < N // 8

    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a, other=0.0)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)

        # Dequantize b.
        offsets_szk = (
            BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
        ) // group_size + tl.arange(0, 1)
        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_zk = offsets_szk < K // group_size
        masks_z = masks_zk[:, None] & masks_zn[None, :]
        zeros_ptrs = zeros_ptr + offsets_z
        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_sk = offsets_szk < K // group_size
        masks_s = masks_sk[:, None] & masks_sn[None, :]
        scales_ptrs = scales_ptr + offsets_s
        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
        scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# qweights - [K     , M // 8], int32
# scales   - [K // G, M     ], float16
# zeros    - [K // G, M // 8], int32
def awq_dequantize_triton(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    block_size_x: int = 32,
    block_size_y: int = 32,
) -> torch.Tensor:
    K = qweight.shape[0]
    M = scales.shape[1]
    group_size = qweight.shape[0] // scales.shape[0]

    assert K > 0 and M > 0
    assert scales.shape[0] == K // group_size and scales.shape[1] == M
    assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    # Result tensor:
    # number of rows = same as input tensor
    # number of cols = 8 x input tensor num cols
    result = torch.empty(
        qweight.shape[0],
        qweight.shape[1] * 8,
        device=qweight.device,
        dtype=scales.dtype,
    )

    Y = qweight.shape[0]  # num rows
    X = qweight.shape[1]  # num cols

    grid = lambda META: (
        triton.cdiv(X, META["BLOCK_SIZE_X"]),
        triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
    )
    awq_dequantize_kernel[grid](
        qweight,
        scales,
        zeros,
        group_size,
        result,
        X,
        Y,
        BLOCK_SIZE_X=block_size_x,
        BLOCK_SIZE_Y=block_size_y,
    )

    return result


# input   - [M, K]
# qweight - [K, N // 8]
# qzeros  - [K // G, N // 8]
# scales  - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
    block_size_m: int = 32,
    block_size_n: int = 32,
    block_size_k: int = 32,
) -> torch.Tensor:
    M, K = input.shape
    N = qweight.shape[1] * 8
    group_size = qweight.shape[0] // qzeros.shape[0]

    assert N > 0 and K > 0 and M > 0
    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
    assert scales.shape[0] == K // group_size and scales.shape[1] == N
    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
    assert split_k_iters <= 32
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        split_k_iters,
    )

    result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)

    # A = input, B = qweight, C = result
    # A = M x K, B = K x N, C = M x N
    awq_gemm_kernel[grid](
        input,
        qweight,
        result,
        qzeros,
        scales,
        M,
        N,
        K,
        group_size,
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        SPLIT_K=split_k_iters,
    )

    result = result.sum(0)

    return result
```
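A minimal usage sketch for the two wrappers, assuming a GPU with a working Triton backend and the shape conventions documented in the comments above (pack factor 8, G = group size); the sizes here are arbitrary:

```python
import torch

from sglang.srt.layers.quantization.awq_triton import (
    awq_dequantize_triton,
    awq_gemm_triton,
)

M, K, G = 256, 128, 64  # output features, input features, group size
qweight = torch.randint(0, 2**31 - 1, (K, M // 8), dtype=torch.int32, device="cuda")
scales = torch.rand(K // G, M, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // G, M // 8), dtype=torch.int32, device="cuda")

w = awq_dequantize_triton(qweight, scales, qzeros)                 # [K, M], float16
x = torch.rand(4, K, dtype=torch.float16, device="cuda")
y = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters=8)   # [4, M]
assert w.shape == (K, M) and y.shape == (4, M)
```

With split-K, each of the `split_k_iters` program instances accumulates a disjoint slice of the K dimension into its own [M, N] plane, and the wrapper reduces the planes with `result.sum(0)`.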
python/sglang/srt/models/deepseek_v2.py
```diff
@@ -127,6 +127,10 @@ if _is_cuda:
     )
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_triton as awq_dequantize,
+    )
 else:
     from vllm._custom_ops import awq_dequantize
```

```diff
@@ -2176,7 +2180,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                if _is_cuda:
+                if _is_cuda or _is_hip:
                     w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
```
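The import alias is what keeps this call site nearly untouched: on HIP the Triton implementation is bound to the same `awq_dequantize` name, so only the platform guard widens. A condensed sketch of the pattern, mirroring the diffs above:

```python
from sglang.srt.utils import is_cuda, is_hip

_is_cuda, _is_hip = is_cuda(), is_hip()

if _is_cuda:
    from sgl_kernel import awq_dequantize
elif _is_hip:
    # Same public name, Triton-backed implementation.
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
```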
test/srt/run_suite.py
```diff
@@ -147,6 +147,7 @@ suites = {
         # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
         TestFile("test_reasoning_parser.py", 5),
         TestFile("test_rope_rocm.py", 3),
+        TestFile("test_awq_dequant.py", 2),
     ],
     "per-commit-npu": [
         TestFile("test_ascend_attention_backend.py", 400),
```
test/srt/test_awq_dequant.py (new file, mode 100644)
```python
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
unittest version of the AWQ Triton kernel tests.

Run with:
    python -m unittest test_awq_dequant.py
"""
import unittest

import torch

from sglang.srt.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES,
    awq_dequantize_triton,
    awq_gemm_triton,
)
from sglang.test.test_utils import CustomTestCase

device = "cuda"


def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    idx = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    idx = idx.view(-1, 32 // bits)[:, AWQ_REVERSE_ORDER].view(-1)
    return (t[:, idx] & 0xF).contiguous()


def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    iweights = reverse_awq_order(iweights.view(iweights.shape[0], -1))

    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = reverse_awq_order(zeros.view(qzeros.shape[0], -1))

    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


class TestAWQTriton(CustomTestCase):
    def test_dequantize(self):
        rows_list = [3584, 18944, 128, 256, 512, 1024]
        cols_list = [448, 576, 4736, 16, 32, 64, 128]
        for qweight_rows in rows_list:
            for qweight_cols in cols_list:
                for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                    with self.subTest(
                        rows=qweight_rows, cols=qweight_cols, g=group_size
                    ):
                        self._run_dequant_case(
                            qweight_rows=qweight_rows,
                            qweight_cols=qweight_cols,
                            group_size=group_size,
                        )

    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        if group_size == -1:
            group_size = qweight_rows

        torch.manual_seed(0)
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            qweight_rows // group_size,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows // group_size, qweight_cols),
            dtype=torch.int32,
            device=device,
        )

        ref = awq_dequantize_torch(qweight, scales, zeros, group_size)
        tri = awq_dequantize_triton(qweight, scales, zeros)

        # sanity
        self.assertFalse(torch.any(torch.isinf(tri)) or torch.any(torch.isnan(tri)))
        torch.testing.assert_close(ref, tri)

    # GEMM
    def test_gemm(self):
        N_list = [1, 2, 4, 8, 14, 17, 23, 32]
        K_list = [128]
        M_list = [16, 24, 32]
        splitK_list = [1, 8]
        for N in N_list:
            for K in K_list:
                for M in M_list:
                    for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                        for splitK in splitK_list:
                            with self.subTest(N=N, K=K, M=M, g=group_size, sk=splitK):
                                self._run_gemm_case(
                                    N=N,
                                    K=K,
                                    M=M,
                                    group_size=group_size,
                                    splitK=splitK,
                                )

    def _run_gemm_case(self, N, K, M, group_size, splitK):
        if group_size == -1:
            group_size = K

        torch.manual_seed(0)
        x = torch.rand((N, K), dtype=torch.float32, device=device)
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K, M // 8),
            dtype=torch.int32,
            device=device,
        )
        qzeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K // group_size, M // 8),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand((K // group_size, M), dtype=torch.float32, device=device)

        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)
        self.assertFalse(
            torch.any(torch.isinf(tri_out)) or torch.any(torch.isnan(tri_out))
        )

        # dequantize & compare
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)
        self.assertFalse(
            torch.any(torch.isinf(ref_out)) or torch.any(torch.isnan(ref_out))
        )

        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    unittest.main(verbosity=2)
```
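For intuition about what the reference path checks, here is a tiny CPU-only example of the dequantize formula w = (q - z) * s, with per-group zeros and scales broadcast along K (values arbitrary; illustrative, not part of the suite):

```python
import torch

group_size = 2
q = torch.tensor([[3, 7], [1, 15], [4, 4], [9, 0]], dtype=torch.int8)  # [K=4, M=2] unpacked weights
z = torch.tensor([[2, 5], [4, 1]], dtype=torch.int8)                   # [K // G, M] zero points
s = torch.tensor([[0.5, 0.25], [1.0, 2.0]])                            # [K // G, M] scales

w = (q - z.repeat_interleave(group_size, dim=0)) * s.repeat_interleave(group_size, dim=0)
# w: [[ 0.5,  0.5], [-0.5,  2.5], [ 0.0,  6.0], [ 5.0, -2.0]]
```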