OpenDAS / AutoAWQ · Commits · c6c7b065 (unverified)

Torch only inference + any-device quantization (#319)

Authored Jan 24, 2024 by Casper, committed by GitHub on Jan 24, 2024
Parent commit: 8117845b

Showing 16 changed files with 375 additions and 197 deletions.
Changed files:

  awq/models/base.py               +5    -9
  awq/modules/fused/mlp.py         +7    -2
  awq/modules/fused/norm.py        +6    -1
  awq/modules/linear/exllama.py    +6    -3
  awq/modules/linear/exllamav2.py  +6    -2
  awq/modules/linear/gemm.py       +184  -0
  awq/modules/linear/gemv.py       +5    -154
  awq/quantize/quantizer.py        +16   -8
  awq/quantize/scale.py            +7    -5
  awq/utils/fused_utils.py         +4    -3
  awq/utils/packing_utils.py       +17   -0
  awq/utils/utils.py               +9    -1
  examples/basic_generate.py       +2    -7
  examples/tinyllama_generate.py   +36   -0
  setup.py                         +7    -2
  tests/test_dequantization.py     +58   -0
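The common thread across these files: every CUDA extension import (awq_ext, exl_ext, exlv2_ext) becomes optional, device placement during quantization goes through a new get_best_device() helper, and WQLinear_GEMM.forward gains a pure-torch dequantize-then-matmul path. A minimal sketch of that dispatch is below; it only uses names that appear in this diff, but the wrapper function w4a16_linear itself is illustrative and not part of the commit.

import torch

try:
    import awq_ext  # CUDA kernels from AutoAWQ_kernels; optional after this commit
    AWQ_INSTALLED = True
except ImportError:
    AWQ_INSTALLED = False


def w4a16_linear(x, qweight, qzeros, scales, w_bit=4, group_size=128):
    """Hypothetical helper: fused CUDA kernel when available, torch-only fallback otherwise."""
    if AWQ_INSTALLED:
        # fast path: dequantization fused into the GEMM kernel
        out = awq_ext.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8)
        return out.reshape(*x.shape[:-1], -1)
    # slow path: materialize fp16 weights, then a plain matmul (runs on CPU, MPS, or CUDA)
    from awq.utils.packing_utils import dequantize_gemm
    weight = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
    return torch.matmul(x, weight)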
awq/models/base.py  (+5 -9)

 import os
 import gc
 import json
-import time
 import torch
-import transformers
 import torch.nn as nn
 from tqdm import tqdm
 from typing import List, Union
 from safetensors.torch import save_file
 from huggingface_hub import snapshot_download
+import transformers
 from transformers.modeling_utils import shard_checkpoint
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.modules.exllama import WQLinear_Exllama, exllama_post_init
-from awq.modules.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
+from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
 from awq.utils.module import (
     get_named_linears,
     set_op_by_name,
...
@@ -35,9 +34,6 @@ from accelerate.big_modeling import (
 from awq.models._config import AwqConfig
 from awq.modules.act import ScaledActivation
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.modules.exllama import WQLinear_Exllama
-from awq.modules.exllamav2 import WQLinear_ExllamaV2
 from awq.quantize.quantizer import AwqQuantizer
 from awq.utils.module import get_named_linears, set_op_by_name
...
awq/modules/fused/mlp.py  (+7 -2)

 import torch.nn as nn
-import awq_ext
 import torch.nn.functional as F
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
+try:
+    import awq_ext  # with CUDA kernels
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

 class QuantFusedMLP(nn.Module):
     def __init__(
...
awq/modules/fused/norm.py  (+6 -1)

 import torch
 from torch import nn
-import awq_ext
+try:
+    import awq_ext  # with CUDA kernels
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

 class FasterTransformerRMSNorm(nn.Module):
     def __init__(self, weight, eps=1e-6):
...
awq/modules/exllama.py → awq/modules/linear/exllama.py  (+6 -3)

 import torch
 import torch.nn as nn
-from awq.utils.exllama_utils import unpack_reorder_pack
+from awq.utils.packing_utils import unpack_reorder_pack

-import exl_ext  # with CUDA kernels (AutoAWQ_kernels)
+try:
+    import exl_ext  # with CUDA kernels (AutoAWQ_kernels)
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
...
awq/modules/exllamav2.py → awq/modules/linear/exllamav2.py  (+6 -2)

 import torch
 import torch.nn as nn
 from typing import Dict
-from awq.utils.exllama_utils import unpack_reorder_pack
+from awq.utils.packing_utils import unpack_reorder_pack

-import exlv2_ext  # with CUDA kernels (AutoAWQ_kernels)
+try:
+    import exlv2_ext  # with CUDA kernels (AutoAWQ_kernels)
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
...
awq/modules/linear/gemm.py  (new file, +184 -0)

import torch
import torch.nn as nn
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm

try:
    import awq_ext  # with CUDA kernels
    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False


class WQLinear_GEMM(nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features

        # quick sanity check (make sure aligment)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0

        self.register_buffer(
            "qweight",
            torch.zeros(
                (in_features, out_features // (32 // self.w_bit)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (in_features // self.group_size, out_features // (32 // self.w_bit)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (in_features // self.group_size, out_features),
                dtype=torch.float16,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:  # just prepare for loading sd
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
        scale_zeros = zeros * scales

        awq_linear.scales = scales.clone().half()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        pack_num = 32 // awq_linear.w_bit

        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[idx // group_size])
                    / awq_linear.scales[idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.to(dtype=torch.int32)

        best_device = get_best_device()

        # Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device
        if "mps" in best_device:
            intweight = intweight.to("cpu")

        qweight = torch.zeros(
            (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=intweight.device,
        )

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.w_bit == 4:
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
            else:
                raise NotImplementedError("Only 4-bit are supported for now.")
            for i in range(pack_num):
                qweight_col = intweight[:, col * pack_num + order_map[i]]
                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
        awq_linear.qweight = qweight

        zeros = zeros.to(dtype=torch.int32, device=best_device)

        if "mps" in best_device:
            zeros = zeros.to("cpu")

        qzeros = torch.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=zeros.device,
        )

        for col in range(zeros.shape[1] // pack_num):
            if awq_linear.w_bit == 4:
                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
            else:
                raise NotImplementedError("Only 4-bit are supported for now.")
            for i in range(pack_num):
                qzero_col = zeros[:, col * pack_num + order_map[i]]
                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
        awq_linear.qzeros = qzeros

        return awq_linear

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)

        input_dtype = x.dtype
        if input_dtype != torch.float16:
            x = x.half()

        if AWQ_INSTALLED:
            out = awq_ext.gemm_forward_cuda(
                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
            )
        else:
            out = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.w_bit, self.group_size)
            out = torch.matmul(x, out)

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)

        out = out + self.bias if self.bias is not None else out

        return out.reshape(out_shape)

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )
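To make the any-device behaviour concrete, here is a hedged end-to-end sketch of quantizing a single nn.Linear with the class above and running it through the torch-only path. It assumes a CPU-only environment without awq_ext installed, and it fabricates a constant per-group scale and a mid-range zero point purely to exercise from_linear(); real scales and zeros come from AwqQuantizer.pseudo_quantize_tensor, not from this snippet.

import torch
import torch.nn as nn
from awq.modules.linear.gemm import WQLinear_GEMM

in_features, out_features, group_size, w_bit = 128, 64, 32, 4
linear = nn.Linear(in_features, out_features, bias=True)

# Toy scales/zeros with the [in_features // group_size, out_features] layout from_linear expects.
# A constant scale plus a mid-range zero point keeps every rounded value inside the 4-bit range.
n_groups = in_features // group_size
scales = torch.full((n_groups, out_features), linear.weight.abs().max().item() / 7)
zeros = torch.full((n_groups, out_features), 8.0)

q_linear = WQLinear_GEMM.from_linear(linear, w_bit, group_size, scales=scales, zeros=zeros)

x = torch.randn(1, 8, in_features)
with torch.no_grad():
    y = q_linear(x)  # without awq_ext, forward() dequantizes via dequantize_gemm and uses torch.matmul
print(y.shape)  # torch.Size([1, 8, 64])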
awq/modules/linear.py → awq/modules/linear/gemv.py  (+5 -154)

 import torch
 import torch.nn as nn
-import awq_ext  # with CUDA kernels
+try:
+    import awq_ext  # with CUDA kernels
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False

 def make_divisible(c, divisor):
...
@@ -23,159 +27,6 @@ def calculate_zeros_width(in_features, group_size=128, pack_num=8):
     return base_width

-class WQLinear_GEMM(nn.Module):
-    ...

(The entire WQLinear_GEMM class, its __init__, from_linear, forward and extra_repr, roughly 150 lines, is deleted from this file; it now lives in awq/modules/linear/gemm.py as shown above. The deleted copy differed from the new gemm.py version only in that it packed weights and zeros on the source tensor's own device, with no get_best_device() or MPS handling, and its forward always called awq_ext.gemm_forward_cuda with no torch-only fallback.)

 class WQLinear_GEMV(nn.Module):
     def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()
...
awq/quantize/quantizer.py  (+16 -8)

...
@@ -6,10 +6,12 @@ import torch.nn as nn
 from tqdm import tqdm
 from typing import Dict, List
 from collections import defaultdict
-from awq.utils.utils import clear_memory
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.scale import apply_scale, apply_clip
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from awq.utils.utils import clear_memory, get_best_device
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
 from awq.utils.module import (
     append_str_prefix,
     get_op_name,
...
@@ -83,7 +85,12 @@ class AwqQuantizer:
             # Move module and inputs to correct device
             common_device = next(self.modules[i].parameters()).device
             if common_device is None or str(common_device) == "cpu":
-                self.modules[i] = self.modules[i].cuda("cuda:" + str(i % torch.cuda.device_count()))
+                if torch.cuda.is_available():
+                    best_device = "cuda:" + str(i % torch.cuda.device_count())
+                else:
+                    best_device = get_best_device()
+
+                self.modules[i] = self.modules[i].to(best_device)
                 common_device = next(self.modules[i].parameters()).device

             if self.module_kwargs.get("position_ids") is not None:
...
@@ -132,7 +139,7 @@ class AwqQuantizer:
     def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
         for name, linear_layer in named_linears.items():
             # NOTE: small regression in perplexity if linear layer uses .cpu().float()
-            linear_layer = linear_layer.cuda().half()
+            linear_layer = linear_layer.to(get_best_device()).half()

             linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                 linear_layer.weight.data,
...
@@ -274,7 +281,7 @@ class AwqQuantizer:
             if any([_ in name for _ in avoid_clipping]):
                 continue

-            named_linears[name].cuda()
+            named_linears[name].to(get_best_device())
             max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
             clip_list.append((name, max_val))
...
@@ -343,8 +350,9 @@ class AwqQuantizer:
         inps = []
         layer_kwargs = {}

-        modules[0] = modules[0].cuda()
-        self.awq_model.move_embed(self.model, "cuda")
+        best_device = get_best_device()
+        modules[0] = modules[0].to(best_device)
+        self.awq_model.move_embed(self.model, best_device)

         # get input and kwargs to layer 0
         # with_kwargs is only supported in PyTorch 2.0
...
@@ -390,7 +398,7 @@ class AwqQuantizer:
         clear_memory()

         if layer_kwargs.get("attention_mask") is not None:
-            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to("cuda")
+            layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device)

         return modules, layer_kwargs, inps
...
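These device changes are what make quantization runnable on machines without CUDA (plain CPU or Apple Silicon), since every hard-coded .cuda() call above now routes through get_best_device(). For context, the surrounding AutoAWQ workflow of that era looked roughly like the sketch below; the model id and output path are placeholders, and the quant_config keys are taken from the project's README rather than from this diff.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"   # placeholder model id
quant_path = "tinyllama-1.1b-chat-awq"              # placeholder output dir
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# With this commit, AwqQuantizer moves layers to get_best_device() instead of .cuda(),
# so this call can complete on CUDA, MPS, or CPU (slowly, in the latter cases).
model.quantize(tokenizer, quant_config=quant_config)

model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)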
awq/quantize/scale.py  (+7 -5)

 import torch
 import torch.nn as nn
 from typing import Tuple, List
+from awq.utils.utils import get_best_device
 from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
...
@@ -14,7 +15,7 @@ allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUA
 def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
     for name, max_val in clip_list:
         layer: nn.Linear = get_op_by_name(module, name)
-        layer.cuda()
+        layer.to(get_best_device())
         max_val = max_val.to(layer.weight.device)
         org_shape = layer.weight.shape
         layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
...
@@ -27,11 +28,12 @@ def apply_scale(module, scales_list, input_feat_dict=None):
     for prev_op_name, layer_names, scales in scales_list:
         prev_op = get_op_by_name(module, prev_op_name)
         layers = [get_op_by_name(module, name) for name in layer_names]

-        prev_op.cuda()
+        best_device = get_best_device()
+        prev_op.to(best_device)
         for layer in layers:
-            layer.cuda()
-        scales.cuda()
+            layer.to(best_device)
+        scales.to(best_device)

         if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear):
             scale_fc_fcs(prev_op, layers, scales)
...
awq/utils/fused_utils.py  (+4 -3)

 import torch

-from awq.modules.exllama import WQLinear_Exllama
-from awq.modules.exllamav2 import WQLinear_ExllamaV2
-from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+from awq.modules.linear.gemm import WQLinear_GEMM
+from awq.modules.linear.gemv import WQLinear_GEMV
+from awq.modules.linear.exllama import WQLinear_Exllama
+from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

 def prepare_correct_devices(next_layer, hidden_states, mask):
     hidden_states = hidden_states.to(next_layer.device)
...
awq/utils/exllama_utils.py → awq/utils/packing_utils.py  (+17 -0)

...
@@ -78,3 +78,20 @@ def unpack_reorder_pack(qweight, qzeros, bits):
     qweight, qzeros = pack_exllama(iweight, izeros, bits)

     return qweight, qzeros

+def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
+    # Unpack the qweight and qzeros tensors
+    iweight, izeros = unpack_awq(qweight, qzeros, bits)
+    # Reverse the order of the iweight and izeros tensors
+    iweight, izeros = reverse_awq_order(iweight, izeros, bits)
+
+    # overflow checks
+    iweight = torch.bitwise_and(iweight, (2**bits) - 1)
+    izeros = torch.bitwise_and(izeros, (2**bits) - 1)
+
+    # fp16 weights
+    scales = scales.repeat_interleave(group_size, dim=0)
+    izeros = izeros.repeat_interleave(group_size, dim=0)
+    iweight = (iweight - izeros) * scales
+
+    return iweight
\ No newline at end of file
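The arithmetic here is plain AWQ dequantization: after unpacking, each 4-bit integer in group g becomes (value - zero_g) * scale_g, and repeat_interleave simply stretches the per-group scales and zero points so the subtraction and multiply broadcast row-wise. A tiny worked sketch with made-up numbers:

import torch

group_size = 4
iweight = torch.tensor([[3], [7], [12], [0]])           # unpacked 4-bit values, one column, one group
izeros = torch.tensor([[8]])                            # zero point for that group/column
scales = torch.tensor([[0.05]], dtype=torch.float16)    # fp16 scale for that group/column

scales_rows = scales.repeat_interleave(group_size, dim=0)   # shape [4, 1], one scale per row
izeros_rows = izeros.repeat_interleave(group_size, dim=0)   # shape [4, 1]
weight_fp16 = (iweight - izeros_rows) * scales_rows
# approximately [[-0.25], [-0.05], [0.20], [-0.40]] in fp16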
awq/utils/utils.py  (+9 -1)

...
@@ -64,4 +64,12 @@ def clear_memory(weight=None):
 def compute_memory_used_pct(device):
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
     memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
-    return memory_pct
\ No newline at end of file
+    return memory_pct
+
+def get_best_device():
+    if torch.backends.mps.is_available():
+        return 'mps'
+    elif torch.cuda.is_available():
+        return 'cuda:0'
+    else:
+        return 'cpu'
\ No newline at end of file
examples/basic_generate.py  (+2 -7)

 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer, TextStreamer

-quant_path = "TheBloke/zephyr-7B-beta-AWQ"
+quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

 # Load model
 model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
...
@@ -9,12 +9,7 @@ tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

 # Convert prompt to tokens
-prompt_template = """\
-<|system|>
-</s>
-<|user|>
-{prompt}</s>
-<|assistant|>"""
+prompt_template = "[INST] {prompt} [/INST]"

 prompt = "You're standing on the surface of the Earth. " \
         "You walk one mile south, one mile west and one mile north. " \
...
examples/tinyllama_generate.py  (new file, +36 -0)

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = """\
<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""

system = "You are a helpful assistant that answers precisely."
prompt = "You're standing on the surface of the Earth. " \
        "You walk one mile south, one mile west and one mile north. " \
        "You end up exactly where you started. Where are you?"

tokens = tokenizer(
    prompt_template.format(system=system, prompt=prompt),
    return_tensors='pt'
).input_ids.to("mps")

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=64
)
\ No newline at end of file
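One caveat on the new example: the hard-coded .to("mps") above assumes Apple Silicon. Since this same commit adds get_best_device(), a more portable variant of the token preparation (reusing the names defined in the script above; not part of the diff) could look like this:

from awq.utils.utils import get_best_device

device = get_best_device()  # 'mps' on Apple Silicon, else 'cuda:0', else 'cpu'
tokens = tokenizer(
    prompt_template.format(system=system, prompt=prompt),
    return_tensors='pt'
).input_ids.to(device)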
setup.py  (+7 -2)

 import os
 import sys
 import torch
+import platform
 from pathlib import Path
 from setuptools import setup, find_packages
...
@@ -8,8 +9,9 @@ os.environ["CC"] = "g++"
 os.environ["CXX"] = "g++"

 AUTOAWQ_VERSION = "0.1.8"
 PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
+HAS_CUDA = torch.cuda.is_available()

-if not PYPI_BUILD:
+if not PYPI_BUILD and HAS_CUDA:
     try:
         CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
         AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
...
@@ -42,7 +44,6 @@ common_setup_kwargs = {
 }

 requirements = [
-    "autoawq-kernels",
     "torch>=2.0.1",
     "transformers>=4.35.0",
     "tokenizers>=0.12.1",
...
@@ -50,6 +51,10 @@ requirements = [
     "datasets",
 ]

+# CUDA kernels
+if platform.system().lower() != "darwin" and HAS_CUDA:
+    requirements.append("autoawq-kernels")
+
 setup(
     packages=find_packages(),
     install_requires=requirements,
...
tests/test_dequantization.py  (new file, +58 -0)

import torch

torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

import awq_ext
from awq.utils.packing_utils import dequantize_gemm

in_features = 4096
out_features = 1792
w_bit = 4
group_size = 128

MAX_INT32 = 0x7fffffff
MIN_INT32 = -MAX_INT32 - 1

qweight = torch.randint(
    MIN_INT32,
    MAX_INT32,
    (in_features, out_features // (32 // w_bit)),
    dtype=torch.int32,
    device="cuda",
)
qzeros = torch.randint(
    MIN_INT32,
    MAX_INT32,
    (in_features // group_size, out_features // (32 // w_bit)),
    dtype=torch.int32,
    device="cuda",
)
scales = torch.randn(
    (in_features // group_size, out_features),
    dtype=torch.float16,
    device="cuda",
)

with torch.no_grad():
    cuda_out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False)
    torch_out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)

assert(torch.allclose(cuda_out, torch_out, rtol=0.0001))
\ No newline at end of file