OpenDAS / AutoAWQ / Commits / d68441d0

Commit d68441d0 authored Aug 19, 2023 by Casper Hansen (parent 84d23089)

Load quantized model with saved quant_config

Showing 3 changed files with 36 additions and 18 deletions:

    awq/entry.py         +3   -3
    awq/models/auto.py   +2   -5
    awq/models/base.py   +31  -10
awq/entry.py

@@ -46,12 +46,12 @@ def run_quant(model_path, search_path, dump_path, quant_config):
     # Save quantized model
     model.save_quantized(dump_path)
 
-def run_perplexity(quant_path, quant_file, quant_config, device):
+def run_perplexity(quant_path, quant_file, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
     tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 
     # Load adapter
@@ -92,6 +92,6 @@ if __name__ == '__main__':
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.quant_path, args.quant_file, args.w_bit, quant_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
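For callers of awq/entry.py, the practical effect is that the perplexity entry point no longer threads the quantization settings through by hand; the loader resolves them from the checkpoint directory. A minimal sketch of the new calling convention follows (the directory and file names are placeholders, and the import path is inferred from the package layout shown in this commit):

    # Sketch: load a quantized checkpoint for evaluation after this commit.
    from awq.models.auto import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = 'vicuna-7b-awq'          # hypothetical directory written by model.save_quantized(...)
    quant_file = 'awq_model_w4_g128.pt'   # hypothetical weights file inside that directory

    # No quant_config argument: from_quantized reads quant_path/quant_config.json
    # if it exists, otherwise it falls back to a built-in default.
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)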
awq/models/auto.py

@@ -16,8 +16,6 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
     return model_type
 
 class AutoAWQForCausalLM:
-    default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
-
     def __init__(self):
         raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                                'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@@ -31,11 +29,10 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, quant_path, quant_filename, quant_config={},
+    def from_quantized(self, quant_path, quant_filename,
                              device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
-        quant_config = quant_config if quant_config else self.default_quant_config
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
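In auto.py the wrapper no longer owns a default config; it only resolves the model type and forwards to the per-model loader. One side effect worth noting, as a sketch rather than documented behavior: the quant_config keyword is simply gone from the wrapper's signature, so code written against the previous API needs updating.

    # Placeholders for illustration; import path inferred from this commit's layout.
    from awq.models.auto import AutoAWQForCausalLM

    quant_path, quant_file = 'vicuna-7b-awq', 'awq_model_w4_g128.pt'

    # Pre-commit style, no longer accepted: passing quant_config= now raises a
    # TypeError (unexpected keyword argument), because the parameter was removed.
    #   AutoAWQForCausalLM.from_quantized(quant_path, quant_file,
    #       quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4})

    # Post-commit style: the config is resolved inside BaseAWQForCausalLM.from_quantized.
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)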
awq/models/base.py

@@ -3,7 +3,6 @@ import gc
 import json
 import torch
 import functools
-import accelerate
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
@@ -44,11 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
             auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
 
         if run_quant:
-            self._awq_quant(quant_config)
+            self._awq_quant()
 
-    def _awq_quant(self, quant_config):
-        assert quant_config["zero_point"], "We only support zero_point quantization now."
+    def _awq_quant(self):
+        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
 
         layers = self.get_model_layers(self.model)
 
         # Run AWQ quantization
@@ -59,11 +58,25 @@ class BaseAWQForCausalLM(nn.Module):
             for name, module in named_linears.items():
                 module.cuda()
                 module.weight.data, scales, zeros = pseudo_quantize_tensor(
-                    module.weight.data, w_bit=quant_config['w_bit'], get_scale_zp=True, **quant_config
+                    module.weight.data,
+                    get_scale_zp=True,
+                    **self.quant_config
                 )
+
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
+
                 q_linear = WQLinear.from_linear(
-                    module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros
+                    module,
+                    self.quant_config['w_bit'],
+                    self.quant_config['q_group_size'],
+                    False,
+                    scales,
+                    zeros
                 )
 
                 module.cpu()
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
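A note on the **self.quant_config expansion above: it relies on pseudo_quantize_tensor accepting the config keys as keyword arguments. The self-contained snippet below demonstrates what the dictionary expansion does with the default config; the real signature of pseudo_quantize_tensor is not shown in this diff, so the stand-in function here is an assumption for illustration only.

    # Tiny demonstration of the keyword expansion used in _awq_quant.
    # 'fake_pseudo_quantize_tensor' is a stand-in, not the real function.
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    def fake_pseudo_quantize_tensor(w, get_scale_zp=False, zero_point=True, q_group_size=-1, w_bit=4):
        # Echo the keyword values it received so the expansion is visible.
        return w, zero_point, q_group_size, w_bit

    print(fake_pseudo_quantize_tensor("weights", get_scale_zp=True, **quant_config))
    # -> ('weights', True, 128, 4): **quant_config supplied zero_point, q_group_size and w_bit.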
@@ -228,7 +241,7 @@ class BaseAWQForCausalLM(nn.Module):
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config={},
+    def from_quantized(self, model_path, model_type, model_filename,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                        safetensors=False, is_quantized=True):
         # Download model if path is not a directory
@@ -245,6 +258,14 @@ class BaseAWQForCausalLM(nn.Module):
             model_filename = model_path + f'/{model_filename}'
 
         # Load config
+        quant_config_path = f'{model_path}/quant_config.json'
+
+        if os.path.exists(quant_config_path):
+            with open(quant_config_path, 'r') as file:
+                quant_config = json.loads(file.read())
+        else:
+            # Default config that works for most models
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 
         # Load empty weights
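The loading path above expects a quant_config.json file sitting next to the weights, containing exactly the keys the quantizer consumes. Based on the keys read here and the fallback default, such a file would look like the following (a sketch of the expected contents; the file itself is not part of this diff):

    {
        "zero_point": true,
        "q_group_size": 128,
        "w_bit": 4
    }

If the file is missing, the else branch falls back to the same values in code, so checkpoints saved before this commit still load.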
@@ -254,7 +275,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, w_bit, quant_config)
+            self._load_quantized_modules(self, model, quant_config)
 
         model.tie_weights()
@@ -281,7 +302,7 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
-    def _load_quantized_modules(self, model, w_bit, quant_config):
+    def _load_quantized_modules(self, model, quant_config):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
@@ -300,7 +321,7 @@ class BaseAWQForCausalLM(nn.Module):
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(
-                    module, w_bit, quant_config['q_group_size'], True)
+                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
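Taken together, the three files turn quant_config from an argument callers had to carry around into state that travels with the checkpoint: it lives on the model as self.quant_config during quantization and is read back from quant_config.json at load time. For the loader to find that file, the save side has to write it; the hypothetical helper below (not part of this diff, names assumed) sketches what save_quantized would need to do for the new loading path to pick the settings up.

    import json
    import os

    def _save_quant_config(quant_config: dict, dump_path: str) -> None:
        # Hypothetical save-side counterpart: persist the same file that
        # BaseAWQForCausalLM.from_quantized reads back.
        os.makedirs(dump_path, exist_ok=True)
        with open(os.path.join(dump_path, 'quant_config.json'), 'w') as f:
            json.dump(quant_config, f, indent=2)

    _save_quant_config({"zero_point": True, "q_group_size": 128, "w_bit": 4}, 'vicuna-7b-awq')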