Commit e39ebf5c (unverified), authored by Elfie Guo, committed by GitHub
Browse files

[Core/Bugfix] Add query dtype as per FlashInfer API requirements. (#8173)

parent ba262c4e
...@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv( ...@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size, head_size,
block_size, block_size,
"NONE", "NONE",
data_type=dtype) data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query, output = wrapper.forward(query,
kv_cache_fp8, kv_cache_fp8,
logits_soft_cap=soft_cap, logits_soft_cap=soft_cap,
......
...@@ -224,6 +224,7 @@ class FlashInferState(AttentionState): ...@@ -224,6 +224,7 @@ class FlashInferState(AttentionState):
query_start_loc=query_start_loc_host, query_start_loc=query_start_loc_host,
device=self.runner.device, device=self.runner.device,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True, use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper, decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None) prefill_wrapper=None)
...@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None page_size: Optional[int] = None
# The data type of the paged kv cache # The data type of the paged kv cache
data_type: torch.dtype = None data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda") device: torch.device = torch.device("cuda")
is_profile_run: bool = False is_profile_run: bool = False
...@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata):
self.page_size, self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope. # Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE", pos_encoding_mode="NONE",
- data_type=self.data_type)
+ # kv-cache data type.
+ data_type=self.data_type,
+ # query data type.
+ q_data_type=self.q_data_type)
def asdict_zerocopy(self, def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None skip_fields: Optional[Set[str]] = None
...@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
device=device, device=device,
data_type=kv_cache_dtype, data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run) is_profile_run=self.is_profile_run)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment