Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
396d92d5
Unverified
Commit
396d92d5
authored
Jul 21, 2024
by
Alexander Matveev
Committed by
GitHub
Jul 21, 2024
Browse files
[Kernel][Core] Add AWQ support to the Marlin kernel (#6612)
parent
25e778aa
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
149 additions
and
1 deletion
+149
-1
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+149
-1
No files found.
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
396d92d5
...
@@ -106,6 +106,67 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -106,6 +106,67 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
)
)
def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
    """Group-wise quantize ``w`` to unsigned integers with a zero-point.

    Each group of ``group_size`` consecutive rows of one column gets its own
    scale and zero-point, derived from that group's min and max values.

    Args:
        w: 2-D floating-point tensor of shape ``(size_k, size_n)``.
        num_bits: Bit width of the quantized values; must be listed in the
            module-level ``SUPPORTED_NUM_BITS``.
        group_size: Number of rows per quantization group; must be listed in
            the module-level ``SUPPORTED_GROUP_SIZES`` or equal ``size_k``.
            ``-1`` is replaced by ``size_k`` (one group per column).

    Returns:
        A tuple ``(w_ref, q_w, s, zp)``, every element moved to the device
        ``w`` was on:
          * ``w_ref`` - dequantized reference, ``(q_w - zp).half() * s``,
            shape ``(size_k, size_n)``.
          * ``q_w`` - quantized integer values clamped to
            ``[0, 2**num_bits - 1]``, shape ``(size_k, size_n)``.
          * ``s`` - per-group scales reshaped to ``(-1, size_n)``.
          * ``zp`` - per-group integer zero-points reshaped to
            ``(-1, size_n)``.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    # -1 means "no grouping": the whole K dimension forms a single group.
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    q_min, q_max = 0, 2**num_bits - 1
    is_grouped = group_size < size_k

    # Lay the tensor out as [group_size, num_groups * size_n] so that every
    # column holds exactly one quantization group.
    if is_grouped:
        w = (w.reshape((-1, group_size, size_n)).permute(1, 0, 2).reshape(
            (group_size, -1)))

    # Per-group value range -> scale (the range is floored at 1e-5 so the
    # scale can never be zero).
    group_max = torch.max(w, 0, keepdim=True).values
    group_min = torch.min(w, 0, keepdim=True).values
    s = (group_max - group_min).clamp(min=1e-5) / q_max

    # Per-group zero-point, kept inside the representable integer range.
    zp = (-torch.round(group_min / s)).clamp(q_min, q_max).int()

    # Quantize, then dequantize again to obtain the reference weights.
    q_w = torch.clamp(torch.round(w / s).int() + zp, q_min, q_max)
    w_ref = (q_w - zp).half() * s

    # Undo the grouping layout to get back to (size_k, size_n).
    if is_grouped:

        def _ungroup(t):
            t = t.reshape((group_size, -1, size_n)).permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w, w_ref = _ungroup(q_w), _ungroup(w_ref)

    s = s.reshape((-1, size_n)).contiguous()
    zp = zp.reshape((-1, size_n)).contiguous()

    return tuple(t.to(device=orig_device) for t in (w_ref, q_w, s, zp))
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
orig_device
=
q_w
.
device
orig_device
=
q_w
.
device
...
@@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
...
@@ -122,7 +183,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
)
)
def
gptq_pack
(
def
pack_rows
(
q_w
:
torch
.
Tensor
,
q_w
:
torch
.
Tensor
,
num_bits
:
int
,
num_bits
:
int
,
size_k
:
int
,
size_k
:
int
,
...
@@ -144,3 +205,90 @@ def gptq_pack(
...
@@ -144,3 +205,90 @@ def gptq_pack(
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
q_res
=
torch
.
from_numpy
(
q_res
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_res
return
q_res
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack neighbouring columns of ``q_w`` into 32-bit words.

    Every run of ``pack_factor`` consecutive columns becomes one output
    column; the i-th column of a run is shifted left by ``num_bits * i`` bits
    and OR-ed into the word. ``pack_factor`` is whatever
    ``get_pack_factor(num_bits)`` returns (presumably ``32 // num_bits`` -
    confirm against that helper).

    Args:
        q_w: Tensor of shape ``(size_k, size_n)`` holding the values to pack.
        num_bits: Bit width reserved for each value inside a word.
        size_k: Expected number of rows of ``q_w``.
        size_n: Expected number of columns of ``q_w``; must be a multiple of
            ``pack_factor``.

    Returns:
        A contiguous int32 tensor of shape ``(size_k, size_n // pack_factor)``
        on the device ``q_w`` was on.
    """
    assert q_w.shape == (size_k, size_n)

    cols_per_word = get_pack_factor(num_bits)
    assert size_n % cols_per_word == 0

    target_device = q_w.device

    # The bit manipulation is done on the host with unsigned 32-bit integers.
    unpacked = q_w.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros((size_k, size_n // cols_per_word),
                         dtype=numpy.uint32)

    for slot in range(cols_per_word):
        shift = num_bits * slot
        packed |= unpacked[:, slot::cols_per_word] << shift

    # Reinterpret as int32 on the way back into torch.
    result = torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
    return result.contiguous()
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Expand column-packed 32-bit words back into one value per column.

    Each packed word is split into ``pack_factor`` fields of ``num_bits``
    bits, lowest bits first; field ``i`` of a word is written to column
    ``i`` of the corresponding run of ``pack_factor`` output columns.
    ``pack_factor`` comes from ``get_pack_factor(num_bits)`` (presumably
    ``32 // num_bits`` - confirm against that helper).

    Args:
        packed_q_w: Tensor of shape ``(size_k, size_n // pack_factor)``.
        num_bits: Bit width of each packed value.
        size_k: Number of rows.
        size_n: Number of columns of the unpacked result; must be a multiple
            of ``pack_factor``.

    Returns:
        A contiguous int32 tensor of shape ``(size_k, size_n)`` on the device
        ``packed_q_w`` was on.
    """
    cols_per_word = get_pack_factor(num_bits)
    assert size_n % cols_per_word == 0
    assert packed_q_w.shape == (
        size_k, size_n // cols_per_word
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, cols_per_word)

    target_device = packed_q_w.device

    # Work on a host-side unsigned copy so right shifts are logical shifts.
    words = packed_q_w.cpu().numpy().astype(numpy.uint32)
    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    field_mask = (1 << num_bits) - 1
    for slot in range(cols_per_word):
        field = (words >> (num_bits * slot)) & field_mask
        unpacked[:, slot::cols_per_word] = field

    result = torch.from_numpy(unpacked.astype(numpy.int32)).to(target_device)
    return result.contiguous()
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack quantized weights for GPTQ by delegating to ``pack_rows``.

    All four arguments are forwarded unchanged, in the same order, and the
    result of ``pack_rows`` is returned as is.
    """
    packed = pack_rows(q_w, num_bits, size_k, size_n)
    return packed
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Reorder the columns of ``q_w`` for AWQ and pack them via ``pack_cols``.

    The flattened tensor is viewed as runs of 8 values (4-bit) or 4 values
    (8-bit); inside every run the values are permuted with a fixed interleave
    order before the column packing is applied.

    Args:
        q_w: Tensor of shape ``(size_k, size_n)`` with the quantized values.
        num_bits: Bit width of the values; only 4 and 8 are accepted.
        size_k: Expected number of rows of ``q_w``.
        size_n: Expected number of columns of ``q_w``.

    Returns:
        Whatever ``pack_cols`` returns for the interleaved tensor.

    Raises:
        AssertionError: If ``q_w.shape`` is not ``(size_k, size_n)``.
        Exception: If ``num_bits`` is neither 4 nor 8.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave order along the column dim (for the dequantize code).
    if num_bits == 4:
        order = [0, 2, 4, 6, 1, 3, 5, 7]
    elif num_bits == 8:
        order = [0, 2, 1, 3]
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(order)

    # Permute inside each run, then restore the (size_k, size_n) layout.
    runs = q_w.reshape((-1, len(interleave)))
    shuffled = runs[:, interleave].ravel()
    q_w = shuffled.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment