AutoAWQ commit 9b2946b6

Add deprecation warning

Authored Sep 08, 2023 by Casper Hansen
Parent: fe314160
Showing 1 changed file with 22 additions and 7 deletions

awq/models/base.py (+22, -7)
@@ -2,12 +2,13 @@ import os
 import gc
 import json
 import torch
+import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
-from awq.modules.qlinear import WQLinear_GEMM
+from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
@@ -254,7 +255,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                              device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                             safetensors=False, is_quantized=True, fuse_layers=False):
+                             safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
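For context, a minimal usage sketch of the new version keyword. Only the version argument and its 'GEMM' default come from this diff; the AutoAWQForCausalLM entry point's other arguments, the checkpoint path, and the file name below are assumptions for illustration.

# Hypothetical usage sketch; the checkpoint path and file name are placeholders,
# and only the version kwarg is taken from this commit.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "my-org/llama-7b-awq",      # quantized checkpoint path (hypothetical)
    "awq_model_w4_g128.pt",     # quantized weights file name (hypothetical)
    version="GEMV",             # new in this commit; defaults to 'GEMM'
)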
@@ -276,7 +277,7 @@ class BaseAWQForCausalLM(nn.Module):
             quant_config = json.loads(file.read())
         else:
             # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
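The fallback config above now carries a "version" key alongside zero_point, q_group_size and w_bit. A small sketch of writing such a config to disk; the quant_config.json file name is an assumption, not something this diff shows.

import json

# Sketch of a quantization config carrying the new "version" key.
# The file name "quant_config.json" is an assumption for illustration.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV"}

with open("quant_config.json", "w") as f:
    json.dump(quant_config, f, indent=2)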
@@ -294,7 +295,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config)
+            self._load_quantized_modules(self, model, quant_config, version)

             model.tie_weights()
@@ -334,9 +335,14 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

-    def _load_quantized_modules(self, model, quant_config):
+    def _load_quantized_modules(self, model, quant_config, version):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."

+        if version == 'GEMM':
+            logging.warning('Deprecated model weight format. Re-quantize '
+                            'your weights again with version="GEMV" for a speedup. '
+                            'In the next AutoAWQ version, GEMM will be deprecated.')
+
         # Get blocks of model
         layers = self.get_model_layers(model)
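Because the notice is emitted through the standard-library root logger (logging.warning), its visibility is controlled with ordinary logging configuration; a minimal sketch:

import logging

# Ensure WARNING-level messages (including the deprecation notice above) are printed.
logging.basicConfig(level=logging.WARNING)

# Or raise the root logger's threshold to hide them.
logging.getLogger().setLevel(logging.ERROR)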
@@ -352,8 +358,17 @@ class BaseAWQForCausalLM(nn.Module):
                 # Replace nn.Linear with WQLinear
                 for name, module in named_linears.items():
-                    q_linear = WQLinear_GEMM.from_linear(
-                        module, quant_config['w_bit'], quant_config['q_group_size'], True)
+                    if version == 'GEMM':
+                        q_linear_module = WQLinear_GEMM
+                    elif version == 'GEMV':
+                        q_linear_module = WQLinear_GEMV
+
+                    q_linear = q_linear_module.from_linear(
+                        module, quant_config['w_bit'], quant_config['q_group_size'], True)
                     q_linear.to(next(layer.parameters()).device)
                     set_op_by_name(layer, name, q_linear)
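Note that the if/elif dispatch above leaves q_linear_module unbound for any other version string, which would surface as a NameError at the from_linear call. A minimal, illustrative alternative (not part of the commit) that fails with an explicit message:

from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV

# Illustrative dict-based dispatch with an explicit error for unknown versions;
# this is a sketch, not what the commit implements.
VERSION_TO_LINEAR = {"GEMM": WQLinear_GEMM, "GEMV": WQLinear_GEMV}

def resolve_qlinear(version: str):
    try:
        return VERSION_TO_LINEAR[version]
    except KeyError:
        raise ValueError(f"Unknown AWQ version {version!r}; expected 'GEMM' or 'GEMV'.")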