Unverified commit f35def86, authored by fzyzcjy and committed by GitHub

Fuse quantize and rope in trtllm_mla MTP (#10779)

parent d61615fe
@@ -568,12 +568,35 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         if forward_batch.forward_mode.is_draft_extend():
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
+
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
 
         # Save KV cache if requested
         if save_kv_cache:
             assert (
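
The fused FP8 path above delegates to self.quantize_and_rope_for_fp8, which, per the assertion message, goes through flashinfer.rope.mla_rope_quantize. As a rough idea of what that step computes, here is an unfused plain-PyTorch sketch: apply rotary embedding to the rope parts of q and k, cast everything to float8_e4m3fn, and merge the query halves. The helper names, the neox-style rotation, and the cos/sin layout are assumptions for illustration only, and quantization scales are omitted; the real path fuses all of this into a single kernel call.

import torch

def _rotate_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [tokens, heads, rot_dim]; cos/sin: [tokens, rot_dim // 2] (assumed layout)
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

def quantize_and_rope_reference(q_nope, q_rope, k_nope, k_rope, cos, sin):
    # Unfused, hypothetical reference: rope on the rotary parts, FP8 cast, query merge.
    # q_nope/q_rope: [tokens, heads, dim]; k_nope/k_rope: [tokens, dim] (already squeezed).
    q_rot = _rotate_neox(q_rope, cos, sin)
    k_rot = _rotate_neox(k_rope.unsqueeze(1), cos, sin).squeeze(1)
    q = torch.cat((q_nope, q_rot), dim=-1).to(torch.float8_e4m3fn)
    return q, k_nope.to(torch.float8_e4m3fn), k_rot.to(torch.float8_e4m3fn)
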
@@ -583,12 +606,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 layer, forward_batch.out_cache_loc, k, k_rope
             )
 
-        if q_rope is not None:
-            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-            q_rope = q_rope.view(
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            q = _concat_mla_absorb_q_general(q, q_rope)
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
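
In the non-FP8 path (merge_query stays True, since rope was already applied upstream in deepseek_v2.py), the backend still concatenates the no-rope and rope parts of the query per head before calling the kernel. A minimal shape check of that merge, assuming DeepSeek-V3-style absorbed-MLA dimensions (v_head_dim = 512, rope dim = 64, head_dim = 576) and using torch.cat as a stand-in for _concat_mla_absorb_q_general:

import torch

tokens, heads = 4, 16                    # hypothetical batch size and TP head count
v_head_dim, rope_dim = 512, 64           # assumed DeepSeek-V3-style MLA sizes
head_dim = v_head_dim + rope_dim         # 576

q = torch.randn(tokens, heads * v_head_dim)      # no-rope query part, flattened
q_rope = torch.randn(tokens, heads * rope_dim)   # rope query part, flattened

q_nope = q.view(-1, heads, v_head_dim)
q_rope_reshaped = q_rope.view(-1, heads, head_dim - v_head_dim)

# torch.cat stands in for _concat_mla_absorb_q_general here.
merged = torch.cat((q_nope, q_rope_reshaped), dim=-1)
assert merged.shape == (tokens, heads, head_dim)
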
@@ -1399,7 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and forward_batch.forward_mode.is_decode_or_idle()
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )
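
This second hunk widens the gate in DeepseekV2AttentionMLA that decides when rope application is deferred to the backend's fused quantize-and-rope path: previously only decode/idle batches on the FP8 trtllm_mla backend qualified; now target-verify (MTP verification) batches qualify as well. A standalone restatement of the updated condition, with a simplified, hypothetical signature in place of the real method:

import torch

def use_fused_rope_quantize(backend_name: str,
                            is_decode_or_idle: bool,
                            is_target_verify: bool,
                            kv_dtype: torch.dtype) -> bool:
    # Mirrors the updated predicate above, pulled out of the class for illustration.
    return (
        backend_name == "trtllm_mla"
        and (is_decode_or_idle or is_target_verify)
        and kv_dtype == torch.float8_e4m3fn
    )

# An MTP target-verify batch on the FP8 trtllm_mla backend now also takes the
# fused path (and therefore skips rope in forward_absorb_prepare):
assert use_fused_rope_quantize("trtllm_mla", False, True, torch.float8_e4m3fn)
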