Unverified Commit 5312a728 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Bug] Fix `'CutlassMLAImpl' object has no attribute '_workspace_buffer'` (#31173)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent de717476
...@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata: ...@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
max_query_len: int max_query_len: int
chunked_context: ChunkedContextMetadata | None = None chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None q_data_type: torch.dtype | None = None
...@@ -986,6 +987,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -986,6 +987,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.query_seq_lens = ( prefill_metadata.query_seq_lens = (
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
) )
prefill_metadata.workspace_buffer = self._workspace_buffer
decode_metadata = None decode_metadata = None
if num_decodes > 0: if num_decodes > 0:
...@@ -1567,6 +1569,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1567,6 +1569,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
from flashinfer.prefill import trtllm_ragged_attention_deepseek from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
if fp8_attention: if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8") logger.debug_once("Running TRT-LLM ragged prefill in FP8")
...@@ -1579,7 +1582,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1579,7 +1582,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
query=q, query=q,
key=k, key=k,
value=v, value=v,
workspace_buffer=self._workspace_buffer, workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.query_seq_lens, seq_lens=prefill.query_seq_lens,
max_q_len=prefill.max_query_len, max_q_len=prefill.max_query_len,
max_kv_len=prefill.max_query_len, max_kv_len=prefill.max_query_len,
...@@ -1615,6 +1618,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1615,6 +1618,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.workspace_buffer is not None
out = torch.zeros( out = torch.zeros(
q.shape[0], q.shape[0],
...@@ -1623,7 +1627,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1623,7 +1627,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
device=q.device, device=q.device,
dtype=q.dtype, dtype=q.dtype,
) )
self._workspace_buffer.fill_(0) prefill.workspace_buffer.fill_(0)
if fp8_attention: if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8") logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
...@@ -1636,7 +1640,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1636,7 +1640,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
query=q, query=q,
key=k, key=k,
value=v, value=v,
workspace_buffer=self._workspace_buffer, workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.chunked_context.seq_lens[chunk_idx], seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
max_q_len=prefill.max_query_len, max_q_len=prefill.max_query_len,
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment