Unverified Commit 311de47b authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

[2/2] Speed up trtllm_mla attention backend (#10474)

parent 373080ea
...@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import ( ...@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import is_cuda, is_flashinfer_available
if is_flashinfer_available(): if is_flashinfer_available():
import flashinfer import flashinfer
...@@ -32,6 +32,11 @@ if TYPE_CHECKING: ...@@ -32,6 +32,11 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo from sglang.srt.speculative.spec_info import SpecInfo
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import concat_mla_absorb_q
# Constants # Constants
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
...@@ -482,6 +487,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): ...@@ -482,6 +487,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
q_rope_reshaped = q_rope.view( q_rope_reshaped = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
) )
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
else:
query = torch.cat([q_nope, q_rope_reshaped], dim=-1) query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
else: else:
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment