norm / vllm · Commits

Commit 932844f1, authored Feb 23, 2023 by Woosuk Kwon
Parent: ba84b872

Fix attention

Showing 2 changed files with 21 additions and 6 deletions (+21 -6):

    cacheflow/models/attention.py       +10 -6
    cacheflow/models/input_metadata.py  +11 -0
cacheflow/models/attention.py
@@ -53,20 +53,19 @@ class OPTCacheFlowAttention(nn.Module):
             context_len = int(input_metadata.context_lens[i])
 
             keys = []
+            values = []
             for j in range(context_len):
-                block_number = block_table[j // block_size]
+                block_number = int(block_table[j // block_size])
                 block_offset = j % block_size
+
                 k = key_cache[block_number, :, :, block_offset, :]
                 k = k.reshape(num_heads, head_size)
                 keys.append(k)
-            keys = torch.stack(keys, dim=0)
 
-            values = []
-            for j in range(context_len):
-                block_number = block_table[j // block_size]
-                block_offset = j % block_size
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
+
+            keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
             q = q.unsqueeze(0)
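This hunk fuses the two per-token gather loops into one, so the block number and offset are computed once per token instead of twice, and wraps block_table[j // block_size] in int(), presumably so the cache tensors are indexed with a plain Python integer rather than a 0-dim tensor. For context, the loop reconstructs one sequence's keys and values from the paged KV cache: token j lives at offset j % block_size inside physical block block_table[j // block_size]. Below is a minimal, self-contained sketch of that gather; the cache shapes and toy numbers are assumptions chosen to be consistent with the indexing in the diff, not taken from the repository.

import torch

# Assumed cache layout, consistent with how the diff indexes it:
#   key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
#   value_cache: [num_blocks, num_heads, block_size, head_size]
num_blocks, num_heads, head_size, block_size, x = 4, 2, 8, 4, 2
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

# One sequence's block table: logical block index -> physical block number.
block_table = torch.tensor([2, 0, 3])
context_len = 10  # tokens 0..9 span the first three logical blocks

keys, values = [], []
for j in range(context_len):
    block_number = int(block_table[j // block_size])  # physical block
    block_offset = j % block_size                     # slot within the block
    k = key_cache[block_number, :, :, block_offset, :]
    keys.append(k.reshape(num_heads, head_size))
    values.append(value_cache[block_number, :, block_offset, :])

keys = torch.stack(keys, dim=0)      # [context_len, num_heads, head_size]
values = torch.stack(values, dim=0)  # [context_len, num_heads, head_size]
assert keys.shape == values.shape == (context_len, num_heads, head_size)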
@@ -87,6 +86,11 @@ class OPTCacheFlowAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
+        # Prune out invalid tokens.
+        query = query[:input_metadata.num_valid_tokens]
+        key = key[:input_metadata.num_valid_tokens]
+        value = value[:input_metadata.num_valid_tokens]
+
         # Reshape the input tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[3]
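The second hunk trims query, key, and value down to input_metadata.num_valid_tokens before any reshaping, which suggests the inputs can carry trailing rows beyond the valid token count (e.g. padding) that must not reach the attention math. A small illustration of the slicing, with tensor sizes invented for the example:

import torch

num_valid_tokens = 5   # in the diff: input_metadata.num_valid_tokens
padded_num_tokens = 8  # assumed: inputs padded past the valid length
hidden_size = 16

query = torch.randn(padded_num_tokens, hidden_size)
key = torch.randn(padded_num_tokens, hidden_size)
value = torch.randn(padded_num_tokens, hidden_size)

# Drop the trailing invalid rows before any attention math,
# mirroring the new lines at the top of forward().
query = query[:num_valid_tokens]
key = key[:num_valid_tokens]
value = value[:num_valid_tokens]
assert query.shape[0] == key.shape[0] == value.shape[0] == num_valid_tokens

Since num_valid_tokens is len(slot_mapping) (see the input_metadata.py change below), the slice keeps exactly the tokens that have a cache slot assigned.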
cacheflow/models/input_metadata.py
@@ -11,6 +11,7 @@ class InputMetadata:
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
+        # FIXME: Rename
         max_context_len: int,
         block_tables: torch.Tensor,
     ) -> None:
@@ -23,9 +24,19 @@ class InputMetadata:
         self.num_prompts = len(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
+        self.num_valid_tokens = len(slot_mapping)
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
             self.max_num_blocks_per_seq = 0
         assert self.num_generation_tokens == block_tables.shape[0]
         assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
 
+    def __repr__(self) -> str:
+        return (f'InputMetadata('
+                f'seq_ids={self.seq_ids}, '
+                f'num_prompts={self.num_prompts}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
+                f'num_valid_tokens={self.num_valid_tokens}, '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'max_context_len={self.max_context_len})')
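The constructor bookkeeping in this hunk implies a few invariants: num_generation_tokens equals context_lens.shape[0] and must match block_tables.shape[0], num_valid_tokens is the length of slot_mapping, and prompts plus generation sequences must together account for every sequence id. A standalone sketch of those invariants, with batch numbers invented for illustration (the class itself is paraphrased, not copied):

import torch

# Invented batch: 2 prompts (7 and 5 tokens) + 3 generation sequences.
seq_ids = [0, 1, 2, 3, 4]
prompt_lens = [7, 5]                      # one entry per prompt
slot_mapping = torch.arange(7 + 5 + 3)    # one cache slot per valid token
context_lens = torch.tensor([9, 12, 30])  # one entry per generation sequence
block_tables = torch.zeros(3, 4, dtype=torch.int)  # [num_gen_seqs, max_blocks]

num_prompts = len(prompt_lens)
num_generation_tokens = context_lens.shape[0]
num_valid_tokens = len(slot_mapping)      # the field the diff adds
max_num_blocks_per_seq = (
    block_tables.shape[1] if block_tables.numel() > 0 else 0)

# The consistency checks the constructor enforces:
assert num_generation_tokens == block_tables.shape[0]
assert num_prompts + num_generation_tokens == len(seq_ids)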