OpenDAS / AutoAWQ · Commit 890b6aa7

GEMM + GEMV compatibility

Authored Sep 08, 2023 by Casper Hansen
Parent: 5297eccc

Showing 2 changed files with 39 additions and 13 deletions:
  awq/models/llama.py       +17 -6
  awq/modules/fused/mlp.py  +22 -7
awq/models/llama.py

@@ -111,20 +111,31 @@ class LlamaFuser:
         q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
 
-        qkv_layer = WQLinear_GEMV(
+        if isinstance(q_proj, WQLinear_GEMV):
+            q_linear = WQLinear_GEMV
+        else:
+            q_linear = WQLinear_GEMM
+
+        qkv_layer = q_linear(
             q_proj.w_bit,
             q_proj.group_size,
             q_proj.in_features,
             q_proj.out_features + k_proj.out_features + v_proj.out_features,
             q_proj.bias is not None,
-            q_proj.qweight.device,
+            next(iter(module.state_dict().values())).device
         )
-        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
-        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
-        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+
+        if isinstance(qkv_layer, WQLinear_GEMV):
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+            qkv_layer.split_k_iters = q_proj.split_k_iters
+        else:
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+
         qkv_layer.bias = bias
-        qkv_layer.split_k_iters = q_proj.split_k_iters
 
         return qkv_layer
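For context on the hunk above: instead of always building the fused QKV layer as WQLinear_GEMV, the fuser now picks the class from the type of q_proj and concatenates the packed tensors along dim=0 for the GEMV layout but dim=1 for the GEMM layout. Below is a minimal sketch of that dispatch, using stand-in classes and made-up packed shapes rather than the real awq.modules.linear modules.

import torch

# Stand-in stubs for this sketch only; the real classes live in
# awq.modules.linear and carry packed int32 weights, scales and zeros.
class WQLinear_GEMM:
    def __init__(self, qweight):
        self.qweight = qweight

class WQLinear_GEMV:
    def __init__(self, qweight):
        self.qweight = qweight

def fuse_qkv_qweight(q_proj, k_proj, v_proj):
    # Same dispatch idea as the commit: infer the layout from q_proj's type,
    # then concatenate the packed tensors along the matching dimension.
    dim = 0 if isinstance(q_proj, WQLinear_GEMV) else 1
    return torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=dim)

# Toy usage with made-up packed shapes.
q = WQLinear_GEMM(torch.zeros(64, 16, dtype=torch.int32))
k = WQLinear_GEMM(torch.zeros(64, 16, dtype=torch.int32))
v = WQLinear_GEMM(torch.zeros(64, 16, dtype=torch.int32))
print(fuse_qkv_qweight(q, k, v).shape)  # torch.Size([64, 48])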
awq/modules/fused/mlp.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 import torch.nn.functional as F
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
 class QuantMPTMLP(nn.Module):
     def __init__(
@@ -19,14 +20,21 @@ class QuantMPTMLP(nn.Module):
         self.act = act
         self.down_proj = down_proj
 
+        if isinstance(down_proj, WQLinear_GEMV):
+            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.group_size = down_proj.group_size
+        else:
+            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.group_size = 8
+
     def forward(self, x: torch.Tensor):
         x = x.reshape(-1, x.shape[-1])
-        x = awq_inference_engine.gemv_forward_cuda(
+        x = self.linear(
             x,
             self.up_proj_qweight,
             self.up_proj_scales,
             self.up_proj_qzeros,
-            self.down_proj.group_size
+            self.group_size
         )
         return self.down_proj(self.act(x))
@@ -37,7 +45,7 @@ class QuantLlamaMLP(nn.Module):
         self,
         gate_proj,
         down_proj,
-        up_proj,
+        up_proj
     ):
         super().__init__()
         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
@@ -53,22 +61,29 @@ class QuantLlamaMLP(nn.Module):
         self.w_bit = gate_proj.w_bit
         self.down_proj = down_proj
 
+        if isinstance(down_proj, WQLinear_GEMV):
+            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.group_size = down_proj.group_size
+        else:
+            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.group_size = 8
+
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
-        gate_output = awq_inference_engine.gemv_forward_cuda(
+        gate_output = self.linear(
             x,
             self.gate_proj_qweight,
             self.gate_proj_scales,
             self.gate_proj_qzeros,
-            self.down_proj.group_size,
+            self.group_size,
         )
-        up_output = awq_inference_engine.gemv_forward_cuda(
+        up_output = self.linear(
             x,
             self.up_proj_qweight,
             self.up_proj_scales,
             self.up_proj_qzeros,
-            self.down_proj.group_size,
+            self.group_size,
         )
         x = F.silu(gate_output) * up_output
         x = x.reshape(out_shape)
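For context on the hunks above: both fused MLP modules now pick the CUDA kernel and group_size once in __init__, based on whether down_proj is a WQLinear_GEMV, so forward() can call self.linear(...) unconditionally; the GEMM branch passes a fixed group_size of 8. Below is a small sketch of the same dispatch shape, with stub kernels standing in for awq_inference_engine and made-up shapes.

import torch
import torch.nn as nn

# Stub kernels for this sketch only; the commit uses
# awq_inference_engine.gemv_forward_cuda / gemm_forward_cuda instead.
def stub_gemv(x, qweight, scales, qzeros, group_size):
    return torch.zeros(x.shape[0], 32)

def stub_gemm(x, qweight, scales, qzeros, group_size):
    return torch.zeros(x.shape[0], 32)

class MLPKernelDispatchSketch(nn.Module):
    def __init__(self, use_gemv: bool, down_proj_group_size: int = 128):
        super().__init__()
        # Decide once at construction time which kernel and group_size to use,
        # so forward() stays identical for both quantization layouts.
        if use_gemv:
            self.linear = stub_gemv
            self.group_size = down_proj_group_size
        else:
            self.linear = stub_gemm
            self.group_size = 8  # the GEMM path in the diff passes a fixed 8

    def forward(self, x):
        x = x.reshape(-1, x.shape[-1])
        return self.linear(x, None, None, None, self.group_size)

# Toy usage
m = MLPKernelDispatchSketch(use_gemv=False)
print(m(torch.randn(2, 3, 16)).shape)  # torch.Size([6, 32])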