OpenDAS / AutoAWQ · Commits

Commit 9b2946b6, authored Sep 08, 2023 by Casper Hansen
Parent: fe314160

    Add deprecation warning

Showing 1 changed file with 22 additions and 7 deletions:

awq/models/base.py (+22, -7)
@@ -2,12 +2,13 @@ import os
 import gc
 import json
 import torch
+import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
-from awq.modules.qlinear import WQLinear_GEMM
+from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
@@ -254,7 +255,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                              device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                             safetensors=False, is_quantized=True, fuse_layers=False):
+                             safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
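The new keyword is the only API change a caller sees. A minimal usage sketch, assuming the AutoAWQForCausalLM entry point; the model path and weight filename are illustrative placeholders, not part of this diff:

# Sketch only: selecting the kernel version when loading a quantized model.
# The entry point and file names are assumptions, not part of this commit.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "vicuna-7b-awq",            # hypothetical local dir or hub id
    "awq_model_w4_g128.pt",     # hypothetical quantized weight file
    version="GEMV",             # new in this commit; defaults to 'GEMM'
)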
@@ -276,7 +277,7 @@ class BaseAWQForCausalLM(nn.Module):
             quant_config = json.loads(file.read())
         else:
             # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
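The default config now records which kernel layout the weights were packed for. A short sketch of reading it defensively; the "quant_config.json" filename is an assumption, and the fallback for older checkpoints is assumed handling, not from this commit:

# Sketch: reading the on-disk quant config after this change.
# The keys mirror the diff; the filename is an assumption.
import json

with open("quant_config.json") as file:
    quant_config = json.loads(file.read())

# Checkpoints written before this commit lack the "version" key, so a
# loader might fall back to the legacy GEMM format (assumed handling).
version = quant_config.get("version", "GEMM")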
@@ -294,7 +295,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config)
+            self._load_quantized_modules(self, model, quant_config, version)

         model.tie_weights()
@@ -334,9 +335,14 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

-    def _load_quantized_modules(self, model, quant_config):
+    def _load_quantized_modules(self, model, quant_config, version):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
+
+        if version == 'GEMM':
+            logging.warning('Deprecated model weight format. Re-quantize '
+                            'your weights again with version="GEMV" for a speedup. '
+                            'In the next AutoAWQ version, GEMM will be deprecated.')

         # Get blocks of model
         layers = self.get_model_layers(model)
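The commit routes the deprecation notice through logging, so it fires on every load. An alternative sketch (not in the commit) using Python's warnings module, which de-duplicates repeated emissions from the same call site:

# Sketch of an alternative deprecation signal; same message, different channel.
import warnings

if version == 'GEMM':
    warnings.warn(
        'Deprecated model weight format. Re-quantize your weights with '
        'version="GEMV" for a speedup.',
        FutureWarning,  # shown to end users by default, unlike DeprecationWarning
        stacklevel=2,   # attribute the warning to the caller, not this helper
    )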
@@ -352,8 +358,17 @@ class BaseAWQForCausalLM(nn.Module):
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
-                q_linear = WQLinear_GEMM.from_linear(module, quant_config['w_bit'], quant_config['q_group_size'], True)
+                if version == 'GEMM':
+                    q_linear_module = WQLinear_GEMM
+                elif version == 'GEMV':
+                    q_linear_module = WQLinear_GEMV
+
+                q_linear = q_linear_module.from_linear(module, quant_config['w_bit'], quant_config['q_group_size'], True)
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
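The loop splices each quantized layer back into its block by dotted name. A self-contained sketch of that pattern; AutoAWQ ships its own set_op_by_name helper, so this implementation is an assumption for illustration:

# Sketch of the replace-by-dotted-name pattern used in the loop above.
import torch.nn as nn

def set_op_by_name(layer: nn.Module, name: str, new_module: nn.Module) -> None:
    # Walk a dotted path like "mlp.down_proj" down to the parent module...
    parts = name.split(".")
    parent = layer
    for part in parts[:-1]:
        parent = getattr(parent, part)
    # ...then swap the leaf module in place.
    setattr(parent, parts[-1], new_module)

# Usage: replace a submodule of a block by its qualified name.
block = nn.Sequential(nn.Linear(8, 8))
set_op_by_name(block, "0", nn.Identity())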