test_pre_big_fuse.py 5.05 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# ruff: noqa
import pytest
import torch

from aiter.ops.tilelang import mhc_pre_big_fuse


def _sinkhorn_matrix_torch(comb: torch.Tensor, sinkhorn_iters: int, eps: float) -> torch.Tensor:
    """Reference Sinkhorn-style normalisation written with plain PyTorch ops.

    The first pass is a max-shifted softmax over ``dim=2`` whose *result* is
    offset by ``eps``, followed by a division by the ``dim=1`` sums (offset by
    ``eps``). Every further pass divides by the ``dim=2`` sums and then by the
    ``dim=1`` sums, each offset by ``eps``.

    Args:
        comb: Input tensor; dims 1 and 2 are the ones being normalised.
        sinkhorn_iters: Total number of passes, counting the initial softmax
            pass. Values of 1 or less run only that initial pass.
        eps: Offset added to the softmax output and to every later denominator.

    Returns:
        A new tensor with the normalised values; ``comb`` is not modified.
    """
    # Subtract the per-row maximum before exponentiating (softmax shift).
    shifted = comb - comb.max(dim=2, keepdim=True).values
    weights = torch.exp(shifted)

    # Initial pass: eps is added to the softmax output, not to its denominator.
    weights = (weights / weights.sum(dim=2, keepdim=True)) + eps
    weights = weights / (weights.sum(dim=1, keepdim=True) + eps)

    extra_passes = sinkhorn_iters - 1
    for _ in range(extra_passes):
        dim2_total = weights.sum(dim=2, keepdim=True) + eps
        weights = weights / dim2_total
        dim1_total = weights.sum(dim=1, keepdim=True) + eps
        weights = weights / dim1_total
    return weights


def generate_big_fuse_test_data(
    n1: int,
    mhc_mult: int,
    hidden_size: int,
    rms_eps: float = 1e-6,
    mhc_pre_eps: float = 1e-6,
    mhc_sinkhorn_eps: float = 1e-6,
    mhc_post_mult_value: float = 1.0,
    sinkhorn_repeat: int = 10,
    n_splits: int = 16,
) -> dict[str, torch.Tensor | float]:
    n0 = 1
    mhc_mult2 = mhc_mult * mhc_mult
    mhc_mult3 = mhc_mult * 2 + mhc_mult2
    device = "cuda"

    residual = (
        torch.randn((n0, n1, mhc_mult, hidden_size), dtype=torch.float, device=device)
        .mul(1 + torch.arange(mhc_mult, device=device).mul(0.01).view(1, 1, -1, 1))
        .bfloat16()
    )

    fn = (
        torch.randn((mhc_mult3, mhc_mult, hidden_size), dtype=torch.float, device=device)
        * 1e-4
        * (1 + torch.arange(mhc_mult, device=device).mul(0.01).view(1, -1, 1))
    ).flatten(1, 2)

    mhc_scale = torch.randn((3,), dtype=torch.float, device=device) * 0.1
    mhc_base = torch.randn((mhc_mult3,), dtype=torch.float, device=device) * 0.1

    return {
        "residual": residual,
        "fn": fn,
        "mhc_scale": mhc_scale,
        "mhc_base": mhc_base,
        "rms_eps": rms_eps,
        "mhc_pre_eps": mhc_pre_eps,
        "mhc_sinkhorn_eps": mhc_sinkhorn_eps,
        "mhc_post_mult_value": mhc_post_mult_value,
        "sinkhorn_repeat": sinkhorn_repeat,
        "n_splits": n_splits,
    }


def big_fuse_reference(
    residual: torch.Tensor,
    fn: torch.Tensor,
    mhc_scale: torch.Tensor,
    mhc_base: torch.Tensor,
    rms_eps: float,
    mhc_pre_eps: float,
    mhc_sinkhorn_eps: float,
    mhc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Plain-PyTorch reference for the fused pre-mix computation.

    Steps, per token (all leading dims of ``residual`` are flattened):
      1. Flatten the last two dims, cast to float32, project with ``fn`` and
         scale by the reciprocal RMS of the flattened token.
      2. Split the projection columns into three groups of width
         ``mhc_mult``, ``mhc_mult`` and ``mhc_mult ** 2``.
      3. Group 1 -> sigmoid + ``mhc_pre_eps`` (pre mix); group 2 -> sigmoid
         times ``mhc_post_mult_value`` (post mix); group 3 -> affine, then
         ``_sinkhorn_matrix_torch`` (comb mix).
      4. Contract the pre mix with the float32 residual over the ``mhc_mult``
         axis and cast to bfloat16 (layer input).

    Args:
        residual: Tensor whose last two dims are ``(mhc_mult, hidden_size)``.
        fn: Projection matrix; its transpose is right-multiplied onto the
            flattened tokens.
        mhc_scale: Indexable with 0, 1, 2; one multiplier per column group.
        mhc_base: Additive offsets, sliced with the same column bounds.
        rms_eps: Added to the mean square before ``rsqrt``.
        mhc_pre_eps: Added to the pre-mix sigmoid output.
        mhc_sinkhorn_eps: Forwarded to ``_sinkhorn_matrix_torch``.
        mhc_post_mult_value: Multiplier on the post-mix sigmoid output.
        sinkhorn_repeat: Forwarded to ``_sinkhorn_matrix_torch``.
        n_splits: Ignored by this reference implementation.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` with trailing shapes
        ``(mhc_mult, 1)``, ``(mhc_mult, mhc_mult)`` and ``(hidden_size,)``
        appended to the leading dims of ``residual``.
    """
    del n_splits  # present in the signature only; the reference never uses it

    lead_shape = residual.shape[:-2]
    n_streams = residual.shape[-2]
    width = residual.shape[-1]

    # Column boundaries of the three groups inside the projection output.
    pre_end = n_streams
    post_end = 2 * n_streams
    comb_end = n_streams * (2 + n_streams)

    tokens = residual.view(-1, n_streams, width)
    n_tokens = tokens.shape[0]

    # RMS-scaled projection of each flattened token.
    flat_f32 = tokens.view(n_tokens, n_streams * width).float()
    inv_rms = torch.rsqrt(flat_f32.square().mean(-1, keepdim=True) + rms_eps)
    mixes = (flat_f32 @ fn.t()) * inv_rms

    pre_logits = mixes[:, :pre_end] * mhc_scale[0] + mhc_base[:pre_end]
    pre_mix = pre_logits.sigmoid() + mhc_pre_eps

    post_logits = mixes[:, pre_end:post_end] * mhc_scale[1] + mhc_base[pre_end:post_end]
    post_mix = post_logits.sigmoid() * mhc_post_mult_value

    comb_logits = mixes[:, post_end:comb_end].view(n_tokens, n_streams, n_streams) * mhc_scale[2]
    comb_logits = comb_logits + mhc_base[post_end:comb_end].view(1, n_streams, n_streams)
    comb_mix = _sinkhorn_matrix_torch(comb_logits, sinkhorn_repeat, mhc_sinkhorn_eps)

    # Weighted sum of the streams, accumulated in float32 and stored as bf16.
    layer_input = torch.einsum("nh,nhd->nd", pre_mix, tokens.float()).to(torch.bfloat16)

    # Restore the caller's leading dimensions on every output.
    return (
        post_mix.view(*lead_shape, n_streams, 1),
        comb_mix.view(*lead_shape, n_streams, n_streams),
        layer_input.view(*lead_shape, width),
    )


@pytest.mark.parametrize("n1", [1, 34, 65, 128, 133, 288, 577, 1010, 2722, 4572, 8192, 9217, 21111])
@pytest.mark.parametrize("hidden_size", [1280, 2560, 4096, 7168])
@pytest.mark.parametrize("mhc_mult", [4])
def test_correctness(
    n1: int,
    hidden_size: int,
    mhc_mult: int,
) -> None:
    """Compare ``mhc_pre_big_fuse`` against ``big_fuse_reference``.

    Both implementations receive the same generated inputs. The post-mix and
    comb-mix outputs are compared with tight tolerances; the layer input,
    which the reference produces in bfloat16, gets looser ones.
    """
    inputs = generate_big_fuse_test_data(n1=n1, mhc_mult=mhc_mult, hidden_size=hidden_size)

    # The four tensors are passed positionally to both implementations.
    tensor_args = (
        inputs["residual"],
        inputs["fn"],
        inputs["mhc_scale"],
        inputs["mhc_base"],
    )
    # The remaining entries are keyed by the parameter names both callees use.
    scalar_names = (
        "rms_eps",
        "mhc_pre_eps",
        "mhc_sinkhorn_eps",
        "mhc_post_mult_value",
        "sinkhorn_repeat",
        "n_splits",
    )
    scalar_kwargs = {name: inputs[name] for name in scalar_names}

    fused_post, fused_comb, fused_layer = mhc_pre_big_fuse(*tensor_args, **scalar_kwargs)
    ref_post, ref_comb, ref_layer = big_fuse_reference(*tensor_args, **scalar_kwargs)

    assert torch.allclose(fused_post, ref_post, rtol=1e-5, atol=1e-6)
    assert torch.allclose(fused_comb, ref_comb, rtol=1e-5, atol=1e-6)
    assert torch.allclose(fused_layer, ref_layer, rtol=1e-2, atol=2e-3)