"examples/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "2c1cc0b5218ceea9890f2800db21744e9d157ddf"
Unverified Commit b1bb8e74 authored by pranavm-nvidia, committed by GitHub

Enables TRT-LLM backend to be used for target_verify (#10281)


Co-authored-by: Pranav Marathe <pranavm@ipp1-3309.ipp1a1.colossus.nvidia.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
parent 38c00ed7
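
The change is mostly length bookkeeping: in target_verify mode (the verification pass of speculative decoding), each request carries `num_draft_tokens` extra query tokens, so the attention metadata has to treat every sequence as that much longer. The sketch below only illustrates that adjustment with made-up numbers; it does not touch the backend, and the names `num_draft_tokens` / `seq_lens` simply mirror the ones used in the diff.

```python
# Minimal sketch of the length bookkeeping this commit adds (illustrative values only).
import torch

num_draft_tokens = 4  # corresponds to server_args.speculative_num_draft_tokens
seq_lens = torch.tensor([9, 17, 33], dtype=torch.int32)  # committed tokens per request

# During target_verify each request also attends over its draft tokens,
# so the effective KV length grows by num_draft_tokens.
verify_seq_lens = seq_lens + num_draft_tokens
print(verify_seq_lens.tolist())  # [13, 21, 37]
```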
@@ -127,6 +127,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             "disable_chunked_prefix_cache"
         ]
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -217,7 +219,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         """Initialize metadata for CUDA graph capture."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -228,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             spec_info,
         )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -270,7 +275,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -282,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             seq_lens_cpu,
         )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # does not account for num_draft_tokens, but it is unused here
+
         metadata = self.decode_cuda_graph_metadata[bs]
         # Update block indices for new sequences.
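
The capture path above keeps its existing comment: capture with full width so longer sequences are safe during replay. Below is a hedged sketch of that idea with invented sizes and a placeholder block table; the real backend builds `block_kv_indices` through `_create_block_kv_indices`, and the `-1` filler here is only an illustration.

```python
# Sketch: why CUDA-graph capture pads the block table to the worst case.
# A captured graph can only replay with buffers of the shapes it was captured
# with, so the table is allocated for max_context_len and partially refilled.
import torch

page_size = 64
max_context_len = 4096
max_blocks_per_seq = (max_context_len + page_size - 1) // page_size  # capture-time width

bs = 2
block_kv_indices = torch.full((bs, max_blocks_per_seq), -1, dtype=torch.int32)

# At replay, only the blocks needed for the draft-extended lengths are rewritten;
# the tail keeps its placeholder value.
num_draft_tokens = 4
seq_lens = torch.tensor([130, 70], dtype=torch.int32) + num_draft_tokens
for i, n in enumerate(seq_lens.tolist()):
    used = (n + page_size - 1) // page_size
    block_kv_indices[i, :used] = torch.arange(used, dtype=torch.int32)  # fake block ids
print(block_kv_indices[:, :4])
```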
@@ -332,7 +341,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif forward_batch.forward_mode.is_decode_or_idle():
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             bs = forward_batch.batch_size
 
             # Get maximum sequence length.
@@ -341,13 +353,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()
 
+            seq_lens = forward_batch.seq_lens
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
+                seq_lens,
+                seq_lens.device,
             )
 
             max_seq_len_val = int(max_seq)
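
With the lengths extended, the decode fast path reuses the same padded block-table sizing. The sketch below stands in for that step: `calc_padded_blocks` is a plain ceil-division placeholder for the backend's `_calc_padded_blocks` (which additionally satisfies TRT-LLM/Triton alignment constraints), and `page_size` plus the sequence lengths are invented.

```python
# Hedged sketch of the block-table sizing in the target_verify branch.
import math
import torch

page_size = 64
num_draft_tokens = 4

def calc_padded_blocks(max_seq_len: int) -> int:
    # Placeholder for _calc_padded_blocks: one block per page_size tokens, rounded up.
    return math.ceil(max_seq_len / page_size)

seq_lens = torch.tensor([120, 250], dtype=torch.int32)
max_seq = int(seq_lens.max().item())

# Mirrors the diff: both the per-request lengths and the maximum grow by the
# draft tokens before the block table is built.
max_seq += num_draft_tokens
seq_lens = seq_lens + num_draft_tokens

max_seqlen_pad = calc_padded_blocks(max_seq)
print(max_seqlen_pad)  # 4 -> four 64-token blocks cover 254 tokens
```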
@@ -553,84 +571,134 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ):
-        if (
-            forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
-        ):
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
-
-        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
-        if forward_batch.attn_attend_prefix_cache is None:
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
-
-        if not forward_batch.attn_attend_prefix_cache:
-            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
-                query=q,
-                key=k,
-                value=v,
-                workspace_buffer=self.workspace_buffer,
-                seq_lens=self.forward_prefill_metadata.seq_lens,
-                max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=self.forward_prefill_metadata.max_seq_len,
-                bmm1_scale=layer.scaling,
-                bmm2_scale=1.0,
-                o_sf_scale=1.0,
-                batch_size=forward_batch.batch_size,
-                window_left=-1,
-                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
-                enable_pdl=False,
-                is_causal=True,
-                return_lse=forward_batch.mha_return_lse,
-            )
-        else:
-            if not (
-                forward_batch.attn_attend_prefix_cache is not None
-                and forward_batch.mha_return_lse
-            ):
-                output = super().forward_extend(
-                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-                )
-            else:
-                # MHA for chunked prefix kv cache when running model with MLA
-                assert forward_batch.prefix_chunk_idx is not None
-                assert forward_batch.prefix_chunk_cu_seq_lens is not None
-                assert q_rope is None
-                assert k_rope is None
-                chunk_idx = forward_batch.prefix_chunk_idx
-
-                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
-                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
-                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
-                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
-                    query=q,
-                    key=k,
-                    value=v,
-                    workspace_buffer=self.workspace_buffer,
-                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
-                    max_q_len=self.forward_prefill_metadata.max_seq_len,
-                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
-                    bmm1_scale=layer.scaling,
-                    bmm2_scale=1.0,
-                    o_sf_scale=-1.0,
-                    batch_size=forward_batch.batch_size,
-                    window_left=-1,
-                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
-                    enable_pdl=False,
-                    is_causal=False,
-                    return_lse=True,
-                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
-                )
-        return output
+    ) -> torch.Tensor:
+        if forward_batch.forward_mode.is_draft_extend():
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        # Save KV cache if requested
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = torch.cat([q, q_rope], dim=-1)
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+        if forward_batch.forward_mode.is_target_verify():
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
+            )
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
+            )
+            bmm1_scale = q_scale * k_scale * layer.scaling
+
+            seq_lens = (
+                forward_batch.seq_lens.to(torch.int32)
+                + forward_batch.spec_info.draft_token_num
+            )
+            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.to(self.data_type)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=-1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                enable_pdl=False,
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+            )
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=forward_batch.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=forward_batch.mha_return_lse,
+        )
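
The interesting part of the new target_verify branch in `forward_extend` is the shape contract around the batched decode kernel: queries are regrouped per request before the call, and the raw output is flattened back to one row per token. The sketch below reproduces only those reshapes with invented sizes; it does not import flashinfer or call `trtllm_batch_decode_with_kv_cache_mla`.

```python
# Shape-only sketch of the target_verify query/output handling (made-up sizes).
import torch

bs, num_draft_tokens = 2, 4
tp_q_head_num, head_dim, v_head_dim = 16, 576, 512

# Queries arrive flattened over tokens, as elsewhere in forward_extend.
q = torch.randn(bs * num_draft_tokens, tp_q_head_num, head_dim)

# The verify path views them as [bs, num_draft_tokens, num_q_heads, head_dim]
# before handing them to the batched decode kernel.
q = q.view(bs, -1, tp_q_head_num, head_dim)
assert q.shape == (bs, num_draft_tokens, tp_q_head_num, head_dim)

# The kernel's output is flattened back to one row per token, which is what the
# surrounding model code expects.
raw_out = torch.randn(bs, num_draft_tokens, tp_q_head_num, v_head_dim)
output = raw_out.view(-1, tp_q_head_num * v_head_dim)
assert output.shape == (bs * num_draft_tokens, tp_q_head_num * v_head_dim)
```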
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
...