from typing import Callable

import torch
import pytest

from aiter.ops.tilelang import mhc_post_fwd as mhc_post

# Bytes per element for the dtypes created in generate_mhc_post_test_data:
# activations/residuals/grads are bfloat16, mixing coefficients are float32.
_BF16_BYTES = 2
_FP32_BYTES = 4

# Forward benchmark sweep: every token count is paired with every hidden size
# (hidden sizes in the outer loop), with mhc_mult fixed.
_BENCH_NUM_TOKENS = (1, 32, 64, 128, 256, 512, 1024, 2048, 8192, 65536)
_BENCH_HIDDEN_SIZES = (7168, 4096)
_BENCH_MHC_MULT = 4


def generate_mhc_post_test_data(
    num_tokens: int,
    h: int,
    mhc_mult: int,
    device: str = 'cuda',
) -> dict[str, torch.Tensor]:
    """Create random inputs and an upstream gradient for ``mhc_post``.

    Returns a dict with the following entries:

    - ``x``: ``(num_tokens, h)`` bfloat16
    - ``residual``: ``(num_tokens, mhc_mult, h)`` bfloat16
    - ``post_layer_mix``: ``(num_tokens, mhc_mult)`` float32
    - ``comb_res_mix``: ``(num_tokens, mhc_mult, mhc_mult)`` float32
    - ``o_grad``: ``(num_tokens, mhc_mult, h)`` bfloat16. Used by ``_tester``
      as the gradient of the output during backward.
    """
    # Generation order is kept fixed so seeded runs reproduce the same data.
    x = torch.randn((num_tokens, h), dtype=torch.bfloat16, device=device)
    residual = torch.randn((num_tokens, mhc_mult, h), dtype=torch.bfloat16, device=device)
    post_layer_mix = torch.randn((num_tokens, mhc_mult), dtype=torch.float32, device=device)
    comb_res_mix = torch.randn((num_tokens, mhc_mult, mhc_mult), dtype=torch.float32, device=device)
    o_grad = torch.randn((num_tokens, mhc_mult, h), dtype=torch.bfloat16, device=device)
    return {
        'x': x,
        'residual': residual,
        'post_layer_mix': post_layer_mix,
        'comb_res_mix': comb_res_mix,
        'o_grad': o_grad,
    }


def _tester(
    impl: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    test_data: dict[str, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run ``impl`` forward and backward on fresh copies of ``test_data``.

    Inputs are cloned and marked ``requires_grad``, so ``test_data`` itself
    is not modified. Each call starts from new leaf tensors, so the returned
    gradients come only from this call. ``test_data['o_grad']`` is
    back-propagated through the output.

    Returns:
        ``(out, x_grad, residual_grad, post_layer_mix_grad, comb_res_mix_grad)``
    """
    x_ = test_data['x'].clone().requires_grad_()
    residual_ = test_data['residual'].clone().requires_grad_()
    post_layer_mix_ = test_data['post_layer_mix'].clone().requires_grad_()
    comb_res_mix_ = test_data['comb_res_mix'].clone().requires_grad_()
    out_ = impl(x_, residual_, post_layer_mix_, comb_res_mix_)
    torch.autograd.backward([out_], [test_data['o_grad']])
    return out_, x_.grad, residual_.grad, post_layer_mix_.grad, comb_res_mix_.grad


def _estimate_fwd_io_bytes(num_tokens: int, h: int, mhc_mult: int) -> int:
    """Estimate the minimum bytes moved by one ``mhc_post`` forward call.

    Counts each input read once and the output written once. The output is
    assumed to be ``(num_tokens, mhc_mult, h)`` bfloat16, matching ``o_grad``
    in ``generate_mhc_post_test_data``.
    """
    n = num_tokens
    read_bytes = (
        n * h * _BF16_BYTES                       # x
        + n * mhc_mult * h * _BF16_BYTES          # residual
        + n * mhc_mult * _FP32_BYTES              # post_layer_mix
        + n * mhc_mult * mhc_mult * _FP32_BYTES   # comb_res_mix
    )
    write_bytes = n * mhc_mult * h * _BF16_BYTES  # output
    return read_bytes + write_bytes


def _estimate_bwd_io_bytes(num_tokens: int, h: int, mhc_mult: int) -> int:
    """Estimate the minimum bytes moved by one ``mhc_post`` backward call.

    Reads the output gradient plus every forward input, and writes one
    gradient per input, each with the same shape and dtype as that input.
    """
    n = num_tokens
    read_bytes = (
        n * mhc_mult * h * _BF16_BYTES            # o_grad
        + n * h * _BF16_BYTES                     # x
        + n * mhc_mult * h * _BF16_BYTES          # residual
        + n * mhc_mult * _FP32_BYTES              # post_layer_mix
        + n * mhc_mult * mhc_mult * _FP32_BYTES   # comb_res_mix
    )
    write_bytes = (
        n * h * _BF16_BYTES                       # x grad
        + n * mhc_mult * h * _BF16_BYTES          # residual grad
        + n * mhc_mult * _FP32_BYTES              # post_layer_mix grad
        + n * mhc_mult * mhc_mult * _FP32_BYTES   # comb_res_mix grad
    )
    return read_bytes + write_bytes


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA is required')
@pytest.mark.benchmark
@pytest.mark.parametrize(
    'num_tokens,h,mhc_mult',
    [
        (num_tokens, h, _BENCH_MHC_MULT)
        for h in _BENCH_HIDDEN_SIZES
        for num_tokens in _BENCH_NUM_TOKENS
    ],
)
def test_mhc_post_fwd_benchmark(
    num_tokens: int,
    h: int,
    mhc_mult: int,
    benchmark_timer,
    benchmark_record,
) -> None:
    """Time the ``mhc_post`` forward pass and record its effective bandwidth.

    ``benchmark_timer`` and ``benchmark_record`` are fixtures provided
    elsewhere (presumably a conftest). The timer's result is treated as
    microseconds here.
    """
    test_data = generate_mhc_post_test_data(num_tokens=num_tokens, h=h, mhc_mult=mhc_mult)
    x = test_data['x']
    residual = test_data['residual']
    post_layer_mix = test_data['post_layer_mix']
    comb_res_mix = test_data['comb_res_mix']

    def fn_fwd() -> torch.Tensor:
        return mhc_post(x, residual, post_layer_mix, comb_res_mix)

    # One untimed warm-up call so one-time setup work (presumably kernel
    # compilation/caching) is not included in the measurement.
    fn_fwd()
    t_tl_us = benchmark_timer(fn_fwd)

    io_bytes = _estimate_fwd_io_bytes(num_tokens, h, mhc_mult)
    # bytes / microseconds = MB/s; dividing by 1e3 converts to GB/s.
    bw_tl_gbs = io_bytes / t_tl_us / 1e3

    benchmark_record(
        kernel='mhc_post',
        operation='fwd',
        params={'num_tokens': num_tokens, 'h': h, 'mhc_mult': mhc_mult},
        time_us=t_tl_us,
        bandwidth_gbs=bw_tl_gbs,
        extras={'num_tokens': num_tokens, 'io_bytes': io_bytes},
    )