Unverified Commit 70c0c1f9 authored by eigen, committed by GitHub

fix: trtllm-gen attention take zero-init workspace (#10330)

parent 760b788a
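
The trtllm-gen MLA decode kernel expects its workspace to be zero-initialized. Previously the CUDA-graph path allocated its own workspace with torch.empty (uninitialized memory) and stashed it on the decode metadata; this change removes that field and routes every decode call through the backend's shared self.workspace_buffer, which is presumably managed (and zeroed) by the base backend. A minimal sketch of the underlying issue, not the backend's actual code; the size and device below are illustrative:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    workspace_size = 128 * 1024 * 1024  # hypothetical workspace size in bytes

    # Before: torch.empty() returns uninitialized memory. A kernel that reads
    # the workspace before writing it (counters, semaphores, accumulators)
    # can see stale garbage and misbehave non-deterministically.
    per_graph_workspace = torch.empty(workspace_size, dtype=torch.int8, device=device)

    # After: allocate one shared, zero-initialized buffer up front and pass it
    # to every decode call, instead of carrying a separate tensor in the
    # per-batch metadata.
    shared_workspace = torch.zeros(workspace_size, dtype=torch.int8, device=device)

Sharing one buffer also avoids allocating a second workspace just for the CUDA-graph path.
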
@@ -58,7 +58,6 @@ class TRTLLMMLAPrefillMetadata:
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None
@@ -187,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
@@ -240,7 +236,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         max_seq_len_val = int(seq_lens.max().item())
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace,
             block_kv_indices,
             max_seq_len_val,
         )
@@ -339,7 +334,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val = int(max_seq)
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
-                self.workspace_buffer, block_kv_indices, max_seq_len_val
+                block_kv_indices, max_seq_len_val
             )
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
@@ -513,7 +508,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=metadata.workspace,
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
...