# ruff: noqa
from typing import Optional

import pytest
import torch

from aiter.ops.tilelang import mhc_post_fwd


def mhc_post_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Compute the PyTorch reference result that ``mhc_post_fwd`` is checked against.

    With the shapes produced by ``generate_mhc_post_test_data``
    (``T`` = num_tokens, ``M`` = mhc_mult, ``H`` = h)::

        out[t, n, :] = x[t, :] * post_layer_mix[t, n]
                       + sum_m comb_res_mix[t, m, n] * residual[t, m, :]

    Args:
        x: ``(T, H)`` tensor; upcast to float32 before use.
        residual: ``(T, M, H)`` tensor; upcast to float32 before use.
        post_layer_mix: ``(T, M)`` per-token scale applied to ``x``.
        comb_res_mix: ``(T, M, M)`` per-token mixing weights; the first
            ``M`` axis is contracted against ``residual``.

    Returns:
        A ``(T, M, H)`` bfloat16 tensor.
    """
    # Contract over the incoming stream axis `m`; all math is done in float32
    # and only the final result is rounded to bfloat16.
    term2 = torch.einsum("tmn,tmc->tnc", comb_res_mix, residual.float())
    # (T, 1, H) * (T, M, 1) broadcasts to (T, M, H).
    term1 = x.float().unsqueeze(-2) * post_layer_mix.unsqueeze(-1)
    return (term1 + term2).bfloat16()


def generate_mhc_post_test_data(
    num_tokens: int,
    h: int,
    mhc_mult: int,
    device: str = "cuda",
    seed: Optional[int] = None,
) -> dict[str, torch.Tensor]:
    """Build random inputs for ``mhc_post_fwd`` / ``mhc_post_ref``.

    Args:
        num_tokens: Size of the leading (token) dimension.
        h: Size of the trailing (hidden) dimension.
        mhc_mult: Size of the middle stream dimension(s).
        device: Device on which the tensors are allocated.
        seed: If given, inputs are drawn from a dedicated generator seeded
            with this value so the data is reproducible and the global RNG
            state is left untouched. If ``None``, the global RNG is used.

    Returns:
        A dict with keys ``"x"`` (bfloat16, ``(num_tokens, h)``),
        ``"residual"`` (bfloat16, ``(num_tokens, mhc_mult, h)``),
        ``"post_layer_mix"`` (float32, ``(num_tokens, mhc_mult)``) and
        ``"comb_res_mix"`` (float32, ``(num_tokens, mhc_mult, mhc_mult)``).
    """
    generator = None
    if seed is not None:
        # The generator must live on the same device as the sampled tensors.
        generator = torch.Generator(device=device).manual_seed(seed)

    x = torch.randn(
        (num_tokens, h),
        generator=generator,
        dtype=torch.bfloat16,
        device=device,
    )
    residual = torch.randn(
        (num_tokens, mhc_mult, h),
        generator=generator,
        dtype=torch.bfloat16,
        device=device,
    )
    post_layer_mix = torch.randn(
        (num_tokens, mhc_mult),
        generator=generator,
        dtype=torch.float32,
        device=device,
    )
    comb_res_mix = torch.randn(
        (num_tokens, mhc_mult, mhc_mult),
        generator=generator,
        dtype=torch.float32,
        device=device,
    )

    return {
        "x": x,
        "residual": residual,
        "post_layer_mix": post_layer_mix,
        "comb_res_mix": comb_res_mix,
    }


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("num_tokens", [1, 16, 32, 4096])
@pytest.mark.parametrize("h", [1280, 2560, 4096, 7168])
def test_mhc_post_fwd_correctness(num_tokens: int, h: int) -> None:
    """Check ``mhc_post_fwd`` against ``mhc_post_ref`` on random inputs."""
    # Fixed seed so that a tolerance failure can be reproduced exactly.
    td = generate_mhc_post_test_data(num_tokens=num_tokens, h=h, mhc_mult=4, seed=0)

    out_tl = mhc_post_fwd(
        td["x"].contiguous(),
        td["residual"].contiguous(),
        td["post_layer_mix"].contiguous(),
        td["comb_res_mix"].contiguous(),
    )
    out_ref = mhc_post_ref(
        td["x"],
        td["residual"],
        td["post_layer_mix"],
        td["comb_res_mix"],
    )

    # Both outputs are bfloat16, so the comparison uses loose tolerances.
    torch.testing.assert_close(out_tl, out_ref, atol=1e-2, rtol=1e-2)