OpenDAS / AutoAWQ · Commit 68c727a1 (unverified)

New optimized kernels (#365)

Authored Feb 25, 2024 by Casper; committed by GitHub on Feb 25, 2024.
Parent commit: 6b7992aa

Showing 7 changed files with 260 additions and 20 deletions (+260 −20):
awq/models/base.py                +13   −5
awq/modules/linear/__init__.py     +4   −3
awq/modules/linear/gemv_fast.py  +209   −0
awq/quantize/quantizer.py         +10   −4
awq/utils/fused_utils.py          +21   −5
examples/benchmark.py              +3   −2
setup.py                           +0   −1
File: awq/models/base.py
@@ -12,11 +12,17 @@ from typing_extensions import Doc, Annotated
 from huggingface_hub import snapshot_download
 from transformers.modeling_utils import shard_checkpoint
-from awq.modules.linear.gemm import WQLinear_GEMM
-from awq.modules.linear.gemv import WQLinear_GEMV
-from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
-from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
-from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
+from awq.modules.linear import (
+    WQLinear_GEMM,
+    WQLinear_GEMV,
+    WQLinear_Marlin,
+    WQLinear_Exllama,
+    WQLinear_ExllamaV2,
+    WQLinear_GEMVFast,
+    marlin_post_init,
+    exllama_post_init,
+    exllamav2_post_init,
+)
 from awq.utils.module import (
     get_named_linears,
     set_op_by_name,
@@ -541,6 +547,8 @@ class BaseAWQForCausalLM(nn.Module):
             q_linear_module = WQLinear_GEMM
         elif version == "gemv":
             q_linear_module = WQLinear_GEMV
+        elif version == "gemv_fast":
+            q_linear_module = WQLinear_GEMVFast
 
         q_linear = q_linear_module.from_linear(
             module, quant_config.w_bit, quant_config.q_group_size, True
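For orientation, here is a minimal sketch of how the new version string reaches this dispatch from user code, assuming AutoAWQ's usual quantization entry points; the model path and output directory below are placeholders, not part of the commit:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "some-org/some-fp16-model"  # placeholder

quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "gemv_fast",  # selects WQLinear_GEMVFast in the dispatch above
}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized("model-awq-gemv-fast")  # placeholder output dir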
File: awq/modules/linear/__init__.py
-from .exllama import WQLinear_Exllama
-from .exllamav2 import WQLinear_ExllamaV2
+from .exllama import WQLinear_Exllama, exllama_post_init
+from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
 from .gemm import WQLinear_GEMM
 from .gemv import WQLinear_GEMV
-from .marlin import WQLinear_Marlin
+from .marlin import WQLinear_Marlin, marlin_post_init
+from .gemv_fast import WQLinear_GEMVFast
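This re-export is what enables the flattened imports in the other files; after this change a single package-root import suffices (illustrative):

from awq.modules.linear import WQLinear_GEMVFast, marlin_post_init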
File: awq/modules/linear/gemv_fast.py (new file, mode 100644, +209 lines)
import torch

try:
    import awq_v2_ext  # with CUDA kernels (AutoAWQ_kernels)

    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False


def make_divisible(c, divisor):
    # ceil(c / divisor)
    return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    # width of the scales/zeros rows, padded so the kernel packing lines up
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError

    base_width = make_divisible(in_features // group_size, pack_num)
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width
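As a quick worked check of the padding arithmetic above (values computed by hand for in_features=4096; not part of the commit):

# group_size=128 -> 4096 // 128 = 32 groups; ceil(32 / 8) = 4; multiplier 1 -> 4
assert calculate_zeros_width(4096, group_size=128) == 4
# group_size=64 -> 4096 // 64 = 64 groups; ceil(64 / 8) = 8; rounded up to a
# multiple of 2 -> 8
assert calculate_zeros_width(4096, group_size=64) == 8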
def pack_intweight(unpacked_qweight, interleave, kstride):
    # unpacked_qweight: [N, K]
    N = unpacked_qweight.shape[0]
    K = unpacked_qweight.shape[1]

    Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)

    # reorder each 8 weights for fast dequantization
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
    Packed_Kernel = Packed_Kernel.reshape(N, K)

    # interleaving every four rows
    Packed_Kernel = Packed_Kernel.reshape(
        N // interleave, interleave, K // kstride, kstride
    )
    # N // 4, K // 64, 4, 64
    Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
    Packed_Kernel = Packed_Kernel.reshape(
        N // interleave, K // kstride, kstride, interleave
    )
    # Packing -> (N // 4, K // 64, 64)
    Packed_Kernel = (
        Packed_Kernel[..., 0]
        | (Packed_Kernel[..., 1] << 4)
        | (Packed_Kernel[..., 2] << 8)
        | (Packed_Kernel[..., 3] << 12)
    )
    # reshape to (N // 4, K), FP16 format
    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
    qweight = (
        torch.tensor(Packed_Kernel.astype("int16"))
        .to(unpacked_qweight.device)
        .contiguous()
    )
    return qweight
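The last step above ORs four 4-bit weights into one 16-bit word. A standalone sketch of just that bit-packing, with made-up values and the row/column interleaving omitted (editor's illustration, not part of the commit):

import numpy as np

vals = np.array([3, 7, 11, 15], dtype=np.int32)  # four 4-bit weights
packed = vals[0] | (vals[1] << 4) | (vals[2] << 8) | (vals[3] << 12)
unpacked = [int(packed >> (4 * i)) & 0xF for i in range(4)]
assert unpacked == [3, 7, 11, 15]  # lossless for values in [0, 15]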
class WQLinear_GEMVFast(torch.nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.split_k_iters = 8
        self.interleave = 4

        # quick sanity check (ensure alignment)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        pack_num = 32 // self.w_bit
        int16_pack_num = 16 // self.w_bit

        assert out_features % (self.interleave) == 0
        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    out_features // self.interleave,
                    in_features // int16_pack_num * self.interleave,
                ),
                dtype=torch.int16,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias",
                torch.zeros((out_features), dtype=torch.float16, device=dev),
            )
        else:
            self.bias = None
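A small sketch of the buffer shapes this constructor produces for a 4096x4096 layer at 4 bits with group_size=128 (CPU device purely for illustration; the kernels themselves require CUDA):

layer = WQLinear_GEMVFast(
    w_bit=4, group_size=128, in_features=4096, out_features=4096,
    bias=False, dev="cpu",
)
print(tuple(layer.qweight.shape))  # (1024, 4096): out_features // 4 rows of int16
print(tuple(layer.scales.shape))   # (32, 4096): padded groups x out_features
print(tuple(layer.qzeros.shape))   # (32, 4096): same padded layout as scales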
    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
        scale_zeros = zeros * scales

        pack_num = 32 // awq_linear.w_bit
        qscales = torch.zeros(
            (
                scales.shape[0],
                calculate_zeros_width(linear.in_features, group_size) * pack_num,
            ),
            dtype=torch.float16,
            device=scales.device,
        )
        qscales[:, : scales.shape[1]] = scales
        # awq_linear.scales = scales.clone().half()
        awq_linear.scales = qscales.transpose(1, 0).contiguous()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
                    / qscales[:, idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.to(dtype=torch.int32)
        awq_linear.qweight = pack_intweight(
            intweight.contiguous(), interleave=4, kstride=64
        )

        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros_like(qscales)
        qzeros[:, : scales.shape[1]] = -(
            qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
        ).to(torch.float16)
        awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()

        return awq_linear
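A scalar sanity check of the affine mapping implemented above, with made-up numbers: from_linear stores q = round((w + zero*scale) / scale) and qzeros = -(scale * zero), so dequantization is simply q * scale + qzero:

scale, zero = 0.05, 8.0
w = 0.12
q = round((w + zero * scale) / scale)  # round(10.4) = 10
qzero = -(scale * zero)                # -0.4
w_hat = q * scale + qzero              # 0.5 - 0.4 = 0.1, within scale/2 of w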
    @torch.no_grad()
    def forward(self, x):
        inputs = x
        if inputs.numel() / inputs.shape[-1] < 8:
            out = awq_v2_ext.gemv_forward_cuda_decode(
                inputs,
                self.qweight,
                self.scales,
                self.qzeros,
                inputs.numel() // inputs.shape[-1],
                self.out_features,
                self.in_features,
                self.group_size,
            )
        else:
            out = awq_v2_ext.gemm_forward_cuda_prefill(
                inputs, self.qweight, self.scales, self.qzeros
            )
        out = out + self.bias if self.bias is not None else out

        return out
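End-to-end usage sketch (editor's illustration, assuming the awq_v2_ext CUDA extension is installed; the constant scales/zeros are placeholders, not a real calibration). With fewer than 8 input rows the forward pass takes the GEMV "decode" path, otherwise the GEMM "prefill" path:

import torch
import torch.nn as nn

linear = nn.Linear(4096, 4096, bias=False).half().cuda()
n_groups = 4096 // 128
scales = torch.full((4096, n_groups), 0.05, dtype=torch.float16, device="cuda")
zeros = torch.full((4096, n_groups), 8, dtype=torch.int32, device="cuda")

qlayer = WQLinear_GEMVFast.from_linear(
    linear, w_bit=4, group_size=128, scales=scales, zeros=zeros
)

x = torch.randn(1, 1, 4096, dtype=torch.float16, device="cuda")
y = qlayer(x)  # one token: numel / last dim = 1 < 8 -> decode (GEMV) kernel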
File: awq/quantize/quantizer.py
@@ -9,9 +9,12 @@ from collections import defaultdict
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.scale import apply_scale, apply_clip
 from awq.utils.utils import clear_memory, get_best_device
-from awq.modules.linear.gemm import WQLinear_GEMM
-from awq.modules.linear.gemv import WQLinear_GEMV
-from awq.modules.linear.marlin import WQLinear_Marlin
+from awq.modules.linear import (
+    WQLinear_GEMM,
+    WQLinear_GEMV,
+    WQLinear_Marlin,
+    WQLinear_GEMVFast,
+)
 from awq.utils.module import (
     append_str_prefix,
     get_op_name,

@@ -200,6 +203,9 @@ class AwqQuantizer:
         elif self.version == "marlin":
             q_linear_module = WQLinear_Marlin
+        elif self.version == "gemv_fast":
+            q_linear_module = WQLinear_GEMVFast
         else:
             raise ValueError(f"Unknown version {self.version}")

@@ -466,6 +472,7 @@ class AwqQuantizer:
             self.model(samples.to(next(self.model.parameters()).device))
         except ValueError:  # work with early exit
             pass
+        modules[0] = modules[0].module  # restore
 
         # Update the layer kwargs with `prepare_inputs_for_generation` method
         # that takes care of everything to avoid unexpected errors.

@@ -474,7 +481,6 @@ class AwqQuantizer:
         layer_kwargs.pop("input_ids")
 
         del samples
-        modules[0] = modules[0].module  # restore
         inps = inps[0]
 
         modules[0] = modules[0].cpu()
File: awq/utils/fused_utils.py
 import torch
 
-from awq.modules.linear.gemm import WQLinear_GEMM
-from awq.modules.linear.gemv import WQLinear_GEMV
-from awq.modules.linear.marlin import WQLinear_Marlin
-from awq.modules.linear.exllama import WQLinear_Exllama
-from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
+from awq.modules.linear import (
+    WQLinear_GEMM,
+    WQLinear_GEMV,
+    WQLinear_Marlin,
+    WQLinear_Exllama,
+    WQLinear_ExllamaV2,
+    WQLinear_GEMVFast,
+)
 
 
 def prepare_correct_devices(next_layer, hidden_states, mask):

@@ -73,6 +76,8 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
         q_linear = WQLinear_ExllamaV2
     elif isinstance(q_proj, WQLinear_Marlin):
         q_linear = WQLinear_Marlin
+    elif isinstance(q_proj, WQLinear_GEMVFast):
+        q_linear = WQLinear_GEMVFast
 
     qkv_layer = q_linear(
         q_proj.w_bit,

@@ -132,6 +137,17 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
             [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
         )
         # workspace is created in post_init
+    elif isinstance(q_proj, WQLinear_GEMVFast):
+        qkv_layer.qweight = torch.cat(
+            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
+        )
+        qkv_layer.qzeros = torch.cat(
+            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
+        ).contiguous()
+        qkv_layer.scales = torch.cat(
+            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
+        ).contiguous()
+        qkv_layer.split_k_iters = q_proj.split_k_iters
 
     qkv_layer.bias = bias
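Note on the concatenation axes in the last hunk: per the layouts in gemv_fast.py, qweight is stored as (out_features // interleave, ...), so fused q/k/v output channels stack along dim 0, while scales and qzeros are stored transposed as (padded_groups, out_features) and so stack along dim 1. A quick shape check (editor's sketch, CPU-only, hypothetical 4096x4096 projections at 4 bits, group_size=128):

import torch
from awq.modules.linear import WQLinear_GEMVFast

def make():
    return WQLinear_GEMVFast(4, 128, 4096, 4096, False, "cpu")

q, k, v = make(), make(), make()
fused_qweight = torch.cat([q.qweight, k.qweight, v.qweight], dim=0)
fused_scales = torch.cat([q.scales, k.scales, v.scales], dim=1)
assert tuple(fused_qweight.shape) == (3 * 4096 // 4, 4096)  # dim 0: channels
assert tuple(fused_scales.shape) == (32, 3 * 4096)          # dim 1: channels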
File: examples/benchmark.py
@@ -117,11 +117,12 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
         raise RuntimeError(ex)
 
     total_memory_used = 0
+    memory_pct = 100
     if successful_generate:
         # number of tokens in context / time for processing context * batch size
-        prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
+        prefill_tokens_per_second = round(input_ids.shape[1] / context_time * batch_size, 2)
         # 1 second / median time per token in seconds * batch size
-        decode_tokens_per_second = 1 / np.median(generate_time) * batch_size
+        decode_tokens_per_second = round(1 / np.median(generate_time) * batch_size, 2)
 
         print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
         print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
File: setup.py
@@ -89,7 +89,6 @@ requirements = [
     "transformers>=4.35.0",
     "tokenizers>=0.12.1",
     "typing_extensions>=4.8.0",
-    "triton",
     "accelerate",
     "datasets",
     "zstandard",
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment