Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9db52eab
Unverified
Commit
9db52eab
authored
Sep 06, 2024
by
rasmith
Committed by
GitHub
Sep 06, 2024
Browse files
[Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput (#8248)
parent
1447c97e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
11 deletions
+23
-11
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+23
-11
No files found.
vllm/model_executor/layers/quantization/awq_triton.py
View file @
9db52eab
...
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
...
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
# Compute offsets and masks for qweight_ptr.
# Compute offsets and masks for qweight_ptr.
offsets_y
=
pid_y
*
BLOCK_SIZE_Y
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
offsets_y
=
pid_y
*
BLOCK_SIZE_Y
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
)
//
8
offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
)
offsets
=
num_cols
*
offsets_y
[:,
None
]
+
offsets_x
[
None
,
:]
offsets
=
num_cols
*
offsets_y
[:,
None
]
+
offsets_x
[
None
,
:]
masks_y
=
offsets_y
<
num_rows
masks_y
=
offsets_y
<
num_rows
...
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
...
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
# Load the weights.
# Load the weights.
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
)
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
# that will map given indices to the correct order.
...
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
...
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
iweights
=
(
iweights
>>
shifts
)
&
0xF
iweights
=
(
iweights
>>
shifts
)
&
0xF
# Compute zero offsets and masks.
# Compute zero offsets and masks.
zero_offsets_y
=
(
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
zero_offsets_y
=
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
1
)
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
//
group_size
)
zero_offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
)
zero_offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
)
//
8
zero_offsets
=
num_cols
*
zero_offsets_y
[:,
None
]
+
zero_offsets_x
[
None
,
:]
zero_offsets
=
num_cols
*
zero_offsets_y
[:,
None
]
+
zero_offsets_x
[
None
,
:]
zero_masks_y
=
zero_offsets_y
<
num_rows
//
group_size
zero_masks_y
=
zero_offsets_y
<
num_rows
//
group_size
...
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
...
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
# Load the zeros.
# Load the zeros.
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
)
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
broadcast_to
(
zeros
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Unpack and reorder: shift out the correct 4-bit value and mask.
# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros
=
(
zeros
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
# Compute scale offsets and masks.
# Compute scale offsets and masks.
scale_offsets_y
=
(
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
scale_offsets_y
=
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
1
)
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
//
group_size
)
scale_offsets_x
=
(
pid_x
*
BLOCK_SIZE_X
*
8
+
scale_offsets_x
=
(
pid_x
*
BLOCK_SIZE_X
*
8
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
))
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
))
scale_offsets
=
(
num_cols
*
8
*
scale_offsets_y
[:,
None
]
+
scale_offsets
=
(
num_cols
*
8
*
scale_offsets_y
[:,
None
]
+
...
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
...
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
# Load the scales.
# Load the scales.
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
)
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Dequantize.
# Dequantize.
iweights
=
(
iweights
-
zeros
)
*
scales
iweights
=
(
iweights
-
zeros
)
*
scales
...
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
...
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
offsets_am
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offsets_am
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
masks_am
=
offsets_am
<
M
masks_am
=
offsets_am
<
M
offsets_bn
=
(
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
offsets_bn
=
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
//
8
)
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
//
8
)
masks_bn
=
offsets_bn
<
N
//
8
masks_bn
=
offsets_bn
<
N
//
8
offsets_zn
=
(
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
offsets_zn
=
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
//
8
)
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
//
8
)
masks_zn
=
offsets_zn
<
N
//
8
masks_zn
=
offsets_zn
<
N
//
8
offsets_sn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offsets_sn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
...
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
...
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_b
=
masks_k
[:,
None
]
&
masks_bn
[
None
,
:]
masks_b
=
masks_k
[:,
None
]
&
masks_bn
[
None
,
:]
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
# Dequantize b.
# Dequantize b.
offsets_szk
=
(
offsets_szk
=
(
(
BLOCK_SIZE_K
*
SPLIT_K
*
k
+
pid_z
*
BLOCK_SIZE_K
)
//
group_size
+
(
BLOCK_SIZE_K
*
SPLIT_K
*
k
+
pid_z
*
BLOCK_SIZE_K
)
//
group_size
+
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
//
group_size
)
tl
.
arange
(
0
,
1
)
)
offsets_z
=
(
N
//
8
)
*
offsets_szk
[:,
None
]
+
offsets_zn
[
None
,
:]
offsets_z
=
(
N
//
8
)
*
offsets_szk
[:,
None
]
+
offsets_zn
[
None
,
:]
masks_zk
=
offsets_szk
<
K
//
group_size
masks_zk
=
offsets_szk
<
K
//
group_size
masks_z
=
masks_zk
[:,
None
]
&
masks_zn
[
None
,
:]
masks_z
=
masks_zk
[:,
None
]
&
masks_zn
[
None
,
:]
zeros_ptrs
=
zeros_ptr
+
offsets_z
zeros_ptrs
=
zeros_ptr
+
offsets_z
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
)
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
broadcast_to
(
zeros
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
offsets_s
=
N
*
offsets_szk
[:,
None
]
+
offsets_sn
[
None
,
:]
offsets_s
=
N
*
offsets_szk
[:,
None
]
+
offsets_sn
[
None
,
:]
masks_sk
=
offsets_szk
<
K
//
group_size
masks_sk
=
offsets_szk
<
K
//
group_size
masks_s
=
masks_sk
[:,
None
]
&
masks_sn
[
None
,
:]
masks_s
=
masks_sk
[:,
None
]
&
masks_sn
[
None
,
:]
scales_ptrs
=
scales_ptr
+
offsets_s
scales_ptrs
=
scales_ptr
+
offsets_s
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
)
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
b
=
(
b
>>
shifts
)
&
0xF
b
=
(
b
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment