import torch
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
from fla.modules.l2norm import l2norm_fwd
from einops import rearrange
from typing import Optional, Tuple


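# JIT-compile the kernel factory: out_idx=[4] marks `final_state` as an output that the compiled
# kernel allocates and returns; TMA lowering and warp specialization are disabled via pass configs.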
@tilelang.jit(
    out_idx=[4],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
)
def tl_fused_chunk_fwd_kernel(
    B,
    S,
    H,
    DK,
    DV,
    dtype: str = "float16",
    scale: Optional[float] = None,
):
    if scale is None:
        scale = DK**-0.5
    accum_dtype = "float"

    chunk_size = 64
    BK = BV = 64  # Setting BK/BV to 128 can be faster, but introduces some numerical differences vs. FLA
    assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
    NK = tilelang.cdiv(DK, BK)
    NV = tilelang.cdiv(DV, BV)
    NT = tilelang.cdiv(S, chunk_size)

    @T.prim_func
    def fused_chunk_linear_attn_fwd(
        Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
        K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
        V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
        O: T.Tensor([B, S, H, DV], accum_dtype),  # type: ignore
        final_state: T.Tensor([B, H, DK, DV], accum_dtype),
    ):  # type: ignore
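        # Each (i_v, i_k) block processes one (batch, head) pair and owns a (BK, BV) tile of the
        # running state h = K^T V accumulated over past chunks. Per chunk (with Q pre-scaled):
        #   o_chunk = q_chunk @ h_prev + causal_mask(q_chunk @ k_chunk^T) @ v_chunk
        # after which h is updated with k_chunk^T @ v_chunk.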
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
            i_h = i_bh % H

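            # Shared-memory tiles for the current chunk of Q/K/V, a register fragment for the
            # running K^T V state h (carried across chunks), plus buffers for the intra-chunk
            # score matrix s and the per-chunk output tile o.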
            q = T.alloc_shared([chunk_size, BK], dtype)
            k = T.alloc_shared([chunk_size, BK], dtype)
            v = T.alloc_shared([chunk_size, BV], dtype)
            h = T.alloc_fragment([BK, BV], accum_dtype)
            h_shared = T.alloc_shared([BK, BV], dtype)
            s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
            s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
            o = T.alloc_fragment([chunk_size, BV], accum_dtype)
            o_shared = T.alloc_shared([chunk_size, BV], accum_dtype)

            T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)})
            T.use_swizzle(10)

            T.clear(h)

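            # Walk over the sequence chunk by chunk (software-pipelined): load one chunk of Q
            # (pre-scaled by `scale`), K, and V into shared memory per iteration.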
            for i in T.Pipelined(0, NT):
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                T.copy(K[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_k * BK : (i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], v)

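                # Intra-chunk attention scores: s = q @ k^T, causally masked by keeping only the
                # lower triangle (row >= col).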
                T.gemm(q, k, s, clear_accum=True, transpose_B=True)
                for row, col in T.Parallel(chunk_size, chunk_size):
                    s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0)

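                # Intra-chunk output from the masked scores, then the inter-chunk term q @ h using
                # the state of *previous* chunks: h is snapshotted into h_shared before being
                # updated with this chunk's k^T @ v.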
                T.gemm(s_shared, v, o, clear_accum=True)
                T.copy(h, h_shared)
                T.gemm(k, v, h, transpose_A=True)
                T.gemm(q, h_shared, o)
                T.copy(o, o_shared)
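                # Blocks with different i_k (DK tiles) write partial sums to the same O slice, so
                # accumulate with an atomic add (O was zero-initialized on the host).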
                T.atomic_add(O[i_b, i * chunk_size : (i + 1) * chunk_size, i_h, i_v * BV : (i_v + 1) * BV], o_shared)

            # Write the final K^T V state of this (BK, BV) tile to final_state
            T.copy(h, final_state[i_b, i_h, i_k * BK : (i_k + 1) * BK, i_v * BV : (i_v + 1) * BV])

    return fused_chunk_linear_attn_fwd


def tl_fused_chunk_fwd(q, k, v):
    B, S, H, D = q.shape
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
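    # O must be zero-initialized and float32: the kernel atomically accumulates partial results
    # from each DK block into it.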
    o = torch.zeros((B, S, H, D), device="cuda", dtype=torch.float32)
    h = kernel(q, k, v, o)
    return o, h


def ref_program(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    q, k, v = q.float(), k.float(), v.float()
    if scale is None:
        scale = q.shape[-1] ** -0.5
    chunk_size = 64
    q = rearrange(q, "b (n c) h d -> b h n c d", c=chunk_size) * scale
    k = rearrange(k, "b (n c) h d -> b h n c d", c=chunk_size)
    v = rearrange(v, "b (n c) h d -> b h n c d", c=chunk_size)
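    # Inter-chunk term: cumulative K^T V state over *previous* chunks (shifted by one chunk so
    # chunk n only sees chunks 0..n-1); the last cumulative entry is the final state h.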
    kv = k.transpose(-1, -2) @ v
    kv = kv.cumsum(2)
    h = kv[:, :, -1, :, :]
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    inter = q @ kv
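    # Intra-chunk term: causal attention within each chunk (strictly upper-triangular entries
    # masked to zero).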
    intra = (
        (q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
    ) @ v
    o = inter + intra
    return rearrange(o, "b h n c d -> b (n c) h d"), h


def main(B=1, S=512, H=16, D=128):
    q = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
    k = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)
    v = torch.randn((B, S, H, D), device="cuda", dtype=torch.float16)

    # qk norm is necessary for linear attn
    q, _ = l2norm_fwd(q)
    k, _ = l2norm_fwd(k)

    o, h = tl_fused_chunk_fwd(q, k, v)
    o_ref, h_ref = ref_program(q, k, v)

    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f"o max err: {(o - o_ref).abs().max()}"
    assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f"h max err: {(h - h_ref).abs().max()}"
    print("Passed all tests!✅")

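    # Benchmark against FLA's Triton fused_chunk_linear_attn baseline. Note that the TileLang
    # path re-prints the generated kernel source and allocates a fresh O on every call.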
    t1 = do_bench(lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False), backend="cupti")
    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend="cupti")
    print(f"Triton latency: {t1:.3f} ms")
    print(f"TileLang latency: {t2:.3f} ms")
    print(f"Speedup: {t1 / t2:.3f}x")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--B", type=int, default=8, help="Batch size")
    parser.add_argument("--S", type=int, default=1024, help="Seq len")
    parser.add_argument("--H", type=int, default=32, help="Num heads")
    parser.add_argument("--D", type=int, default=128, help="Head dim")
    args = parser.parse_args()

    main(args.B, args.S, args.H, args.D)