Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e53dfd3e
Unverified
Commit
e53dfd3e
authored
Aug 07, 2024
by
Lily Liu
Committed by
GitHub
Aug 07, 2024
Browse files
[Kernel] Fix Flashinfer Correctness (#7284)
parent
6d944202
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
3 deletions
+7
-3
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+7
-3
No files found.
vllm/attention/backends/flashinfer.py
View file @
e53dfd3e
...
@@ -127,6 +127,7 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -127,6 +127,7 @@ class FlashInferMetadata(AttentionMetadata):
raise
ValueError
(
raise
ValueError
(
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
"received
{
self
.
head_dim
}
."
)
f
"received
{
self
.
head_dim
}
."
)
self
.
is_profile_run
=
is_block_tables_empty
(
self
.
block_tables
)
def
begin_forward
(
self
):
def
begin_forward
(
self
):
if
self
.
num_prefill_tokens
>
0
:
if
self
.
num_prefill_tokens
>
0
:
...
@@ -140,11 +141,14 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -140,11 +141,14 @@ class FlashInferMetadata(AttentionMetadata):
assert
self
.
paged_kv_last_page_len
is
not
None
assert
self
.
paged_kv_last_page_len
is
not
None
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
batch_size
=
self
.
query_start_loc
.
shape
[
0
]
-
1
assert
batch_size
>=
0
assert
batch_size
>=
0
# The pr
e
fil
l stage
does not read kv cache.
# The pr
o
fil
e run
does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
if
self
.
is_profile_run
:
self
.
paged_kv_indptr
=
torch
.
zeros
(
batch_size
+
1
,
self
.
paged_kv_indptr
=
torch
.
zeros
(
batch_size
+
1
,
device
=
self
.
device
)
device
=
self
.
device
)
else
:
self
.
paged_kv_indptr
=
self
.
paged_kv_indptr
.
to
(
self
.
device
)
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
paged_kv_last_page_len
=
self
.
paged_kv_last_page_len
.
to
(
self
.
device
)
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
self
.
paged_kv_indices
=
self
.
paged_kv_indices
.
to
(
self
.
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment