Unverified Commit 0aac2048 authored by haosdent's avatar haosdent Committed by GitHub
Browse files

[Bugfix] Restore CUDA graph persistent buffers for FP8 FlashMLA decode (#35175)


Signed-off-by: default avatarhaosdent <haosdent@gmail.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent cb226321
...@@ -172,6 +172,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -172,6 +172,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
num_q_tokens_per_head_k, num_q_tokens_per_head_k,
1, # MQA for the decode path 1, # MQA for the decode path
) )
# Copy FP8 metadata into persistent CUDA graph buffers
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
n = tile_scheduler_metadata.size(0)
assert n <= self.cg_buf_tile_scheduler_metadata.size(0)
self.cg_buf_tile_scheduler_metadata[:n].copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[:n]
n = num_splits.size(0)
assert n <= self.cg_buf_num_splits.size(0)
self.cg_buf_num_splits[:n].copy_(num_splits)
num_splits = self.cg_buf_num_splits[:n]
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits scheduler_metadata.num_splits = num_splits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment