Unverified commit 6a7528e6 authored by Trevor Morris, committed by GitHub

[bugfix] Fix page size for create_flashmla_kv_indices_triton() for cutlass mla (#8685)

parent 2ae95d17
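The change replaces positional arguments to `create_flashmla_kv_indices_triton()` with keyword arguments. Below is a minimal sketch of why that matters, assuming the kernel's trailing parameters are defaulted `tl.constexpr` values; only the names `NUM_PAGE_PER_BLOCK` and `PAGED_SIZE` are confirmed by this diff, and the rest of the signature and the defaults shown are assumptions:

```python
import triton
import triton.language as tl

# Assumed sketch of the kernel signature, not copied from the source tree.
# Only NUM_PAGE_PER_BLOCK and PAGED_SIZE are confirmed by this commit's diff.
@triton.jit
def create_flashmla_kv_indices_triton(
    req_to_token_ptr,                        # assumed leading pointer arguments
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
    kv_indices_ptr_stride: tl.constexpr,
    NUM_PAGE_PER_BLOCK: tl.constexpr = 64,   # assumed default
    PAGED_SIZE: tl.constexpr = 64,           # assumed default
):
    ...
```

With trailing defaults like these, the old Cutlass MLA call ending in `(..., max_seqlen_pad, PAGE_SIZE)` binds `PAGE_SIZE` positionally to `NUM_PAGE_PER_BLOCK` and silently leaves `PAGED_SIZE` at its default; passing `PAGED_SIZE=PAGE_SIZE` by keyword restores the intended binding.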
@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         workspace_size = cutlass_mla_get_workspace_size(
             max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             self.cuda_graph_kv_indices,
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         self.forward_metadata = CutlassMLADecodeMetadata(
             self.cuda_graph_mla_workspace,
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices,
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
         else:
             super().init_forward_metadata_replay_cuda_graph(
@@ -147,8 +147,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks,
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )
         return block_kv_indices
@@ -204,8 +204,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )
         metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
@@ -248,8 +248,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
             metadata.block_kv_indices.shape[1],
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

     def get_cuda_graph_seq_len_fill_value(self) -> int:
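For completeness, a tiny self-contained illustration of the pitfall this commit fixes, using a hypothetical plain-Python stand-in; `build_indices` is invented for this example and is not part of the codebase:

```python
def build_indices(kv_indices, stride, NUM_PAGE_PER_BLOCK=64, PAGED_SIZE=64):
    # Hypothetical stand-in for a kernel with trailing defaulted parameters.
    return NUM_PAGE_PER_BLOCK, PAGED_SIZE

PAGE_SIZE = 128
print(build_indices([], 8, PAGE_SIZE))             # (128, 64): PAGE_SIZE bound to the wrong slot
print(build_indices([], 8, PAGED_SIZE=PAGE_SIZE))  # (64, 128): intended binding
```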