Unverified commit bebd0576, authored by Elfie Guo and committed by GitHub

Integrate trtllm ragged attention for prefill self-attention (#9801)

parent f9836660
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
......
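A note on the control flow above: disable_flashinfer_ragged only skips planning of the ragged-prefill wrapper; the chunked-KV wrappers for the prefix are still set up. The sketch below mirrors that gating with throwaway stand-ins (RaggedWrapperStub and ChunkKVRunnerSketch are invented for illustration and are not part of this commit or of FlashInfer).

from dataclasses import dataclass, field


class RaggedWrapperStub:
    """Throwaway stand-in for FlashInfer's ragged prefill wrapper (illustration only)."""

    def __init__(self) -> None:
        self.planned = False

    def begin_forward(self, qo_indptr) -> None:
        # The real wrapper precomputes scheduling metadata here.
        self.planned = True


@dataclass
class ChunkKVRunnerSketch:
    """Mirrors only the gating added to FlashInferMhaChunkKVRunner.update_wrapper."""

    ragged_wrapper: RaggedWrapperStub = field(default_factory=RaggedWrapperStub)

    def update_wrapper(self, qo_indptr, disable_flashinfer_ragged: bool = False) -> None:
        # ... chunked-KV wrappers for the prefix would be planned here in any case ...
        if not disable_flashinfer_ragged:
            # Plan the ragged self-attention wrapper only when the backend uses it.
            self.ragged_wrapper.begin_forward(qo_indptr)


runner = ChunkKVRunnerSketch()
# The TRT-LLM MLA backend passes disable_flashinfer_ragged=True, so the FlashInfer
# ragged wrapper is left unplanned and the TRT-LLM kernel handles self-attention.
runner.update_wrapper(qo_indptr=[0, 4, 9], disable_flashinfer_ragged=True)
print(runner.ragged_wrapper.planned)  # False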
@@ -45,6 +45,15 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None


+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
@@ -101,7 +110,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -235,7 +245,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
+        self.forward_decode_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +301,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
         # Delegate to parent for non-decode modes.
-        if not forward_batch.forward_mode.is_decode_or_idle():
-            return super().init_forward_metadata(forward_batch)
-
-        bs = forward_batch.batch_size
-
-        # Get maximum sequence length.
-        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-            max_seq = forward_batch.seq_lens_cpu.max().item()
-        else:
-            max_seq = forward_batch.seq_lens.max().item()
-
-        max_seqlen_pad = self._calc_padded_blocks(max_seq)
-        block_kv_indices = self._create_block_kv_indices(
-            bs,
-            max_seqlen_pad,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.seq_lens.device,
-        )
-
-        max_seq_len_val = int(max_seq)
-        self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
+
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
+
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                self.workspace_buffer, block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
+        else:
+            return super().init_forward_metadata(forward_batch)
+
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

     def quantize_and_rope_for_fp8(
         self,
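The prefill branch above derives the ragged metadata from the new tokens only (seq_lens minus extend_prefix_lens), since attention against the cached prefix is handled by the chunked-KV path; that is also why the same index pointer is later passed as both cum_seq_lens_q and cum_seq_lens_kv. A small torch-only recomputation with invented lengths (the real code takes max_seq_len from forward_batch.extend_seq_lens_cpu):

import torch

# Hypothetical extend batch: total sequence lengths and already-cached prefix lengths.
# Only the variable names mirror the ForwardBatch fields used above; the values are invented.
seq_lens_total = torch.tensor([16, 9, 30])
extend_prefix_lens = torch.tensor([12, 0, 25])

# New query tokens per request, exactly as the prefill branch computes them.
seq_lens = seq_lens_total - extend_prefix_lens  # tensor([4, 9, 5])

# Ragged index pointer shared by Q and KV for the self-attention call:
# request i occupies rows cum_seq_lens_q[i]:cum_seq_lens_q[i + 1] of the packed tensors.
cum_seq_lens_q = torch.cat(
    (torch.tensor([0]), torch.cumsum(seq_lens, dim=0))
).int()
print(cum_seq_lens_q.tolist())  # [0, 4, 13, 18]

# max_seq_len in the metadata corresponds to the longest extend length.
max_seq_len = int(seq_lens.max().item())  # 9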
@@ -459,7 +490,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.forward_metadata
+            or self.forward_decode_metadata
         )

         # Scale computation for TRTLLM MLA kernel BMM1 operation:
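The metadata lookup above prefers whatever decode_trtllm_mla_metadata has been attached to the ForwardBatch and falls back to the backend's forward_decode_metadata via or, so a missing attribute and an explicit None behave the same. A tiny standalone illustration with throwaway placeholder objects:

from types import SimpleNamespace

backend_metadata = "metadata-from-init_forward_metadata"  # stand-in object

# Batch without the attribute: getattr(...) returns None, so `or` falls back
# to the backend-level metadata.
plain_batch = SimpleNamespace()
chosen = getattr(plain_batch, "decode_trtllm_mla_metadata", None) or backend_metadata
print(chosen)  # metadata-from-init_forward_metadata

# A batch that carries its own metadata wins.
tagged_batch = SimpleNamespace(decode_trtllm_mla_metadata="metadata-attached-to-batch")
chosen = getattr(tagged_batch, "decode_trtllm_mla_metadata", None) or backend_metadata
print(chosen)  # metadata-attached-to-batch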
@@ -496,6 +527,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+

 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
......
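To summarize the branch order in the new forward_extend: target-verify and draft-extend batches still go to the FlashInfer parent; for ordinary extend batches, self-attention over the new tokens (attn_attend_prefix_cache is False) uses flashinfer.prefill.trtllm_ragged_attention_deepseek, while attention against the cached prefix stays on the parent path for now, per the in-code comment about an unresolved accuracy issue. The sketch below restates just that dispatch with a simplified stand-in for ForwardMode; it is an illustration, not the backend code.

from enum import Enum, auto


class Mode(Enum):
    # Simplified stand-ins for ForwardMode and its helper predicates.
    EXTEND = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()


def choose_prefill_kernel(mode: Mode, attn_attend_prefix_cache: bool) -> str:
    """Mirror of the branch order in TRTLLMMLABackend.forward_extend (sketch only)."""
    if mode in (Mode.TARGET_VERIFY, Mode.DRAFT_EXTEND):
        return "flashinfer parent forward_extend"
    if not attn_attend_prefix_cache:
        return "flashinfer.prefill.trtllm_ragged_attention_deepseek"
    # Prefix-cache (chunked KV) attention stays on the parent path for now.
    return "flashinfer parent forward_extend"


print(choose_prefill_kernel(Mode.EXTEND, attn_attend_prefix_cache=False))
print(choose_prefill_kernel(Mode.EXTEND, attn_attend_prefix_cache=True))
print(choose_prefill_kernel(Mode.TARGET_VERIFY, attn_attend_prefix_cache=False))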
@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
......
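The new deepseek_v2 branch gives trtllm_mla the same treatment as the other MHA-chunked-KV backends for ordinary extend batches while leaving everything else on the MLA subtype dispatch; the explicit negations suggest that target-verify and draft-extend batches also report as extend modes. A compact restatement of the predicate with plain booleans (illustrative only):

# Ordinary extend batches use MHA with chunked KV; everything else falls through
# to the MLA subtype dispatch. Mode flags are passed as plain booleans here
# instead of a ForwardMode object.
def trtllm_mla_forward_method(is_extend: bool, is_target_verify: bool, is_draft_extend: bool) -> str:
    if is_extend and not is_target_verify and not is_draft_extend:
        return "MHA_CHUNKED_KV"
    return "_dispatch_mla_subtype()"


for flags in [(True, False, False), (True, True, False), (True, False, True), (False, False, False)]:
    print(flags, "->", trtllm_mla_forward_method(*flags))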
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
     "v_head_dim": 512,
     "num_kv_heads": 1,
     "layer_id": 0,
+    "tp_q_head_num": 128,
+    "tp_k_head_num": 128,
+    "prefill_head_dim": 192,
+    "prefill_v_head_dim": 128,
 }

 ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
             "description": "Medium-scale batch",
         },
     ],
-    "decode_output_match": [
+    "output_match": [
        {
            "name": "single_fp16",
            "batch_size": 1,
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config.update(test_case)
         return config

-    def _create_model_components(self, config):
+    def _create_model_components(self, config, is_prefill=False):
         """Create model runners, backends, and layer for testing."""
         # Create model runners
         model_runner_trtllm = MockModelRunner(config)
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
         trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
         reference_backend = FlashInferMLAAttnBackend(model_runner_reference)

+        head_dim = (
+            config["kv_lora_rank"] + config["qk_rope_head_dim"]
+            if not is_prefill
+            else config["prefill_head_dim"]
+        )
+        v_head_dim = (
+            config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
+        )
+
         # Create RadixAttention layer
         layer = RadixAttention(
             num_heads=config["num_attention_heads"],
-            head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
+            head_dim=head_dim,
             scaling=model_runner_trtllm.model_config.scaling,
             num_kv_heads=config["num_kv_heads"],
             layer_id=config["layer_id"],
-            v_head_dim=config["v_head_dim"],
+            v_head_dim=v_head_dim,
             prefix="attn_mqa",
         )
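The helper now picks head dimensions per mode: decode-style tests keep the compressed MLA layout (kv_lora_rank + qk_rope_head_dim for head_dim, v_head_dim as configured), while prefill tests use the new prefill_head_dim and prefill_v_head_dim entries. The recomputation below uses the values visible in this diff; kv_lora_rank and qk_rope_head_dim are not shown in the excerpt, so their values are assumed.

# Recomputing the head dimensions the test helper picks, for both modes.
# The prefill_* and v_head_dim values appear in DEFAULT_CONFIG above;
# kv_lora_rank and qk_rope_head_dim are assumed for illustration.
config = {
    "kv_lora_rank": 512,      # assumed
    "qk_rope_head_dim": 64,   # assumed
    "v_head_dim": 512,
    "prefill_head_dim": 192,
    "prefill_v_head_dim": 128,
}

for is_prefill in (False, True):
    head_dim = (
        config["kv_lora_rank"] + config["qk_rope_head_dim"]
        if not is_prefill
        else config["prefill_head_dim"]
    )
    v_head_dim = config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
    print(f"is_prefill={is_prefill}: head_dim={head_dim}, v_head_dim={v_head_dim}")
# is_prefill=False: head_dim=576, v_head_dim=512  (compressed decode layout)
# is_prefill=True:  head_dim=192, v_head_dim=128  (uncompressed prefill layout)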
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
         """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
         print(f"\nRunning decode output matching tests...")

-        for test_case in TEST_CASES["decode_output_match"]:
+        for test_case in TEST_CASES["output_match"]:
             with self.subTest(test_case=test_case["name"]):
                 print(f"  Testing {test_case['name']}: {test_case['description']}")
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])

+    def test_prefill_output_match_self_attention(self):
+        """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
+        print(f"\nRunning prefill output tests...")
+        for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
+            with self.subTest(test_case=test_case["name"]):
+                print(
+                    f"Prefill Testing {test_case['name']}: {test_case['description']}"
+                )
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config, is_prefill=True)
+
+                # Prefill uses full sequences
+                seq_lens = torch.full(
+                    (batch_size,), max_seq_len, device=config["device"]
+                )
+
+                def _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens,
+                    extend_prefix_lens,
+                    backend,
+                    model_runner,
+                    config,
+                ):
+                    """Create a forward batch for the given backend."""
+                    fb = ForwardBatch(
+                        batch_size=batch_size,
+                        input_ids=torch.randint(
+                            0, 100, (batch_size, 1), device=config["device"]
+                        ),
+                        out_cache_loc=torch.arange(batch_size, device=config["device"]),
+                        seq_lens_sum=int(seq_lens.sum().item()),
+                        extend_prefix_lens=extend_prefix_lens,
+                        extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
+                        extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
+                        .cpu()
+                        .int()
+                        .tolist(),
+                        forward_mode=ForwardMode.EXTEND,
+                        req_pool_indices=torch.arange(
+                            batch_size, device=config["device"]
+                        ),
+                        seq_lens=seq_lens,
+                        seq_lens_cpu=seq_lens.cpu(),
+                        attn_attend_prefix_cache=False,
+                        mha_return_lse=False,
+                        attn_backend=backend,
+                    )
+                    fb.req_to_token_pool = model_runner.req_to_token_pool
+                    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+                    # Add position information for RoPE
+                    fb.positions = torch.arange(batch_size, device=config["device"])
+
+                    return fb
+
+                # Create forward batches
+                fb_trtllm = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
+
+                # Create Q, K, V tensors for prefill
+                torch.manual_seed(config["seed_qkv"])
+
+                def _create_qkv_tensors_prefill(
+                    batch_size, seq_len, config, dtype_override=None
+                ):
+                    """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
+                    device = config["device"]
+                    dtype = dtype_override or config["dtype"]
+
+                    total_tokens = batch_size * seq_len
+                    tp_q_head_num = config["tp_q_head_num"]
+                    tp_k_head_num = config["tp_k_head_num"]
+                    head_dim = config["prefill_head_dim"]
+                    v_head_dim = config["prefill_v_head_dim"]
+
+                    q = torch.randn(
+                        (total_tokens, tp_q_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    k = torch.randn(
+                        (total_tokens, tp_k_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    v = torch.randn(
+                        (total_tokens, tp_k_head_num * v_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+
+                    # Reshape as requested
+                    q = q.view(-1, tp_q_head_num, head_dim)
+                    k = k.view(-1, tp_k_head_num, head_dim)
+                    v = v.view(-1, tp_k_head_num, v_head_dim)
+                    return q, k, v
+
+                q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
+
+                # Run prefill on both backends
+                out_trtllm = trtllm_backend.forward_extend(
+                    q, k, v, layer, fb_trtllm, False
+                ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                out_reference = reference_backend.forward_extend(
+                    q, k, v, layer, fb_reference, False
+                )
+
+                tolerance = config.get("tolerance", 1e-2)
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=tolerance
+                )
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+

 if __name__ == "__main__":
     unittest.main()