OpenDAS / AutoAWQ · Commit 3b362c0d

[`core`] Replace `QuantLlamaMLP` with `QuantFusedMLP` (#188)

Authored by Younes Belkada on Nov 15, 2023; committed via GitHub on Nov 15, 2023.
Parent commit: 09c73fb2
Showing 6 changed files with 40 additions and 13 deletions (+40 −13):

- awq/models/aquila.py (+2 −2)
- awq/models/llama.py (+2 −2)
- awq/models/mistral.py (+2 −2)
- awq/models/yi.py (+2 −2)
- awq/modules/fused/attn.py (+10 −1)
- awq/modules/fused/mlp.py (+22 −4)
awq/models/aquila.py

```diff
@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OldAquilaDecoderLayer,
     LlamaForCausalLM as OldAquilaForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class AquilaAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class AquilaFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
```
awq/models/llama.py

```diff
@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OldLlamaDecoderLayer,
     LlamaForCausalLM as OldLlamaForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class LlamaFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
```
awq/models/mistral.py

```diff
@@ -8,7 +8,7 @@ from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OldMistralDecoderLayer,
     MistralForCausalLM as OldMistralForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class MistralAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class MistralFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
```
awq/models/yi.py

```diff
@@ -4,7 +4,7 @@ from .base import BaseAWQForCausalLM
 from awq.utils.fused_utils import fuse_qkv
 from awq.modules.fused.block import LlamaLikeBlock
 from awq.modules.fused.model import LlamaLikeModel
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class YiAWQForCausalLM(BaseAWQForCausalLM):
@@ -90,7 +90,7 @@ class YiFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
```
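The four model files above (aquila.py, llama.py, mistral.py, yi.py) carry the identical two-line change: the import and the fuser's constructor call move from `QuantLlamaMLP` to `QuantFusedMLP`. A minimal sketch of that fuser pattern, where `fuse_mlp` is an illustrative helper (not part of the repository) and `module` is assumed to be one decoder layer holding AWQ-quantized projections:

```python
import torch.nn.functional as F

from awq.modules.fused.mlp import QuantFusedMLP


def fuse_mlp(module, activation=F.silu):
    """Build a fused MLP from one decoder layer's quantized projections.

    `module` is assumed to be a decoder layer whose `mlp` sub-module holds
    AWQ-quantized WQLinear projections, exactly as iterated by the fusers
    in the diffs above; `fuse_mlp` itself is illustrative only.
    """
    return QuantFusedMLP(
        module.mlp.gate_proj,
        module.mlp.down_proj,
        module.mlp.up_proj,
        activation=activation,
    )
```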
awq/modules/fused/attn.py

```diff
@@ -131,7 +131,16 @@ class QuantAttentionFused(nn.Module):
             elif bsz < self.cache_batch_size:
                 self.cache.decrease_batch_size(bsz)
                 self.cache_batch_size = bsz
-            self.start_pos = 0
+
+            # Always reset to 0
+            self.start_pos = 0
+
+        # In case we re-generate, we need to refresh the starting position
+        # to 0. We detect it by checking if `past_key_values` is set to None,
+        # which indicates that we are on the first step of `generate()`.
+        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
+            self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
```
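The new guard relies on how `generate()` in transformers drives the module: the first forward call passes `past_key_value` as `None`, and only later decode steps carry a populated cache, so seeing `None` is a signal that a fresh generation has started and the fused module's rolling `start_pos` should return to 0. A self-contained sketch of just that detection logic (the `SimpleFusedCache` class is illustrative, not an AutoAWQ API):

```python
class SimpleFusedCache:
    """Illustrative stand-in for the fused module's rolling KV-cache position."""

    def __init__(self):
        self.start_pos = 0  # write offset into a pre-allocated cache

    def step(self, seqlen, **kwargs):
        # generate() passes past_key_value=None on its first forward call;
        # treat that as the start of a fresh generation and rewind.
        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
            self.start_pos = 0

        write_at = self.start_pos   # where this step's keys/values would land
        self.start_pos += seqlen    # advance for the next step
        return write_at


cache = SimpleFusedCache()
assert cache.step(8, past_key_value=None) == 0  # prompt: reset, write at 0
assert cache.step(1) == 8                       # decode step: append at 8
assert cache.step(5, past_key_value=None) == 0  # new generate() call: rewound
```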
awq/modules/fused/mlp.py

```diff
@@ -3,15 +3,17 @@ import awq_inference_engine
 import torch.nn.functional as F
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
-class QuantLlamaMLP(nn.Module):
+class QuantFusedMLP(nn.Module):
     def __init__(
         self,
         gate_proj,
         down_proj,
-        up_proj
+        up_proj,
+        activation=F.silu,
     ):
         super().__init__()
 
         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
         self.register_buffer('gate_proj_scales', gate_proj.scales)
         self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
@@ -32,6 +34,8 @@ class QuantLlamaMLP(nn.Module):
         self.linear = awq_inference_engine.gemm_forward_cuda
         self.group_size = 8
 
+        self.activation = activation
+
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
@@ -49,8 +53,22 @@ class QuantLlamaMLP(nn.Module):
             self.up_proj_qzeros,
             self.group_size,
         )
-        x = F.silu(gate_output) * up_output
+        x = self.activation(gate_output) * up_output
         x = x.reshape(out_shape)
         x = self.down_proj(x)
-        return x
\ No newline at end of file
+        return x
+
+
+class QuantLlamaMLP(QuantFusedMLP):
+    r"""
+    QuantLlamaMLP class kept for backward compatibilty, in the future, users
+    should always use `QuantFusedMLP` class instead.
+    """
+    def __init__(
+        self,
+        gate_proj,
+        down_proj,
+        up_proj
+    ):
+        super().__init__(gate_proj, down_proj, up_proj)
\ No newline at end of file
```
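With the activation injectable, the fused MLP is no longer hard-wired to SiLU, and `QuantLlamaMLP` survives only as a thin backward-compatibility subclass. A hedged usage sketch, assuming `gate_proj`, `down_proj`, and `up_proj` are AWQ-quantized WQLinear projections taken from a loaded model (the helper below is illustrative, not part of this commit):

```python
import torch.nn.functional as F

from awq.modules.fused.mlp import QuantFusedMLP, QuantLlamaMLP


def build_fused_mlp(gate_proj, down_proj, up_proj, use_gelu=False):
    # gate_proj/down_proj/up_proj are assumed to be WQLinear_GEMM or
    # WQLinear_GEMV modules from a quantized checkpoint; this helper is
    # illustrative only.
    activation = F.gelu if use_gelu else F.silu  # F.silu is already the default
    return QuantFusedMLP(gate_proj, down_proj, up_proj, activation=activation)


# The legacy class still constructs fine and now subclasses QuantFusedMLP:
# QuantLlamaMLP(gate_proj, down_proj, up_proj)
```

Keeping the old name as a subclass means existing fusers and downstream code continue to work unchanged while new code can standardize on `QuantFusedMLP`.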