OpenDAS / AutoAWQ · Commits

Commit fbfa9d82, authored Sep 08, 2023 by Casper Hansen
Parent: 75710806

    Attention works

Showing 2 changed files with 342 additions and 91 deletions:

    awq/models/llama.py         +50   -8
    awq/modules/fused/attn.py   +292  -83
awq/models/llama.py · view file @ fbfa9d82

@@ -70,7 +70,7 @@ from typing import List, Tuple, Union
 from awq.utils.utils import set_module_name
 from awq.modules.fused.mlp import QuantLlamaMLP
 from awq.modules.fused.norm import FTLlamaRMSNorm
-from awq.modules.fused.attn import QuantLlamaAttention
+from awq.modules.fused.attn import QuantLlamaAttention, QuantLlamaAttentionFused
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

@@ -96,18 +96,59 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
-            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
-            attn = QuantLlamaAttention(
+            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv2(module)
+            # attn = QuantLlamaAttention(
+            #     module.hidden_size,
+            #     module.num_heads,
+            #     module.num_key_value_heads,
+            #     qkv_layer,
+            #     module.o_proj,
+            #     next(iter(qkv_layer.state_dict().values())).device,
+            #     self.model.config.max_new_tokens
+            # )
+            attn = QuantLlamaAttentionFused(
                 module.hidden_size,
                 module.num_heads,
-                module.num_key_value_heads,
                 qkv_layer,
                 module.o_proj,
                 next(iter(qkv_layer.state_dict().values())).device,
                 self.model.config.max_new_tokens
             )
             set_module_name(self.model, name, attn)
+
+    def _fuse_qkv2(self, module: LlamaAttention):
+        q_proj = module.q_proj
+        k_proj = module.k_proj
+        v_proj = module.v_proj
+
+        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
+        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
+        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+        # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
+        g_idx = None
+
+        bias = (
+            torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
+            if q_proj.bias is not None
+            else None
+        )
+
+        qkv_layer = WQLinear_GEMV(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            q_proj.qweight.device,
+        )
+        qkv_layer.qweight = qweights
+        qkv_layer.qzeros = qzeros
+        qkv_layer.scales = scales
+        qkv_layer.bias = bias
+        qkv_layer.split_k_iters = q_proj.split_k_iters
+
+        return qkv_layer
+
     def _fuse_qkv(self, module: LlamaAttention):
         # get qkv and bias
         q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj

@@ -129,10 +170,11 @@ class LlamaFuser:
         )

         # replace buffers with real weights
-        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        qkv_layer.qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
+        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
+        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
         qkv_layer.bias = bias
+        qkv_layer.split_k_iters = q_proj.split_k_iters

         return qkv_layer
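Note (illustration, not part of the diff): the point of `_fuse_qkv2` is that the q/k/v projections share the same input features, so their packed buffers can be concatenated along the output dimension and served by a single fused linear. A minimal sketch of the same bookkeeping with plain float tensors and hypothetical sizes:

    import torch

    # Hypothetical sizes, dense weights standing in for the packed qweight/qzeros/scales.
    hidden_size, q_out, kv_out = 32, 32, 32
    x = torch.randn(2, 5, hidden_size)            # [batch, seq, hidden]

    wq = torch.randn(q_out, hidden_size)
    wk = torch.randn(kv_out, hidden_size)
    wv = torch.randn(kv_out, hidden_size)

    # Fuse along the output dimension, like _fuse_qkv2 concatenates the quantized buffers.
    w_qkv = torch.cat([wq, wk, wv], dim=0)        # [q_out + 2*kv_out, hidden]

    qkv = x @ w_qkv.T                             # one projection instead of three
    q, k, v = qkv.split([q_out, kv_out, kv_out], dim=-1)

    # Same result as running the three projections separately.
    assert torch.allclose(q, x @ wq.T, atol=1e-5)
    assert torch.allclose(k, x @ wk.T, atol=1e-5)
    assert torch.allclose(v, x @ wv.T, atol=1e-5)

The quantized case concatenates qweight, qzeros and scales instead of a dense weight, but the shape bookkeeping is the same.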
awq/modules/fused/attn.py · view file @ fbfa9d82

+import math
 import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding
+
+try:
+    from flash_attn import flash_attn_func
+    FLASH_INSTALLED = True
+except:
+    FLASH_INSTALLED = False
+class QuantLlamaRotary(nn.Module):
+    def __init__(self, dim=4096, max_position_embeddings=2048, base=10000, device=None, is_neox=True, num_heads=None, num_kv_heads=None):
+        super().__init__()
+        self.dim = dim
+        self.is_neox = is_neox
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+
+        # create cache
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device) / dim))
+        t = torch.arange(max_position_embeddings, device=device).float()
+        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1).to(torch.get_default_dtype())
+
+        # Embedding size: [max_position, rotary_dim]
+        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
+
+    def forward(self, qkv_states: torch.Tensor, position_ids: torch.Tensor, batch_size: int, q_len: int):
+        # get qkv
+        query, key, value = qkv_states.chunk(chunks=3, dim=-1)
+        del qkv_states
+
+        # [num_tokens, num_heads * head_size]
+        query_batch_size, query_len, _ = query.shape
+        query = query.view(query_len * query_batch_size, self.num_heads * self.dim)
+
+        # [num_tokens, num_kv_heads * head_size]
+        key_batch_size, key_len, _ = key.shape
+        key = key.view(key_len * key_batch_size, self.num_kv_heads * self.dim)
+
+        # [num_tokens]
+        positions = position_ids.view(-1).to(query.device)
+
+        # Apply rotary embedding to the query and key before passing them
+        # to the attention op.
+        query = query.contiguous()
+        key = key.contiguous()
+        awq_inference_engine.rotary_embedding(positions, query, key, self.dim, self.cos_sin_cache, self.is_neox)
+
+        query = query.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
+        key = key.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
+        value = value.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
+
+        return query, key, value
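Note (illustration, not part of the diff): `awq_inference_engine.rotary_embedding` is a CUDA kernel, so its exact semantics are not visible here. Assuming the usual NeoX-style rotation over the `[max_position, rotary_dim]` cos||sin cache built above, a pure-PyTorch reference could look like this hypothetical helper (shapes illustrative):

    import torch

    def rotary_reference(positions, x, rot_dim, cos_sin_cache):
        # cos_sin_cache: [max_position, rot_dim] = cat(cos, sin), as built in __init__ above.
        cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # [num_tokens, rot_dim // 2] each
        x = x.view(x.shape[0], -1, rot_dim)                    # [num_tokens, num_heads, rot_dim]
        x1, x2 = x.chunk(2, dim=-1)                            # NeoX convention: rotate the two halves
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
        out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return out.reshape(x.shape[0], -1)

    # Build a small cache the same way as the module does and rotate a dummy query.
    head_dim, max_pos, num_heads = 16, 32, 4
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
    t = torch.arange(max_pos).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
    cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)      # [32, 16]

    positions = torch.arange(5)                                # 5 tokens
    query = torch.randn(5, num_heads * head_dim)
    print(rotary_reference(positions, query, head_dim, cache).shape)  # torch.Size([5, 64])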
 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

@@ -11,28 +80,35 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        )
         self.register_buffer("inv_freq", inv_freq)

         # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(),
         )

     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

+        # [max_position, rot_dim]
+        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
         self,
         query: torch.Tensor,

@@ -41,20 +117,72 @@ class QuantLlamaRotaryEmbedding(nn.Module):
     ):
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
+        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
         query = query.contiguous()
         key = key.contiguous()
         awq_inference_engine.rotary_embedding(
             positions,
             query,
             key,
             self.dim,
             self.cos_sin_cache,
-            True # is_neox
+            True
         )
         return query, key
+class TorchAttention(nn.Module):
+    def __init__(self, hidden_size, use_flash=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.use_flash = use_flash
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, use_cache: bool, past_key_value: torch.Tensor, batch_size: int, q_len: int):
+        is_causal = past_key_value is None
+
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        value = value.to(key.device)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)
+
+        if use_cache:
+            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
+            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()
+
+        past_key_value = (key, value) if use_cache else None
+
+        if self.use_flash and FLASH_INSTALLED:
+            query = query.transpose(1, 2)
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            attn_output = flash_attn_func(query, key, value, causal=is_causal)
+        else:
+            attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
+        del query, key, value
+
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size)
+
+        return attn_output, past_key_value
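Note (illustration, not part of the diff): the new TorchAttention module only needs `[batch, num_heads, seq, head_dim]` tensors, so its cache behaviour can be smoke-tested in isolation. Shapes below are arbitrary:

    import torch

    attn = TorchAttention(hidden_size=64)          # 4 heads * head_dim 16, for example
    q = torch.randn(1, 4, 5, 16)                   # [batch, num_heads, q_len, head_dim]
    k = torch.randn(1, 4, 5, 16)
    v = torch.randn(1, 4, 5, 16)

    out, kv = attn(q, k, v, use_cache=True, past_key_value=None, batch_size=1, q_len=5)
    print(out.shape)      # torch.Size([1, 5, 64])
    print(kv[0].shape)    # torch.Size([1, 4, 5, 16]) -- cached keys

    # Next step: a single new token; the cache grows along dim=2.
    q1 = torch.randn(1, 4, 1, 16)
    k1 = torch.randn(1, 4, 1, 16)
    v1 = torch.randn(1, 4, 1, 16)
    out1, kv1 = attn(q1, k1, v1, use_cache=True, past_key_value=kv, batch_size=1, q_len=1)
    print(kv1[0].shape)   # torch.Size([1, 4, 6, 16])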
 class QuantLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -66,101 +194,182 @@ class QuantLlamaAttention(nn.Module):
         qkv_proj,
         o_proj,
         dev,
-        max_new_tokens,
-        use_hf_rotary=False
+        max_new_tokens
     ):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.num_heads = num_heads
-        self.num_kv_heads = num_kv_heads
-        self.head_dim = hidden_size // num_heads
-        self.use_hf_rotary = use_hf_rotary
-
-        if (self.head_dim * num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {num_heads})."
-            )
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        if use_hf_rotary:
-            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
-        else:
-            self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)
+        self.attn = TorchAttention(hidden_size)
+        self.rotary_emb = QuantLlamaRotary(
+            dim=hidden_size // num_heads, max_position_embeddings=max_new_tokens,
+            device=dev, is_neox=True, num_heads=num_heads, num_kv_heads=num_heads
+        )

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
         """Input shape: Batch x Time x Channel"""
-        bsz, q_len, _ = hidden_states.size()
+        batch_size, q_len, _ = hidden_states.size()
         qkv_states = self.qkv_proj(hidden_states)
-
-        if self.use_hf_rotary:
-            # get qkv
-            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-            query, key, value = torch.split(qkv_states, 1, dim=2)
-            del qkv_states
-
-            # reshape for hf rotary
-            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-            kv_seq_len = key.shape[-2]
-            if past_key_value is not None:
-                kv_seq_len += past_key_value[0].shape[-2]
-            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
-            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
-        else:
-            # get qkv
-            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
-            del qkv_states
-
-            # [num_tokens, num_heads * head_size]
-            query_batch_size, query_len, _ = query.shape
-            query = query.view(query_len * query_batch_size, self.num_heads * self.head_dim)
-
-            # [num_tokens, num_kv_heads * head_size]
-            key_batch_size, key_len, _ = key.shape
-            key = key.view(key_len * key_batch_size, self.num_kv_heads * self.head_dim)
-
-            # [num_tokens]
-            positions = position_ids.view(-1).to(query.device)
-            query, key = self.rotary_emb(query, key, positions)
-
-            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        is_causal = past_key_value is None
-        kv_seq_len = q_len
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        value = value.to(key.device)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key = torch.cat([past_key_value[0], key], dim=2)
-            value = torch.cat([past_key_value[1], value], dim=2)
-
-        if use_cache:
-            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
-            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key = key.contiguous()
-            value = value.contiguous()
-            query = query.contiguous()
-
-        past_key_value = (key, value) if use_cache else None
-
-        # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
-        del query, key, value
-
-        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
+
+        query, key, value = self.rotary_emb(qkv_states, position_ids, batch_size, q_len)
+        attn_output, past_key_value = self.attn(query, key, value, use_cache, past_key_value, batch_size, q_len)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+):
+    xq_ = torch.view_as_complex(
+        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+    )
+    xk_ = torch.view_as_complex(
+        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+    )
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
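Note (illustration, not part of the diff): a quick shape check of the complex-valued rotary helpers added above, using arbitrary sizes and the helpers exactly as reconstructed:

    import torch

    freqs_cis = precompute_freqs_cis(dim=16, end=64)   # [64, 8] complex64
    xq = torch.randn(2, 5, 4, 16)                      # [bsz, seqlen, n_heads, head_dim]
    xk = torch.randn(2, 5, 4, 16)

    q_rot, k_rot = apply_rotary_emb(xq, xk, freqs_cis[:5])
    print(q_rot.shape, q_rot.dtype)                    # torch.Size([2, 5, 4, 16]) torch.float32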
+class QuantLlamaAttentionFused(nn.Module):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.n_local_heads = num_heads
+        self.head_dim = self.hidden_size // num_heads
+        self.qkv_proj = qkv_layer
+        self.o_proj = o_proj
+        self.start_pos = 0
+
+        self.freqs_cis = precompute_freqs_cis(
+            self.head_dim,
+            max_position_embeddings * 2,
+        )
+
+        # following fastertransformer definition
+        self.cache_v = (
+            torch.zeros(
+                (
+                    1,
+                    self.n_local_heads,
+                    max_position_embeddings,
+                    self.head_dim,
+                )
+            )
+            .to(dev)
+            .half()
+        )  # added to half
+
+        # 8: pack 8 fp16 in FT, if fp32 then use 4
+        self.cache_k = (
+            torch.zeros(
+                (
+                    1,
+                    self.n_local_heads,
+                    self.head_dim // 8,
+                    max_position_embeddings,
+                    8,
+                )
+            )
+            .to(dev)
+            .half()
+        )  # added to half
+
+        # dummy
+        self.rotary_emb = QuantLlamaRotaryEmbedding(hidden_size // num_heads, max_position_embeddings=max_position_embeddings, base=10000, device=dev)
+
+    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
+        bsz, seqlen, _ = hidden_states.shape
+        xqkv = self.qkv_proj(hidden_states)
+        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
+        xq = xqkv[:, :, 0]
+        xk = xqkv[:, :, 1]
+        xv = xqkv[:, :, 2]
+
+        if seqlen > 1:
+            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+            xq, xk = self.rotary_emb(xq, xk, position_ids)
+
+            self.cache_k = self.cache_k.to(xq)
+            self.cache_v = self.cache_v.to(xq)
+
+            values_store = xv.transpose(2, 1)
+            keys_store = (
+                xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
+                .permute(0, 2, 3, 1, 4)
+                .contiguous()
+            )
+
+            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
+            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+
+            keys = xk
+            values = xv
+            past_key_value = (xk, xv) if use_cache else None
+
+            xq = xq.transpose(1, 2)
+            keys = keys.transpose(1, 2)
+            values = values.transpose(1, 2)
+            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if attention_mask is not None:
+                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
+            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        else:
+            xq = xq[:, 0, :, :]
+            xk = xk[:, 0, :, :]
+            xv = xv[:, 0, :, :]
+            past_key_value = (xk, xv) if use_cache else None
+            output = awq_inference_engine.single_query_attention(
+                xq,
+                xk,
+                xv,
+                self.cache_k,
+                self.cache_v,
+                None,
+                None,
+                self.start_pos,
+                self.head_dim,
+                10000,
+                True,
+            )
+            output = output.reshape(bsz, 1, -1)
+
+        attn_output = self.o_proj(output)
+
+        if use_cache:
+            self.start_pos += seqlen
+        else:
+            self.start_pos = 0
+
+        return attn_output, None, past_key_value
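Note (illustration, not part of the diff): the fastertransformer-style caches above store values as `[1, heads, max_pos, head_dim]` and keys with the head dimension packed in groups of 8 fp16 values (the "pack 8 fp16 in FT" comment). The layout of the two store tensors can be checked standalone with arbitrary sizes:

    import torch

    bsz, seqlen, n_heads, head_dim = 1, 5, 4, 32
    xk = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16)
    xv = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16)

    keys_store = (
        xk.reshape(bsz, seqlen, n_heads, head_dim // 8, 8)
        .permute(0, 2, 3, 1, 4)
        .contiguous()
    )
    print(keys_store.shape)    # torch.Size([1, 4, 4, 5, 8]) -> cache_k[:, :, :, pos:pos+seqlen, :]

    values_store = xv.transpose(2, 1)
    print(values_store.shape)  # torch.Size([1, 4, 5, 32])   -> cache_v[:, :, pos:pos+seqlen, :]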