Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
7fbe9bbc
Unverified
Commit
7fbe9bbc
authored
Aug 27, 2023
by
Casper
Committed by
GitHub
Aug 27, 2023
Browse files
Merge pull request
#1
from jamesdborin/new_model/gptj
Add GPTJ Support
parents
3a8072a1
50d1025f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
8 deletions
+63
-8
awq/models/__init__.py
awq/models/__init__.py
+2
-1
awq/models/auto.py
awq/models/auto.py
+2
-1
awq/models/base.py
awq/models/base.py
+3
-3
awq/models/gptj.py
awq/models/gptj.py
+53
-0
awq/quantize/auto_scale.py
awq/quantize/auto_scale.py
+3
-3
No files found.
awq/models/__init__.py
View file @
7fbe9bbc
...
...
@@ -2,4 +2,5 @@ from .mpt import MptAWQForCausalLM
from
.llama
import
LlamaAWQForCausalLM
from
.opt
import
OptAWQForCausalLM
from
.falcon
import
FalconAWQForCausalLM
from
.bloom
import
BloomAWQForCausalLM
\ No newline at end of file
from
.bloom
import
BloomAWQForCausalLM
from
.gptj
import
GPTJAWQForCausalLM
\ No newline at end of file
awq/models/auto.py
View file @
7fbe9bbc
...
...
@@ -8,7 +8,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"opt"
:
OptAWQForCausalLM
,
"RefinedWeb"
:
FalconAWQForCausalLM
,
"RefinedWebModel"
:
FalconAWQForCausalLM
,
"bloom"
:
BloomAWQForCausalLM
"bloom"
:
BloomAWQForCausalLM
,
"gptj"
:
GPTJAWQForCausalLM
}
def
check_and_get_model_type
(
model_dir
,
trust_remote_code
=
True
):
...
...
awq/models/base.py
View file @
7fbe9bbc
...
...
@@ -113,8 +113,8 @@ class BaseAWQForCausalLM(nn.Module):
super
().
__init__
()
self
.
module
=
module
def
forward
(
self
,
inp
,
**
kwargs
):
inps
.
append
(
inp
)
def
forward
(
self
,
hijacked_inputs
,
**
kwargs
):
inps
.
append
(
hijacked_inputs
)
layer_kwargs
.
update
(
kwargs
)
raise
ValueError
# early exit to break later inference
...
...
@@ -358,4 +358,4 @@ class BaseAWQForCausalLM(nn.Module):
# scale activation
scaled_act
=
ScaledActivation
(
scale_dict
[
'scale_layer'
],
scale_like
)
set_op_by_name
(
layer
,
scale_dict
[
'scale_name'
],
scaled_act
)
\ No newline at end of file
set_op_by_name
(
layer
,
scale_dict
[
'scale_name'
],
scaled_act
)
awq/models/gptj.py
0 → 100644
View file @
7fbe9bbc
from
.base
import
BaseAWQForCausalLM
from
transformers.models.gptj.modeling_gptj
import
GPTJForCausalLM
,
GPTJBlock
class
GPTJAWQForCausalLM
(
BaseAWQForCausalLM
):
layer_type
=
"GPTJBlock"
max_new_tokens_key
=
"n_positions"
@
staticmethod
def
get_model_layers
(
model
:
GPTJForCausalLM
):
return
model
.
transformer
.
h
@
staticmethod
def
get_act_for_scaling
(
module
:
GPTJBlock
):
return
dict
(
is_scalable
=
True
,
scale_name
=
"mlp.act"
,
scale_layer
=
module
.
mlp
.
act
,
scale_shape
=
module
.
mlp
.
fc_in
.
out_features
)
@
staticmethod
def
move_embed
(
model
:
GPTJForCausalLM
,
device
:
str
):
model
.
transformer
.
wte
=
model
.
transformer
.
wte
.
to
(
device
)
@
staticmethod
def
get_layers_for_scaling
(
module
:
GPTJBlock
,
input_feat
,
module_kwargs
):
layers
=
[]
# attention input + linear 1
layers
.
append
(
dict
(
prev_op
=
module
.
ln_1
,
layers
=
[
module
.
attn
.
q_proj
,
module
.
attn
.
k_proj
,
module
.
attn
.
v_proj
,
module
.
mlp
.
fc_in
],
inp
=
input_feat
[
'attn.q_proj'
],
module2inspect
=
module
,
kwargs
=
module_kwargs
))
# attention out
layers
.
append
(
dict
(
prev_op
=
module
.
attn
.
v_proj
,
layers
=
[
module
.
attn
.
out_proj
],
inp
=
input_feat
[
'attn.out_proj'
],
))
# linear 2
layers
.
append
(
dict
(
prev_op
=
module
.
mlp
.
act
,
layers
=
[
module
.
mlp
.
fc_out
],
inp
=
input_feat
[
'mlp.fc_out'
],
))
return
layers
\ No newline at end of file
awq/quantize/auto_scale.py
View file @
7fbe9bbc
...
...
@@ -5,7 +5,7 @@ import torch.nn as nn
from
transformers.models.bloom.modeling_bloom
import
BloomBlock
,
BloomGelu
from
transformers.models.opt.modeling_opt
import
OPTDecoderLayer
from
transformers.models.llama.modeling_llama
import
LlamaDecoderLayer
,
LlamaRMSNorm
from
transformers.activations
import
NewGELUActivation
from
.qmodule
import
ScaledActivation
from
awq.utils.module
import
get_op_by_name
,
get_op_name
,
set_op_by_name
...
...
@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
@
torch
.
no_grad
()
def
scale_gelu_fc
(
gelu
,
fc
,
scales
):
assert
isinstance
(
gelu
,
nn
.
GELU
)
or
isinstance
(
gelu
,
BloomGelu
)
assert
any
(
isinstance
(
gelu
,
t
)
for
t
in
[
nn
.
GELU
,
BloomGelu
,
NewGELUActivation
]
)
assert
isinstance
(
fc
,
nn
.
Linear
)
fc
.
weight
.
mul_
(
scales
.
view
(
1
,
-
1
).
to
(
fc
.
weight
.
device
))
...
...
@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
scale_fc_fc
(
prev_op
,
layers
[
0
],
scales
)
elif
isinstance
(
prev_op
,
(
nn
.
LayerNorm
,
LlamaRMSNorm
)):
scale_ln_fcs
(
prev_op
,
layers
,
scales
)
elif
isinstance
(
prev_op
,
nn
.
GELU
)
or
isinstance
(
prev_op
,
BloomGelu
):
elif
any
(
isinstance
(
prev_op
,
t
)
for
t
in
[
nn
.
GELU
,
BloomGelu
,
NewGELUActivation
]
):
new_module
=
ScaledActivation
(
prev_op
,
scales
)
set_op_by_name
(
module
,
prev_op_name
,
new_module
)
scale_gelu_fc
(
prev_op
,
layers
[
0
],
scales
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment