Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
ad45716f
Commit
ad45716f
authored
Oct 17, 2023
by
twaka
Browse files
gpt_neox
parent
1b54b9f9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
65 additions
and
4 deletions
+65
-4
awq/models/__init__.py
awq/models/__init__.py
+2
-1
awq/models/auto.py
awq/models/auto.py
+2
-1
awq/models/gpt_neox.py
awq/models/gpt_neox.py
+59
-0
awq/quantize/scale.py
awq/quantize/scale.py
+2
-2
No files found.
awq/models/__init__.py
View file @
ad45716f
...
...
@@ -5,4 +5,5 @@ from .falcon import FalconAWQForCausalLM
from
.bloom
import
BloomAWQForCausalLM
from
.gptj
import
GPTJAWQForCausalLM
from
.gpt_bigcode
import
GptBigCodeAWQForCausalLM
from
.mistral
import
MistralAWQForCausalLM
\ No newline at end of file
from
.mistral
import
MistralAWQForCausalLM
from
.gpt_neox
import
GPTNeoXAWQForCausalLM
awq/models/auto.py
View file @
ad45716f
...
...
@@ -13,7 +13,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"bloom"
:
BloomAWQForCausalLM
,
"gptj"
:
GPTJAWQForCausalLM
,
"gpt_bigcode"
:
GptBigCodeAWQForCausalLM
,
"mistral"
:
MistralAWQForCausalLM
"mistral"
:
MistralAWQForCausalLM
,
"gpt_neox"
:
GPTNeoXAWQForCausalLM
,
}
def
check_and_get_model_type
(
model_dir
,
trust_remote_code
=
True
):
...
...
awq/models/gpt_neox.py
0 → 100644
View file @
ad45716f
from
.base
import
BaseAWQForCausalLM
from
typing
import
Dict
from
transformers.models.gpt_neox.modeling_gpt_neox
import
GPTNeoXLayer
,
GPTNeoXForCausalLM
class
GPTNeoXAWQForCausalLM
(
BaseAWQForCausalLM
):
layer_type
=
"GPTNeoXDecoderLayer"
max_new_tokens_key
=
"max_position_embeddings"
@
staticmethod
def
get_model_layers
(
model
:
GPTNeoXForCausalLM
):
return
model
.
gpt_neox
.
layers
@
staticmethod
def
get_act_for_scaling
(
module
:
GPTNeoXLayer
):
return
dict
(
is_scalable
=
True
,
scale_name
=
"mlp.act"
,
scale_layer
=
module
.
mlp
.
act
,
scale_shape
=
module
.
mlp
.
dense_h_to_4h
.
out_features
,
)
@
staticmethod
def
move_embed
(
model
:
GPTNeoXForCausalLM
,
device
:
str
):
model
.
gpt_neox
.
embed_in
=
model
.
gpt_neox
.
embed_in
.
to
(
device
)
@
staticmethod
def
get_layers_for_scaling
(
module
:
GPTNeoXLayer
,
input_feat
,
module_kwargs
):
layers
=
[]
# attention input
layers
.
append
(
dict
(
prev_op
=
module
.
input_layernorm
,
layers
=
[
module
.
attention
.
query_key_value
],
inp
=
input_feat
[
'attention.query_key_value'
],
))
# # attention out
# layers.append(dict(
# prev_op=module.attention.query_key_value,
# layers=[module.attention.dense],
# inp=input_feat['attention.dense'],
# ))
# NOTE: assumes "use_parallel_residual": false
# linear 1
layers
.
append
(
dict
(
prev_op
=
module
.
post_attention_layernorm
,
layers
=
[
module
.
mlp
.
dense_h_to_4h
],
inp
=
input_feat
[
'mlp.dense_h_to_4h'
],
))
# linear 2
layers
.
append
(
dict
(
prev_op
=
module
.
mlp
.
act
,
layers
=
[
module
.
mlp
.
dense_4h_to_h
],
inp
=
input_feat
[
'mlp.dense_4h_to_h'
],
))
return
layers
awq/quantize/scale.py
View file @
ad45716f
...
...
@@ -5,10 +5,10 @@ from awq.modules.act import ScaledActivation
from
awq.utils.module
import
get_op_by_name
,
set_op_by_name
from
transformers.models.bloom.modeling_bloom
import
BloomGelu
from
transformers.models.llama.modeling_llama
import
LlamaRMSNorm
from
transformers.activations
import
NewGELUActivation
,
PytorchGELUTanh
from
transformers.activations
import
NewGELUActivation
,
PytorchGELUTanh
,
GELUActivation
allowed_norms
=
[
nn
.
LayerNorm
,
LlamaRMSNorm
]
allowed_act_fns
=
[
nn
.
GELU
,
BloomGelu
,
NewGELUActivation
,
PytorchGELUTanh
]
allowed_act_fns
=
[
nn
.
GELU
,
BloomGelu
,
NewGELUActivation
,
PytorchGELUTanh
,
GELUActivation
]
@
torch
.
no_grad
()
def
apply_clip
(
module
,
clip_list
:
Tuple
[
str
,
torch
.
Tensor
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment