OpenDAS / AutoAWQ

Commit fe314160, authored Sep 08, 2023 by Casper Hansen

Refactor modules, create separate GEMM and GEMV

Parent: ef6b60e2
Changes (10 files): 226 additions and 14 deletions (+226 -14)

awq/models/base.py            +4    -3
awq/models/llama.py           +6    -6
awq/models/mpt.py             +1    -1
awq/modules/__init__.py       +0    -3
awq/modules/act.py            +10   -0
awq/modules/fused/attn.py     +0    -0
awq/modules/fused/mlp.py      +0    -0
awq/modules/fused/norm.py     +0    -0
awq/modules/qlinear.py        +204  -0
awq/quantize/auto_scale.py    +1    -1
awq/models/base.py

...
@@ -7,10 +7,11 @@ import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
+from awq.modules.qlinear import WQLinear_GEMM
+from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.quantizer import pseudo_quantize_tensor
-from awq.quantize.qmodule import WQLinear, ScaledActivation
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
...
@@ -76,7 +77,7 @@ class BaseAWQForCausalLM(nn.Module):
             scales = scales.t().contiguous()
             zeros = zeros.t().contiguous()
-            q_linear = WQLinear.from_linear(
+            q_linear = WQLinear_GEMM.from_linear(
                 module,
                 self.quant_config['w_bit'],
                 self.quant_config['q_group_size'],
...
@@ -351,7 +352,7 @@ class BaseAWQForCausalLM(nn.Module):
            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
-                q_linear = WQLinear.from_linear(module, quant_config['w_bit'], quant_config['q_group_size'], True)
+                q_linear = WQLinear_GEMM.from_linear(module, quant_config['w_bit'], quant_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
...
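For reference, the replacement pattern above can be exercised on a standalone layer. This is a minimal sketch, assuming the package is installed together with its awq_inference_engine extension; the layer size and quant settings below are illustrative, not values taken from the diff, and the fourth positional argument is init_only (it just allocates empty packed buffers, as in the loading path above):

import torch.nn as nn
from awq.modules.qlinear import WQLinear_GEMM

# Illustrative settings; AutoAWQ reads these from its quant_config.
quant_config = {'w_bit': 4, 'q_group_size': 128}

linear = nn.Linear(4096, 4096, bias=False).half()
q_linear = WQLinear_GEMM.from_linear(
    linear,
    quant_config['w_bit'],
    quant_config['q_group_size'],
    True,  # init_only: allocate empty packed buffers, no real quantization
)
q_linear.to(linear.weight.device)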
awq/models/llama.py

...
@@ -67,11 +67,11 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
 import torch
 from typing import List, Tuple
-from awq.quantize.qmodule import WQLinear
+from awq.modules.qlinear import WQLinear_GEMM
 from awq.utils.utils import set_module_name
-from awq.modules.fused_mlp import QuantLlamaMLP
-from awq.modules.fused_norm import FTLlamaRMSNorm
-from awq.modules.fused_attn import QuantLlamaAttention
+from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.norm import FTLlamaRMSNorm
+from awq.modules.fused.attn import QuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

 class LlamaFuser:
...
@@ -95,7 +95,7 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
-            qkv_layer: WQLinear = self._fuse_qkv(module)
+            qkv_layer: WQLinear_GEMM = self._fuse_qkv(module)
             attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
...
@@ -113,7 +113,7 @@ class LlamaFuser:
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

         # create module
-        qkv_layer = WQLinear(
+        qkv_layer = WQLinear_GEMM(
             q_proj.w_bit,
             q_proj.group_size,
             q_proj.in_features,
...
awq/models/mpt.py

...
@@ -67,7 +67,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
-from awq.modules.fused_mlp import QuantMPTMLP
+from awq.modules.fused.mlp import QuantMPTMLP

 class MptFuser:
     def __init__(self, model):
...
awq/modules/__init__.py

-from .fused_norm import *
-from .fused_attn import *
-from .fused_mlp import *
awq/modules/act.py (new file, 0 → 100644)

+import torch.nn as nn
+
+class ScaledActivation(nn.Module):
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
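The new ScaledActivation wrapper simply divides the wrapped activation's output by a per-channel scale vector, which is how AWQ folds activation scales back out at inference time. A minimal sketch of its behaviour, with an illustrative GELU module and sizes (not taken from the diff):

import torch
import torch.nn as nn
from awq.modules.act import ScaledActivation

hidden = 8
scales = torch.full((hidden,), 2.0)        # per-channel scales (illustrative)
scaled_act = ScaledActivation(nn.GELU(), scales)

x = torch.randn(2, 3, hidden)              # (batch, seq, hidden)
y = scaled_act(x)                          # equals nn.GELU()(x) / scales
assert torch.allclose(y, nn.GELU()(x) / scales.view(1, 1, -1))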
awq/modules/fused_attn.py → awq/modules/fused/attn.py (file moved)

awq/modules/fused_mlp.py → awq/modules/fused/mlp.py (file moved)

awq/modules/fused_norm.py → awq/modules/fused/norm.py (file moved)
awq/quantize/qmodule.py → awq/modules/qlinear.py
...
@@ -4,17 +4,24 @@ import torch.nn as nn
 import awq_inference_engine  # with CUDA kernels

-class ScaledActivation(nn.Module):
-    def __init__(self, module, scales):
-        super().__init__()
-        self.act = module
-        self.scales = nn.Parameter(scales.data)
-
-    def forward(self, x):
-        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+def make_divisible(c, divisor):
+    return (c + divisor - 1) // divisor
+
+def calculate_zeros_width(in_features, group_size=128, pack_num=8):
+    if group_size >= 128:
+        size_multiplier = 1
+    elif group_size == 64:
+        size_multiplier = 2
+    elif group_size == 32:
+        size_multiplier = 4
+    else:
+        raise NotImplementedError
+
+    base_width = make_divisible(in_features // group_size, pack_num)
+    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
+    return base_width

-class WQLinear(nn.Module):
+class WQLinear_GEMM(nn.Module):
     def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         super().__init__()
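As a quick, illustrative check of the new width helpers (the sizes below are examples, not values from the diff; importing awq.modules.qlinear assumes the awq_inference_engine extension is built):

from awq.modules.qlinear import calculate_zeros_width

# in_features=4096, group_size=128: make_divisible(4096 // 128, 8) = (32 + 7) // 8 = 4,
# and group_size >= 128 gives size_multiplier = 1, so the width stays at 4.
assert calculate_zeros_width(4096, group_size=128) == 4

# group_size=64: make_divisible(4096 // 64, 8) = 8, then rounded up to a multiple of 2 -> 8.
assert calculate_zeros_width(4096, group_size=64) == 8

# WQLinear_GEMV (further down) sizes 'qzeros' with this width and 'scales' with width * pack_num.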
...
@@ -25,6 +32,7 @@ class WQLinear(nn.Module):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features

         # quick sanity check (make sure aligment)
         assert self.in_features % self.group_size == 0
         assert out_features % (32 // self.w_bit) == 0
...
@@ -74,7 +82,7 @@ class WQLinear(nn.Module):
         zeros = zeros.to(dtype=torch.int32)
         qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
-        for col in range(zeros.shape[1] // pack_num):
+        for col in range(zeros.shape[1] // pack_num):
             if awq_linear.w_bit == 4:
                 order_map = [0, 2, 4, 6, 1, 3, 5, 7]
             else:
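For intuition, the loop above packs eight 4-bit values into each int32 column of qzeros using the interleaved order_map [0, 2, 4, 6, 1, 3, 5, 7] expected by the GEMM kernel. A minimal plain-Python sketch of one packed column (the eight values are illustrative):

w_bit, pack_num = 4, 8
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
vals = [0, 1, 2, 3, 4, 5, 6, 7]          # eight 4-bit zero-points (illustrative)

packed = 0
for i in range(pack_num):
    packed |= vals[order_map[i]] << (i * w_bit)

# Nibble i of the packed word now holds vals[order_map[i]], so reading nibbles
# from low to high yields the values at indices 0, 2, 4, 6, 1, 3, 5, 7.
for i in range(pack_num):
    assert ((packed >> (i * w_bit)) & 0xF) == vals[order_map[i]]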
...
@@ -97,3 +105,100 @@ class WQLinear(nn.Module):
         return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
             self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
         )
+
+
+class WQLinear_GEMV(nn.Module):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+        super().__init__()
+
+        if w_bit not in [4]:
+            raise NotImplementedError("Only 4-bit are supported for now.")
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else in_features
+        self.split_k_iters = 8
+
+        # quick sanity check (make sure aligment)
+        assert self.in_features % self.group_size == 0
+        assert out_features % (32 // self.w_bit) == 0
+        pack_num = (32 // self.w_bit)
+
+        self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev))
+        self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev))
+        self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev))
+        if bias:
+            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
+        else:
+            self.bias = None
+
+    @classmethod
+    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
+        awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
+        if init_only:
+            # just prepare for loading sd
+            return awq_linear
+
+        # need scales and zeros info for real quantization
+        assert scales is not None and zeros is not None
+        scale_zeros = zeros * scales
+
+        pack_num = 32 // awq_linear.w_bit
+        qscales = torch.zeros(
+            (scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
+            dtype=torch.float16,
+            device=scales.device
+        )
+        qscales[:, :scales.shape[1]] = scales
+        awq_linear.scales = qscales
+        if linear.bias is not None:
+            awq_linear.bias = linear.bias.clone().half()
+
+        intweight = []
+        for idx in range(awq_linear.in_features):
+            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.to(dtype=torch.int32)
+        qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
+
+        for col in range(intweight.shape[1] // pack_num):
+            if awq_linear.w_bit == 4:
+                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
+            else:
+                raise NotImplementedError("Only 4-bit are supported for now.")
+            for i in range(pack_num):
+                qweight_col = intweight[:, col * pack_num + order_map[i]]
+                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
+        awq_linear.qweight = qweight
+
+        zeros = zeros.to(dtype=torch.int32)
+        qzeros = torch.zeros(
+            (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
+            dtype=torch.int32,
+            device=zeros.device,
+        )
+        for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
+            if awq_linear.w_bit == 4:
+                order_map = [0, 1, 2, 3, 4, 5, 6, 7]
+            else:
+                raise NotImplementedError("Only 4-bit are supported for now.")
+            for i in range(pack_num):
+                if col * pack_num + order_map[i] >= zeros.shape[1]:
+                    continue
+                qzero_col = zeros[:, col * pack_num + order_map[i]]
+                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
+        awq_linear.qzeros = qzeros
+
+        return awq_linear
+
+    @torch.no_grad()
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.out_features,)
+        out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size)
+        out = out + self.bias if self.bias is not None else out
+        return out.reshape(out_shape)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
+            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
+        )
\ No newline at end of file
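To see the buffer layout the new WQLinear_GEMV allocates, here is a minimal sketch using the init_only path (sizes are illustrative; importing the module assumes the awq_inference_engine extension is built, and no CUDA kernel is actually run here):

import torch.nn as nn
from awq.modules.qlinear import WQLinear_GEMV

linear = nn.Linear(4096, 4096, bias=False).half()
q_linear = WQLinear_GEMV.from_linear(linear, w_bit=4, group_size=128, init_only=True)

print(q_linear.qweight.shape)  # (4096, 4096 // 8) = (4096, 512), packed int32
print(q_linear.qzeros.shape)   # (4096, calculate_zeros_width(4096, 128)) = (4096, 4)
print(q_linear.scales.shape)   # (4096, 4 * 8) = (4096, 32), float16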
awq/quantize/auto_scale.py

...
@@ -7,7 +7,7 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
 from transformers.activations import NewGELUActivation
-from .qmodule import ScaledActivation
+from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

 __all__ = ["auto_scale_block", "apply_scale"]
...