norm / vllm · Commits

Commit 04e5acc0 (Unverified)
Authored Mar 06, 2023 by Woosuk Kwon, committed via GitHub on Mar 06, 2023
Parent: 3e9f991d

Fix a bug in 1D input shape (#5)
3 changed files, with 11 additions and 6 deletions (+11 / -6):

    cacheflow/models/attention.py       +8 -3
    cacheflow/models/input_metadata.py  +1 -1
    server.py                           +2 -2
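What the fix is about: with the 1D input layout, all prompt tokens and all generation tokens of a scheduling step are packed into a single flattened token dimension, so the number of prompt tokens is generally smaller than the total number of valid tokens in the batch. The hunks below slice the prompt portion out of that flat batch before the prompt attention runs, and derive the token counts from the tensors themselves. A minimal sketch of the packing assumed here (shapes and values are illustrative; only prompt_lens, slot_mapping, and the slicing pattern come from the diff):

    import torch

    # Hypothetical flattened 1D batch: two prompts followed by three
    # single-token generation requests, packed along one token dimension.
    prompt_lens = [4, 3]                     # tokens per prompt
    num_generation_tokens = 3                # one token per running sequence
    num_prompt_tokens = sum(prompt_lens)     # 7
    num_valid_tokens = num_prompt_tokens + num_generation_tokens  # 10

    hidden_size = 8
    query = torch.randn(num_valid_tokens, hidden_size)
    output = torch.empty_like(query)

    # The prompt attention only sees the leading slice of the flat batch;
    # the cached-KV (generation) attention handles the remaining rows.
    prompt_query = query[:num_prompt_tokens]
    generation_query = query[num_prompt_tokens:]
    assert prompt_query.shape[0] + generation_query.shape[0] == num_valid_tokens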
cacheflow/models/attention.py

@@ -47,9 +47,8 @@ class OPTCacheFlowAttention(nn.Module):
             max_s=max_prompt_len,
             causal=True,
         )[0]
-        num_tokens = prefix_sum[-1]
         # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output[:num_tokens].copy_(out, non_blocking=True)
+        output.copy_(out, non_blocking=True)

     def single_query_cached_kv_attention(
         self,
...

@@ -108,8 +107,14 @@ class OPTCacheFlowAttention(nn.Module):
         # Compute the attention op for prompts.
         if input_metadata.num_prompts > 0:
+            num_prompt_tokens = sum(input_metadata.prompt_lens)
             self.multi_query_kv_attention(
-                output, query, key, value, input_metadata.prompt_lens)
+                output[:num_prompt_tokens],
+                query[:num_prompt_tokens],
+                key[:num_prompt_tokens],
+                value[:num_prompt_tokens],
+                input_metadata.prompt_lens,
+            )

         # Wait until the cache op is done.
         if cache_event is not None:
...
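The second hunk is where the shape bug shows up: query, key, value, and output cover every valid token in the flattened batch (prompt tokens plus one token per generation sequence), while multi_query_kv_attention should only operate on the prompt tokens. Passing the num_prompt_tokens slices keeps the prompt attention from reading or overwriting the generation rows, which in turn lets the callee's first hunk do a plain full-tensor copy. A rough sketch of the sliced call under that assumption (the stand-in callee and all shapes are illustrative, not the real FlashAttention kernel):

    import torch

    def multi_query_kv_attention(output, query, key, value, prompt_lens):
        # Illustrative stand-in: write an attention-shaped result into the
        # slice it was given. The real code runs FlashAttention here.
        output.copy_(query)

    num_prompt_tokens = 7     # sum(prompt_lens) in the real code
    num_valid_tokens = 10     # prompt tokens + generation tokens
    hidden = 8
    query = torch.randn(num_valid_tokens, hidden)
    key = torch.randn_like(query)
    value = torch.randn_like(query)
    output = torch.zeros_like(query)

    # After the fix: only the prompt rows are passed in, so the callee's
    # full-tensor copy_ cannot clobber the generation rows.
    multi_query_kv_attention(
        output[:num_prompt_tokens],
        query[:num_prompt_tokens],
        key[:num_prompt_tokens],
        value[:num_prompt_tokens],
        prompt_lens=[4, 3],
    )
    assert output[num_prompt_tokens:].abs().sum() == 0  # generation rows untouched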
cacheflow/models/input_metadata.py

@@ -24,7 +24,7 @@ class InputMetadata:
         self.num_prompts = len(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = len(slot_mapping)
+        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
...
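This one-liner makes num_valid_tokens come from the tensor's shape, consistent with num_generation_tokens just above it, which is already derived from context_lens.shape[0]. For a 1D tensor the two expressions agree, so the point is uniform bookkeeping over the flattened token dimension rather than a behavioral change for that case. A small sketch of the relationship being relied on (assuming slot_mapping is a 1D torch tensor, as the surrounding code suggests; the values are illustrative):

    import torch

    # slot_mapping: one cache-slot index per valid token in the flat batch.
    slot_mapping = torch.tensor([0, 1, 2, 3, 16, 17, 18, 32, 48, 64])

    num_valid_tokens = slot_mapping.shape[0]             # what the fix uses
    assert num_valid_tokens == len(slot_mapping) == 10   # equal for a 1D tensor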
server.py

@@ -57,11 +57,11 @@ def main():
         'UC Berkeley is',
         'The future of cloud computing is',
     ]
-    for prompt in test_inputs:
-        frontend.query(prompt)

     # FIXME
     while True:
+        if test_inputs:
+            frontend.query(test_inputs.pop())
         scheduler.step()
         if not scheduler.pending and not scheduler.running:
             break
...
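In the test driver, the prompts are no longer all queued before the loop; instead one prompt is popped from test_inputs and handed to the frontend on each iteration, so new requests keep arriving while earlier ones are still being scheduled. A rough sketch of the resulting loop shape, with stand-in Frontend and Scheduler classes (the real ones live in cacheflow; only the loop structure mirrors the diff):

    class Frontend:
        """Stand-in: turns a prompt string into a pending request."""
        def __init__(self, scheduler):
            self.scheduler = scheduler

        def query(self, prompt):
            self.scheduler.pending.append(prompt)

    class Scheduler:
        """Stand-in: admits one pending request per step, then finishes it."""
        def __init__(self):
            self.pending, self.running = [], []

        def step(self):
            if self.pending:
                self.running.append(self.pending.pop(0))
            elif self.running:
                print("finished:", self.running.pop(0))

    scheduler = Scheduler()
    frontend = Frontend(scheduler)
    test_inputs = ['UC Berkeley is', 'The future of cloud computing is']

    # Same shape as the fixed loop in server.py: feed one prompt per step.
    while True:
        if test_inputs:
            frontend.query(test_inputs.pop())
        scheduler.step()
        if not scheduler.pending and not scheduler.running:
            break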