Unverified Commit d9408ffb authored by Koushik Dutta's avatar Koushik Dutta Committed by GitHub
Browse files

Triton MLA perf fixes (#33529)


Signed-off-by: default avatarKoushik Dutta <koushd@gmail.com>
Co-authored-by: default avatarroot <root@ubuntu-nvidia.localdomain>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 16a65e41
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import triton
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
...@@ -115,6 +116,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -115,6 +116,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
self.supports_quant_query_input = False self.supports_quant_query_input = False
self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
def _flash_attn_varlen_diff_headdims( def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
): ):
...@@ -149,7 +152,24 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -149,7 +152,24 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
# For batch invariance, use only 1 split to ensure deterministic reduction # For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4 if envs.VLLM_BATCH_INVARIANT:
num_kv_splits = 1
else:
# Minimum work per split
# hardware dependent
min_work_per_split = 512
ideal_splits = max(1, attn_metadata.max_seq_len // min_work_per_split)
# use power of 2 to avoid excessive kernel instantiations
ideal_splits = triton.next_power_of_2(ideal_splits)
# Calculate SM-based maximum splits with occupancy multiplier
# 2-4x allows multiple blocks per SM for latency hiding
# hardware dependent
occupancy_multiplier = 2
max_splits = self._sm_count * occupancy_multiplier
num_kv_splits = min(ideal_splits, max_splits)
# TODO(lucas) Allocate ahead of time # TODO(lucas) Allocate ahead of time
attn_logits = torch.empty( attn_logits = torch.empty(
...@@ -186,6 +206,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -186,6 +206,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
PAGE_SIZE, PAGE_SIZE,
k_scale=layer._k_scale, k_scale=layer._k_scale,
v_scale=layer._k_scale, v_scale=layer._k_scale,
is_mla=True,
) )
return o, lse return o, lse
...@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1( ...@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
logit_cap: tl.constexpr, logit_cap: tl.constexpr,
Lk: tl.constexpr, Lk: tl.constexpr,
Lv: tl.constexpr, Lv: tl.constexpr,
IS_MLA: tl.constexpr = False,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1) cur_head_id = tl.program_id(1)
...@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1( ...@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
cur_batch_req_idx = cur_batch cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) q = tl.load(
Q + offs_q,
mask=(mask_h[:, None]) & (mask_d[None, :]),
other=0.0,
cache_modifier=".ca",
)
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
...@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1( ...@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
) )
qpe = tl.load( qpe = tl.load(
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 Q + off_qpe,
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
other=0.0,
cache_modifier=".ca",
) )
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
...@@ -331,9 +340,14 @@ def _fwd_grouped_kernel_stage1( ...@@ -331,9 +340,14 @@ def _fwd_grouped_kernel_stage1(
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start: if split_kv_end > split_kv_start:
base_offs_k = cur_kv_head * stride_buf_kh + offs_d[:, None]
base_offs_v = cur_kv_head * stride_buf_vh + offs_dv[None, :]
if BLOCK_DPE > 0:
base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe[:, None]
ks = tl.load(k_scale) ks = tl.load(k_scale)
vs = tl.load(v_scale) vs = tl.load(v_scale)
for start_n in range(split_kv_start, split_kv_end, BLOCK_N): for start_n in tl.range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N) offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load( kv_page_number = tl.load(
Req_to_tokens Req_to_tokens
...@@ -341,31 +355,29 @@ def _fwd_grouped_kernel_stage1( ...@@ -341,31 +355,29 @@ def _fwd_grouped_kernel_stage1(
+ offs_n // PAGE_SIZE, + offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end, mask=offs_n < split_kv_end,
other=0, other=0,
cache_modifier=".ca",
) )
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (
kv_loc[None, :] * stride_buf_kbs # explicitly facilitate overlapping load/compute
+ cur_kv_head * stride_buf_kh offs_buf_k = kv_loc[None, :] * stride_buf_kbs + base_offs_k
+ offs_d[:, None]
)
k = tl.load( k = tl.load(
K_Buffer + offs_buf_k, K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0, other=0.0,
cache_modifier=".cg",
) )
if k.dtype.is_fp8(): if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype) k = (k.to(tl.float32) * ks).to(q.dtype)
qk = tl.dot(q, k.to(q.dtype)) qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0: if BLOCK_DPE > 0:
offs_buf_kpe = ( offs_buf_kpe = kv_loc[None, :] * stride_buf_kbs + base_offs_kpe
kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load( kpe = tl.load(
K_Buffer + offs_buf_kpe, K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
other=0.0, other=0.0,
cache_modifier=".cg",
) )
if kpe.dtype.is_fp8(): if kpe.dtype.is_fp8():
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype) kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
...@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1( ...@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
) )
offs_buf_v = ( if not IS_MLA:
kv_loc[:, None] * stride_buf_vbs offs_buf_v = kv_loc[:, None] * stride_buf_vbs + base_offs_v
+ cur_kv_head * stride_buf_vh v = tl.load(
+ offs_dv[None, :] V_Buffer + offs_buf_v,
) mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
v = tl.load( other=0.0,
V_Buffer + offs_buf_v, )
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), if v.dtype.is_fp8():
other=0.0, v = (v.to(tl.float32) * vs).to(q.dtype)
) else:
if v.dtype.is_fp8(): # MLA uses a single c_kv.
v = (v.to(tl.float32) * vs).to(q.dtype) # loading the same c_kv to interpret it as v is not necessary.
# transpose the existing c_kv (aka k) for the dot product.
v = tl.trans(k)
n_e_max = tl.maximum(tl.max(qk, 1), e_max) n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max) re_scale = tl.exp(e_max - n_e_max)
...@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd( ...@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
logit_cap, logit_cap,
k_scale, k_scale,
v_scale, v_scale,
is_mla=False,
): ):
# with is_mla there is only a single c_kv in smem.
# could increase BLOCK or num_stages.
BLOCK = 32 BLOCK = 32
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd( ...@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
num_stages=num_stages, num_stages=num_stages,
Lk=Lk, Lk=Lk,
Lv=Lv, Lv=Lv,
IS_MLA=is_mla,
**extra_kargs, **extra_kargs,
) )
...@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped( ...@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
logit_cap=0.0, logit_cap=0.0,
k_scale=None, k_scale=None,
v_scale=None, v_scale=None,
is_mla=False,
): ):
_decode_grouped_att_m_fwd( _decode_grouped_att_m_fwd(
q, q,
...@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped( ...@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
logit_cap, logit_cap,
k_scale, k_scale,
v_scale, v_scale,
is_mla=is_mla,
) )
_decode_softmax_reducev_fwd( _decode_softmax_reducev_fwd(
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
...@@ -708,6 +728,7 @@ def decode_attention_fwd( ...@@ -708,6 +728,7 @@ def decode_attention_fwd(
logit_cap=0.0, logit_cap=0.0,
k_scale=None, k_scale=None,
v_scale=None, v_scale=None,
is_mla=False,
): ):
assert num_kv_splits == attn_logits.shape[2] assert num_kv_splits == attn_logits.shape[2]
...@@ -753,4 +774,5 @@ def decode_attention_fwd( ...@@ -753,4 +774,5 @@ def decode_attention_fwd(
logit_cap, logit_cap,
k_scale, k_scale,
v_scale, v_scale,
is_mla=is_mla,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment