norm / vllm · Commits

Commit 7f22f90e, authored Feb 24, 2023 by Woosuk Kwon

Remove xformers

Parent: afdbe5d3
Showing 1 changed file with 20 additions and 16 deletions.
cacheflow/models/attention.py (+20, -16)
@@ -2,7 +2,6 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import xformers.ops as xops
 
 from cacheflow import ops
 from cacheflow.models import InputMetadata
@@ -14,8 +13,20 @@ class OPTCacheFlowAttention(nn.Module):
         super().__init__()
         self.scale = scale
 
-        # Shape-agnostic attention mask.
-        self.attention_mask = xops.LowerTriangularMask()
+    def _masked_attention(
+        self,
+        query: torch.Tensor,                        # [num_queries, num_heads, head_size]
+        key: torch.Tensor,                          # [num_keys, num_heads, head_size]
+        value: torch.Tensor,                        # [num_keys, num_heads, head_size]
+        attn_mask: Optional[torch.Tensor] = None,   # [num_queries, num_keys]
+    ) -> torch.Tensor:                              # [num_queries, num_heads, head_size]
+        query = query * self.scale
+        attn = torch.einsum('qhd,khd->hqk', query, key)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = torch.softmax(attn, dim=-1)
+        out = torch.einsum('hqk,khd->qhd', attn, value)
+        return out
 
     def multi_query_kv_attention(
         self,
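For context, the new _masked_attention is a plain PyTorch formulation of scaled dot-product attention over unbatched [num_tokens, num_heads, head_size] tensors, replacing the xformers kernel. Below is a minimal standalone sketch of the same computation; the helper name masked_attention and the example sizes are chosen only for illustration and are not part of the commit.

import torch

def masked_attention(query, key, value, attn_mask=None, scale=1.0):
    # query: [num_queries, num_heads, head_size]
    # key, value: [num_keys, num_heads, head_size]
    # attn_mask: optional additive mask of shape [num_queries, num_keys]
    query = query * scale
    attn = torch.einsum('qhd,khd->hqk', query, key)    # [num_heads, num_queries, num_keys]
    if attn_mask is not None:
        attn = attn + attn_mask                        # broadcasts over the head dimension
    attn = torch.softmax(attn, dim=-1)                 # normalize over keys
    return torch.einsum('hqk,khd->qhd', attn, value)   # [num_queries, num_heads, head_size]

# Example: 4 tokens, 2 heads, head size 8, with the same causal mask the commit builds.
num_tokens, num_heads, head_size = 4, 2, 8
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1) * -1e5
out = masked_attention(q, k, v, mask, scale=head_size ** -0.5)
print(out.shape)  # torch.Size([4, 2, 8])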
@@ -24,13 +35,11 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> None:
-        query = query.unsqueeze(0)
-        key = key.unsqueeze(0)
-        value = value.unsqueeze(0)
-        out = xops.memory_efficient_attention(
-            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
-        out = out.squeeze(0)
-        # FIXME(woosuk): Directly write the attention output.
+        # FIXME(woosuk): Replace this with a custom op call.
+        attention_mask = torch.triu(
+            torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
+        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
+        out = self._masked_attention(query, key, value, attention_mask)
         output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
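The additive mask that replaces xformers' LowerTriangularMask is an upper-triangular matrix of large negative values; after softmax, each query effectively attends only to itself and earlier positions. A small illustration (the 4x4 size is arbitrary):

import torch

mask = torch.triu(torch.ones(4, 4), diagonal=1) * -1e5
print(mask)
# tensor([[     0., -100000., -100000., -100000.],
#         [     0.,      0., -100000., -100000.],
#         [     0.,      0.,      0., -100000.],
#         [     0.,      0.,      0.,      0.]])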
@@ -64,15 +73,10 @@ class OPTCacheFlowAttention(nn.Module):
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
             keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
-            q = q.unsqueeze(0)
-            keys = keys.unsqueeze(0)
-            values = values.unsqueeze(0)
-            out = xops.memory_efficient_attention(
-                q, keys, values, scale=self.scale)
+            out = self._masked_attention(q, keys, values)
             out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)
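As a sanity check that the einsum-plus-mask path computes ordinary causal attention, one can compare it against a naive per-query loop. This is a hedged sketch, not part of the commit; it assumes float32 CPU tensors and arbitrary example sizes, and reuses the mask construction from the diff.

import torch

num_tokens, num_heads, head_size = 5, 4, 16
scale = head_size ** -0.5
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)

# New path: additive -1e5 causal mask plus the einsum attention from the diff.
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1) * -1e5
attn = torch.einsum('qhd,khd->hqk', q * scale, k) + mask
attn = torch.softmax(attn, dim=-1)
out = torch.einsum('hqk,khd->qhd', attn, v)

# Reference: each query i attends only to keys/values up to position i.
ref = torch.empty_like(q)
for i in range(num_tokens):
    logits = torch.einsum('hd,khd->hk', q[i] * scale, k[: i + 1])
    probs = torch.softmax(logits, dim=-1)
    ref[i] = torch.einsum('hk,khd->hd', probs, v[: i + 1])

print(torch.allclose(out, ref, atol=1e-5))  # expected: True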