OpenDAS / AutoAWQ · Commit d68441d0
Authored Aug 19, 2023 by Casper Hansen
Load quantized model with saved quant_config
parent 84d23089
Showing 3 changed files with 36 additions and 18 deletions (+36 -18)
awq/entry.py        +3  -3
awq/models/auto.py  +2  -5
awq/models/base.py  +31 -10
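With this commit, the quantization settings no longer have to be passed when loading: from_quantized reads quant_config.json from the model directory and falls back to a built-in default. A minimal usage sketch under those assumptions; the import paths, directory and file names below are not part of this diff and are only illustrative:

    # Hypothetical example: paths, filenames and the awq import path are assumptions.
    from transformers import AutoTokenizer
    from awq.models.auto import AutoAWQForCausalLM

    quant_path = 'vicuna-7b-awq'      # directory holding the weights and quant_config.json
    quant_file = 'awq_model_w4.pt'    # quantized weights file inside that directory

    # quant_config is no longer an argument; it is read from quant_path/quant_config.json.
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)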
awq/entry.py

@@ -46,12 +46,12 @@ def run_quant(model_path, search_path, dump_path, quant_config):
     # Save quantized model
     model.save_quantized(dump_path)
 
-def run_perplexity(quant_path, quant_file, quant_config, device):
+def run_perplexity(quant_path, quant_file, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
     tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 
     # Load adapter
@@ -92,6 +92,6 @@ if __name__ == '__main__':
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.quant_path, args.quant_file, args.w_bit, quant_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
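Taken together, the two entry points now split responsibilities: run_quant still takes an explicit quant_config when quantizing and saving, while run_perplexity reloads everything from disk. A hedged end-to-end sketch; all model names, paths and filenames here are made up for illustration:

    # Hypothetical flow through the two entry points touched by this commit.
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    # Quantize and save; signature taken from the hunk header above.
    run_quant('lmsys/vicuna-7b-v1.5', 'awq_search_results.pt', 'vicuna-7b-awq', quant_config)

    # Evaluate perplexity; quant_config is no longer passed, the loader
    # recovers it from vicuna-7b-awq/quant_config.json (or falls back to defaults).
    run_perplexity('vicuna-7b-awq', 'awq_model_w4.pt', 'cuda:0')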
awq/models/auto.py

@@ -16,8 +16,6 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
     return model_type
 
 class AutoAWQForCausalLM:
-    default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
-
     def __init__(self):
         raise EnvironmentError('You must instantiate AutoAWQForCausalLM with \n'
                                'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@@ -31,11 +29,10 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, quant_config={},
+    def from_quantized(self, quant_path, quant_filename,
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
-        quant_config = quant_config if quant_config else self.default_quant_config
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
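The default_quant_config attribute is gone from the auto class; the single source of truth is now the quant_config.json that BaseAWQForCausalLM.from_quantized looks for (see awq/models/base.py below). A short sketch of producing such a file by hand, using only the keys and the filename the loader reads; the output directory is hypothetical:

    import json
    import os

    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    model_path = 'vicuna-7b-awq'  # hypothetical directory holding the quantized weights
    # The loader checks for {model_path}/quant_config.json before falling back to defaults.
    with open(os.path.join(model_path, 'quant_config.json'), 'w') as f:
        json.dump(quant_config, f)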
awq/models/base.py

@@ -3,7 +3,6 @@ import gc
 import json
 import torch
-import functools
 import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
@@ -44,11 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
             auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data
         )
 
         if run_quant:
-            self._awq_quant(quant_config)
+            self._awq_quant()
 
-    def _awq_quant(self, quant_config):
-        assert quant_config["zero_point"], "We only support zero_point quantization now."
+    def _awq_quant(self):
+        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
 
         layers = self.get_model_layers(self.model)
 
         # Run AWQ quantization
@@ -59,11 +58,25 @@ class BaseAWQForCausalLM(nn.Module):
             for name, module in named_linears.items():
                 module.cuda()
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(
-                    module.weight.data, w_bit=quant_config['w_bit'], get_scale_zp=True, **quant_config
-                )
+                module.weight.data, scales, zeros = pseudo_quantize_tensor(
+                    module.weight.data,
+                    get_scale_zp=True,
+                    **self.quant_config
+                )
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
-                q_linear = WQLinear.from_linear(module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
+                q_linear = WQLinear.from_linear(
+                    module,
+                    self.quant_config['w_bit'],
+                    self.quant_config['q_group_size'],
+                    False,
+                    scales,
+                    zeros
+                )
                 module.cpu()
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
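Because _awq_quant now forwards **self.quant_config, pseudo_quantize_tensor picks up zero_point, q_group_size and w_bit as keyword arguments instead of receiving w_bit explicitly. A small, self-contained sketch of what that expansion amounts to, with a stand-in for the real function:

    # Stand-in with the keyword interface implied by the call above; it is not
    # the real pseudo_quantize_tensor, only an illustration of the expansion.
    def pseudo_quantize_stub(weight, get_scale_zp, zero_point, q_group_size, w_bit):
        return weight, zero_point, q_group_size, w_bit, get_scale_zp

    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    # **quant_config expands to zero_point=True, q_group_size=128, w_bit=4,
    # so the call is equivalent to writing those keywords out by hand.
    print(pseudo_quantize_stub('W', get_scale_zp=True, **quant_config))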
@@ -228,7 +241,7 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config={},
+    def from_quantized(self, model_path, model_type, model_filename,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # Download model if path is not a directory
@@ -245,6 +258,14 @@ class BaseAWQForCausalLM(nn.Module):
             model_filename = model_path + f'/{model_filename}'
 
         # Load config
+        quant_config_path = f'{model_path}/quant_config.json'
+
+        if os.path.exists(quant_config_path):
+            with open(quant_config_path, 'r') as file:
+                quant_config = json.loads(file.read())
+        else:
+            # Default config that works for most models
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
 
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 
         # Load empty weights
@@ -254,7 +275,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, w_bit, quant_config)
+            self._load_quantized_modules(self, model, quant_config)
 
             model.tie_weights()
@@ -281,7 +302,7 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
-    def _load_quantized_modules(self, model, w_bit, quant_config):
+    def _load_quantized_modules(self, model, quant_config):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
@@ -300,7 +321,7 @@ class BaseAWQForCausalLM(nn.Module):
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
-                q_linear = WQLinear.from_linear(module, w_bit, quant_config['q_group_size'], True)
+                q_linear = WQLinear.from_linear(module, quant_config['w_bit'], quant_config['q_group_size'], True)
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)