"""Hcutune tuning script for the aiter Triton extend-attention kernels.

Wraps ``_fwd_kernel`` (v1), ``_fwd_kernel_v2`` (v2 prefill) and ``_fwd_kernel_v2_decode``
(v2 decode) from ``aiter.ops.triton.extend_attention`` with ``triton.utils.hcutune`` and
benchmarks them over fixed shape grids. The kernel is selected with ``--v1`` /
``--v2-decode`` (default: v2 prefill); ``--perf`` enables hcutune perf mode.
"""

import os
import sys

# Set before ``triton`` / ``aiter`` are imported below, presumably so the ROCm backend sees it
# when those modules initialise. NOTE(review): confirm where AMDGCN_USE_BUFFER_OPS is read.
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"

# :class:`Hcutuner` reads ``TRITON_HCUTUNE_PERF_MODE`` in ``__init__``. The module-level
# ``fn`` / ``fn_v2`` / ``fn_v2_decode`` = ``triton.utils.hcutune(...)`` wrappers run at import
# time, *before* ``if __name__ == "__main__"``. ``--perf`` is therefore sniffed from the raw
# argv here (or the variable must be set in the shell); argparse runs too late.
if "--perf" in sys.argv:
    os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"

# GPU / ROCm tuning: run **inside** the ``zww_tl_1`` container (not the bare host), e.g.:
# docker exec zww_tl_1 bash -lc 'cd /data/zhouweiwang/aiter/op_tests/triton_autotune && python tune_extend_attention.py --perf'

# ``do_bench`` timing during tuning (smaller => faster iteration; raise for stable prod numbers).
TUNE_DO_BENCH_WARMUP = 5
TUNE_DO_BENCH_REP = 20

import json
import torch
import triton
import random
import itertools
import argparse

from aiter.ops.triton.extend_attention import (
    _fwd_kernel,
    _fwd_kernel_v2,
    _fwd_kernel_v2_decode,
)

# Not referenced elsewhere in this file.
_is_hip = True

# hcutune key for :func:`_fwd_kernel_v2`. JSON block-size lookup uses :func:`_get_config_v2`
# ``want7`` only; these names add kernel constexprs (``SKIP_PREFIX_CUSTOM_MASK``,
# ``xai_temperature_len``) used for autotune but not in the V2 JSON key.
# Log alignment (e.g. ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` ~2291): ``kv_group_num = q.size(-2)//k.size(-2)``,
# ``Lq``/``Lv`` last dims, ``USE_CUSTOM_MASK = custom_mask is not None``, ``HAS_SINK = sinks is not None``.
HCUTUNE_KEY_V2 = [
    "kv_group_num",
    "Lq",
    "Lv",
    "USE_CUSTOM_MASK",
    "IS_CAUSAL",
    "SKIP_PREFIX_CUSTOM_MASK",
    "HAS_SINK",
    "SLIDING_WINDOW_SIZE",
    "xai_temperature_len",
]

# Triton major/minor version (not referenced elsewhere in this script). Parsed with ``int``
# rather than ``eval`` so an unexpected version string can never execute code.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])


def _bench_device():
    """Return the device the benchmarks allocate tensors on.

    ``"cpu"`` when ``TRITON_HCUTUNE_COMPILE_ONLY=1`` (presumably a compile-only hcutune run
    without a GPU), otherwise ``"cuda"``.
    """
    return "cpu" if os.getenv("TRITON_HCUTUNE_COMPILE_ONLY", "") == "1" else "cuda"


def _cu_seqlens(seqlens, device):
    """Return the exclusive prefix sum ``[0, s0, s0 + s1, ...]`` of ``seqlens`` (int32, on ``device``)."""
    zero = torch.tensor([0], dtype=torch.int32)
    return torch.cat([zero, seqlens.cumsum(dim=0, dtype=torch.int32)]).to(device=device)


def input_helper(
    B,
    H,
    prefix_length,
    extend_length,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    dtype,
    device,
    attn_impl="normal",
    equal_seqlens=False,
    requires_grad=False,
    kv_num_heads: int = 1,
):
    """Build random ragged extend-attention inputs.

    Per-batch sequence lengths:
        ``equal_seqlens=True``: every batch uses exactly ``extend_length`` / ``prefix_length``.
        Otherwise extend lengths are drawn from ``[1, extend_length]``. Prefix lengths are
        drawn from ``[1, prefix_length]``, except that they stay 0 when ``prefix_length == 0``.

    Head dims:
        ``attn_impl == "absorb"``: ``Lq = Lk = kv_lora_rank + qk_rope_head_dim`` and
        ``Lv = kv_lora_rank``. V is a view of the first ``Lv`` columns of K.
        Any other ``attn_impl``: ``Lq = Lk = v_head_dim + qk_rope_head_dim`` and
        ``Lv = v_head_dim``. V is drawn independently of K.

    Q has ``H`` heads. K/V have ``kv_num_heads`` heads (e.g. 2 for GQA when Q has 16 heads,
    i.e. ``kv_group_num == 8``). Every tensor, including the indptr/index tensors, is
    created on ``device``.

    Returns:
        ``(q_extend, k_extend, v_extend, k_buffer, v_buffer, kv_indptr, kv_indices,
        qo_indptr, custom_mask, mask_indptr, max_len_extend)``. ``custom_mask`` and
        ``mask_indptr`` are always ``None``, and ``max_len_extend == extend_length``.
    """
    torch.manual_seed(0)

    if equal_seqlens:
        seqlens_extend = torch.full((B,), extend_length, dtype=torch.int32)
        seqlens_prefix = torch.full((B,), prefix_length, dtype=torch.int32)
    else:
        seqlens_extend = torch.randint(1, extend_length + 1, (B,), dtype=torch.int32)
        if prefix_length == 0:
            seqlens_prefix = torch.full((B,), 0, dtype=torch.int32)
        else:
            seqlens_prefix = torch.randint(
                1, prefix_length + 1, (B,), dtype=torch.int32
            )

    # These index tensors used to be moved to a hard-coded "cuda" device. That ignored
    # ``device`` (e.g. the CPU compile-only mode) and disagreed with ``kv_indices`` below.
    cu_seqlens_extend = _cu_seqlens(seqlens_extend, device)
    cu_seqlens_prefix = _cu_seqlens(seqlens_prefix, device)
    total_extend = cu_seqlens_extend[-1].item()
    total_prefix = cu_seqlens_prefix[-1].item()

    if attn_impl == "absorb":
        Lq = Lk = kv_lora_rank + qk_rope_head_dim
        Lv = kv_lora_rank
    else:
        Lq = Lk = v_head_dim + qk_rope_head_dim
        Lv = v_head_dim

    # The draw order (q, k_extend, k_buffer, then both V tensors) is kept fixed so that a
    # given seed always reproduces the same data.
    q_extend = torch.randn(
        total_extend, H, Lq, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    k_extend = torch.randn(
        total_extend, kv_num_heads, Lk, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    v_extend = k_extend[..., :Lv]
    k_buffer = torch.randn(
        total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    v_buffer = k_buffer[..., :Lv]
    if attn_impl != "absorb":
        # simulate v = kv_latent * w_vc which changes the values compared to k
        v_extend = torch.randn_like(v_extend, dtype=v_extend.dtype)
        v_buffer = torch.randn_like(v_buffer, dtype=v_buffer.dtype)

    qo_indptr = cu_seqlens_extend
    kv_indptr = cu_seqlens_prefix
    kv_indices = torch.arange(total_prefix, device=device, dtype=torch.int32)

    custom_mask = None
    mask_indptr = None
    max_len_extend = extend_length
    return (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
        max_len_extend,
    )


def _v2_flat_causal_custom_mask(B, prefix_len, extend_len, device):
    """Build a flattened causal custom mask for ``B`` batches that share the same lengths.

    Each batch contributes a row-major ``[extend_len, prefix_len + extend_len]`` boolean block.
    The block is True where ``k_global <= q_global``, with query row ``i`` at global position
    ``prefix_len + i``. ``mask_indptr[b]`` is the start offset of batch ``b``'s block.
    """
    total = prefix_len + extend_len
    q_row = torch.arange(extend_len, device=device) + prefix_len
    k_col = torch.arange(total, device=device)
    one_batch = (k_col.unsqueeze(0) <= q_row.unsqueeze(1)).reshape(-1).contiguous()
    custom_mask = one_batch.repeat(B)
    seg = extend_len * total
    # ``arange(B + 1) * seg`` gives the same values as ``arange(0, (B + 1) * seg, seg)`` for
    # seg > 0, but does not raise when extend_len == 0 (step would be 0).
    mask_indptr = torch.arange(B + 1, dtype=torch.int32, device=device) * seg
    return custom_mask, mask_indptr


def _flat_causal_custom_mask_varlen(qo_indptr, kv_indptr, device):
    """Build a causal custom mask for per-batch (possibly unequal) extend/prefix lengths.

    The block layout matches :func:`_v2_flat_causal_custom_mask`: batch ``b`` gets a row-major
    ``[extend_b, prefix_b + extend_b]`` boolean block with ``k_global <= q_global``.
    Lengths come from consecutive differences of ``qo_indptr`` (extend) and ``kv_indptr``
    (prefix). Returns ``(custom_mask, mask_indptr)`` on ``device``.
    """
    extend_lens = (qo_indptr[1:] - qo_indptr[:-1]).tolist()
    prefix_lens = (kv_indptr[1:] - kv_indptr[:-1]).tolist()
    blocks = []
    offsets = [0]
    for extend_len, prefix_len in zip(extend_lens, prefix_lens):
        total = prefix_len + extend_len
        q_row = torch.arange(extend_len, device=device) + prefix_len
        k_col = torch.arange(total, device=device)
        blocks.append((k_col.unsqueeze(0) <= q_row.unsqueeze(1)).reshape(-1))
        offsets.append(offsets[-1] + extend_len * total)
    custom_mask = torch.cat(blocks)
    mask_indptr = torch.tensor(offsets, dtype=torch.int32, device=device)
    return custom_mask, mask_indptr


def get_bench_inputs():
    """Return ``(names, vals)`` for the v1 benchmark grid.

    Each shape ``(B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim, v_head_dim)`` is
    crossed with ``causal`` in {True, False} and ``custom_mask`` in {True, False}.
    """
    names = ["B", "H", "prefix", "extend", "kv_lora_rank", "qk_rope_head_dim", "v_head_dim", "causal", "custom_mask"]
    shapes = [
        (2, 4, 0, 512, 32, 16, 32),
        (3, 5, 0, 333, 18, 13, 17),
        (3, 5, 512, 333, 18, 0, 17),
        (3, 5, 110, 333, 18, 0, 19),
        (8, 16, 0, 1024, 128, 0, 128),  # this one passes
        # (8, 16, 0, 16324, 128, 0, 128),  # this one fails, numeric precision is likely the issue
        # NOTE(review): the next two rows are identical; confirm whether the duplicate is intended.
        (2, 1, 64, 32, 128, 64, 128),
        (2, 1, 64, 32, 128, 64, 128),
        (4, 16, 64, 96, 128, 64, 128),
        (1, 16, 0, 7, 512, 64, 512),
        (1, 16, 7, 4, 512, 64, 512),
        (1, 16, 32, 4, 512, 64, 512),
        (1, 16, 64, 3, 512, 64, 512),
        (1, 16, 127, 4, 512, 64, 512),
        (1, 16, 255, 15, 512, 64, 512),
        (3, 16, 452, 16, 512, 64, 512),
        (4, 16, 512, 14, 512, 64, 512),
        (4, 16, 1024, 16, 512, 64, 512),
        (4, 16, 2048, 13, 512, 64, 512),
    ]
    vals = [
        (*s, is_causal, use_custom_mask)
        for is_causal in (True, False)
        for use_custom_mask in (True, False)
        for s in shapes
    ]
    return names, vals


def generate_configs(config):
    """Expand a ``{name: [candidates]}`` dict into the list of all combinations (cartesian product)."""
    keys = list(config)
    return [dict(zip(keys, combo)) for combo in itertools.product(*config.values())]


def get_triton_configs():
    """Return the hcutune search space as a list of ``triton.Config`` objects.

    ``num_warps`` / ``num_stages`` are passed as launch options; every other key becomes a
    config kwarg. An earlier, wider space used ``num_warps`` in {4, 8, 16},
    ``num_stages`` in {1, 2, 3} and ``kpack`` in {1, 2}. Widen the lists below to reproduce it.
    """
    config = {
        "BLOCK_M": [16, 32, 64],
        "BLOCK_N": [16, 32, 64],
        "waves_per_eu": [1],
        "num_warps": [4, 8],
        # Other knobs that can be added here:
        #   "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
        #   "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
        "matrix_instr_nonkdim": [16],
        "num_stages": [1, 2],
        "sched_latency": ["none", "mmac5-ds10"],
        "kpack": [1],
    }
    tt_configs = []
    for c in generate_configs(config):
        num_warps = c.pop("num_warps")
        num_stages = c.pop("num_stages")
        tt_configs.append(triton.Config(c, num_warps=num_warps, num_stages=num_stages))
    return tt_configs


def prune_configs(configs, nargs, **kwargs):
    """Early-prune hook for hcutune.

    Drops a config when ``BLOCK_M == 64 and num_stages >= 2`` or when
    ``BLOCK_M >= 64 and num_warps >= 8``. ``nargs`` / ``kwargs`` are unused.
    """

    def _prune(config):
        c = config.all_kwargs()
        if c["BLOCK_M"] == 64 and c["num_stages"] >= 2:
            return True
        return c["BLOCK_M"] >= 64 and c["num_warps"] >= 8

    return [c for c in configs if not _prune(c)]


# hcutune key for the v1 kernel.
key = [
    'kv_group_num',
    'SLIDING_WINDOW_SIZE',
    'Lq',
    'Lv',
    'USE_CUSTOM_MASK',
    'IS_CAUSAL',
    'SKIP_PREFIX_CUSTOM_MASK',
    'STORE_TRANSPOSE',
]


def _hcutune(kernel, tune_key):
    """Wrap ``kernel`` with ``triton.utils.hcutune`` using the shared search space, prune hook and bench settings."""
    return triton.utils.hcutune(
        configs=get_triton_configs(),
        key=tune_key,
        perf_debug=True,
        prune_configs_by={"early_config_prune": prune_configs},
        warmup=TUNE_DO_BENCH_WARMUP,
        rep=TUNE_DO_BENCH_REP,
    )(kernel)


fn = _hcutune(_fwd_kernel, key)
fn_v2 = _hcutune(_fwd_kernel_v2, HCUTUNE_KEY_V2)
fn_v2_decode = _hcutune(_fwd_kernel_v2_decode, HCUTUNE_KEY_V2)


def _head_dim_blocks(Lq, Lv):
    """Return ``(BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV)`` for head dims ``Lq`` (q/k) and ``Lv`` (v).

    ``Lq`` of 576 / 288 / 192 is split into a power-of-two part plus a ``BLOCK_DPE`` part
    (presumably the rope dims). Any other ``Lq`` is padded to the next power of two with
    ``BLOCK_DPE = 0``. ``BLOCK_DV`` is ``Lv`` padded to the next power of two.
    """
    if Lq == 576:
        block_dmodel, block_dpe = 512, 64
    elif Lq == 288:
        block_dmodel, block_dpe = 256, 32
    elif Lq == 192:
        block_dmodel, block_dpe = 128, 64
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(Lq), 0
    return block_dmodel, block_dpe, triton.next_power_of_2(Lv)


def _placeholder_mask(device):
    """Return 1-element ``(custom_mask, mask_indptr)`` tensors passed when no custom mask is used."""
    return (
        torch.tensor([0], dtype=torch.bool, device=device),
        torch.tensor([0], dtype=torch.int32, device=device),
    )


def _stride_args(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer):
    """Return ``(t.stride(0), t.stride(1))`` for each tensor, flattened in kernel argument order."""
    return tuple(
        stride
        for t in (q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer)
        for stride in (t.stride(0), t.stride(1))
    )


def _resolve_v2_masks(custom_mask, mask_indptr, window_kv_offsets, sliding_window_size, batch_size, device):
    """Fill in the mask-related kernel arguments for the v2 launchers.

    Without a custom mask, 1-element placeholders replace ``custom_mask`` / ``mask_indptr``.
    With a custom mask and ``sliding_window_size > 0``, a missing ``window_kv_offsets``
    defaults to int32 zeros of shape ``[batch_size]``.

    Returns:
        ``(custom_mask, mask_indptr, window_kv_offsets)``.
    """
    if custom_mask is None:
        custom_mask, mask_indptr = _placeholder_mask(device)
    elif sliding_window_size > 0 and window_kv_offsets is None:
        window_kv_offsets = torch.zeros(batch_size, dtype=torch.int32, device=device)
    return custom_mask, mask_indptr, window_kv_offsets


def extend_attention_fwd(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale=None,
    logit_cap=0.0,
    skip_prefix_custom_mask=True,
):
    """Launch :data:`fn` (hcutune-wrapped :func:`_fwd_kernel`), writing the result into ``o_extend``.

    q_extend, k_extend, v_extend, o_extend: contiguous tensors
    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager

    ``sm_scale`` defaults to ``1 / sqrt(Lq)`` when falsy. When ``custom_mask`` is ``None``,
    1-element placeholders are passed instead of the mask and ``mask_indptr``. The grid is
    ``(batch, q_heads, cdiv(max_len_extend, BLOCK_M))``.
    """
    Lq, Lv = q_extend.shape[-1], v_extend.shape[-1]
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV = _head_dim_blocks(Lq, Lv)
    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_group_num = head_num // k_extend.shape[1]

    USE_CUSTOM_MASK = custom_mask is not None
    if not USE_CUSTOM_MASK:
        custom_mask, mask_indptr = _placeholder_mask(q_extend.device)

    def grid(META):
        return (batch_size, head_num, triton.cdiv(max_len_extend, META["BLOCK_M"]))

    fn[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sm_scale,
        kv_group_num,
        *_stride_args(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer),
        logit_cap=logit_cap,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        # Skip custom mask for prefix part
        SKIP_PREFIX_CUSTOM_MASK=skip_prefix_custom_mask,
        STORE_TRANSPOSE=True,
    )


def extend_attention_fwd_v2(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale,
    k_scale,
    v_scale,
    sliding_window_size,
    sinks,
    window_kv_offsets,
    xai_temperature_len,
    skip_prefix_custom_mask,
    logit_cap,
):
    """Launch :data:`fn_v2` (hcutune-wrapped :func:`_fwd_kernel_v2`); kwargs align with ``extend_attention_fwd`` v2 path.

    The grid is ``(batch, q_heads, cdiv(max_len_extend, BLOCK_M))``. ``HAS_SINK`` is
    ``sinks is not None``. Mask placeholders and default ``window_kv_offsets`` follow
    :func:`_resolve_v2_masks`.
    """
    Lq, Lv = q_extend.shape[-1], v_extend.shape[-1]
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV = _head_dim_blocks(Lq, Lv)
    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_group_num = head_num // k_extend.shape[1]

    USE_CUSTOM_MASK = custom_mask is not None
    HAS_SINK = sinks is not None
    custom_mask, mask_indptr, window_kv_offsets = _resolve_v2_masks(
        custom_mask, mask_indptr, window_kv_offsets, sliding_window_size, batch_size, q_extend.device
    )

    def grid(META):
        return (batch_size, head_num, triton.cdiv(max_len_extend, META["BLOCK_M"]))

    fn_v2[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sinks,
        window_kv_offsets,
        sm_scale,
        k_scale,
        v_scale,
        kv_group_num,
        *_stride_args(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer),
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        HAS_SINK=HAS_SINK,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=skip_prefix_custom_mask,
        STORE_TRANSPOSE=True,
    )


def extend_attention_fwd_v2_decode(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale,
    k_scale,
    v_scale,
    sliding_window_size,
    sinks,
    window_kv_offsets,
    xai_temperature_len,
    skip_prefix_custom_mask,
    logit_cap,
):
    """Launch :data:`fn_v2_decode` (hcutune-wrapped :func:`_fwd_kernel_v2_decode`); grid-2 is KV head.

    Grid: ``(batch, kv_heads, cdiv(max_len_extend, max(BLOCK_M // kv_group_num, 1)))``.
    Each program therefore covers ``BLOCK_M // kv_group_num`` query positions (at least 1).
    Mask placeholders and default ``window_kv_offsets`` follow :func:`_resolve_v2_masks`.
    """
    Lq, Lv = q_extend.shape[-1], v_extend.shape[-1]
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV = _head_dim_blocks(Lq, Lv)
    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_head_num = k_extend.shape[1]
    kv_group_num = head_num // kv_head_num

    USE_CUSTOM_MASK = custom_mask is not None
    HAS_SINK = sinks is not None
    custom_mask, mask_indptr, window_kv_offsets = _resolve_v2_masks(
        custom_mask, mask_indptr, window_kv_offsets, sliding_window_size, batch_size, q_extend.device
    )

    def grid(META):
        q_seq = max(META["BLOCK_M"] // kv_group_num, 1)
        return (batch_size, kv_head_num, triton.cdiv(max_len_extend, q_seq))

    fn_v2_decode[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sinks,
        window_kv_offsets,
        sm_scale,
        k_scale,
        v_scale,
        *_stride_args(q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer),
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        HAS_SINK=HAS_SINK,
        kv_group_num=kv_group_num,
        num_query_heads=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=skip_prefix_custom_mask,
        STORE_TRANSPOSE=True,
    )


# Benchmark columns shared by the v2 prefill and v2 decode grids.
_V2_BENCH_NAMES = [
    "B",
    "H",
    "prefix",
    "extend",
    "kv_lora_rank",
    "qk_rope_head_dim",
    "v_head_dim",
    "causal",
    "custom_mask",
    "sliding_window_size",
    "has_sink",
    "kv_num_heads",
]


def get_bench_inputs_v2():
    """Cases for :func:`_fwd_kernel_v2` / ``_get_config_v2`` keys.

    After checking ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` entry by entry, the
    extend-related **keyed combinations** fall into only two classes:

    - **MHA** (``k_extend[...,1,...]``): ``sliding_window_size=-1``, ``sinks=None``
      (e.g. the ``k_extend [1152,1,192]`` / ``[3952,1,192]`` segments of the log).
    - **GQA** (``k_extend[...,2,...]``): ``sliding_window_size=128``, ``sinks shape [16]``
      (nowhere in the file does GQA appear together with ``-1`` / no sinks).

    The log contains **no** ``sliding_window_size=64`` and no "GQA + no SWA + no sinks" case.
    To cover other models, add separate rows and note where they come from.

    Two extra rows cover **causal=True + custom_mask=True** (a flattened causal mask,
    matching the kernel's ``USE_CUSTOM_MASK`` branch; the 4th want7 entry is True).
    """
    vals = [
        # (prefix, extend) only affect memory traffic / grid size; want7 is determined by
        # head / dim / SWA / sinks. prefix=8192, extend=1024 follows the usual long-KV bench.
        # (1) GQA Q16/KV2: matches the log's ``[3952,2,192]`` + ``sliding_window_size=128`` + ``sinks [16]``; want7 (8,192,128,F,T,T,True)
        (4, 16, 8192, 1024, 128, 64, 128, True, False, 128, True, 2),
        # (2) MHA Q16/KV1: matches the log's ``[...,1,192]`` + ``sliding_window_size=-1`` + ``sinks=None``; want7 (16,192,128,F,T,F,False)
        (4, 16, 8192, 1024, 128, 64, 128, True, False, -1, False, 1),
        # (3) GQA + custom causal mask; want7 (8,192,128,T,T,T,True)
        (4, 16, 8192, 1024, 128, 64, 128, True, True, 128, True, 2),
        # (4) MHA + custom causal mask; want7 (16,192,128,T,T,F,False)
        (4, 16, 8192, 1024, 128, 64, 128, True, True, -1, False, 1),
    ]
    return list(_V2_BENCH_NAMES), vals


def get_bench_inputs_v2_decode():
    """Cases for :func:`_fwd_kernel_v2_decode` (short ``extend`` / decode grid).

    Fixed ``B=64``, ``prefix=256``, ``extend=4``, ``equal_seqlens``. The causal /
    custom_mask / sliding_window / sinks / GQA|MHA columns correspond one-to-one to the four
    rows of :func:`get_bench_inputs_v2`.
    """
    vals = [
        (64, 16, 256, 4, 128, 64, 128, True, False, 128, True, 2),
        (64, 16, 256, 4, 128, 64, 128, True, False, -1, False, 1),
        (64, 16, 256, 4, 128, 64, 128, True, True, 128, True, 2),
        (64, 16, 256, 4, 128, 64, 128, True, True, -1, False, 1),
    ]
    return list(_V2_BENCH_NAMES), vals


x_names, x_vals = get_bench_inputs()
configs = [
    triton.testing.Benchmark(
        x_names=x_names,
        x_vals=x_vals,
        line_arg="provider",
        line_vals=["triton"],
        line_names=["triton"],
        styles=[("red", "-")],
        ylabel="ms",
        plot_name="extend_attention",
        args={'dtype': torch.float16},
    )
]


@triton.utils.dist_perf_report(configs)
def bench_extend_attention(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    provider,
    dtype,
):
    """Time one v1 case through :func:`extend_attention_fwd`, which launches :data:`fn`.

    ``custom_mask`` is the boolean benchmark flag. When it is set, a flattened causal mask
    matching the (random) per-batch lengths is passed, so the ``USE_CUSTOM_MASK`` variant is
    actually tuned. Previously ``input_helper``'s ``None`` mask shadowed the flag and it was
    silently ignored.
    """
    torch.manual_seed(42)
    device = _bench_device()
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        B,
        H,
        prefix,
        extend,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        "normal",
    )
    if custom_mask:
        custom_mask_t, mask_indptr = _flat_causal_custom_mask_varlen(
            qo_indptr, kv_indptr, device
        )

    tri_out = torch.empty(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=q_extend.device,
    )

    def run_once():
        extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            tri_out,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask_t,
            causal,
            mask_indptr,
            max_len_extend,
            sm_scale=1.0,
            logit_cap=0.0,
        )

    return triton.testing.do_bench_cudagraph(run_once)


def _build_v2_launch_args(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
    dtype,
):
    """Build the positional argument tuple shared by both v2 launchers for one benchmark case.

    Uses equal per-batch lengths. When ``custom_mask`` is set, a flattened causal mask is
    built. When ``has_sink`` is set, random sinks of shape ``[H]`` are drawn. When both
    ``custom_mask`` and a positive sliding window are set, ``window_kv_offsets`` are zeros.
    ``k_scale = v_scale = 1.0``, ``xai_temperature_len = -1``,
    ``skip_prefix_custom_mask = True``, ``logit_cap = 0.0``, and ``sm_scale = 1 / sqrt(Lq)``.
    """
    torch.manual_seed(42)
    device = _bench_device()
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        B,
        H,
        prefix,
        extend,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        "normal",
        equal_seqlens=True,
        kv_num_heads=kv_num_heads,
    )
    if custom_mask:
        custom_mask_t, mask_indptr = _v2_flat_causal_custom_mask(
            B, prefix, extend, device
        )

    sm_scale = float(1.0 / (q_extend.shape[-1] ** 0.5))
    sinks = torch.randn(H, dtype=q_extend.dtype, device=device) if has_sink else None
    window_kv_offsets = None
    if custom_mask and sliding_window_size > 0:
        window_kv_offsets = torch.zeros(B, dtype=torch.int32, device=device)

    tri_out = torch.empty(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=q_extend.device,
    )
    return (
        q_extend,
        k_extend,
        v_extend,
        tri_out,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask_t,
        causal,
        mask_indptr,
        max_len_extend,
        sm_scale,
        1.0,  # k_scale
        1.0,  # v_scale
        sliding_window_size,
        sinks,
        window_kv_offsets,
        -1,  # xai_temperature_len
        True,  # skip_prefix_custom_mask
        0.0,  # logit_cap
    )


x_names_v2, x_vals_v2 = get_bench_inputs_v2()
configs_v2 = [
    triton.testing.Benchmark(
        x_names=x_names_v2,
        x_vals=x_vals_v2,
        line_arg="provider",
        line_vals=["triton_v2"],
        line_names=["triton_v2"],
        styles=[("blue", "-")],
        ylabel="ms",
        plot_name="extend_attention_v2_hcutune",
        args={"dtype": torch.bfloat16},
    )
]


@triton.utils.dist_perf_report(configs_v2)
def bench_extend_attention_v2(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
    provider,
    dtype,
):
    """Time one v2 prefill case through :func:`extend_attention_fwd_v2`."""
    launch_args = _build_v2_launch_args(
        B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim, v_head_dim,
        causal, custom_mask, sliding_window_size, has_sink, kv_num_heads, dtype,
    )

    def run_once():
        extend_attention_fwd_v2(*launch_args)

    return triton.testing.do_bench_cudagraph(run_once)


x_names_v2_decode, x_vals_v2_decode = get_bench_inputs_v2_decode()
configs_v2_decode = [
    triton.testing.Benchmark(
        x_names=x_names_v2_decode,
        x_vals=x_vals_v2_decode,
        line_arg="provider",
        line_vals=["triton_v2_decode"],
        line_names=["triton_v2_decode"],
        styles=[("green", "-")],
        ylabel="ms",
        plot_name="extend_attention_v2_decode_hcutune",
        args={"dtype": torch.bfloat16},
    )
]


@triton.utils.dist_perf_report(configs_v2_decode)
def bench_extend_attention_v2_decode(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
    provider,
    dtype,
):
    """Time one v2 decode case through :func:`extend_attention_fwd_v2_decode`."""
    launch_args = _build_v2_launch_args(
        B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim, v_head_dim,
        causal, custom_mask, sliding_window_size, has_sink, kv_num_heads, dtype,
    )

    def run_once():
        extend_attention_fwd_v2_decode(*launch_args)

    return triton.testing.do_bench_cudagraph(run_once)


def parse_args():
    """Parse the CLI.

    ``--v1`` and ``--v2-decode`` are mutually exclusive; with neither, v2 prefill is tuned.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--perf", action="store_true", default=False, help="benchmark with hcutuner perf mode")
    # Previously both flags could be passed and --v1 silently took precedence.
    kernel_choice = parser.add_mutually_exclusive_group()
    kernel_choice.add_argument(
        "--v1",
        action="store_true",
        default=False,
        help="Tune v1 ``_fwd_kernel`` only.",
    )
    kernel_choice.add_argument(
        "--v2-decode",
        action="store_true",
        default=False,
        help="Tune v2 decode ``_fwd_kernel_v2_decode`` (B=64, prefix=256, extend=4). Default: v2 prefill ``_fwd_kernel_v2``.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.perf:
        os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"  # idempotent; real set is at top for Hcutuner init
    if args.v1:
        bench_extend_attention.run(print_data=True, save_path="./tune_extend_attention_out")
    elif args.v2_decode:
        bench_extend_attention_v2_decode.run(
            print_data=True, save_path="./tune_extend_attention_v2_decode_out"
        )
    else:
        bench_extend_attention_v2.run(
            print_data=True, save_path="./tune_extend_attention_v2_out"
        )