OpenDAS / AutoAWQ, commit 3b362c0d
"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "93cc6da7771baf4c7beae0b6373efbe9dc16485d"
Unverified commit 3b362c0d, authored Nov 15, 2023 by Younes Belkada, committed by GitHub on Nov 15, 2023
Parent: 09c73fb2

[`core`] Replace `QuantLlamaMLP` with `QuantFusedMLP` (#188)
Showing 6 changed files with 40 additions and 13 deletions (+40, -13):
awq/models/aquila.py        +2  -2
awq/models/llama.py         +2  -2
awq/models/mistral.py       +2  -2
awq/models/yi.py            +2  -2
awq/modules/fused/attn.py   +10 -1
awq/modules/fused/mlp.py    +22 -4
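For orientation, the core of the change is a rename of the fused MLP module in awq/modules/fused/mlp.py plus a new activation argument, condensed here from the diffs that follow (constructor bodies elided; this is a sketch of the new surface, not the full implementation):

import torch.nn as nn
import torch.nn.functional as F

class QuantFusedMLP(nn.Module):
    # Renamed from QuantLlamaMLP; the activation function is now a constructor
    # argument instead of F.silu being hard-coded in forward().
    def __init__(self, gate_proj, down_proj, up_proj, activation=F.silu):
        ...

class QuantLlamaMLP(QuantFusedMLP):
    # Thin subclass kept so existing imports and call sites keep working.
    def __init__(self, gate_proj, down_proj, up_proj):
        super().__init__(gate_proj, down_proj, up_proj)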
awq/models/aquila.py

@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OldAquilaDecoderLayer,
     LlamaForCausalLM as OldAquilaForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class AquilaAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class AquilaFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
awq/models/llama.py

@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OldLlamaDecoderLayer,
     LlamaForCausalLM as OldLlamaForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class LlamaFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
awq/models/mistral.py

@@ -8,7 +8,7 @@ from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OldMistralDecoderLayer,
     MistralForCausalLM as OldMistralForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class MistralAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class MistralFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
awq/models/yi.py

@@ -4,7 +4,7 @@ from .base import BaseAWQForCausalLM
 from awq.utils.fused_utils import fuse_qkv
 from awq.modules.fused.block import LlamaLikeBlock
 from awq.modules.fused.model import LlamaLikeModel
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class YiAWQForCausalLM(BaseAWQForCausalLM):
@@ -90,7 +90,7 @@ class YiFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
awq/modules/fused/attn.py

@@ -131,7 +131,16 @@ class QuantAttentionFused(nn.Module):
             elif bsz < self.cache_batch_size:
                 self.cache.decrease_batch_size(bsz)
                 self.cache_batch_size = bsz
 
-            # Always reset to 0
-            self.start_pos = 0
+        # In case we re-generate, we need to refresh the starting position
+        # to 0. We detect it by checking if `past_key_values` is set to None,
+        # which indicates that we are on the first step of `generate()`.
+        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
+            self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
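The attn.py hunk above replaces an unconditional reset of self.start_pos with a reset that only fires when the past_key_value kwarg is explicitly None, which happens on the first step of generate(). A minimal, self-contained sketch of that behavior; only start_pos and the kwarg check mirror the diff, the surrounding class is purely illustrative:

class _StartPosSketch:
    # Illustrative stand-in for the rolling cache position kept by QuantAttentionFused.
    def __init__(self):
        self.start_pos = 0

    def forward(self, seqlen, **kwargs):
        # First step of generate(): the caller passes past_key_value=None,
        # so the cache position is refreshed before this forward pass.
        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
            self.start_pos = 0
        pos = self.start_pos
        self.start_pos += seqlen
        return pos

attn = _StartPosSketch()
print(attn.forward(5, past_key_value=None))      # prefill of a new prompt -> 0
print(attn.forward(1, past_key_value=object()))  # decoding step -> 5
print(attn.forward(4, past_key_value=None))      # second generate() call -> back to 0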
awq/modules/fused/mlp.py

@@ -3,15 +3,17 @@ import awq_inference_engine
 import torch.nn.functional as F
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
-class QuantLlamaMLP(nn.Module):
+class QuantFusedMLP(nn.Module):
     def __init__(
         self,
         gate_proj,
         down_proj,
-        up_proj
+        up_proj,
+        activation=F.silu,
     ):
         super().__init__()
 
         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
         self.register_buffer('gate_proj_scales', gate_proj.scales)
         self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
@@ -32,6 +34,8 @@ class QuantLlamaMLP(nn.Module):
         self.linear = awq_inference_engine.gemm_forward_cuda
         self.group_size = 8
 
+        self.activation = activation
+
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
@@ -49,8 +53,22 @@ class QuantLlamaMLP(nn.Module):
             self.up_proj_qzeros,
             self.group_size,
         )
-        x = F.silu(gate_output) * up_output
+        x = self.activation(gate_output) * up_output
         x = x.reshape(out_shape)
         x = self.down_proj(x)
-        return x
\ No newline at end of file
+        return x
+
+
+class QuantLlamaMLP(QuantFusedMLP):
+    r"""
+    QuantLlamaMLP class kept for backward compatibilty, in the future, users
+    should always use `QuantFusedMLP` class instead.
+    """
+    def __init__(
+        self,
+        gate_proj,
+        down_proj,
+        up_proj
+    ):
+        super().__init__(gate_proj, down_proj, up_proj)
\ No newline at end of file
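For reference, a self-contained, non-quantized sketch of the computation QuantFusedMLP performs, using plain nn.Linear layers in place of the AWQ GEMM kernels; it only illustrates where the new activation argument enters the gated MLP, down_proj(activation(gate_proj(x)) * up_proj(x)), and all sizes here are arbitrary:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 128
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj   = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
activation = F.silu  # the QuantFusedMLP default; now swappable per model

x = torch.randn(2, 8, hidden_size)
out = down_proj(activation(gate_proj(x)) * up_proj(x))
print(out.shape)  # torch.Size([2, 8, 64])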