norm / vllm

Commit 762fd1c3, authored Feb 24, 2023 by Woosuk Kwon
Refactor and annotate types for attention
Parent: 7f22f90e

Showing 1 changed file with 40 additions and 32 deletions.

cacheflow/models/attention.py (+40, -32)
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
...
@@ -30,24 +30,34 @@ class OPTCacheFlowAttention(nn.Module):
     def multi_query_kv_attention(
         self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        output: torch.Tensor,     # [num_prompt_tokens, num_heads, head_size]
+        query: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
+        key: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
+        value: torch.Tensor,      # [num_prompt_tokens, num_heads, head_size]
+        prompt_lens: List[int],
     ) -> None:
-        # FIXME(woosuk): Replace this with a custom op call.
-        attention_mask = torch.triu(
-            torch.ones(query.shape[0], key.shape[0]),
-            diagonal=1) * -1e5
-        attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
-        out = self._masked_attention(query, key, value, attention_mask)
-        output.copy_(out, non_blocking=True)
+        # FIXME(woosuk): Replace the following with a custom op.
+        start_idx = 0
+        for prompt_len in prompt_lens:
+            out = output[start_idx:start_idx + prompt_len]
+            q = query[start_idx:start_idx + prompt_len]
+            k = key[start_idx:start_idx + prompt_len]
+            v = value[start_idx:start_idx + prompt_len]
+            attention_mask = torch.triu(
+                torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
+            attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
+            attention_out = self._masked_attention(q, k, v, attention_mask)
+            out.copy_(attention_out, non_blocking=True)
+            start_idx += prompt_len
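
Reviewer context: the refactored method now builds a separate additive causal mask per prompt rather than one mask over all packed prompt tokens, which prevents tokens of one prompt from attending to another prompt in the same batch. A minimal standalone sketch of how the triu-based mask behaves (illustrative only, not part of the commit):

    import torch

    # Positions above the diagonal get a large negative bias so that, after
    # softmax, each query token attends only to itself and earlier tokens.
    q_len, k_len = 4, 4
    mask = torch.triu(torch.ones(q_len, k_len), diagonal=1) * -1e5
    # mask[i, j] == -1e5 for j > i, and 0 otherwise; adding it to the raw
    # attention scores before softmax enforces causal attention.
    scores = torch.randn(q_len, k_len) + mask
    probs = scores.softmax(dim=-1)  # probs[i, j] is ~0 for every j > i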
     def single_query_cached_kv_attention(
         self,
-        output: torch.Tensor,
-        query: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
+        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
+        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         input_metadata: InputMetadata,
     ) -> None:
         num_heads = value_cache.shape[1]
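
The new shape annotations document the block-paged KV cache layout. The commit does not define x in this hunk; in later vLLM code it is the number of cache elements that fit in one 16-byte vector, and that definition is assumed below. A sketch with hypothetical sizes, only to make the annotated shapes concrete:

    import torch

    num_blocks, num_heads, head_size, block_size = 128, 12, 64, 16  # hypothetical
    dtype = torch.float16
    x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 elements for fp16

    # [num_blocks, num_heads, head_size/x, block_size, x]: the key cache splits
    # head_size into (head_size // x, x) so that groups of x consecutive key
    # elements stay contiguous for vectorized reads in the attention kernel.
    key_cache = torch.empty(
        num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype)
    # [num_blocks, num_heads, block_size, head_size]: the value cache keeps the
    # plain per-block row layout.
    value_cache = torch.empty(
        num_blocks, num_heads, block_size, head_size, dtype=dtype)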
...
@@ -82,15 +92,18 @@ class OPTCacheFlowAttention(nn.Module):
     def forward(
         self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
+        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
+        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
+        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
-    ) -> torch.Tensor:
-        # Prune out invalid tokens.
+    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
+        # Pre-allocate the output tensor.
+        output = torch.empty_like(query)
+
+        # Prune out paddings if any.
         query = query[:input_metadata.num_valid_tokens]
         key = key[:input_metadata.num_valid_tokens]
         value = value[:input_metadata.num_valid_tokens]
...
@@ -101,18 +114,11 @@ class OPTCacheFlowAttention(nn.Module):
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
+        output = output.view(-1, num_heads, head_size)

         # Compute the attention op for prompts.
-        output = torch.empty_like(query)
-        start_idx = 0
-        for i in range(input_metadata.num_prompts):
-            prompt_len = input_metadata.prompt_lens[i]
-            out = output[start_idx:start_idx + prompt_len]
-            q = query[start_idx:start_idx + prompt_len]
-            k = key[start_idx:start_idx + prompt_len]
-            v = value[start_idx:start_idx + prompt_len]
-            self.multi_query_kv_attention(out, q, k, v)
-            start_idx += prompt_len
+        self.multi_query_kv_attention(
+            output, query, key, value, input_metadata.prompt_lens)

         # Wait until the cache op is done.
         if cache_event is not None:
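
The per-prompt loop that used to live here in forward (see the removed lines above) now runs inside multi_query_kv_attention, which receives the whole pre-allocated output plus prompt_lens. This works because slicing a tensor yields a view, so the callee's in-place copy_ lands directly in the caller's buffer. A toy illustration (not from the commit, with made-up sizes):

    import torch

    output = torch.zeros(6, 2, 4)  # pretend: 6 tokens, 2 heads, head_size 4
    prompt_lens = [2, 4]           # two prompts packed back to back

    start_idx = 0
    for prompt_len in prompt_lens:
        out = output[start_idx:start_idx + prompt_len]  # a view, no copy
        out.copy_(torch.ones(prompt_len, 2, 4))         # writes into `output`
        start_idx += prompt_len

    assert bool((output == 1).all())  # every token row was written in place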
...
@@ -124,6 +130,7 @@ class OPTCacheFlowAttention(nn.Module):
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
+            start_idx = sum(input_metadata.prompt_lens)
             self.single_query_cached_kv_attention(
                 output[start_idx:],
                 query[start_idx:],
...
@@ -132,4 +139,5 @@ class OPTCacheFlowAttention(nn.Module):
                 input_metadata)

         # Reshape the output tensor.
+        # NOTE(woosuk): The output tensor may include paddings.
         return output.view(-1, num_heads * head_size)
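
The new NOTE matters because output is now created with torch.empty_like(query) before the padding rows are pruned, so rows past num_valid_tokens are never written and hold uninitialized memory. A sketch of the caveat, with made-up sizes:

    import torch

    num_tokens, num_valid_tokens, hidden = 8, 5, 16  # hypothetical
    query = torch.randn(num_tokens, hidden)

    output = torch.empty_like(query)   # allocated from the still-padded query
    valid = query[:num_valid_tokens]   # attention runs only on these rows
    output[:num_valid_tokens] = valid  # stand-in for the attention kernels

    # output[num_valid_tokens:] was never written and is garbage, so the
    # caller must ignore everything past num_valid_tokens in the result.
    result = output.view(-1, hidden)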