# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py

# Changes:
# - Add support for page size >= 1.

# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""

import logging
import os

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform

# True when vLLM reports a ROCm platform; used below to pass AMD-only launch
# options to the stage-2 reduction kernel.
is_hip_ = current_platform.is_rocm()

# Set unconditionally at import time (after triton has been imported).
# NOTE(review): presumably this disables Triton's new HIP stream pipeliner;
# confirm that the installed Triton reads the variable lazily.
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = "0"

logger = logging.getLogger(__name__)

# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
    "The following error message 'operation scheduled before its operands' "
    "can be ignored.")


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


# Stage 1 of split-KV decode attention, one query head per program.
#
# Grid: (batch, query head, KV split). Each program attends its query vector
# to the slice [split_kv_start, split_kv_end) of the request's KV sequence
# with an online softmax (running max e_max, running sum e_sum) and writes to
# Att_Out at [batch, head, split]:
#   * offsets 0 .. Lv-1 : the normalised partial output acc / e_sum
#   * offset Lv         : the slice's log-sum-exp, e_max + log(e_sum)
# A program whose slice is empty writes nothing.
#
# Token position n is mapped to a cache slot through Req_to_tokens:
#   page = Req_to_tokens[batch, n // PAGE_SIZE]
#   slot = page * PAGE_SIZE + n % PAGE_SIZE
# Lk / Lv are the real K / V head sizes; BLOCK_DMODEL / BLOCK_DV are the
# padded block sizes, with the padding masked out on every load and store.
@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_kv_id = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    # The batch index doubles as the row index into Req_to_tokens.
    cur_batch_req_idx = cur_batch

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

    # Even partition of the sequence; trailing splits may be empty.
    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[None, :])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                other=0.0,
            )
            qk = tl.sum(q[None, :] * k, 1)
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            # Out-of-range positions must not contribute to the softmax.
            qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Online softmax: rescale the running state to the new maximum.
            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max)
            acc *= re_scale
            acc += tl.sum(p[:, None] * v, 0)

            e_sum = e_sum * re_scale + tl.sum(p, 0)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv)

        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum,
            mask=(mask_dv),
        )

        # The log-sum-exp is stored right after the Lv output values.
        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)

        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
        )


def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch ``_fwd_kernel_stage1`` over (batch, query head, KV split).

    The batch size and head count are taken from ``q.shape[0]`` and
    ``q.shape[1]``; the K/V head sizes from the last dimension of
    ``k_buffer`` / ``v_buffer``. ``att_out`` is written in place, indexed
    through its first three strides; nothing is returned.

    The KV buffers are assumed to be laid out as
    ``(..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)``.
    """
    BLOCK = 64
    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    batch, head_num = q.shape[0], q.shape[1]

    grid = (batch, head_num, NUM_KV_SPLITS)
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    num_warps = 4 if kv_group_num == 1 else 2

    # tl.arange needs power-of-two extents; the kernel masks the padding.
    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=num_warps,
        Lk=Lk,
        Lv=Lv,
    )


# Stage 2 of split-KV decode attention: merge the per-split partial results.
#
# Grid: (batch, query head). For every non-empty split the program reads the
# partial output (offsets 0 .. Lv-1) and its log-sum-exp (offset Lv) from
# Mid_O, combines them with the same running-max rescaling used in stage 1,
# and stores acc / e_sum to `o`. The split boundaries are recomputed exactly
# as in stage 1 so that splits which wrote nothing are skipped.
@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch ``_fwd_kernel_stage2`` over (batch, query head).

    ``logits`` is the stage-1 scratch buffer; ``q`` is only used for its
    first two dimensions and ``v_buffer`` only for its last dimension.
    ``o`` is written in place; nothing is returned. On ROCm
    (``is_hip_``) AMD-specific launch options are added.
    """
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 0,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        **extra_kargs,
    )


def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run the two-stage decode path that uses one program per query head.

    Stage 1 fills ``attn_logits`` with per-split partial outputs and
    log-sum-exps; stage 2 reduces them into ``o``. Returns ``None``.
    """
    _decode_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
                                num_kv_splits)


# opt
# "v1" stage 1: raw attention logits for a block of query heads, using tl.dot.
#
# Grid: (batch, head block, K split). Each program handles up to BLOCK_H
# query heads that share one KV head, computes sm_scale * Q @ K^T (plus the
# optional BLOCK_DPE tail of the head dimension and the optional tanh logit
# cap) for its slice of the sequence, and stores the un-normalised logits at
#   Att_Out[head * att_stride_h + B_Start_Loc[batch] + position].
# No softmax is applied here; that happens in the v1 stage-2 kernel.
#
# Autotuning sweeps BLOCK_N in {16, 32, 64, 128, 256} x num_warps in
# {2, 4, 8}, in that order.
# NOTE(review): `num_ldmatrixes` is not an upstream triton.Config argument as
# far as this reviewer knows - presumably a vendor-fork option; confirm.
# NOTE(review): the autotune key includes the B_Seqlen pointer argument;
# confirm this gives the intended cache behaviour.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n},
                      num_warps=num_warps,
                      num_ldmatrixes=0,
                      num_stages=1)
        for block_n in (16, 32, 64, 128, 256)
        for num_warps in (2, 4, 8)
    ],
    key=["B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh"],
)
@triton.jit
def _decode_v1_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    sm_scale,
    Req_to_tokens,
    #B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    SPLIT_K: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_k_id = tl.program_id(2)

    # Q and K are cast to the logits buffer's element type before tl.dot.
    reduce_dtype = Att_Out.dtype.element_ty

    # Number of heads in this block that are real (the rest are padding).
    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
        None, :]
    q = tl.load(
        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
    ).to(reduce_dtype)

    if BLOCK_DPE > 0:
        # Extra slice of the head dimension located after BLOCK_DMODEL.
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        off_qpe = (
            cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
            offs_dpe[None, :]
        )
        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None],
                      other=0.0).to(reduce_dtype)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
    split_k_start = kv_len_per_split * split_k_id
    split_k_end = tl.minimum(split_k_start + kv_len_per_split,
                             cur_batch_seq_len)

    for start_n in range(split_k_start, split_k_end, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        kv_page_number = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
            offs_n // PAGE_SIZE,
            mask=offs_n < split_k_end,
            other=0,
        )
        k_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # K is loaded transposed: (head_dim, BLOCK_N).
        offs_buf_k = (
            k_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh +
            offs_d[:, None]
        )
        k = tl.load(
            K_Buffer + offs_buf_k,
            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
            other=0.0,
        ).to(reduce_dtype)
        qk = tl.dot(q, k)
        if BLOCK_DPE > 0:
            offs_buf_kpe = (
                k_loc[None, :] * stride_buf_kbs +
                cur_kv_head * stride_buf_kh + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Buffer + offs_buf_kpe,
                mask=offs_n[None, :] < split_k_end,
                other=0.0,
            ).to(reduce_dtype)
            qk += tl.dot(qpe, kpe)
        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        offs_o = cur_head[:, None] * att_stride_h + (
            cur_batch_in_all_start_index + offs_n[None, :]
        )

        tl.store(
            Att_Out + offs_o,
            qk,
            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
        )


# "v1" stage 2: softmax over the stored logits and multiplication with V.
#
# Grid axis 0 is the batch index and grid axis 1 is used as the KV head
# index. Each program handles the kv_group_num query heads of that KV head
# (padded to BLOCK_H), walks the whole sequence in BLOCK_N chunks with an
# online softmax over the logits written by the v1 stage-1 kernel, and
# stores the normalised result to Out.
#
# Autotuning sweeps the (BLOCK_N, num_ldmatrixes) pairs listed below, each
# with num_warps in {1, 2, 4, 8}; the order matches the original
# hand-written list.
# NOTE(review): `num_ldmatrixes` is not an upstream triton.Config argument as
# far as this reviewer knows - presumably a vendor-fork option; confirm.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n},
                      num_warps=num_warps,
                      num_ldmatrixes=num_ldmatrixes,
                      num_stages=1)
        for block_n, num_ldmatrixes in (
            (32, 1),
            (64, 1),
            (8, 0),
            (16, 0),
            (32, 0),
            (64, 0),
            (128, 1),
            (256, 1),
            (512, 1),
            (128, 0),
            (256, 0),
            (512, 0),
        )
        for num_warps in (1, 2, 4, 8)
    ],
    key=["B_Seqlen", "stride_logic_h", "stride_buf_vbs", "stride_buf_vh"],
)
@triton.jit
def _decode_v1_kernel_stage2_use_tc(
    logits,
    V_Buffer,
    Out,
    Req_to_tokens,
    #B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    stride_logic_h,
    stride_buf_vbs,
    stride_buf_vh,
    stride_obs,
    stride_oh,
    stride_req_to_token_b,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_H: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_kv_head = tl.program_id(1)

    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
    mask_h = mask_h & (cur_head < q_head_num)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
    cur_batch_req_idx = cur_batch  #tl.load(B_req_idx + cur_batch)

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
    v_ptrs = V_Buffer + offs_buf_v

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        v_page_number = tl.load(
            Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b +
            (start_n + offs_n) // PAGE_SIZE,
            mask=(start_n + offs_n) < cur_batch_seq_len,
            other=0,
        )
        v_loc = v_page_number * PAGE_SIZE + (start_n + offs_n) % PAGE_SIZE

        offs_qk = cur_head[:, None] * stride_logic_h + (
            cur_batch_start_loc + start_n + offs_n[None, :]
        )

        # Positions past the sequence end load -inf so that exp() yields 0.
        qk = tl.load(
            logits + offs_qk,
            mask=mask_h[:, None] &
            (start_n + offs_n[None, :] < cur_batch_seq_len),
            other=float("-inf"),
        )  #[head, block_n]

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        old_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        e_sum = e_sum * old_scale + tl.sum(p, 1)
        # Only the head dimension is masked here; rows past the sequence end
        # are read from the slot derived from page 0 (the `other=0` above).
        v = tl.load(
            v_ptrs + v_loc[:, None] * stride_buf_vbs,
            mask=(offs_d[None, :] < Lv)
        )  #[block_n,head_dim]
        p = p.to(v.dtype)
        acc = acc * old_scale[:, None] + tl.dot(p, v)
        e_max = n_e_max

    acc = acc / e_sum[:, None]
    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[
        None, :]
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))


def _decode_v1_stage1_use_tc(
    q,
    k_buffer,
    att_out,
    Req_to_tokens,
    #B_req_idx,
    B_Start_Loc,
    B_Seqlen,
    sm_scale,
    page_size,
    num_kv_splits,
    logit_cap,
):
    """Launch the autotuned v1 stage-1 (logits) kernel.

    K head sizes 576 and 288 are split into a main block (512 / 256) plus
    an extra ``BLOCK_DPE`` tail (64 / 32); any other size is padded to the
    next power of two with no tail.

    Returns:
        The ``best_config`` attribute of the autotuned kernel after the
        launch.
    """
    Lk = k_buffer.shape[-1]

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0

    # batch, head_num = B_req_idx.shape[0], q.shape[1]
    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    SPLIT_K = num_kv_splits
    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
    grid = lambda META: (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        SPLIT_K,
    )

    _decode_v1_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        sm_scale,
        Req_to_tokens,
        #B_req_idx,
        B_Start_Loc,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        att_out.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_H=BLOCK_H,
        SPLIT_K=SPLIT_K,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        kpack=2,
    )
    return _decode_v1_kernel_stage1_use_tc.best_config


def _decode_v1_stage2_use_tc(
    logits,
    v_buffer,
    o,
    req_to_tokens,
    #b_req_idx,
    b_start_loc,
    b_seq_len,
    page_size,
):
    """Launch the autotuned v1 stage-2 (softmax and V reduction) kernel.

    The batch size comes from ``b_seq_len.shape[0]`` and the query head
    count from ``logits.shape[0]``. ``o`` is written in place.

    Returns:
        The ``best_config`` attribute of the autotuned kernel after the
        launch.
    """
    batch, head_num = b_seq_len.shape[0], logits.shape[0]
    kv_group_num = logits.shape[0] // v_buffer.shape[-2]

    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

    Lv = v_buffer.shape[-1]
    BLOCK_DMODEL = triton.next_power_of_2(Lv)

    _decode_v1_kernel_stage2_use_tc[grid](
        logits,
        v_buffer,
        o,
        req_to_tokens,
        #b_req_idx,
        b_start_loc,
        b_seq_len,
        logits.stride(0),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        o.stride(0),
        o.stride(1),
        req_to_tokens.stride(0),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_H=BLOCK_H,
        PAGE_SIZE=page_size,
        Lv=Lv,
    )
    return _decode_v1_kernel_stage2_use_tc.best_config


def decode_attention_v1(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    #b_req_idx,
    b_start_loc,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run the v1 grouped decode path (logits pass, then softmax/V pass).

    ``attn_logits`` receives the raw logits from stage 1 and is read back
    by stage 2; ``b_start_loc`` gives each request's start offset inside
    that buffer. ``o`` is written in place.

    Returns:
        A tuple ``(stage1_best_config, stage2_best_config)``.
    """
    # GQA/MQA/MLA
    _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc(
        q,
        k_buffer,
        attn_logits,
        req_to_token,
        #b_req_idx,
        b_start_loc,
        b_seq_len,
        sm_scale,
        page_size,
        num_kv_splits,
        logit_cap,
    )
    _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc(
        attn_logits,
        v_buffer,
        o,
        req_to_token,
        #b_req_idx,
        b_start_loc,
        b_seq_len,
        page_size,
    )
    return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config


# "v2" stage 1: grouped split-KV decode attention using tl.dot.
#
# Grid: (batch, head block, KV split). Same contract as _fwd_kernel_stage1
# (online softmax over one slice of the sequence; partial output at offsets
# 0 .. Lv-1 and log-sum-exp at offset Lv of Att_Out[batch, head, split]),
# but each program processes up to BLOCK_H query heads that share one KV
# head. A program whose slice is empty writes nothing.
#
# Autotuning sweeps BLOCK_N in {16, 32, 64, 128, 256} x num_warps in
# {2, 4, 8}, in that order.
# NOTE(review): the autotune key includes the B_Seqlen pointer argument;
# confirm this gives the intended cache behaviour.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": block_n},
                      num_warps=num_warps,
                      num_stages=1)
        for block_n in (16, 32, 64, 128, 256)
        for num_warps in (2, 4, 8)
    ],
    key=[
        "B_Seqlen", "stride_qbs", "stride_buf_kbs", "stride_buf_kh",
        "stride_buf_vbs", "stride_buf_vh"
    ],
)
@triton.jit
def _decode_v2_kernel_stage1_use_tc(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    # Number of heads in this block that are real (the rest are padding).
    if BLOCK_H < kv_group_num:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    # cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
        None, :]
    q = tl.load(Q + offs_q,
                mask=(mask_h[:, None]) & (mask_d[None, :]),
                other=0.0)

    if BLOCK_DPE > 0:
        # Extra slice of the head dimension located after BLOCK_DMODEL.
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (
            cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
            offs_dpe[None, :]
        )
        qpe = tl.load(
            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]),
            other=0.0
        )

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            # K is loaded transposed: (head_dim, BLOCK_N).
            offs_buf_k = (
                kv_loc[None, :] * stride_buf_kbs +
                cur_kv_head * stride_buf_kh + offs_d[:, None]
            )
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (
                    kv_loc[None, :] * stride_buf_kbs +
                    cur_kv_head * stride_buf_kh + offs_dpe[:, None]
                )
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) &
                    (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            # Padding heads and out-of-range positions get -inf.
            qk = tl.where(
                mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk,
                float("-inf")
            )

            offs_buf_v = (
                kv_loc[:, None] * stride_buf_vbs +
                cur_kv_head * stride_buf_vh + offs_dv[None, :]
            )
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            # Online softmax: rescale the running state to the new maximum.
            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (
            cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh +
            split_kv_id * stride_mid_os + offs_dv[None, :]
        )

        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        # The log-sum-exp is stored right after the Lv output values.
        offs_mid_o_1 = (
            cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
            split_kv_id * stride_mid_os + Lv
        )

        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )


def _decode_v2_stage1_use_tc(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    # B_req_idx,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    """Launch the autotuned v2 stage-1 (grouped split-KV) kernel.

    K head sizes 576 and 288 are split into a main block (512 / 256) plus
    an extra ``BLOCK_DPE`` tail (64 / 32); any other size is padded to the
    next power of two with no tail. ``att_out`` is written in place.

    Returns:
        The ``best_config`` attribute of the autotuned kernel after the
        launch.
    """
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    # batch, head_num = B_req_idx.shape[0], q.shape[1]
    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    grid = lambda META: (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    _decode_v2_kernel_stage1_use_tc[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        # B_req_idx,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-3),
        k_buffer.stride(-2),
        v_buffer.stride(-3),
        v_buffer.stride(-2),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        Lk=Lk,
        Lv=Lv,
        kpack=2,
    )
    return _decode_v2_kernel_stage1_use_tc.best_config


# "v2" stage 2: merge the per-split partial results (autotuned num_warps).
#
# Same computation as _fwd_kernel_stage2: for each (batch, query head) the
# non-empty splits' partial outputs and log-sum-exps are read from Mid_O,
# combined with running-max rescaling, and acc / e_sum is stored to O.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=1)
        for num_warps in (1, 2, 4, 8)
    ],
    key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"],
)
@triton.jit
def _decode_v2_kernel_stage2(
    Mid_O,
    O,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(
                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d,
                other=0.0
            )
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_v2_stage2_use_tc(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    """Launch the autotuned v2 stage-2 (split reduction) kernel.

    ``q`` is only used for its first two dimensions and ``v_buffer`` only
    for its last dimension. ``o`` is written in place.

    Returns:
        The ``best_config`` attribute of the autotuned kernel after the
        launch.
    """
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    grid = (batch, head_num)
    _decode_v2_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
    )
    return _decode_v2_kernel_stage2.best_config


def decode_attention_v2(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    # b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    """Run the v2 grouped decode path (split-KV stage 1, reduction stage 2).

    ``attn_logits`` is the scratch buffer shared by both stages; ``o`` is
    written in place.

    Returns:
        A tuple ``(stage1_best_config, stage2_best_config)``.
    """
    _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        # b_req_idx,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(
        attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
    return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config


def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    """Decode-time paged attention entry point; writes the result into ``o``.

    Dispatches on the number of query heads per KV head
    (``q.shape[1] // v_buffer.shape[-2]``): a ratio of 1 (MHA) uses
    ``decode_attention_fwd_normal``; anything else (GQA/MQA/MLA) uses
    ``decode_attention_v2``.

    Args:
        q: Query tensor; dim 0 is the batch and dim 1 the query heads.
        k_buffer: Paged key cache; the launchers assume a
            ``(..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)`` layout.
        v_buffer: Paged value cache with the same layout assumption.
        o: Output tensor, written in place.
        req_to_token: Per-request page table; row ``b`` is indexed with
            ``position // page_size``.
        b_seq_len: Per-request KV sequence lengths.
        attn_logits: Scratch buffer for the per-split partial results.
            Dim 2 must equal ``num_kv_splits``; the last dim must hold the
            V head size plus one slot for the log-sum-exp.
        num_kv_splits: Number of slices each KV sequence is divided into.
        sm_scale: Scale applied to the QK^T logits.
        page_size: Tokens per KV-cache page.
        logit_cap: If > 0, logits are soft-capped with
            ``logit_cap * tanh(logits / logit_cap)``.

    Raises:
        AssertionError: If ``num_kv_splits != attn_logits.shape[2]``.
    """
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]

    # NOTE: the start-offset tensor needed only by the disabled v1 path below
    # is intentionally no longer built here. It allocated a device tensor on
    # every call, hard-coded device="cuda", and torch.arange raised whenever
    # its integer step (k_buffer.shape[0] * page_size // q.shape[0]) was 0.

    if kv_group_num == 1:
        # MHA
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )
    else:
        # GQA/MQA/MLA
        decode_attention_v2(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )
        # Disabled alternative (v1 path), kept for reference:
        # b_start_loc = torch.arange(
        #     0,
        #     k_buffer.shape[0] * page_size,
        #     k_buffer.shape[0] * page_size // q.shape[0],
        #     device="cuda").to(torch.int32)
        # attn_logits_v1 = torch.empty(
        #     (q.shape[1],k_buffer.shape[0]*page_size),
        #     dtype=torch.float16,
        #     device="cuda")
        # decode_attention_v1(
        #     q,
        #     k_buffer,
        #     v_buffer,
        #     o,
        #     req_to_token,
        #     b_start_loc,
        #     b_seq_len,
        #     attn_logits_v1,
        #     num_kv_splits,
        #     sm_scale,
        #     page_size,
        #     logit_cap,
        # )