Unverified commit b1a63d1b, authored by nvjullin and committed by GitHub
Browse files

[BugFix] Make FlashInferMetadataBuilder non-blocking (#25040)


Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 48ecb443
...@@ -585,9 +585,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -585,9 +585,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_data_type=self.kv_cache_dtype, kv_data_type=self.kv_cache_dtype,
) )
else: else:
attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(self.device) attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
self.device, non_blocking=True)
attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to(
self.device) self.device, non_blocking=True)
if num_decodes > 0: if num_decodes > 0:
pure_decode = num_prefills == 0 pure_decode = num_prefills == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment