OpenDAS / AutoAWQ · Commits

Commit 6b1c96c7
Authored Aug 25, 2023 by EC2 Default User
Parent: fac1af55

fixed catcher input name
Showing 2 changed files with 5 additions and 5 deletions.

awq/models/base.py          +2 -2
awq/quantize/auto_scale.py  +3 -3
awq/models/base.py
@@ -113,8 +113,8 @@ class BaseAWQForCausalLM(nn.Module):
                 super().__init__()
                 self.module = module

-            def forward(self, inp, **kwargs):
-                inps.append(inp)
+            def forward(self, hidden_states, **kwargs):
+                inps.append(hidden_states)
                 layer_kwargs.update(kwargs)
                 raise ValueError  # early exit to break later inference
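For context on this rename: the method above belongs to the Catcher module that the quantization path temporarily wraps around the first decoder layer to capture calibration inputs, and calling the parameter hidden_states lets callers that pass the activation as a keyword argument (hidden_states=...) bind to it. Below is a minimal sketch of that pattern, assuming the usual AWQ calibration flow; the helper name capture_first_layer_inputs and the model/layers/samples arguments are illustrative, not taken from this repository.

import torch.nn as nn

def capture_first_layer_inputs(model, layers, samples):
    """Illustrative helper: wrap the first layer, run one forward pass,
    and collect what that layer receives as calibration data."""
    inps = []
    layer_kwargs = {}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        # The parameter is named hidden_states so that callers which pass
        # the activation as a keyword argument (hidden_states=...) still
        # bind to it.
        def forward(self, hidden_states, **kwargs):
            inps.append(hidden_states)
            layer_kwargs.update(kwargs)
            raise ValueError  # early exit to break later inference

    layers[0] = Catcher(layers[0])
    try:
        model(samples)
    except ValueError:  # raised deliberately by Catcher.forward
        pass
    layers[0] = layers[0].module  # restore the original first layer
    return inps, layer_kwargs

Raising ValueError right after recording the inputs keeps the calibration pass cheap: only the layers before the wrapped one actually run, and the caller swallows the exception.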
awq/quantize/auto_scale.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+from transformers.activations import NewGELUActivation
 from .qmodule import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
 @torch.no_grad()
 def scale_gelu_fc(gelu, fc, scales):
-    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
     assert isinstance(fc, nn.Linear)

     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
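The rewritten assertion broadens the accepted activation types from nn.GELU and BloomGelu to also cover transformers' NewGELUActivation (the "gelu_new" variant used by several GPT-style models). A small standalone sketch of the new condition, using a hypothetical helper name:

import torch.nn as nn
from transformers.activations import NewGELUActivation
from transformers.models.bloom.modeling_bloom import BloomGelu

# Hypothetical helper mirroring the broadened check in scale_gelu_fc / apply_scale.
def is_supported_gelu(op: nn.Module) -> bool:
    return any(isinstance(op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])

print(is_supported_gelu(nn.GELU()))            # True
print(is_supported_gelu(NewGELUActivation()))  # True
print(is_supported_gelu(nn.ReLU()))            # False, still unsupported

Writing the check as any(isinstance(op, t) for t in [...]) is equivalent to isinstance(op, (nn.GELU, BloomGelu, NewGELUActivation)); the list form simply makes it easy to extend with further activation classes.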
@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
-        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
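Why this branch is numerically safe: dividing the activation output by per-channel scales while multiplying the next linear layer's corresponding weight columns by the same scales leaves the composed computation unchanged. A minimal sketch under that assumption; ScaledActivationSketch below is an illustrative stand-in, not the actual ScaledActivation from awq/quantize/qmodule.py:

import torch
import torch.nn as nn

class ScaledActivationSketch(nn.Module):
    """Illustrative stand-in: apply the wrapped activation, then divide
    each output channel by its scale."""
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, -1).to(x.device)

act, fc = nn.GELU(), nn.Linear(8, 4, bias=False)
scales = torch.rand(8) + 0.5
x = torch.randn(2, 8)
reference = fc(act(x))

# Mirror what scale_gelu_fc does: fold the scales into fc's input columns...
fc.weight.data.mul_(scales.view(1, -1))
# ...and undo them on the activation side, as the wrapper would.
scaled = fc(ScaledActivationSketch(act, scales)(x))

print(torch.allclose(reference, scaled, atol=1e-5))  # True, up to float error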