import torch
import tilelang
import tilelang.language as T
from tilelang.profiler import do_bench
import argparse
from fla.ops.linear_attn import fused_chunk_linear_attn  # FLA's Triton baseline for comparison
from fla.modules.l2norm import l2norm_fwd
from einops import rearrange
from typing import Optional, Tuple


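# Fused chunked linear attention forward kernel (TileLang).
# The (DK, DV) state is tiled into (BK, BV) blocks; each block scans the sequence in
# chunks, combining intra-chunk attention with a running K^T V state, and accumulates
# its partial output into O with atomic adds across K-tiles.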
@tilelang.jit(
    out_idx=[4],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
def tl_fused_chunk_fwd_kernel(
    B,
    S,
    H,
    DK,
    DV,
    dtype: str = 'float16',
    scale: Optional[float] = None,
) -> torch.Tensor:

    if scale is None:
        scale = DK**-0.5
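    # Accumulate in fp32 for accuracy; inputs default to fp16.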
    accum_dtype = 'float'

    chunk_size = 64
    BK = BV = 64  # Setting these to 128 can be faster, but introduces small numerical differences vs. FLA
    assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0

    NK = tilelang.cdiv(DK, BK)
    NV = tilelang.cdiv(DV, BV)
    NT = tilelang.cdiv(S, chunk_size)

    @T.prim_func
    def fused_chunk_linear_attn_fwd(
            Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
            V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
            O: T.Tensor([B, S, H, DV], accum_dtype),  # type: ignore
            final_state: T.Tensor([B, H, DK, DV], accum_dtype)):  # type: ignore
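        # Launch grid: (NV, NK, B * H); each block handles one (i_k, i_v) state tile
        # for one (batch, head) pair.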
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
            i_h = i_bh % H

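            # Shared-memory tiles for the current chunk (q, k, v) and register fragments
            # for the running state (h), attention scores (s) and the chunk output (o).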
            q = T.alloc_shared([chunk_size, BK], dtype)
            k = T.alloc_shared([chunk_size, BK], dtype)
            v = T.alloc_shared([chunk_size, BV], dtype)
            h = T.alloc_fragment([BK, BV], accum_dtype)
            h_shared = T.alloc_shared([BK, BV], dtype)
            s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
            s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
            o = T.alloc_fragment([chunk_size, BV], accum_dtype)
            o_shared = T.alloc_shared([chunk_size, BV], accum_dtype)

            T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)})
            T.use_swizzle(10)

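            # h is the running state: the sum of K_chunk^T @ V_chunk over all chunks
            # processed so far.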
            T.clear(h)

            for i in T.Pipelined(0, NT):
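                # Load the current chunk; Q is scaled by `scale` on the fly during the copy.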
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)

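                # Intra-chunk attention scores: S = Q_chunk @ K_chunk^T, with a causal
                # (lower-triangular, diagonal included) mask applied before hitting V.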
                T.gemm(q, k, s, clear_accum=True, transpose_B=True)
                for row, col in T.Parallel(chunk_size, chunk_size):
                    s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0)

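                # Intra-chunk output mask(S) @ V, then inter-chunk output Q @ h using the
                # state from previous chunks (h_shared is snapshotted before the state
                # update h += K_chunk^T @ V_chunk).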
                T.gemm(s_shared, v, o, clear_accum=True)
                T.copy(h, h_shared)
                T.gemm(k, v, h, transpose_A=True)
                T.gemm(q, h_shared, o)
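
                # Accumulate this K-tile's partial output into O; different i_k blocks
                # write to the same output tile, hence the atomic add.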
                T.copy(o, o_shared)
                T.atomic_add(
                    O[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
                    o_shared)

            # Output final state
            T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])

    return fused_chunk_linear_attn_fwd


def tl_fused_chunk_fwd(q, k, v):
    B, S, H, D = q.shape
    kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D)
    print(kernel.get_kernel_source())
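    # O is accumulated with atomic adds across K-tiles, so it is allocated zero-filled
    # and in fp32 on the host side.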
    o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32)
    h = kernel(q, k, v, o)
    return o, h


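# Reference implementation in plain PyTorch: the cumulative K^T V state over previous
# chunks gives the inter-chunk term, a causally masked (Q K^T) V gives the intra-chunk
# term, and the last cumulative state is returned as the final state.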
def ref_program(q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    q, k, v = q.float(), k.float(), v.float()
    if scale is None:
        scale = q.shape[-1]**-0.5
    chunk_size = 64
    q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
    k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
    kv = k.transpose(-1, -2) @ v
    kv = kv.cumsum(2)
    h = kv[:, :, -1, :, :]
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    inter = q @ kv
    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
        0)) @ v
    o = inter + intra
    return rearrange(o, 'b h n c d -> b (n c) h d'), h


def main(B=1, S=512, H=16, D=128):
    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)

    # L2-normalizing Q and K is necessary for linear attention
    q, _ = l2norm_fwd(q)
    k, _ = l2norm_fwd(k)

    o, h = tl_fused_chunk_fwd(q, k, v)
    o_ref, h_ref = ref_program(q, k, v)

    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}'
    assert torch.allclose(h, h_ref, atol=1e-2, rtol=1e-2), f'h max err: {(h - h_ref).abs().max()}'
    print('Passed all tests!✅')

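    # Benchmark against FLA's Triton kernel.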
    t1 = do_bench(
        lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False),
        backend='cupti')
    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti')
    print(f'Triton latency: {t1:.3f} ms')
    print(f'TileLang latency: {t2:.3f} ms')
    print(f'Speedup: {t1/t2:.3f}x')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=1024, help='Seq len')
    parser.add_argument('--H', type=int, default=32, help='Num heads')
    parser.add_argument('--D', type=int, default=128, help='Head dim')
    args = parser.parse_args()

    main(args.B, args.S, args.H, args.D)