"examples/tensorflow/rgcn/model.py" did not exist on "31a7d50964ae1c9d5693661567e7a3e034383bd7"
Unverified Commit 70c0c1f9 authored by eigen's avatar eigen Committed by GitHub
Browse files

fix: trtllm-gen attention take zero-init workspace (#10330)

parent 760b788a
......@@ -58,7 +58,6 @@ class TRTLLMMLAPrefillMetadata:
class TRTLLMMLADecodeMetadata:
"""Metadata for TRTLLM MLA decode operations."""
workspace: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
......@@ -187,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.decode_cuda_graph_kv_indices = torch.full(
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
)
self.decode_cuda_graph_workspace = torch.empty(
self.workspace_size, dtype=torch.int8, device=self.device
)
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
......@@ -240,7 +236,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
max_seq_len_val = int(seq_lens.max().item())
metadata = TRTLLMMLADecodeMetadata(
self.decode_cuda_graph_workspace,
block_kv_indices,
max_seq_len_val,
)
......@@ -339,7 +334,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
max_seq_len_val = int(max_seq)
self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
self.workspace_buffer, block_kv_indices, max_seq_len_val
block_kv_indices, max_seq_len_val
)
forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
else:
......@@ -513,7 +508,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
workspace_buffer=metadata.workspace,
workspace_buffer=self.workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment