OpenDAS / AutoAWQ · Commits

Commit 8eb26eb2 (unverified)
Authored Sep 26, 2023 by Casper; committed by GitHub on Sep 26, 2023

Merge pull request #69 from VikParuchuri/main

Use typing classes over base types

Parents: 386fede8, 4abfefc9
Changes: 7 files, with 18 additions and 13 deletions (+18 / -13)

awq/models/base.py          +2 -2
awq/models/falcon.py        +2 -1
awq/models/llama.py         +2 -1
awq/models/mpt.py           +2 -1
awq/modules/fused/model.py  +3 -2
awq/quantize/quantizer.py   +5 -4
awq/quantize/scale.py       +2 -2
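Every file in this commit makes the same mechanical substitution: built-in container types used as annotations (list[...], dict[...]) are replaced by their typing aliases (List[...], Dict[...]), with the matching typing imports added. The commit message does not state the motivation, but a likely one is interpreter compatibility: subscripting the built-in types in annotations is only supported from Python 3.9 onward (PEP 585), so the typing forms keep the package importable on Python 3.8. The sketch below is not part of the repository; the function names are hypothetical, and torch is imported only so the annotations mirror the repo's own types.

# Minimal sketch (hypothetical, not from this repo), assuming Python 3.8.
import torch.nn as nn
from typing import Dict, List

# typing.Dict / typing.List are subscriptable on Python 3.7+:
def apply_quant_ok(named_linears: Dict[str, nn.Linear]) -> List[str]:
    return list(named_linears.keys())

# The pre-commit style works only on Python 3.9+ (PEP 585). On Python 3.8,
# defining the function below raises "TypeError: 'type' object is not
# subscriptable" at import time, because annotations are evaluated when the
# def statement executes:
#
# def apply_quant_old(named_linears: dict[str, nn.Linear]) -> list[str]:
#     return list(named_linears.keys())

An alternative with a similar effect would be "from __future__ import annotations" (PEP 563), which defers annotation evaluation; this commit opts for the typing aliases instead.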
awq/models/base.py
@@ -4,7 +4,7 @@ import json
 import torch
 import torch.nn as nn
 from tqdm import tqdm
-from typing import List, Union
+from typing import List, Union, Dict
 from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
@@ -23,7 +23,7 @@ class BaseAWQForCausalLM(nn.Module):
         self.model_type: str = model_type
         self.is_quantized: bool = is_quantized
         self.search_result = None
-        self.quant_config: dict = quant_config
+        self.quant_config: Dict = quant_config

     def to(self, device: str):
         return self.model.to(device)
awq/models/falcon.py
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention

 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"

     @staticmethod
-    def fuse_layers(model: FalconForCausalLM, quant_config: dict):
+    def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
         fuser = FalconFuser(model)

         # TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
awq/models/llama.py
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -6,7 +7,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
+    def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
         fuser = LlamaFuser(model, quant_config)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
awq/models/mpt.py
 from .base import BaseAWQForCausalLM
+from typing import Dict
 from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM

 class MptAWQForCausalLM(BaseAWQForCausalLM):
@@ -6,7 +7,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_seq_len"

     @staticmethod
-    def fuse_layers(model: MptForCausalLM, quant_config: dict):
+    def fuse_layers(model: MptForCausalLM, quant_config: Dict):
         fuser = MptFuser(model)

         fuser.fuse_transformer()
awq/modules/fused/model.py
 import torch
 import torch.nn as nn
+from typing import List
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -8,7 +9,7 @@ class MPTModel(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.wte = wte
-        self.blocks: list[MPTBlock] = nn.ModuleList(blocks)
+        self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
         self.norm_f = norm_f
         self.attn_uses_sequence_id = False
         self.prefix_lm = False
@@ -36,7+37,7 @@ class FalconModel(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.word_embeddings = word_embeddings
-        self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
+        self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
         self.ln_f = ln_f
         self.attn_uses_sequence_id = False
         self.prefix_lm = False
awq/quantize/quantizer.py
@@ -3,6 +3,7 @@ import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
+from typing import Dict, List
 from collections import defaultdict
 from awq.utils.utils import clear_memory
 from awq.utils.calib_data import get_calib_dataset
@@ -62,7 +63,7 @@ class AwqQuantizer:
             clear_memory()

             # [STEP 2]: Compute and apply scale list
-            module_config: list[dict] = self.awq_model.get_layers_for_scaling(
+            module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
                 self.modules[i], input_feat, self.module_kwargs
             )
             scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
@@ -78,7 +79,7 @@ class AwqQuantizer:
             self._apply_quant(self.modules[i], named_linears)
             clear_memory()

-    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
+    def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
         for name, linear_layer in named_linears.items():
             # NOTE: small regression in perplexity if linear layer uses .cpu().float()
             linear_layer = linear_layer.cuda().half()
@@ -111,7 +112,7 @@ class AwqQuantizer:
             clear_memory()

     @torch.no_grad()
-    def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
+    def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
         if module2inspect is None:
             assert len(layers) == 1
             module2inspect = layers[0]
@@ -148,7 +149,7 @@ class AwqQuantizer:
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear],
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear],
                             fp16_output, kwargs={}):
         """
         Compute loss and select best scales
awq/quantize/scale.py
 import torch
 import torch.nn as nn
-from typing import Tuple
+from typing import Tuple, List
 from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
@@ -62,7 +62,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         scales.cpu()

 @torch.no_grad()
-def scale_ln_fcs(ln: nn.Linear, fcs: list[nn.Linear], scales: torch.Tensor):
+def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     if not isinstance(fcs, list):
         fcs = [fcs]