# SPDX-License-Identifier: MIT

from aiter.test_common import (
    checkAllclose,
    benchmark,
    run_perftest,
)
import math
import torch
import aiter
from aiter import dtypes
import argparse
import os
from pathlib import Path
import sys
from typing import Callable, Optional
import pandas as pd


def _truthy_env(name: str) -> bool:
    """Return True if env var *name* is "1", "true", "yes" or "on" (case-insensitive)."""
    v = os.environ.get(name, "").strip().lower()
    return v in ("1", "true", "yes", "on")


# Set True after argparse when running this script as main (TileKernels / TileLang optional).
_compare_tilekernels = False
_enable_breakdown = False

torch.set_default_device("cuda")
# torch.cuda.manual_seed_all(0)
# torch.set_printoptions(precision=3, linewidth=200, sci_mode=False)

# Make a sibling TileKernels checkout importable if one exists two levels up.
_tilekernels_root = Path(__file__).resolve().parents[2] / "TileKernels"
if _tilekernels_root.exists():
    sys.path.insert(0, str(_tilekernels_root))

try:
    from tile_kernels.modeling.mhc.ops import mhc_pre_big_fuse as mhc_pre_tile
    from tile_kernels.modeling.mhc.ops import mhc_post as mhc_post_tile
    from tile_kernels.mhc.norm_fn_kernel import _mhc_pre_norm_fn_fwd_mul, round_to_tf32
    from tile_kernels.mhc.pre_norm_fn_splitk_kernel import (
        mhc_pre_gemm_sqrsum_splitk_kernel,
    )
    from tile_kernels.mhc.pre_big_fuse_kernel import _mhc_pre_big_fuse
except Exception:
    # TileKernels is optional: any import failure just disables the comparison
    # paths, which check these names for None before use.
    mhc_pre_tile = None
    mhc_post_tile = None
    _mhc_pre_norm_fn_fwd_mul = None
    round_to_tf32 = None
    mhc_pre_gemm_sqrsum_splitk_kernel = None
    _mhc_pre_big_fuse = None


# copy from tilelang/examples/deepseek_mhc/example_mhc_pre.py
def mhc_pre_ref(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    test_hc_head: bool = False,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
    """Pure-PyTorch reference for mhc_pre.

    ``residual`` is indexed as (m, hc_mult, hidden_size). It is flattened to
    (m, hc_mult * hidden_size), multiplied by ``fn.T`` in fp32 and
    RMS-normalized to produce the mixes.

    Args:
        residual: Input hidden streams, shape (m, hc_mult, hidden_size).
        fn: Projection weights, shape (hc_mult3, hc_mult * hidden_size).
        hc_scale: Scales. In full mode entries 0/1/2 scale the pre, post and
            comb mixes. In hc_head mode only entry 0 is used.
        hc_base: Bias added to the scaled mixes, length hc_mult3.
        rms_eps: Epsilon inside the RMS rsqrt.
        hc_pre_eps: Added to sigmoid(pre mix).
        hc_sinkhorn_eps: Epsilon used by the Sinkhorn normalization.
        hc_post_mult_value: Multiplier applied to sigmoid(post mix).
        sinkhorn_repeat: Number of Sinkhorn row/column normalization rounds.
        test_hc_head: If True, only the pre mix is computed (hc_mult3 == hc_mult).

    Returns:
        ``(post_mix, res_mix, layer_input)``.
        - ``post_mix`` has shape (m, hc_mult, 1).
        - ``res_mix`` has shape (m, hc_mult, hc_mult).
        - ``layer_input`` is bf16 with shape (m, hidden_size).
        In hc_head mode ``post_mix`` and ``res_mix`` are None.
    """
    hc_mult = residual.shape[-2]
    residual_flat = residual.flatten(-2, -1).float()
    sqrsum = residual_flat.square().sum(-1)
    out = residual_flat @ fn.T
    mixes = out * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt()

    if not test_hc_head:
        # Expand the 3 scalar scales to per-column scales: pre | post | comb.
        hc_scale = torch.cat(
            [
                hc_scale[0].expand(hc_mult),
                hc_scale[1].expand(hc_mult),
                hc_scale[2].expand(hc_mult * hc_mult),
            ],
        )
        mixes = mixes * hc_scale + hc_base
        pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps
        post_mix = (
            mixes[:, hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value
        ).unsqueeze(-1)
        res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult)

        def sinkhorn_normalize_ref(
            x: torch.Tensor, repeat: int, eps: float
        ) -> torch.Tensor:
            # Softmax over rows, one column normalization, then (repeat - 1)
            # alternating row/column normalizations.
            x = x.softmax(-1) + eps
            x = x / (x.sum(-2, keepdim=True) + eps)
            for _ in range(repeat - 1):
                x = x / (x.sum(-1, keepdim=True) + eps)
                x = x / (x.sum(-2, keepdim=True) + eps)
            return x

        res_mix = sinkhorn_normalize_ref(
            res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps
        )
    else:
        hc_scale = hc_scale[0].expand(hc_mult)
        mixes = mixes * hc_scale + hc_base
        pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps
        post_mix = None
        res_mix = None

    # Weighted sum of the hc_mult streams using the pre mix.
    layer_input = (residual * pre_mix).sum(-2).bfloat16()

    return post_mix, res_mix, layer_input


def mhc_pre_hip(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    use_tf32: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward all arguments positionally to ``aiter.mhc_pre``."""
    return aiter.mhc_pre(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        use_tf32,
    )


def mhc_pre_tilekernels(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run TileKernels' mhc_pre_big_fuse with the same arguments.

    Raises:
        RuntimeError: If TileKernels could not be imported.
    """
    if mhc_pre_tile is None:
        raise RuntimeError("TileKernels mhc_pre_big_fuse is unavailable")
    return mhc_pre_tile(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
    )


# The string literal below is intentionally dead code (never executed); it
# preserves the MHC pre breakdown helpers for future re-enabling.
"""
Disabled MHC pre breakdown helpers.

These helpers were used during MHC kernel tuning to separately time aiter
stage1/stage2/reduce and TileKernels stage1/stage2. Keep them together here
so future debugging can re-enable the whole block, instead of scattering
breakdown logic through the benchmark.

def _select_mhc_pre_launch(
    m: int,
    hc_mult: int,
    hc_mult3: int,
    hc_hidden_size: int,
    sinkhorn_repeat: int,
) -> tuple[int, int]:
    prefetch_stages = 2
    stage1_variant = os.environ.get("AITER_MHC_PRE_STAGE1", "auto").strip().lower()
    hidden_size = hc_hidden_size // hc_mult
    use_stage1_m128_auto = (
        sinkhorn_repeat > 0
        and hc_mult3 == hc_mult * (2 + hc_mult)
        and not (hidden_size in (1280, 2560) and m <= 512)
    )
    if stage1_variant in ("", "auto"):
        use_stage1_m128 = use_stage1_m128_auto
    elif stage1_variant in ("aiter", "legacy"):
        use_stage1_m128 = False
    elif stage1_variant in ("m128", "tlstyle"):
        use_stage1_m128 = True
    else:
        raise ValueError("AITER_MHC_PRE_STAGE1 must be 'auto' or 'm128' ('tlstyle' is accepted as an alias)")
    tile_m = 128 if use_stage1_m128 else 16 * 4
    tile_k_tg_dict = {128: 2} if use_stage1_m128 else {128: 2, 64: 4}
    num_cu = torch.cuda.get_device_properties("cuda").multi_processor_count
    selected_splitk = 1
    selected_tile_k = 128 if use_stage1_m128 else 64
    num_tg_m = (m + tile_m - 1) // tile_m
    if num_tg_m >= num_cu:
        min_splitk = 2
        max_splitk = 2
    else:
        min_splitk = 1
        max_splitk = 32
    selected_score = num_tg_m / (num_cu * tile_k_tg_dict[selected_tile_k])
    selected_score = selected_score / math.ceil(selected_score)
    for tile_k, tg_per_cu in tile_k_tg_dict.items():
        if (hc_hidden_size % tile_k) != 0:
            continue
        meanwhile_tg = num_cu * tg_per_cu
        for splitk in range(min_splitk, max_splitk + 1):
            if hc_hidden_size % (splitk * tile_k) != 0 or (hc_hidden_size // splitk) < (
                tile_k * prefetch_stages
            ):
                continue
            num_tg = num_tg_m * splitk
            score = num_tg / meanwhile_tg
            score = score / math.ceil(score)
            if selected_score < score:
                selected_splitk = splitk
                selected_tile_k = tile_k
                selected_score = score
            if num_tg > meanwhile_tg * 4:
                break
    # Keep TileLang-style M128 split-k aligned with aiter.ops.mhc.mhc_pre.
    if use_stage1_m128 and hc_hidden_size in (4 * 4096, 4 * 7168):
        if num_tg_m >= num_cu:
            candidate_splitk = 2
        elif m >= 2048:
            candidate_splitk = 8
        else:
            candidate_splitk = 32
        if (
            hc_hidden_size % (candidate_splitk * selected_tile_k) == 0
            and (hc_hidden_size // candidate_splitk) >= selected_tile_k * prefetch_stages
        ):
            selected_splitk = candidate_splitk
    # Keep work-bound tile_k override in sync with aiter.ops.mhc.mhc_pre.
    if not use_stage1_m128 and num_tg_m >= num_cu and selected_tile_k == 128:
        candidate_tile_k = 64
        candidate_splitk = 2
        if (
            hc_hidden_size % (candidate_splitk * candidate_tile_k) == 0
            and (hc_hidden_size // candidate_splitk) >= candidate_tile_k * prefetch_stages
        ):
            selected_tile_k = candidate_tile_k
            selected_splitk = candidate_splitk
    # Keep small/medium DeepSeek stage1 override in sync with aiter.ops.mhc.mhc_pre.
    candidate_tile_k = 64
    candidate_splitk = 32
    if (
        not use_stage1_m128
        and hc_hidden_size in (4 * 4096, 4 * 7168)
        and (m <= 1024 or (m == 2048 and hc_hidden_size == 4 * 7168))
        and hc_hidden_size % (candidate_splitk * candidate_tile_k) == 0
        and (hc_hidden_size // candidate_splitk) >= candidate_tile_k * prefetch_stages
    ):
        selected_tile_k = candidate_tile_k
        selected_splitk = candidate_splitk
    # Keep breakdown stage selection aligned with aiter.ops.mhc.mhc_pre env overrides.
    env_tile_k = os.environ.get("AITER_MHC_PRE_TILE_K", "").strip()
    if env_tile_k:
        forced_tile_k = int(env_tile_k)
        if forced_tile_k not in tile_k_tg_dict:
            msg = "AITER_MHC_PRE_TILE_K must be 128 when AITER_MHC_PRE_STAGE1=m128"
            if not use_stage1_m128:
                msg = "AITER_MHC_PRE_TILE_K must be 64 or 128"
            raise ValueError(msg)
        if (hc_hidden_size % forced_tile_k) != 0:
            raise ValueError(
                f"AITER_MHC_PRE_TILE_K={forced_tile_k} is incompatible with hc_hidden_size={hc_hidden_size}"
            )
        selected_tile_k = forced_tile_k
    env_splitk = os.environ.get("AITER_MHC_PRE_SPLITK", "").strip()
    if env_splitk:
        forced_splitk = int(env_splitk)
        if forced_splitk < 1:
            raise ValueError("AITER_MHC_PRE_SPLITK must be >= 1")
        if hc_hidden_size % (forced_splitk * selected_tile_k) != 0:
            raise ValueError(
                "AITER_MHC_PRE_SPLITK is incompatible with selected tile_k/hc_hidden_size"
            )
        if (hc_hidden_size // forced_splitk) < (selected_tile_k * prefetch_stages):
            raise ValueError(
                "AITER_MHC_PRE_SPLITK violates prefetch stage constraint for selected tile_k"
            )
        selected_splitk = forced_splitk
    return selected_splitk, selected_tile_k


def _hip_stage1(
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    residual: torch.Tensor,
    fn: torch.Tensor,
    tile_k: int,
) -> None:
    stage1_variant = os.environ.get("AITER_MHC_PRE_STAGE1", "auto").strip().lower()
    m = residual.size(0)
    hc_mult = residual.size(1)
    hidden_size = residual.size(2)
    hc_mult3 = fn.size(0)
    use_stage1_m128_auto = (
        hc_mult3 == hc_mult * (2 + hc_mult)
        and not (hidden_size in (1280, 2560) and m <= 512)
    )
    if stage1_variant in ("m128", "tlstyle") or (stage1_variant in ("", "auto") and use_stage1_m128_auto):
        aiter.mhc_pre_gemm_sqrsum_stage1_m128(out, sqrsum, residual, fn)
    else:
        aiter.mhc_pre_gemm_sqrsum(out, sqrsum, residual, fn, tile_k)


def _use_mhc_pre_tlstyle(
    m: int,
    hidden_size: int,
    hc_mult: int,
    hc_mult3: int,
    sinkhorn_repeat: int,
) -> bool:
    env_kernel = os.environ.get("AITER_MHC_PRE_KERNEL", "auto").strip().lower()
    use_tlstyle_auto = (
        sinkhorn_repeat > 0
        and hc_mult3 == hc_mult * (2 + hc_mult)
        and m > 128
        and not (hidden_size in (1280, 2560) and m <= 512)
    )
    if env_kernel in ("aiter", "legacy"):
        return False
    if env_kernel == "tlstyle":
        return True
    return use_tlstyle_auto


def _hip_stage2(
    post_mix: torch.Tensor,
    comb_mix: torch.Tensor,
    layer_input: torch.Tensor,
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    residual: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
) -> None:
    m = residual.shape[0]
    hc_mult = residual.shape[1]
    hidden_size = residual.shape[2]
    hc_mult3 = out.shape[2]
    big_fuse = (
        aiter.mhc_pre_big_fuse_tlstyle
        if _use_mhc_pre_tlstyle(m, hidden_size, hc_mult, hc_mult3, sinkhorn_repeat)
        else aiter.mhc_pre_big_fuse
    )
    big_fuse(
        post_mix,
        comb_mix,
        layer_input,
        out,
        sqrsum,
        hc_scale,
        hc_base,
        residual,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
    )


def _hip_reduce_splitk(
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    hc_mult3: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    m = out.shape[1]
    out_red_pad = torch.empty(
        1,
        m,
        (hc_mult3 + 31) // 32 * 32,
        dtype=out.dtype,
        device=out.device,
    )
    out_red = out_red_pad[:, :, :hc_mult3]
    sqrsum_red = torch.empty(1, m, dtype=sqrsum.dtype, device=sqrsum.device)
    aiter.mhc_pre_reduce_splitk(out_red, sqrsum_red, out, sqrsum)
    return out_red, sqrsum_red


def _tile_stage1_prepare(
    residual: torch.Tensor,
    fn: torch.Tensor,
) -> tuple[Callable[[], None], torch.Tensor, torch.Tensor, int]:
    if (
        _mhc_pre_norm_fn_fwd_mul is None
        or round_to_tf32 is None
        or mhc_pre_gemm_sqrsum_splitk_kernel is None
    ):
        raise RuntimeError("TileKernels stage1 internals are unavailable")
    mhc_mult = residual.shape[-2]
    hidden_size = residual.shape[-1]
    mhc_mult2 = mhc_mult * mhc_mult
    mhc_mult3 = mhc_mult * 2 + mhc_mult2
    mhc_hidden_size = mhc_mult * hidden_size
    residual_flat = residual.view(-1, mhc_mult, hidden_size)
    num_tokens = residual_flat.shape[0]
    token_block = 128
    hidden_block = 128
    hidden_loop = mhc_hidden_size // hidden_block
    token_loop = (num_tokens + token_block - 1) // token_block
    cu_count = torch.cuda.get_device_properties("cuda").multi_processor_count
    if token_loop <= 2:
        if num_tokens > 128:
            n_splits_pre = 64
            if hidden_loop % n_splits_pre != 0:
                token_block = 64
                n_splits_pre = 32
        elif num_tokens > 64:
            token_block = 64
            n_splits_pre = 64
            if hidden_loop % n_splits_pre != 0:
                token_block = 32
                n_splits_pre = 32
        elif num_tokens > 32:
            token_block = 32
            n_splits_pre = 64
            if hidden_loop % n_splits_pre != 0:
                n_splits_pre = 32
        else:
            token_block = 32
            n_splits_pre = 64
            if hidden_loop % n_splits_pre != 0:
                n_splits_pre = 32
    elif token_loop <= 4:
        n_splits_pre = 32
    elif token_loop <= cu_count // 8:
        n_splits_pre = 16
    elif token_loop <= cu_count // 4:
        n_splits_pre = 8
    elif token_loop <= cu_count * 0.75:
        n_splits_pre = 8
    elif token_loop < cu_count * 2:
        n_splits_pre = 4
    else:
        n_splits_pre = 1
    fn_tf32 = round_to_tf32(fn)
    use_small_token_splitk = (
        n_splits_pre > 1
        and num_tokens < token_block * cu_count * 2
        and hidden_loop > 0
        and hidden_loop % n_splits_pre == 0
    )
    if use_small_token_splitk:
        kernel_0, kernel_1 = mhc_pre_gemm_sqrsum_splitk_kernel(
            mhc_mult3,
            mhc_hidden_size,
            split_k=n_splits_pre,
            token_block=token_block,
            hidden_block=hidden_block,
        )
        partial_out = torch.empty(
            n_splits_pre, num_tokens, mhc_mult3, dtype=torch.float32, device=residual.device
        )
        partial_sqrsum = torch.empty(
            n_splits_pre, num_tokens, dtype=torch.float32, device=residual.device
        )
        gemm_out_mul = torch.empty(
            1, num_tokens, mhc_mult3, dtype=torch.float32, device=residual.device
        )
        gemm_out_sqrsum = torch.empty(
            1, num_tokens, dtype=torch.float32, device=residual.device
        )

        def _tile_stage1() -> None:
            kernel_0(
                residual_flat.view(-1, mhc_hidden_size),
                fn_tf32,
                partial_out,
                partial_sqrsum,
            )
            kernel_1(
                partial_out,
                partial_sqrsum,
                gemm_out_mul.squeeze(0),
                gemm_out_sqrsum.squeeze(0),
            )

        return _tile_stage1, gemm_out_mul, gemm_out_sqrsum, 1

    gemm_out_mul = torch.empty(
        1, num_tokens, mhc_mult3, dtype=torch.float32, device=residual.device
    )
    gemm_out_sqrsum = torch.empty(1, num_tokens, dtype=torch.float32, device=residual.device)
    fwd_mul_kernel = _mhc_pre_norm_fn_fwd_mul(
        mhc_mult3, 1, mhc_hidden_size, token_block=128, hidden_block=128
    )

    def _tile_stage1() -> None:
        fwd_mul_kernel(
            residual_flat.view(-1, mhc_hidden_size),
            fn_tf32,
            gemm_out_mul.view(-1, 1, mhc_mult3),
            gemm_out_sqrsum.view(-1, 1),
        )

    return _tile_stage1, gemm_out_mul, gemm_out_sqrsum, 1


def _tile_stage2(
    residual: torch.Tensor,
    gemm_out_mul: torch.Tensor,
    gemm_out_sqrsum: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if _mhc_pre_big_fuse is None:
        raise RuntimeError("TileKernels stage2 internals are unavailable")
    mhc_mult = residual.shape[-2]
    hidden_size = residual.shape[-1]
    mhc_mult2 = mhc_mult * mhc_mult
    residual_flat = residual.view(-1, mhc_mult, hidden_size)
    num_tokens = residual_flat.shape[0]
    post_mix = torch.empty(num_tokens, mhc_mult, dtype=torch.float32, device=residual.device)
    comb_mix = torch.empty(num_tokens, mhc_mult2, dtype=torch.float32, device=residual.device)
    layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device)
    _mhc_pre_big_fuse(
        hidden_size,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits=n_splits,
        mhc_mult=mhc_mult,
    )(
        gemm_out_mul,
        gemm_out_sqrsum,
        hc_scale,
        hc_base,
        residual_flat,
        post_mix,
        comb_mix,
        layer_input,
    )
    post_mix = post_mix.view(*residual.shape[:-2], mhc_mult, 1)
    comb_mix = comb_mix.view(*residual.shape[:-2], mhc_mult, mhc_mult)
    layer_input = layer_input.view(*residual.shape[:-2], hidden_size)
    return post_mix, comb_mix, layer_input
"""


def _round_to_tf32_like_tilekernels(x: torch.Tensor) -> torch.Tensor:
    """Round fp32 values to the nearest TF32-representable fp32 value.

    TF32 keeps fp32's 8-bit exponent but only the top 10 of its 23 mantissa
    bits. Adding 0x1000 to the raw bits (half a TF32 ulp) rounds the
    magnitude to nearest, with ties away from zero. Clearing the low 13
    mantissa bits then leaves an exact TF32 value.

    The mask matters because the result feeds an fp32 reference GEMM. Without
    it, the reference would still use the low mantissa bits that a TF32 GEMM
    discards, and would carry a one-sided bias of up to half a TF32 ulp per
    weight.
    """
    bits = x.view(torch.int32)
    return ((bits + 0x1000) & ~0x1FFF).view(torch.float32)


@benchmark()
def test_mhc_pre(m, hidden_size, hc_mult, test_hc_head=False, use_tf32=False):
    """Check and time aiter.mhc_pre against mhc_pre_ref for one shape.

    Inputs are random: bf16 residual of shape (m, hc_mult, hidden_size) and
    fp32 weights. When ``use_tf32`` is set, the reference uses TF32-rounded
    weights. When TileKernels comparison is enabled, TileKernels is also
    timed and checked against the same reference.

    Returns:
        A dict of metrics:
        - ``hip_err``: the layer_input check result.
        - ``hip_us``: aiter time.
        - ``use_tf32``.
        - Optionally ``tile_err``, ``tile_us`` and ``tile/hip_us``.
    """
    hc_mult2 = hc_mult * hc_mult
    hc_mult3 = (hc_mult * 2 + hc_mult2) if not test_hc_head else hc_mult
    hc_hidden_size = hc_mult * hidden_size

    residual = torch.randn(m, hc_mult, hidden_size, dtype=dtypes.bf16)
    fn = torch.randn(hc_mult3, hc_hidden_size, dtype=dtypes.fp32)
    hc_scale = torch.randn((3,), dtype=dtypes.fp32) * 0.1
    hc_base = torch.randn((hc_mult3,), dtype=dtypes.fp32) * 0.1
    extra_args = {
        "rms_eps": 1e-6,
        "hc_pre_eps": 1e-6,
        "hc_sinkhorn_eps": 1e-6,
        "hc_post_mult_value": 1.0,
        "sinkhorn_repeat": 20 if not test_hc_head else 0,
    }

    fn_ref = _round_to_tf32_like_tilekernels(fn) if use_tf32 else fn
    post_mix_ref, comb_mix_ref, layer_input_ref = mhc_pre_ref(
        residual,
        fn_ref,
        hc_scale,
        hc_base,
        **extra_args,
        test_hc_head=test_hc_head,
    )

    (post_mix_hip, comb_mix_hip, layer_input_hip), hip_us = run_perftest(
        mhc_pre_hip,
        residual,
        fn,
        hc_scale,
        hc_base,
        **extra_args,
        use_tf32=use_tf32,
    )
    if not test_hc_head:
        checkAllclose(post_mix_ref, post_mix_hip, msg="post_mix")
        checkAllclose(comb_mix_ref, comb_mix_hip, msg="comb_mix")
    hip_err = checkAllclose(layer_input_ref, layer_input_hip, msg="layer_input")

    ret = {}
    ret["hip_err"] = hip_err
    ret["hip_us"] = hip_us
    ret["use_tf32"] = use_tf32

    # Breakdown timing is disabled. See the contiguous disabled helper block
    # above if stage1/stage2 analysis is needed again.
    # ret["TFLOPS * us"] = 2.0 * m * hidden_size * hc_mult * hc_mult3 / 1e6
    # ret["GB"] = (m * hc_mult3 * dtypes.fp32.itemsize + (m * hc_mult + m) * hidden_size * dtypes.bf16.itemsize) / 1e6

    if _compare_tilekernels:
        try:
            if test_hc_head:
                raise RuntimeError(
                    "TileKernels mhc_pre_big_fuse does not support hc_head-only mode"
                )
            (post_mix_tile, comb_mix_tile, layer_input_tile), tile_us = run_perftest(
                mhc_pre_tilekernels,
                residual,
                fn,
                hc_scale,
                hc_base,
                **extra_args,
            )
            # A TileKernels-specific reference would use the same fn rounding
            # and the same arguments as the one computed above, so reuse it.
            checkAllclose(post_mix_ref, post_mix_tile, msg="tile_post_mix")
            checkAllclose(comb_mix_ref, comb_mix_tile, msg="tile_comb_mix")
            # Report the layer_input error so tile_err is comparable to hip_err.
            tile_err = checkAllclose(
                layer_input_ref, layer_input_tile, msg="tile_layer_input"
            )
            ret["tile_err"] = tile_err
            ret["tile_us"] = tile_us
            if tile_us and hip_us:
                ret["tile/hip_us"] = tile_us / hip_us
            # TileKernels stage breakdown is disabled with the aiter breakdown block.
        except Exception as e:
            # Best effort: a TileKernels failure is recorded, not fatal.
            tile_err = str(e)
            print(f"tilekernels mhc_pre error: {tile_err}")
            ret["tile_err"] = tile_err
            ret["tile_us"] = None
    return ret


def mhc_post_hip(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Allocate an output shaped like ``residual`` and fill it via ``aiter.mhc_post``."""
    out = torch.empty_like(residual)
    aiter.mhc_post(
        out,
        x,
        residual,
        post_layer_mix,
        comb_res_mix,
    )
    return out


def mhc_post_tilekernels(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Run TileKernels' mhc_post, adapting flat-batch inputs when needed.

    Raises:
        RuntimeError: If TileKernels could not be imported.
    """
    if mhc_post_tile is None:
        raise RuntimeError("TileKernels mhc_post is unavailable")
    # TileKernels expects (num_seqs, num_tokens, ...) like tests/mhc/test_post.py;
    # aiter op_tests use flat batch (m, ...). Insert seq dim when needed.
    if residual.ndim == 3:
        x_tl = x.unsqueeze(0)
        residual_tl = residual.unsqueeze(0)
        post_layer_mix_tl = post_layer_mix.unsqueeze(0)
        comb_res_mix_tl = comb_res_mix.unsqueeze(0)
        out_tl = mhc_post_tile(
            x_tl,
            residual_tl,
            post_layer_mix_tl,
            comb_res_mix_tl,
        )
        return out_tl.squeeze(0)
    return mhc_post_tile(
        x,
        residual,
        post_layer_mix,
        comb_res_mix,
    )


# copy from tilelang/examples/deepseek_mhc/example_mhc_post.py
def mhc_post_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch reference for mhc_post.

    Computes ``x * post_layer_mix + comb_res_mix^T @ residual`` in fp32 and
    casts the result to bf16. In test_mhc_post the shapes are:
    - x: (m, hidden)
    - residual: (m, hc_mult, hidden)
    - post_layer_mix: (m, hc_mult, 1)
    - comb_res_mix: (m, hc_mult, hc_mult)
    """
    term2 = torch.bmm(comb_res_mix.mT, residual.float())
    return (x.float().unsqueeze(-2) * post_layer_mix + term2).bfloat16()


@benchmark()
def test_mhc_post(m, hidden_size, hc_mult):
    """Check and time aiter.mhc_post (and optionally TileKernels) for one shape.

    Returns:
        A dict with ``hip_err``, ``hip_us`` and achieved ``TB/s``. When
        comparison is enabled, it also has ``tile_err``, ``tile_us`` and
        ``tile/hip_us``.
    """
    x = torch.randn(m, hidden_size, dtype=dtypes.bf16)
    residual = torch.randn(m, hc_mult, hidden_size, dtype=dtypes.bf16)
    post_layer_mix = torch.randn(m, hc_mult, 1, dtype=dtypes.fp32)
    comb_res_mix = torch.randn(m, hc_mult, hc_mult, dtype=dtypes.fp32)

    out_ref = mhc_post_ref(x, residual, post_layer_mix, comb_res_mix)
    out_hip, hip_us = run_perftest(
        mhc_post_hip,
        x,
        residual,
        post_layer_mix,
        comb_res_mix,
    )
    hip_err = checkAllclose(out_ref, out_hip, msg="out")

    ret = {}
    ret["hip_err"] = hip_err
    ret["hip_us"] = hip_us
    # bytes / 1e6 / us == bytes / 1e12 / s, i.e. TB/s.
    ret["TB/s"] = (
        (
            out_ref.numel() * out_ref.dtype.itemsize
            + x.numel() * x.dtype.itemsize
            + residual.numel() * residual.dtype.itemsize
            + post_layer_mix.numel() * post_layer_mix.dtype.itemsize
            + comb_res_mix.numel() * comb_res_mix.dtype.itemsize
        )
        / 1e6
        / hip_us
    )
    if _compare_tilekernels:
        try:
            out_tile, tile_us = run_perftest(
                mhc_post_tilekernels,
                x,
                residual,
                post_layer_mix,
                comb_res_mix,
            )
            tile_err = checkAllclose(out_ref, out_tile, msg="tile_out")
            ret["tile_err"] = tile_err
            ret["tile_us"] = tile_us
            if tile_us and hip_us:
                ret["tile/hip_us"] = tile_us / hip_us
        except Exception as e:
            # Best effort: a TileKernels failure is recorded, not fatal.
            tile_err = str(e)
            print(f"tilekernels mhc_post error: {tile_err}")
            ret["tile_err"] = tile_err
            ret["tile_us"] = None
    return ret


def check_mhc_post_dispatch_regression(df: pd.DataFrame) -> None:
    """Optional guard for auto-dispatch correctness and its key win anchors."""
    if df.empty:
        return

    max_hip_err = df["hip_err"].max()
    assert max_hip_err == 0, f"mhc_post hip_err regression: max hip_err={max_hip_err}"

    post_kernel = os.environ.get("AITER_MHC_POST_KERNEL", "").strip().lower()
    auto_dispatch = post_kernel in ("", "auto")
    # Performance anchors only make sense under auto dispatch and when the
    # TileKernels comparison produced ratios.
    if not auto_dispatch or "tile/hip_us" not in df:
        return

    win_anchors = [
        (128, 1280),
        (256, 1280),
        (512, 1280),
    ]
    for m, hidden_size in win_anchors:
        row = df[(df["m"] == m) & (df["hidden_size"] == hidden_size)]
        if row.empty:
            continue
        ratio = row.iloc[0]["tile/hip_us"]
        # No ratio means TileKernels failed for this shape; there is nothing to compare.
        if pd.isna(ratio):
            continue
        assert ratio > 1.0, (
            "mhc_post auto dispatch regression: expected aiter to beat TileKernels "
            f"at m={m}, hidden_size={hidden_size}, got tile/hip_us={ratio}"
        )


parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="config input of test",
)
parser.add_argument(
    "-d",
    "--dtype",
    type=dtypes.str2Dtype,
    choices=[dtypes.d_dtypes["fp16"], dtypes.d_dtypes["bf16"]],
    nargs="*",
    metavar="{fp16, bf16}",
    default=["bf16"],
    help="""Data type.
    e.g.: -d bf16""",
)
parser.add_argument(
    "-m",
    type=int,
    nargs="*",
    choices=[1, 32, 64, 128, 256, 512, 1024, 2048, 8192, 65536],
    default=[1, 32, 64, 128, 256, 512, 1024, 2048, 8192, 65536],
    help="""M.
    e.g.: -m 32""",
)
parser.add_argument(
    "-n",
    "--hidden_size",
    type=int,
    nargs="*",
    choices=[1280, 2560, 4096, 7168],
    default=[1280, 2560, 4096, 7168],
    help="""hidden_size.
    e.g.: -n 4096""",
)
parser.add_argument(
    "--hc_head",
    action="store_true",
    help="""Test mhc_pre for hc_head only.""",
)
parser.add_argument(
    "--compare-tilekernels",
    action="store_true",
    help="""Also compare against TileKernels MHC (TileLang-backed) implementations.
Default off so environments without TileLang still pass.
Equivalent: set env AITER_MHC_COMPARE_TILEKERNELS=1 (or true/yes/on).""",
)
# Breakdown analysis is intentionally disabled after the MHC tuning pass.
# Re-enable the contiguous disabled helper block above before restoring this arg.
# parser.add_argument(
#     "--breakdown",
#     action="store_true",
#     help="""Collect mhc_pre stage1/stage2 timing ratios for both aiter and TileKernels.""",
# )
parser.add_argument(
    "--post-only",
    action="store_true",
    help="""Only run mhc_post tests. Useful for post dispatch sweeps/regression.""",
)
parser.add_argument(
    "--post-dispatch-regression",
    action="store_true",
    help="""Assert mhc_post auto-dispatch correctness and key performance anchors.""",
)
parser.add_argument(
    "--use-tf32",
    action="store_true",
    help="""Run mhc_pre stage1 GEMM with the optional HCU TF32 MMAC path.""",
)

args = parser.parse_args()
_compare_tilekernels = bool(args.compare_tilekernels) or _truthy_env(
    "AITER_MHC_COMPARE_TILEKERNELS"
)
_enable_breakdown = False
if _compare_tilekernels:
    aiter.logger.info(
        "TileKernels compare enabled (--compare-tilekernels or AITER_MHC_COMPARE_TILEKERNELS=1)"
    )

# NOTE(review): `dtype` is not forwarded to the tests, which always build bf16
# inputs; each extra -d value re-runs the same cases.
if not args.post_only:
    df = []
    for dtype in args.dtype:
        for m in args.m:
            for hidden_size in args.hidden_size:
                for hc_mult in [4]:
                    ret = test_mhc_pre(
                        m=m,
                        hidden_size=hidden_size,
                        hc_mult=hc_mult,
                        test_hc_head=args.hc_head,
                        use_tf32=args.use_tf32,
                    )
                    df.append(ret)
    df = pd.DataFrame(df)
    df_md = df.to_markdown(index=False)
    aiter.logger.info("mhc_pre summary (markdown):\n%s", df_md)

if not args.hc_head:
    df = []
    for dtype in args.dtype:
        for hidden_size in args.hidden_size:
            for m in args.m:
                for hc_mult in [4]:
                    ret = test_mhc_post(m=m, hidden_size=hidden_size, hc_mult=hc_mult)
                    df.append(ret)
    df = pd.DataFrame(df)
    df_md = df.to_markdown(index=False)
    aiter.logger.info("mhc_post summary (markdown):\n%s", df_md)
    if args.post_dispatch_regression:
        check_mhc_post_dispatch_regression(df)