OpenDAS / AutoAWQ · Commits · 9c3dfa07

Unverified commit 9c3dfa07, authored Dec 11, 2023 by Younes Belkada, committed via GitHub on Dec 11, 2023
FEAT: Add possibility of skipping modules when quantizing (#248)
Parent: 78b59d73
Showing 4 changed files with 24 additions and 7 deletions (+24 -7):

awq/models/_config.py      +5  -2
awq/models/base.py         +2  -2
awq/quantize/quantizer.py  +13 -1
awq/quantize/scale.py      +4  -2
awq/models/_config.py (view file @ 9c3dfa07)
 import os
 import json
 import logging
-from typing import Dict
+from typing import Dict, Optional, List
 from dataclasses import dataclass, field, fields
 from transformers.utils.hub import PushToHubMixin, cached_file
...
@@ -13,6 +13,7 @@ class AwqConfig(PushToHubMixin):
     w_bit: int = field(default=4)
     version: str = field(default="GEMM")
     config_file_name = "quant_config.json"
+    modules_to_not_convert: Optional[List] = None

     def save_pretrained(self, save_dir: str, **kwargs):
         logging.warning(
...
@@ -76,7 +77,8 @@ class AwqConfig(PushToHubMixin):
             "zero_point": self.zero_point,
             "q_group_size": self.q_group_size,
             "w_bit": self.w_bit,
-            "version": self.version
+            "version": self.version,
+            "modules_to_not_convert": self.modules_to_not_convert,
         }

     def to_transformers_dict(self):
...
@@ -86,4 +88,5 @@ class AwqConfig(PushToHubMixin):
             "group_size": self.q_group_size,
             "bits": self.w_bit,
             "version": self.version.lower(),
+            "modules_to_not_convert": self.modules_to_not_convert,
         }
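
Taken together, the _config.py changes let the exclusion list round-trip through the config. A minimal sketch of the intended flow, assuming AwqConfig.from_dict maps every dataclass field present in the input dict (all values below are illustrative):

from awq.models._config import AwqConfig

quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": ["gate"],  # matched by substring against layer names
}

cfg = AwqConfig.from_dict(quant_config)
print(cfg.modules_to_not_convert)                            # ['gate']
print(cfg.to_transformers_dict()["modules_to_not_convert"])  # ['gate']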
awq/models/base.py (view file @ 9c3dfa07)
...
@@ -49,12 +49,12 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text", duo_scaling=True):
+                       split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
         quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column, duo_scaling
+            self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
         )
         quantizer.quantize()

         self.is_quantized = True
...
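
For context, a hedged usage sketch of the new keyword as it would be called at this commit. The checkpoint path and the "gate" pattern are illustrative; loading follows AutoAWQ's usual AutoAWQForCausalLM flow:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mixtral-8x7B-v0.1"  # illustrative model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

modules_to_not_convert = ["gate"]  # e.g. keep MoE routing Linears in full precision
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4,
                "version": "GEMM", "modules_to_not_convert": modules_to_not_convert}

# At this commit the list is passed twice: inside quant_config so it is
# serialized with the model, and as a keyword so AwqQuantizer applies it.
model.quantize(tokenizer, quant_config=quant_config,
               modules_to_not_convert=modules_to_not_convert)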
awq/quantize/quantizer.py (view file @ 9c3dfa07)
...
@@ -14,7 +14,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                       calib_data, split, text_column, duo_scaling) -> None:
+                       calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
...
@@ -25,6 +25,7 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.duo_scaling = duo_scaling
+        self.modules_to_not_convert = modules_to_not_convert if modules_to_not_convert is not None else []
         self.modules, self.module_kwargs, self.inps = self.init_quant()

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
...
@@ -68,6 +69,13 @@ class AwqQuantizer:
         return w

+    def _exclude_layers_to_not_quantize(self, linear_layers):
+        filtered_layers = {}
+        for name, linear_layer in linear_layers.items():
+            if not any(key in name for key in self.modules_to_not_convert):
+                filtered_layers[name] = linear_layer
+        return filtered_layers
+
     def quantize(self):
         for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # Move module and inputs to correct device
...
@@ -80,6 +88,10 @@ class AwqQuantizer:
             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = get_named_linears(self.modules[i])

+            # Filter out the linear layers we don't want to quantize
+            named_linears = self._exclude_layers_to_not_quantize(named_linears)
+
             input_feat = self._get_input_feat(self.modules[i], named_linears)
             clear_memory()
...
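
The exclusion test in _exclude_layers_to_not_quantize is a plain substring match, so a single key covers the same layer in every decoder block. A self-contained sketch of that behavior (the layer names are illustrative):

import torch.nn as nn

linears = {
    "block_sparse_moe.gate": nn.Linear(16, 8),
    "self_attn.q_proj": nn.Linear(16, 16),
    "self_attn.k_proj": nn.Linear(16, 16),
}
modules_to_not_convert = ["gate"]

# Same logic as the method above: keep a layer only if no exclusion key
# appears anywhere in its qualified name.
filtered = {name: layer for name, layer in linears.items()
            if not any(key in name for key in modules_to_not_convert)}
print(sorted(filtered))  # ['self_attn.k_proj', 'self_attn.q_proj']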
awq/quantize/scale.py (view file @ 9c3dfa07)
...
@@ -53,6 +53,8 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         # apply the scaling to input feat if given; prepare it for clipping
         if input_feat_dict is not None:
             for layer_name in layer_names:
-                inp = input_feat_dict[layer_name]
-                inp.div_(scales.view(1, -1).to(inp.device))
+                # Skip the modules that are not quantized
+                if layer_name in input_feat_dict:
+                    inp = input_feat_dict[layer_name]
+                    inp.div_(scales.view(1, -1).to(inp.device))
...
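
Because excluded layers never make it into input_feat_dict, the membership check above is all apply_scale needs to skip them. A minimal, self-contained sketch of the guarded in-place scaling (tensor shapes are illustrative):

import torch

# "gate" was filtered out earlier, so it has no recorded input features.
input_feat_dict = {"self_attn.q_proj": torch.randn(4, 8)}
scales = torch.full((8,), 2.0)

for layer_name in ["self_attn.q_proj", "block_sparse_moe.gate"]:
    if layer_name in input_feat_dict:  # skip modules that are not quantized
        inp = input_feat_dict[layer_name]
        # Broadcast: divide every row of the features by the per-channel scale.
        inp.div_(scales.view(1, -1).to(inp.device))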