OpenDAS / AutoAWQ

Commit 7fbe9bbc (unverified)
Authored Aug 27, 2023 by Casper; committed by GitHub on Aug 27, 2023

Merge pull request #1 from jamesdborin/new_model/gptj

Add GPTJ Support

Parents: 3a8072a1, 50d1025f
Showing 5 changed files, with 63 additions and 8 deletions:

- awq/models/__init__.py (+2, -1)
- awq/models/auto.py (+2, -1)
- awq/models/base.py (+3, -3)
- awq/models/gptj.py (+53, -0)
- awq/quantize/auto_scale.py (+3, -3)
awq/models/__init__.py

@@ -2,4 +2,5 @@ from .mpt import MptAWQForCausalLM
 from .llama import LlamaAWQForCausalLM
 from .opt import OptAWQForCausalLM
 from .falcon import FalconAWQForCausalLM
-from .bloom import BloomAWQForCausalLM
\ No newline at end of file
+from .bloom import BloomAWQForCausalLM
+from .gptj import GPTJAWQForCausalLM
\ No newline at end of file
awq/models/auto.py

@@ -8,7 +8,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "opt": OptAWQForCausalLM,
     "RefinedWeb": FalconAWQForCausalLM,
     "RefinedWebModel": FalconAWQForCausalLM,
-    "bloom": BloomAWQForCausalLM
+    "bloom": BloomAWQForCausalLM,
+    "gptj": GPTJAWQForCausalLM
 }

 def check_and_get_model_type(model_dir, trust_remote_code=True):
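With the new map entry, checkpoints whose Hugging Face config reports model_type == "gptj" can be routed to GPTJAWQForCausalLM. A minimal, self-contained sketch of that dispatch follows; the diff only shows the signature of check_and_get_model_type, so the helper body and names below (lookup_awq_class, the example checkpoint path) are illustrative assumptions rather than the repository's actual code.

from transformers import AutoConfig
from awq.models.gptj import GPTJAWQForCausalLM

# Simplified stand-in for AWQ_CAUSAL_LM_MODEL_MAP, showing only the new entry.
MODEL_MAP = {"gptj": GPTJAWQForCausalLM}

def lookup_awq_class(model_dir, trust_remote_code=True):
    # Read model_type from the checkpoint's config and map it to an AWQ wrapper class.
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in MODEL_MAP:
        raise TypeError(f"{config.model_type} is not supported")
    return MODEL_MAP[config.model_type]

# GPT-J configs report model_type == "gptj", so this resolves to GPTJAWQForCausalLM.
awq_cls = lookup_awq_class("EleutherAI/gpt-j-6b")  # hypothetical checkpoint path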
awq/models/base.py

@@ -113,8 +113,8 @@ class BaseAWQForCausalLM(nn.Module):
                 super().__init__()
                 self.module = module

-            def forward(self, inp, **kwargs):
-                inps.append(inp)
+            def forward(self, hijacked_inputs, **kwargs):
+                inps.append(hijacked_inputs)
                 layer_kwargs.update(kwargs)
                 raise ValueError  # early exit to break later inference

@@ -358,4 +358,4 @@ class BaseAWQForCausalLM(nn.Module):
             # scale activation
             scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
-            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
+            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
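The renamed hijacked_inputs argument belongs to the calibration shim that captures what flows into the first decoder block: the block is wrapped, one forward pass records its inputs and keyword arguments, and a deliberate ValueError aborts the rest of the run. A self-contained sketch of that pattern follows; the Catcher class name and the surrounding usage are illustrative, not the exact code in base.py.

import torch.nn as nn

class Catcher(nn.Module):
    """Illustrative input catcher: records what reaches the wrapped block, then bails out."""
    def __init__(self, module, inps, layer_kwargs):
        super().__init__()
        self.module = module
        self.inps = inps
        self.layer_kwargs = layer_kwargs

    def forward(self, hijacked_inputs, **kwargs):
        self.inps.append(hijacked_inputs)   # keep the hidden states for calibration
        self.layer_kwargs.update(kwargs)    # keep attention masks, position ids, etc.
        raise ValueError                    # early exit to break later inference

# usage sketch:
#   inps, layer_kwargs = [], {}
#   layers[0] = Catcher(layers[0], inps, layer_kwargs)
#   try:
#       model(calibration_batch)    # stops inside layer 0
#   except ValueError:
#       pass
#   layers[0] = layers[0].module    # restore the original block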
awq/models/gptj.py (new file, mode 0 → 100644)

from .base import BaseAWQForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock

class GPTJAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GPTJBlock"
    max_new_tokens_key = "n_positions"

    @staticmethod
    def get_model_layers(model: GPTJForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: GPTJBlock):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.fc_in.out_features
        )

    @staticmethod
    def move_embed(model: GPTJForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)

    @staticmethod
    def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
        layers = []

        # attention input + linear 1
        layers.append(dict(
            prev_op=module.ln_1,
            layers=[module.attn.q_proj,
                    module.attn.k_proj, module.attn.v_proj, module.mlp.fc_in],
            inp=input_feat['attn.q_proj'],
            module2inspect=module, kwargs=module_kwargs,
        ))

        # attention out
        layers.append(dict(
            prev_op=module.attn.v_proj,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj'],
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.act,
            layers=[module.mlp.fc_out],
            inp=input_feat['mlp.fc_out'],
        ))

        return layers
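The wrapper's first scaling group mirrors GPT-J's parallel block layout: q_proj, k_proj, v_proj and mlp.fc_in all consume the same ln_1 output, so they can share one set of scales. For context, here is a minimal end-to-end sketch of quantizing a GPT-J checkpoint through this wrapper, following the README-style API of the repository at the time; the checkpoint name, output directory, and quantization settings are placeholders, not values taken from this commit.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "EleutherAI/gpt-j-6b"   # placeholder checkpoint
quant_path = "gpt-j-6b-awq"          # placeholder output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# Load the FP16 model; model_type "gptj" now dispatches to GPTJAWQForCausalLM.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Run AWQ calibration, quantize, and persist the quantized weights.
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)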
awq/quantize/auto_scale.py

@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+from transformers.activations import NewGELUActivation

 from .qmodule import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
 @torch.no_grad()
 def scale_gelu_fc(gelu, fc, scales):
-    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
     assert isinstance(fc, nn.Linear)

     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
-        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
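Extending these checks to NewGELUActivation (the activation GPT-J uses) is safe because of the identity behind scale_gelu_fc and ScaledActivation: dividing the activation output per channel and multiplying the columns of the following linear layer's weight by the same factors leaves the layer output unchanged. A small self-contained numeric check of that identity; shapes and scale values below are arbitrary.

import torch
import torch.nn as nn
from transformers.activations import NewGELUActivation

torch.manual_seed(0)
act = NewGELUActivation()
fc = nn.Linear(8, 4)
scales = torch.rand(8) + 0.5          # arbitrary positive per-channel scales
x = torch.randn(2, 8)

reference = fc(act(x))

# What scale_gelu_fc does: fold the scales into the next layer's weight columns.
scaled_fc = nn.Linear(8, 4)
scaled_fc.load_state_dict(fc.state_dict())
with torch.no_grad():
    scaled_fc.weight.mul_(scales.view(1, -1))

# What ScaledActivation does: divide the activation output by the same scales.
rescaled = scaled_fc(act(x) / scales)

assert torch.allclose(reference, rescaled, atol=1e-5)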