import argparse
import itertools
import json
import os
import random
import sys

# Exported ahead of the ``triton`` import further down in this preamble.
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"

# :class:`Hcutuner` reads ``TRITON_HCUTUNE_PERF_MODE`` in ``__init__``. The module-level
# ``fn`` / ``fn_v2 = triton.utils.hcutune(...)`` assignments execute at import time, i.e.
# *before* the ``if __name__ == "__main__"`` section, so ``--perf`` has to be honoured
# here (or the variable has to be set in the shell).
if "--perf" in sys.argv:
    os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"

# GPU / ROCm tuning: run **inside** the ``zww_tl_1`` container (not the bare host), e.g.:
#   docker exec zww_tl_1 bash -lc 'cd /data/zhouweiwang/aiter/op_tests/triton_autotune && python tune_extend_attention.py --perf'

# ``do_bench`` timing used while tuning (smaller => faster iteration; raise these for
# stable production numbers).
TUNE_DO_BENCH_WARMUP = 5
TUNE_DO_BENCH_REP = 20

import torch
import triton

from aiter.ops.triton.extend_attention import _fwd_kernel, _fwd_kernel_v2

_is_hip = True

# hcutune key for :func:`_fwd_kernel_v2`. JSON block-size lookup uses :func:`_get_config_v2`
# ``want7`` only; these names add kernel constexprs (``SKIP_PREFIX_CUSTOM_MASK``,
# ``xai_temperature_len``) used for autotune but not in the V2 JSON key.
# Log alignment (e.g. ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` ~2291):
# ``kv_group_num = q.size(-2)//k.size(-2)``, ``Lq``/``Lv`` last dims,
# ``USE_CUSTOM_MASK = custom_mask is not None``, ``HAS_SINK = sinks is not None``.
HCUTUNE_KEY_V2 = [
    "kv_group_num",
    "Lq",
    "Lv",
    "USE_CUSTOM_MASK",
    "IS_CAUSAL",
    "SKIP_PREFIX_CUSTOM_MASK",
    "HAS_SINK",
    "SLIDING_WINDOW_SIZE",
    "xai_temperature_len",
]

# Version components are parsed with ``int`` rather than ``eval`` so that the
# version string is never executed as code.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])


def input_helper(
    B,
    H,
    prefix_length,
    extend_length,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    dtype,
    device,
    attn_impl="normal",
    equal_seqlens=False,
    requires_grad=False,
    kv_num_heads: int = 1,
):
    """Build random extend-attention inputs for one benchmark case.

    Args:
        B: Number of sequences in the batch.
        H: Number of query heads.
        prefix_length: Prefix length per sequence (upper bound when
            ``equal_seqlens`` is false).
        extend_length: Extend length per sequence (upper bound when
            ``equal_seqlens`` is false).
        kv_lora_rank, qk_rope_head_dim, v_head_dim: Head-dimension components;
            how they combine into ``Lq``/``Lk``/``Lv`` depends on ``attn_impl``.
        dtype: dtype of the q/k/v tensors.
        device: Device on which every returned tensor is placed.
        attn_impl: ``"absorb"`` uses ``kv_lora_rank``-based dims and lets V be
            a slice of K; any other value uses ``v_head_dim``-based dims and
            draws independent random V tensors.
        equal_seqlens: When true every sequence has exactly the given lengths;
            otherwise lengths are drawn uniformly from ``[1, length]``.
        requires_grad: Forwarded to ``requires_grad_`` on q/k tensors.
        kv_num_heads: Number of K/V heads (GQA when smaller than ``H``).

    Returns:
        Tuple ``(q_extend, k_extend, v_extend, k_buffer, v_buffer, kv_indptr,
        kv_indices, qo_indptr, custom_mask, mask_indptr, max_len_extend)``.
        ``custom_mask`` and ``mask_indptr`` are always ``None`` and
        ``max_len_extend`` is ``extend_length``.
    """
    torch.manual_seed(0)

    if not equal_seqlens:
        max_extend_length = extend_length
        max_prefix_length = prefix_length

        seqlens_extend = torch.randint(
            1, max_extend_length + 1, (B,), dtype=torch.int32
        )
        if prefix_length == 0:
            seqlens_prefix = torch.full((B,), prefix_length, dtype=torch.int32)
        else:
            seqlens_prefix = torch.randint(
                1, max_prefix_length + 1, (B,), dtype=torch.int32
            )
    else:
        seqlens_extend = torch.full((B,), extend_length, dtype=torch.int32)
        seqlens_prefix = torch.full((B,), prefix_length, dtype=torch.int32)

    cu_seqlens_extend = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            seqlens_extend.cumsum(dim=0, dtype=torch.int32),
        ]
    )
    cu_seqlens_prefix = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            seqlens_prefix.cumsum(dim=0, dtype=torch.int32),
        ]
    )

    # Honour the caller's ``device`` (previously hard-coded to "cuda", which
    # broke callers that request CPU tensors, e.g. a compile-only run).
    cu_seqlens_extend = cu_seqlens_extend.to(device=device)
    cu_seqlens_prefix = cu_seqlens_prefix.to(device=device)

    total_extend = cu_seqlens_extend[-1].item()
    total_prefix = cu_seqlens_prefix[-1].item()

    if attn_impl == "absorb":
        Lq = kv_lora_rank + qk_rope_head_dim
        Lk = kv_lora_rank + qk_rope_head_dim
        Lv = kv_lora_rank
    else:
        Lq = v_head_dim + qk_rope_head_dim
        Lk = v_head_dim + qk_rope_head_dim
        Lv = v_head_dim

    q_extend = torch.randn(
        total_extend, H, Lq, dtype=dtype, device=device
    ).requires_grad_(requires_grad)

    # extend parts (``kv_num_heads`` for GQA: e.g. 2 when q has 16 heads and kv_group_num is 8)
    k_extend = torch.randn(
        total_extend, kv_num_heads, Lk, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    v_extend = k_extend[..., :Lv]

    # extend indexing
    qo_indptr = cu_seqlens_extend

    # prefix parts
    k_buffer = torch.randn(
        total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
    ).requires_grad_(requires_grad)
    v_buffer = k_buffer[..., :Lv]

    if attn_impl != "absorb":
        # simulate v = kv_latent * w_vc which changes the values compared to k
        v_extend = torch.randn_like(v_extend, dtype=v_extend.dtype)
        v_buffer = torch.randn_like(v_buffer, dtype=v_buffer.dtype)

    # prefix indexing
    kv_indptr = cu_seqlens_prefix
    kv_indices = torch.arange(total_prefix, device=device, dtype=torch.int32)

    custom_mask = None
    mask_indptr = None
    max_len_extend = extend_length

    return (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask,
        mask_indptr,
        max_len_extend,
    )


def get_bench_inputs():
    """Return ``(names, vals)`` for the v1 benchmark sweep.

    ``vals`` is the cross product of ``causal`` x ``custom_mask`` x shape, with
    the shape varying fastest.
    """
    names = [
        "B",
        "H",
        "prefix",
        "extend",
        "kv_lora_rank",
        "qk_rope_head_dim",
        "v_head_dim",
        "causal",
        "custom_mask",
    ]
    shapes = [
        (2, 4, 0, 512, 32, 16, 32),
        (3, 5, 0, 333, 18, 13, 17),
        (3, 5, 512, 333, 18, 0, 17),
        (3, 5, 110, 333, 18, 0, 19),
        (8, 16, 0, 1024, 128, 0, 128),  # this one passes
        # (8, 16, 0, 16324, 128, 0, 128),  # this one fails, numeric precision is likely the issue
        (2, 1, 64, 32, 128, 64, 128),
        # NOTE(review): identical to the previous row — confirm whether intended.
        (2, 1, 64, 32, 128, 64, 128),
        (4, 16, 64, 96, 128, 64, 128),
        (1, 16, 0, 7, 512, 64, 512),
        (1, 16, 7, 4, 512, 64, 512),
        (1, 16, 32, 4, 512, 64, 512),
        (1, 16, 64, 3, 512, 64, 512),
        (1, 16, 127, 4, 512, 64, 512),
        (1, 16, 255, 15, 512, 64, 512),
        (3, 16, 452, 16, 512, 64, 512),
        (4, 16, 512, 14, 512, 64, 512),
        (4, 16, 1024, 16, 512, 64, 512),
        (4, 16, 2048, 13, 512, 64, 512),
    ]
    vals = [
        (*shape, is_causal, use_custom_mask)
        for is_causal in (True, False)
        for use_custom_mask in (True, False)
        for shape in shapes
    ]
    return names, vals


def generate_configs(config):
    """Expand ``{name: [candidates...]}`` into one dict per combination.

    Combinations are produced in ``itertools.product`` order over the dict's
    values.
    """
    keys = list(config.keys())
    return [
        dict(zip(keys, combination))
        for combination in itertools.product(*config.values())
    ]


def get_triton_configs():
    """Build the list of ``triton.Config`` candidates for hcutune.

    ``num_warps`` and ``num_stages`` are lifted out of each combination and
    passed as ``triton.Config`` keyword arguments; every other entry stays in
    the config's kwargs dict.
    """
    # A wider search space was used previously:
    # num_warps [4, 8, 16], num_stages [1, 2, 3], kpack [1, 2].
    search_space = {
        "BLOCK_M": [16, 32, 64],
        "BLOCK_N": [16, 32, 64],
        "waves_per_eu": [1],
        "num_warps": [4, 8],
        # "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
        # "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
        "matrix_instr_nonkdim": [16],
        "num_stages": [1, 2],
        "sched_latency": ["none", "mmac5-ds10"],
        "kpack": [1],
    }
    tt_configs = []
    for c in generate_configs(search_space):
        num_warps = c.pop("num_warps")
        num_stages = c.pop("num_stages")
        tt_configs.append(
            triton.Config(c, num_warps=num_warps, num_stages=num_stages)
        )
    return tt_configs


def prune_configs(configs, nargs, **kwargs):
    """Early-prune hook: drop configs with ``BLOCK_M`` 64+ that are too heavy.

    A config is removed when ``BLOCK_M == 64`` with ``num_stages >= 2``, or
    when ``BLOCK_M >= 64`` with ``num_warps >= 8``. ``nargs`` and ``kwargs``
    are accepted for the hook signature but not used.
    """

    def _prune(config):
        c = config.all_kwargs()
        large_m_multi_stage = c["BLOCK_M"] == 64 and c["num_stages"] >= 2
        large_m_many_warps = c["BLOCK_M"] >= 64 and c["num_warps"] >= 8
        return large_m_multi_stage or large_m_many_warps

    return [c for c in configs if not _prune(c)]


key = [
    "kv_group_num",
    "SLIDING_WINDOW_SIZE",
    "Lq",
    "Lv",
    "USE_CUSTOM_MASK",
    "IS_CAUSAL",
    "SKIP_PREFIX_CUSTOM_MASK",
    "STORE_TRANSPOSE",
]

fn = triton.utils.hcutune(
    configs=get_triton_configs(),
    key=key,
    perf_debug=True,
    prune_configs_by={"early_config_prune": prune_configs},
    warmup=TUNE_DO_BENCH_WARMUP,
    rep=TUNE_DO_BENCH_REP,
)(_fwd_kernel)

fn_v2 = triton.utils.hcutune(
    configs=get_triton_configs(),
    key=HCUTUNE_KEY_V2,
    perf_debug=True,
    prune_configs_by={"early_config_prune": prune_configs},
    warmup=TUNE_DO_BENCH_WARMUP,
    rep=TUNE_DO_BENCH_REP,
)(_fwd_kernel_v2)


# Head dims that are split into a (BLOCK_DMODEL, BLOCK_DPE) pair instead of
# being rounded up to the next power of two.
_SPLIT_HEAD_DIMS = {
    576: (512, 64),
    288: (256, 32),
    192: (128, 64),
}


def _head_block_dims(Lq, Lv):
    """Return ``(BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV)`` for the given head dims."""
    if Lq in _SPLIT_HEAD_DIMS:
        block_dmodel, block_dpe = _SPLIT_HEAD_DIMS[Lq]
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(Lq), 0
    return block_dmodel, block_dpe, triton.next_power_of_2(Lv)


def _resolve_custom_mask(custom_mask, mask_indptr, device):
    """Return ``(use_custom_mask, custom_mask, mask_indptr)`` for a launch.

    When no mask is supplied, one-element placeholder tensors are substituted
    so the kernel still receives tensor arguments.
    """
    use_custom_mask = custom_mask is not None
    if not use_custom_mask:
        custom_mask = torch.tensor([0], dtype=torch.bool, device=device)
        mask_indptr = torch.tensor([0], dtype=torch.int32, device=device)
    return use_custom_mask, custom_mask, mask_indptr


def extend_attention_fwd(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale=None,
    logit_cap=0.0,
    skip_prefix_custom_mask=True,
    # config: Optional[dict[str, any]] = None,
):
    """Launch the hcutune-wrapped ``_fwd_kernel`` (module-level ``fn``).

    q_extend, k_extend, v_extend, o_extend: contiguous tensors
    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager

    ``sm_scale`` falls back to ``1 / sqrt(Lq)`` when falsy. ``BLOCK_M`` /
    ``BLOCK_N`` / ``num_warps`` / ``num_stages`` are not passed here; the grid
    reads ``BLOCK_M`` from the tuner-provided ``META``.
    """
    Lq, Lv = q_extend.shape[-1], v_extend.shape[-1]
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV = _head_block_dims(Lq, Lv)

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]

    USE_CUSTOM_MASK, custom_mask, mask_indptr = _resolve_custom_mask(
        custom_mask, mask_indptr, q_extend.device
    )
    # Skip custom mask for prefix part
    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

    def grid(META):
        return (
            batch_size,
            head_num,
            triton.cdiv(max_len_extend, META["BLOCK_M"]),
        )

    fn[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
        q_extend.stride(1),
        k_extend.stride(0),
        k_extend.stride(1),
        v_extend.stride(0),
        v_extend.stride(1),
        o_extend.stride(0),
        o_extend.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        logit_cap=logit_cap,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
        STORE_TRANSPOSE=True,
    )


def extend_attention_fwd_v2(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale,
    k_scale,
    v_scale,
    sliding_window_size,
    sinks,
    window_kv_offsets,
    xai_temperature_len,
    skip_prefix_custom_mask,
    logit_cap,
):
    """Launch :data:`fn_v2` (hcutune-wrapped :func:`_fwd_kernel_v2`); kwargs align with ``extend_attention_fwd`` v2 path.

    ``HAS_SINK`` is derived from ``sinks is not None`` and ``USE_CUSTOM_MASK``
    from ``custom_mask is not None``. ``sm_scale`` falls back to
    ``1 / sqrt(Lq)`` when falsy.
    """
    Lq, Lv = q_extend.shape[-1], v_extend.shape[-1]
    BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV = _head_block_dims(Lq, Lv)

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]

    USE_CUSTOM_MASK, custom_mask, mask_indptr = _resolve_custom_mask(
        custom_mask, mask_indptr, q_extend.device
    )
    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
    HAS_SINK = sinks is not None

    def grid(META):
        return (
            batch_size,
            head_num,
            triton.cdiv(max_len_extend, META["BLOCK_M"]),
        )

    stride_args = (
        q_extend.stride(0),
        q_extend.stride(1),
        k_extend.stride(0),
        k_extend.stride(1),
        v_extend.stride(0),
        v_extend.stride(1),
        o_extend.stride(0),
        o_extend.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
    )

    fn_v2[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sinks,
        window_kv_offsets,
        sm_scale,
        k_scale,
        v_scale,
        kv_group_num,
        *stride_args,
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        HAS_SINK=HAS_SINK,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        Lq=Lq,
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
        STORE_TRANSPOSE=True,
    )


def get_bench_inputs_v2():
    """Cases for :func:`_fwd_kernel_v2` / ``_get_config_v2`` keys.

    After checking ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` entry by entry,
    only two **key combinations** related to extend were found:

    - **MHA** (``k_extend[...,1,...]``): ``sliding_window_size=-1``,
      ``sinks=None`` (e.g. the ``k_extend [1152,1,192]`` / ``[3952,1,192]``
      sections of the log).
    - **GQA** (``k_extend[...,2,...]``): ``sliding_window_size=128``,
      ``sinks shape [16]`` (nowhere in the file does a GQA record appear
      together with ``-1`` / no sinks).

    The log contains **no** ``sliding_window_size=64`` and no
    "GQA + no SWA + no sinks" case; to cover other models, add separate rows
    and note where they come from.
    """
    names = [
        "B",
        "H",
        "prefix",
        "extend",
        "kv_lora_rank",
        "qk_rope_head_dim",
        "v_head_dim",
        "causal",
        "custom_mask",
        "sliding_window_size",
        "has_sink",
        "kv_num_heads",
    ]
    vals = [
        # (prefix, extend) only affect memory access / grid; want7 is determined by
        # head/dim/SWA/sinks. prefix=8192, extend=1024 follow the long-KV bench convention.
        # (1) GQA Q16/KV2: matches ``[3952,2,192]`` + ``sliding_window_size=128`` +
        #     ``sinks [16]`` in the log; want7 (8,192,128,F,T,T,128)
        (4, 16, 8192, 1024, 128, 64, 128, True, False, 128, True, 2),
        # (2) MHA Q16/KV1: matches ``[...,1,192]`` + ``sliding_window_size=-1`` +
        #     ``sinks=None`` in the log; want7 (16,192,128,F,T,F,-1)
        (4, 16, 8192, 1024, 128, 64, 128, True, False, -1, False, 1),
    ]
    return names, vals


x_names, x_vals = get_bench_inputs()
configs = [
    triton.testing.Benchmark(
        x_names=x_names,
        x_vals=x_vals,
        line_arg="provider",
        line_vals=["triton"],
        line_names=["triton"],
        styles=[("red", "-")],
        ylabel="ms",
        plot_name="extend_attention",
        args={"dtype": torch.float16},
    )
]
@triton.utils.dist_perf_report(configs)
def bench_extend_attention(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    provider,
    dtype,
):
    """Benchmark one v1 case through ``extend_attention_fwd``.

    Inputs are placed on ``"cpu"`` when ``TRITON_HCUTUNE_COMPILE_ONLY=1`` and
    on ``"cuda"`` otherwise. Returns the value of
    ``triton.testing.do_bench_cudagraph`` for a single forward launch.

    NOTE(review): the ``custom_mask`` benchmark flag is not used — the mask
    passed to the kernel is the one returned by ``input_helper``. ``provider``
    is likewise unused. Confirm whether masked cases should be exercised.
    """
    torch.manual_seed(42)
    device = "cpu" if os.getenv("TRITON_HCUTUNE_COMPILE_ONLY", "") == "1" else "cuda"
    ref_attn_impl = "normal"
    sm_scale = 1.0
    logit_cap = 0.0

    # The mask tensor is unpacked under its own name so it does not clobber
    # the ``custom_mask`` benchmark flag.
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        B,
        H,
        prefix,
        extend,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        ref_attn_impl,
    )

    tri_out = torch.empty(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=q_extend.device,
    )

    # Named ``run_once`` (not ``fn``) to avoid shadowing the module-level
    # tuned kernel ``fn``.
    def run_once():
        return extend_attention_fwd(
            q_extend,
            k_extend,
            v_extend,
            tri_out,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask_t,
            causal,
            mask_indptr,
            max_len_extend,
            sm_scale=sm_scale,
            logit_cap=logit_cap,
        )

    return triton.testing.do_bench_cudagraph(run_once)


x_names_v2, x_vals_v2 = get_bench_inputs_v2()
configs_v2 = [
    triton.testing.Benchmark(
        x_names=x_names_v2,
        x_vals=x_vals_v2,
        line_arg="provider",
        line_vals=["triton_v2"],
        line_names=["triton_v2"],
        styles=[("blue", "-")],
        ylabel="ms",
        plot_name="extend_attention_v2_hcutune",
        args={"dtype": torch.bfloat16},
    )
]


@triton.utils.dist_perf_report(configs_v2)
def bench_extend_attention_v2(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
    provider,
    dtype,
):
    """Benchmark one v2 case through ``extend_attention_fwd_v2``.

    Uses equal sequence lengths, ``k_scale = v_scale = 1.0``,
    ``xai_temperature_len = -1`` and ``sm_scale = 1 / sqrt(Lq)``. A random
    ``sinks`` tensor of length ``H`` is created when ``has_sink`` is true.

    Raises:
        NotImplementedError: If ``custom_mask`` is truthy; no mask tensors are
            generated by this script.
    """
    # Fail fast, before any tensors are allocated.
    if custom_mask:
        raise NotImplementedError(
            "tune v2 with custom_mask requires mask tensors; use custom_mask=False for hcutune key matching"
        )

    torch.manual_seed(42)
    device = "cpu" if os.getenv("TRITON_HCUTUNE_COMPILE_ONLY", "") == "1" else "cuda"
    ref_attn_impl = "normal"
    logit_cap = 0.0
    k_scale = 1.0
    v_scale = 1.0
    xai_temperature_len = -1
    skip_prefix_custom_mask = True

    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        B,
        H,
        prefix,
        extend,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        ref_attn_impl,
        equal_seqlens=True,
        kv_num_heads=kv_num_heads,
    )

    sm_scale = float(1.0 / (q_extend.shape[-1] ** 0.5))
    sinks = (
        torch.randn(H, dtype=q_extend.dtype, device=device) if has_sink else None
    )
    window_kv_offsets = None

    tri_out = torch.empty(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=q_extend.device,
    )

    def run_once():
        extend_attention_fwd_v2(
            q_extend,
            k_extend,
            v_extend,
            tri_out,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask_t,
            causal,
            mask_indptr,
            max_len_extend,
            sm_scale,
            k_scale,
            v_scale,
            sliding_window_size,
            sinks,
            window_kv_offsets,
            xai_temperature_len,
            skip_prefix_custom_mask,
            logit_cap,
        )

    return triton.testing.do_bench_cudagraph(run_once)


def parse_args():
    """Parse the command line: ``--perf`` and ``--v1`` (both boolean flags)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--perf",
        action="store_true",
        default=False,
        help="benchmark with hcutuner perf mode",
    )
    parser.add_argument(
        "--v1",
        action="store_true",
        default=False,
        help="Tune v1 ``_fwd_kernel`` only. Default: v2 ``_fwd_kernel_v2`` (JSON: ``_get_config_v2`` / EXTEND_ATTENTION-V2-FP16).",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.perf:
        # idempotent; real set is at top for Hcutuner init
        os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"
    if args.v1:
        bench_extend_attention.run(
            print_data=True, save_path="./tune_extend_attention_out"
        )
    else:
        bench_extend_attention_v2.run(
            print_data=True, save_path="./tune_extend_attention_v2_out"
        )