OpenDAS / AutoAWQ · Commit 5db43a7f

Authored Sep 08, 2023 by Casper Hansen
Parent: 9b2946b6

Implement GEMM/GEMV in quantize function and fused modules
Showing 4 changed files, with 41 additions and 22 deletions (+41 −22).
awq/models/base.py       +27 −14
awq/models/llama.py      +13  −7
awq/models/mpt.py         +1  −1
awq/modules/linear.py     +0  −0
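For context on what the commit enables: quant_config now carries a "version" key ("GEMM" or "GEMV") that selects which WQLinear kernel is built during quantization and used by the fused modules. A minimal usage sketch follows; the quant_config keys mirror the defaults shown in base.py below, while the from_pretrained/quantize/save_quantized calls are assumed AutoAWQ entry points, not part of this diff.

# Usage sketch (not part of this commit): quantize with the GEMV kernel.
from awq import AutoAWQForCausalLM          # assumed top-level export
from transformers import AutoTokenizer

model_path = "path/to/model"                # placeholder
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized("path/to/output")      # assumed save API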
awq/models/base.py

@@ -8,11 +8,11 @@ import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
-from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.quantizer import pseudo_quantize_tensor
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel

@@ -43,6 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
                  auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                  calib_data="pileval"):
         self.quant_config = quant_config
+        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
+        if quant_config["version"] == "GEMM":
+            logging.warning('Deprecated model weight format. Re-quantize '
+                            'your weights again with version="GEMV" for a speedup. '
+                            'In the next AutoAWQ version, GEMM will be deprecated.')

         if run_search:
             self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,

@@ -53,7 +58,7 @@ class BaseAWQForCausalLM(nn.Module):
         self.is_quantized = True

     @staticmethod
-    def fuse_layers(model):
+    def fuse_layers(model, quant_config):
         pass

     def _awq_quant(self):

@@ -78,12 +83,17 @@ class BaseAWQForCausalLM(nn.Module):
             scales = scales.t().contiguous()
             zeros = zeros.t().contiguous()

-            q_linear = WQLinear_GEMM.from_linear(
-                module,
-                self.quant_config['w_bit'],
-                self.quant_config['q_group_size'],
-                False,
-                scales,
+            if self.quant_config["version"] == 'GEMM':
+                q_linear_module = WQLinear_GEMM
+            elif self.quant_config["version"] == 'GEMV':
+                q_linear_module = WQLinear_GEMV
+
+            q_linear = q_linear_module.from_linear(
+                module,
+                self.quant_config['w_bit'],
+                self.quant_config['q_group_size'],
+                False,
+                scales,
                 zeros
             )

@@ -275,9 +285,12 @@ class BaseAWQForCausalLM(nn.Module):
         if os.path.exists(quant_config_path):
             with open(quant_config_path, 'r') as file:
                 quant_config = json.loads(file.read())
+
+            if "version" not in quant_config.keys():
+                quant_config["version"] = version
         else:
             # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}

         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):

@@ -295,7 +308,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config, version)
+            self._load_quantized_modules(self, model, quant_config, quant_config["version"])

             model.tie_weights()

@@ -315,7 +328,7 @@ class BaseAWQForCausalLM(nn.Module):
             )

             if fuse_layers:
-                self.fuse_layers(model)
+                self.fuse_layers(model, quant_config)
         else:
             # If not quantized, must load with AutoModelForCausalLM

@@ -364,9 +377,9 @@ class BaseAWQForCausalLM(nn.Module):
                 q_linear_module = WQLinear_GEMV
             q_linear = q_linear_module.from_linear(
                 module,
                 quant_config['w_bit'],
                 quant_config['q_group_size'],
                 True
             )
             q_linear.to(next(layer.parameters()).device)
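The heart of the base.py change is a dispatch on quant_config["version"] between the two linear implementations during quantization. A standalone sketch of that dispatch, using the module names from the diff (the helper function select_qlinear itself is illustrative, not part of the codebase):

# Illustrative helper mirroring the branching added above; select_qlinear is hypothetical.
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

def select_qlinear(quant_config: dict):
    # The diff defaults a missing "version" key to "GEMM" for backward compatibility.
    version = quant_config.get("version", "GEMM")
    if version == "GEMM":
        return WQLinear_GEMM
    if version == "GEMV":
        return WQLinear_GEMV
    raise ValueError(f"Unsupported quantization version: {version}")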
awq/models/llama.py

@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM):
-        fuser = LlamaFuser(model)
+    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
+        fuser = LlamaFuser(model, quant_config)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
         fuser.fuse_mlp()

@@ -66,17 +66,18 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
         return layers

 import torch
-from typing import List, Tuple
-from awq.modules.qlinear import WQLinear_GEMM
+from typing import List, Tuple, Union
 from awq.utils.utils import set_module_name
 from awq.modules.fused.mlp import QuantLlamaMLP
 from awq.modules.fused.norm import FTLlamaRMSNorm
 from awq.modules.fused.attn import QuantLlamaAttention
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

 class LlamaFuser:
-    def __init__(self, model):
+    def __init__(self, model, quant_config):
         self.model = model
+        self.quant_config = quant_config

         self.attention_modules: List[Tuple[str, LlamaAttention]] = [
             (name, module) for name, module in self.model.named_modules()

@@ -95,7 +96,7 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
-            qkv_layer: WQLinear_GEMM = self._fuse_qkv(module)
+            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
             attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,

@@ -113,7 +114,12 @@ class LlamaFuser:
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

         # create module
-        qkv_layer = WQLinear_GEMM(
+        if self.quant_config["version"] == 'GEMM':
+            qkv_module = WQLinear_GEMM
+        elif self.quant_config["version"] == 'GEMV':
+            qkv_module = WQLinear_GEMV
+
+        qkv_layer = qkv_module(
             q_proj.w_bit,
             q_proj.group_size,
             q_proj.in_features,
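Because fuse_layers now receives quant_config, the Llama fuser can build the fused QKV projection with whichever WQLinear variant the checkpoint was quantized for. A loading sketch under that assumption; the from_quantized call and its fuse_layers flag are inferred from the base.py hunks above rather than spelled out in this diff:

# Loading sketch (assumed API surface): with fuse_layers=True, loading ends up
# calling fuse_layers(model, quant_config), so the fused attention uses
# WQLinear_GEMM or WQLinear_GEMV to match quant_config["version"].
from awq import AutoAWQForCausalLM          # assumed top-level export

quant_path = "path/to/quantized-model"      # placeholder
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)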
awq/models/mpt.py

@@ -6,7 +6,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_seq_len"

     @staticmethod
-    def fuse_layers(model: MptForCausalLM):
+    def fuse_layers(model: MptForCausalLM, quant_config: dict):
         fuser = MptFuser(model)
         fuser.fuse_mlp()
awq/modules/qlinear.py → awq/modules/linear.py

File moved (no content changes).