Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e39ebf5c
Unverified
Commit
e39ebf5c
authored
Sep 04, 2024
by
Elfie Guo
Committed by
GitHub
Sep 05, 2024
Browse files
[Core/Bugfix] Add query dtype as per FlashInfer API requirements. (#8173)
parent
ba262c4e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
2 deletions
+10
-2
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+2
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-1
No files found.
tests/kernels/test_flashinfer.py
View file @
e39ebf5c
...
...
@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size
,
block_size
,
"NONE"
,
data_type
=
dtype
)
data_type
=
dtype
,
q_data_type
=
dtype
)
output
=
wrapper
.
forward
(
query
,
kv_cache_fp8
,
logits_soft_cap
=
soft_cap
,
...
...
vllm/attention/backends/flashinfer.py
View file @
e39ebf5c
...
...
@@ -224,6 +224,7 @@ class FlashInferState(AttentionState):
query_start_loc
=
query_start_loc_host
,
device
=
self
.
runner
.
device
,
data_type
=
kv_cache_dtype
,
q_data_type
=
self
.
runner
.
model_config
.
dtype
,
use_cuda_graph
=
True
,
decode_wrapper
=
self
.
_graph_decode_wrapper
,
prefill_wrapper
=
None
)
...
...
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
page_size
:
Optional
[
int
]
=
None
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
# The data type of the query
q_data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
is_profile_run
:
bool
=
False
...
...
@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata):
self
.
page_size
,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode
=
"NONE"
,
data_type
=
self
.
data_type
)
# kv-cache data type.
data_type
=
self
.
data_type
,
# query data type.
q_data_type
=
self
.
q_data_type
)
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
...
...
@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc
=
query_start_loc
,
device
=
device
,
data_type
=
kv_cache_dtype
,
q_data_type
=
self
.
runner
.
model_config
.
dtype
,
use_cuda_graph
=
use_captured_graph
,
is_profile_run
=
self
.
is_profile_run
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment