Commit 87e0bcd4, authored Feb 23, 2023 by Woosuk Kwon

Fix attention

parent 1ce13335
Showing 1 changed file with 12 additions and 7 deletions.
cacheflow/models/attention.py
@@ -44,19 +44,18 @@ class OPTCacheFlowAttention(nn.Module):
         # FIXME(woosuk): Replace the following with a custom op.
         for i in range(input_metadata.num_generation_tokens):
-            q = query[i]
+            q = query[i].unsqueeze(0)
             block_table = block_tables[i]
             context_len = int(input_metadata.context_lens[i])
             keys = []
             for j in range(context_len):
                 block_number = block_table[j // block_size]
                 block_offset = j % block_size
                 k = key_cache[block_number, :, :, block_offset, :]
-                k = k.view(num_heads, head_size)
+                k = k.reshape(num_heads, head_size)
                 keys.append(k)
-            keys = torch.stack(keys, dim=-1)
-            logits = q @ keys
-            attention_weights = torch.softmax(logits, dim=-1)
+            keys = torch.stack(keys, dim=0)
             values = []
             for j in range(context_len):
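Note on the hunk above (an illustration, not part of the commit): switching the stack dimension from -1 to 0 changes the layout of the gathered keys. A minimal sketch with placeholder sizes:

import torch

# Placeholder sizes, for illustration only.
num_heads, head_size, context_len = 12, 64, 5

# Each k gathered from the paged cache has shape [num_heads, head_size].
keys = [torch.randn(num_heads, head_size) for _ in range(context_len)]

# Before: stacking on the last dim gives [num_heads, head_size, context_len].
print(torch.stack(keys, dim=-1).shape)  # torch.Size([12, 64, 5])

# After: stacking on dim 0 gives [context_len, num_heads, head_size],
# i.e. a (tokens, heads, head_dim) layout.
print(torch.stack(keys, dim=0).shape)   # torch.Size([5, 12, 64])

The dim=0 layout lines up with the unsqueeze(0) calls in the next hunk, which add a leading batch dimension before the tensors are handed to xformers.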
@@ -64,8 +63,14 @@ class OPTCacheFlowAttention(nn.Module):
                 block_offset = j % block_size
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
-            values = torch.stack(values, dim=-1)
-            out = attention_weights @ values
+            values = torch.stack(values, dim=0)
+            q = q.unsqueeze(0)
+            keys = keys.unsqueeze(0)
+            values = values.unsqueeze(0)
+            out = xops.memory_efficient_attention(
+                q, keys, values, scale=self.scale)
+            out = out.view(num_heads, head_size)
             output[i].copy_(out, non_blocking=True)

     def forward(
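For reference (again illustrative, not part of the commit), the per-token computation the new code delegates to xops.memory_efficient_attention is single-query scaled-dot-product attention over the gathered context. A plain-PyTorch sketch with placeholder sizes:

import torch

# Placeholder sizes and scale, for illustration only.
num_heads, head_size, context_len = 12, 64, 5
scale = head_size ** -0.5

q = torch.randn(num_heads, head_size)                  # one generation token
keys = torch.randn(context_len, num_heads, head_size)  # stacked with dim=0
values = torch.randn(context_len, num_heads, head_size)

# Per head: softmax over the context of scaled dot products,
# then a weighted sum of the cached values.
logits = scale * torch.einsum("hd,jhd->hj", q, keys)   # [num_heads, context_len]
weights = torch.softmax(logits, dim=-1)
out = torch.einsum("hj,jhd->hd", weights, values)      # [num_heads, head_size]
print(out.shape)                                       # torch.Size([12, 64])

With xformers installed, calling xops.memory_efficient_attention on the same tensors unsqueezed to [1, 1, num_heads, head_size] and [1, context_len, num_heads, head_size], with scale=scale, should give the same result up to floating-point tolerance.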