OpenDAS / AutoAWQ / Commits

Commit e46703d8
Authored Oct 06, 2023 by Casper Hansen
Parent: 66b2e233

Use new WindowedCache
Showing 1 changed file with 66 additions and 82 deletions.
awq/modules/fused/attn.py (+66, -82)
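The commit's main change is to move the per-layer KV-cache bookkeeping out of QuantAttentionFused and into a WindowedCache object imported from awq/modules/fused/cache.py. That module is not part of this diff, so the sketch below is only a reconstruction of the interface the new code relies on, with method bodies adapted from the inline cache code this commit removes; anything beyond the call sites visible in the diff is an assumption.

import torch

# Sketch only: reconstructed from this diff's call sites; the real class lives in
# awq/modules/fused/cache.py and may differ.
class WindowedCache:
    def __init__(self, cache_v_shape, cache_k_shape, dev):
        # Pre-allocate fp16 buffers, as the removed QuantAttentionFused.__init__ code did.
        self.v = torch.zeros(cache_v_shape).to(dev).half()
        self.k = torch.zeros(cache_k_shape).to(dev).half()

    def to(self, x):
        # Match the query tensor's device/dtype (replaces cache_k.to(xq) / cache_v.to(xq)).
        self.v = self.v.to(x)
        self.k = self.k.to(x)

    def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
        # Write the new keys/values into the window (replaces the removed slice assignments).
        self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
        self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

    def get_kv(self, batch_size, start_pos, seqlen, head_dim):
        # Read back contiguous values/keys up to the current position (replaces the removed slicing).
        xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
        xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
        xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
        return xv, xk

    def roll_kv(self, roll_len, start_pos):
        # Shift the window left and zero the freed tail, then return the new start position.
        # The removed inline code rolled by the full start_pos; the new call sites roll only by
        # the overflow, so returning start_pos - roll_len is an assumption.
        self.v = torch.roll(self.v, shifts=-roll_len, dims=2)
        self.k = torch.roll(self.k, shifts=-roll_len, dims=3)
        self.v[:, :, -roll_len:, :] = 0
        self.k[:, :, :, -roll_len:, :] = 0
        return start_pos - roll_len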
@@ -2,8 +2,8 @@ import os
 import math
 import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from awq.modules.fused.cache import WindowedCache

 try:
     import ft_inference_engine
@@ -25,11 +25,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)

-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-):
+def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous())
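For orientation, apply_rotary_emb (whose signature is collapsed to one line above) rotates query/key pairs by viewing them as complex numbers and multiplying by precomputed unit phasors. Below is a toy, self-contained version of that trick; the freqs_cis precomputation and the un-rotation step are standard Llama-style code written as an assumption, not taken from this file.

import torch

def demo_rotary(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (bsz, seqlen, n_heads, head_dim); freqs_cis: (seqlen, head_dim // 2), complex unit phasors.
    # Pack the last dim into complex pairs (as the line above does), rotate, then unpack.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], 2, -1).transpose(-2, -1).contiguous())
    x_ = x_ * freqs_cis.view(1, freqs_cis.shape[0], 1, freqs_cis.shape[1])
    return torch.view_as_real(x_).transpose(-2, -1).flatten(-2).type_as(x)

bsz, seqlen, n_heads, head_dim = 1, 4, 2, 8
theta = 10000.0 ** (-torch.arange(0, head_dim, 2).float() / head_dim)
freqs_cis = torch.polar(torch.ones(seqlen, head_dim // 2), torch.outer(torch.arange(seqlen).float(), theta))
x = torch.randn(bsz, seqlen, n_heads, head_dim)
print(demo_rotary(x, freqs_cis).shape)  # torch.Size([1, 4, 2, 8])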
@@ -65,6 +61,49 @@ def build_alibi_bias(
     slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
     return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

+def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
+    if attention_shapes is not None:
+        attention_shapes = attention_shapes
+
+    elif n_kv_heads == 0:
+        attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (-1, n_heads, head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, 0],
+            "xk_slice": lambda xqkv: xqkv[:, :, 1],
+            "xv_slice": lambda xqkv: xqkv[:, :, 2],
+            "xq_view": (n_heads, head_dim),
+            "xk_view": (n_heads, head_dim),
+            "xv_view": (n_heads, head_dim),
+            "xk_reshape": (n_heads, head_dim // 8, 8),
+            "single_xq_view": (n_heads, head_dim),
+            "single_xk_view": (n_heads, head_dim),
+            "single_xv_view": (n_heads, head_dim)
+        }
+
+    else:
+        attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
+            "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
+            "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
+            "xq_view": (n_heads, head_dim),
+            "xk_view": (n_kv_heads, head_dim),
+            "xv_view": (n_kv_heads, head_dim),
+            "xk_reshape": (n_kv_heads, head_dim // 8, 8),
+            "single_xq_view": (n_heads, head_dim),
+            "single_xk_view": (n_kv_heads, head_dim),
+            "single_xv_view": (n_kv_heads, head_dim)
+        }
+
+    return attention_shapes
+
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
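A quick sanity check of the new module-level helper, using hypothetical grouped-query numbers (32 query heads, 8 KV heads, head_dim 128) that are not taken from this commit:

from awq.modules.fused.attn import get_attention_shapes  # module shown in this diff

shapes = get_attention_shapes(
    attention_shapes=None, max_seq_len=2048, cache_batch_size=1,
    n_heads=32, n_kv_heads=8, head_dim=128,
)
assert shapes["cache_v"] == (1, 8, 2048, 128)          # (batch, kv_heads, seq, head_dim)
assert shapes["cache_k"] == (1, 8, 128 // 8, 2048, 8)  # keys packed 8 fp16 values per slot
assert shapes["xqkv_view"] == (32 + 8 * 2, 128)        # fused QKV output split per head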
@@ -81,9 +120,15 @@ class QuantAttentionFused(nn.Module):
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
-        self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
-        self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
-        self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )
+
+        # attention shapes for self attention
+        self.attention_shapes = get_attention_shapes(
+            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
+        )
+
+        # cache store that rolls cache
+        self.cache = WindowedCache(
+            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
+        )

         if use_alibi:
             alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
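The constructor now sizes both cache buffers up front from attention_shapes, so the window's memory footprint is fixed at load time. A back-of-the-envelope estimate with hypothetical Llama-7B-like numbers, not taken from this commit:

# fp16 cache_v and cache_k each hold batch * kv_heads * max_seq_len * head_dim values per layer.
batch, n_kv_heads, max_seq_len, head_dim, n_layers = 1, 32, 2048, 128, 32
elems_per_tensor = batch * n_kv_heads * max_seq_len * head_dim
bytes_per_layer = 2 * elems_per_tensor * 2                  # two tensors, 2 bytes per fp16 value
print(f"{bytes_per_layer * n_layers / 1024**3:.2f} GiB")    # 1.00 GiB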
@@ -100,55 +145,7 @@ class QuantAttentionFused(nn.Module):
             self.alibi_slopes = None
             self.is_neox = True

-    def _get_attention_shapes(self, attention_shapes, max_seq_len):
-        if attention_shapes is not None:
-            attention_shapes = attention_shapes
-
-        elif self.n_kv_heads == 0:
-            attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (-1, self.n_heads, self.head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, 0],
-                "xk_slice": lambda xqkv: xqkv[:, :, 1],
-                "xv_slice": lambda xqkv: xqkv[:, :, 2],
-                "xq_view": (self.n_heads, self.head_dim),
-                "xk_view": (self.n_heads, self.head_dim),
-                "xv_view": (self.n_heads, self.head_dim),
-                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
-                "single_xq_view": (self.n_heads, self.head_dim),
-                "single_xk_view": (self.n_heads, self.head_dim),
-                "single_xv_view": (self.n_heads, self.head_dim)
-            }
-
-        else:
-            attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
-                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
-                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
-                "xq_view": (self.n_heads, self.head_dim),
-                "xk_view": (self.n_kv_heads, self.head_dim),
-                "xv_view": (self.n_kv_heads, self.head_dim),
-                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
-                "single_xq_view": (self.n_heads, self.head_dim),
-                "single_xk_view": (self.n_kv_heads, self.head_dim),
-                "single_xv_view": (self.n_kv_heads, self.head_dim)
-            }
-
-        return attention_shapes
-
-    def forward(
-        self, hidden_states: torch.Tensor, past_key_value=None, attention_mask=None,
-        position_ids=None, output_attentions=False, use_cache=False, *args, **kwargs
-    ):
+    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape
         if bsz != self.cache_batch_size:
             raise RuntimeError(
@@ -157,14 +154,8 @@ class QuantAttentionFused(nn.Module):
             )

         if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
-            # Roll cache to the left
-            roll_len = self.start_pos
-            self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
-            self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
-            # Zero out the new part
-            self.cache_v[:, :, -roll_len:, :] = 0
-            self.cache_k[:, :, :, -roll_len:, :] = 0
-            self.start_pos = 0
+            excess_length = self.start_pos + seqlen - self.max_seq_len
+            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)

         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
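The removed code rolled the whole window back to position 0 whenever it overflowed and zeroed the tail; the new code only sheds the overflow. A worked example with made-up numbers (that roll_kv returns start_pos minus the rolled amount is an assumption, inferred from how its return value is used here):

max_seq_len, start_pos, seqlen = 2048, 2040, 16
excess_length = start_pos + seqlen - max_seq_len   # 8 old positions must be dropped to fit the new tokens
print(excess_length)                                # 8
print(start_pos - excess_length)                    # 2032 -- the start_pos we would expect roll_kv to hand back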
@@ -181,8 +172,7 @@ class QuantAttentionFused(nn.Module):
         if not self.use_alibi:
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

-        self.cache_k = self.cache_k.to(xq)
-        self.cache_v = self.cache_v.to(xq)
+        self.cache.to(xq)

         values_store = xv.transpose(2, 1)
         keys_store = (
@@ -191,13 +181,10 @@ class QuantAttentionFused(nn.Module):
             .contiguous()
         )

-        self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
-        self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

         if seqlen == 1:
-            xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
-            xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
-            xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
+            xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

             keys = xk
             values = xv
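The key cache keeps the FasterTransformer packed layout (head_dim // 8 groups of 8 fp16 values), so the keys have to be unpacked before attention. The toy below mirrors the transpose/reshape sequence of the inline code removed above, which get_kv presumably now performs internally; the sizes are made up for illustration:

import torch

bsz, n_kv_heads, head_dim, seq = 1, 2, 16, 4
cache_k = torch.arange(bsz * n_kv_heads * (head_dim // 8) * seq * 8, dtype=torch.float16)
cache_k = cache_k.reshape(bsz, n_kv_heads, head_dim // 8, seq, 8)   # packed: (b, kv_heads, d//8, seq, 8)
xk = cache_k.transpose(2, 3).contiguous()                           # (b, kv_heads, seq, d//8, 8)
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2)        # (b, seq, kv_heads, head_dim)
print(xk.shape)  # torch.Size([1, 4, 2, 16])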
@@ -229,8 +216,8 @@ class QuantAttentionFused(nn.Module):
                 xq, # query
                 xk, # key
                 xv, # value
-                self.cache_k, # key cache
-                self.cache_v, # value cache
+                self.cache.k, # key cache
+                self.cache.v, # value cache
                 None, # length per sample
                 self.alibi_slopes, # alibi slopes
                 self.start_pos, # timestep
@@ -241,11 +228,8 @@ class QuantAttentionFused(nn.Module):
             attention_weight = attention_weight.reshape(bsz, 1, -1)

         attn_output = self.o_proj(attention_weight)
+        self.start_pos += seqlen

-        if use_cache:
-            self.start_pos += seqlen
-        else:
-            self.start_pos = 0
-
-        # past_key_value is replaced with cache_v, cache_k, returning None
-        return attn_output, attention_weight, None
+        # past_key_value is replaced with cache_v, cache_k, returning empty data
+        past_key_value = [torch.Tensor([[[[0]], [[0]], [[0]]]])]
+
+        return attn_output, attention_weight, past_key_value
\ No newline at end of file
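With the rolling cache held inside the module, the real past key/values are no longer surfaced to the caller; forward now returns a tiny placeholder instead, presumably so callers that expect a non-empty past_key_value keep working. Its shape, for reference:

import torch

past_key_value = [torch.Tensor([[[[0]], [[0]], [[0]]]])]   # placeholder returned above
print(past_key_value[0].shape)   # torch.Size([1, 3, 1, 1])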