Unverified Commit 5312a728 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Bug] Fix `'CutlassMLAImpl' object has no attribute '_workspace_buffer'` (#31173)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent de717476
......@@ -355,6 +355,7 @@ class MLACommonPrefillMetadata:
max_query_len: int
chunked_context: ChunkedContextMetadata | None = None
query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
......@@ -986,6 +987,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_metadata.query_seq_lens = (
prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
)
prefill_metadata.workspace_buffer = self._workspace_buffer
decode_metadata = None
if num_decodes > 0:
......@@ -1567,6 +1569,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
from flashinfer.prefill import trtllm_ragged_attention_deepseek
assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill in FP8")
......@@ -1579,7 +1582,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.query_seq_lens,
max_q_len=prefill.max_query_len,
max_kv_len=prefill.max_query_len,
......@@ -1615,6 +1618,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
assert prefill.chunked_context is not None
assert prefill.chunked_context.seq_lens[chunk_idx] is not None
assert prefill.workspace_buffer is not None
out = torch.zeros(
q.shape[0],
......@@ -1623,7 +1627,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
device=q.device,
dtype=q.dtype,
)
self._workspace_buffer.fill_(0)
prefill.workspace_buffer.fill_(0)
if fp8_attention:
logger.debug_once("Running TRT-LLM ragged prefill context chunk in FP8")
......@@ -1636,7 +1640,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
query=q,
key=k,
value=v,
workspace_buffer=self._workspace_buffer,
workspace_buffer=prefill.workspace_buffer,
seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
max_q_len=prefill.max_query_len,
max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment