# example_triton_nsa_fwd.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import torch
from typing import Optional, Union

import torch
import triton
import triton.language as tl
from einops import rearrange

from fla.ops.common.utils import (prepare_chunk_indices, prepare_lens, prepare_token_indices)
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

from reference import naive_nsa, naive_nsa_simple
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
                            block_counts, offsets, token_indices, T, H: tl.constexpr,
                            HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
                            S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
                            BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
                            USE_BLOCK_COUNTS: tl.constexpr):
    """Forward kernel for NSA selected-block attention.

    Each program instance handles one (token, value-block, batch*head) triple:
    it streams over the S selected key/value blocks of that token and
    accumulates a numerically stable online softmax over them.

    Writes:
        o_slc:   [B, T, HQ, V] selected-block attention output.
        lse_slc: [B, T, HQ]    log-sum-exp of the selected-block scores.

    o_swa / lse_swa are accepted for interface parity but never written here.
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    bos, eos = i_b * T, i_b * T + T

    # advance k/v to this (batch, head); q is indexed per-token below
    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    block_indices += (bos + i_t) * H * S + i_h * S

    # NOTE(review): block_counts is intentionally ignored for now — all S
    # selected blocks are visited.  Re-enable the commented lines below to
    # honor per-token block counts.
    # if USE_BLOCK_COUNTS:
    #     NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    # else:
    NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
                            (1, 0))
    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
                                (G, BV), (1, 0))
    p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
    # [G, BV]
    b_o_slc = tl.zeros([G, BV], dtype=tl.float32)

    # running max and softmax normalizer for the online-softmax recurrence
    b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc_slc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # only visit blocks that start at or before the current token
        # (block-level causal mask; negative / padded indices are skipped)
        if i_s <= i_t and i_s >= 0:
            p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
            p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
            # [BS, BV]
            b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
            # [G, BS] attention scores for this block
            b_s_slc = tl.dot(b_q, b_k_slc)
            # token-level causal mask inside the block: position i_t may only
            # attend to positions <= i_t
            b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))

            # [G] — new running max; b_r_slc rescales previous accumulators
            b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
            b_r_slc = tl.exp(b_mp_slc - b_m_slc)
            # [G, BS]
            b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
            # [G]
            b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
            # [G, BV]
            b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)

            b_mp_slc = b_m_slc
    b_o_slc = b_o_slc / b_acc_slc[:, None]
    b_m_slc += tl.log(b_acc_slc)
    tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))


class ParallelNSAFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton NSA forward kernel.

    Only the forward pass is implemented here; tensors needed for a future
    backward pass are stashed on ``ctx``.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        # NOTE(review): `parallel_nsa_fwd` accepts no `scale` argument (it
        # derives 1/sqrt(K) internally) and returns four values; the previous
        # call passed `scale=` and unpacked two values, which raised a
        # TypeError / unpacking error at runtime.
        o, lse, _, _ = parallel_nsa_fwd(
            q=q,
            k=k,
            v=v,
            block_indices=block_indices,
            block_counts=block_indices.shape[-1],  # visit all S selected blocks
            block_size=block_size,
            window_size=0,
            offsets=offsets,
            token_indices=token_indices)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.block_indices = block_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype)


def parallel_nsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int,
    window_size: int,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the NSA selected-block attention forward kernel.

    Args:
        q: [B, T, HQ, K] queries (HQ = H * G grouped query heads).
        k: [B, T, H, K] keys.
        v: [B, T, H, V] values.
        block_indices: [B, T, H, S] indices of the selected key blocks.
        block_counts: per-token number of valid blocks, or an int.
            NOTE(review): currently unused by the kernel, which always
            visits all S blocks.
        block_size: key/value block size (BS).
        window_size: sliding-window size; the swa buffers are allocated only
            when > 0.  NOTE(review): the kernel never writes o_swa/lse_swa,
            so when allocated they remain uninitialized `torch.empty` storage.
        offsets / token_indices: optional variable-length-sequence metadata.

    Returns:
        Tuple of (o_slc, lse_slc, o_swa, lse_swa).
    """
    import math

    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H  # query heads per KV head
    # scaling is applied inside the kernel; any caller-side scale is ignored
    scale = 1.0 / math.sqrt(K)
    BS = block_size
    WS = window_size
    # Hopper (SM90+) has enough shared memory for larger tiles
    if torch.cuda.get_device_capability()[0] >= 9:
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # the kernel loads the whole key dimension as one block, so K must fit in
    # a single BK tile (the old message hard-coded 256, which is wrong on
    # pre-SM90 devices where BK caps at 128)
    assert NK == 1, f"the key dimension K={K} must fit in a single block (BK={BK})"

    # one program per (token, value block, batch*head)
    grid = (T, NV, B * H)
    o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
    lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
    lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None

    parallel_nsa_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o_slc=o_slc,
        o_swa=o_swa,
        lse_slc=lse_slc,
        lse_swa=lse_swa,
        scale=scale,
        block_indices=block_indices,
        block_counts=block_counts,
        offsets=offsets,
        token_indices=token_indices,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        WS=WS,
        BK=BK,
        BV=BV,
    )
    return o_slc, lse_slc, o_swa, lse_swa


if __name__ == "__main__":
    # Problem sizes: batch, seq length, KV heads, query heads, head dim,
    # selected blocks per token, block size.
    B, T, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
    torch.random.manual_seed(0)

    q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
    k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
    do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')

    # Every slot starts at T (an out-of-range block id); each (batch, token,
    # head) then gets a random selection of causally valid block indices.
    block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
    for i_b in range(B):
        for i_t in range(T):
            for i_h in range(H):
                sel = torch.randperm(max(1, (i_t // block_size)))[:S]
                block_indices[i_b, i_t, i_h, :len(sel)] = sel
    block_indices = block_indices.sort(dim=-1).values

    block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')

    # Reference output from the naive PyTorch implementation.
    ref = naive_nsa_simple(
        q=q,
        k=k,
        v=v,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
    )

    # Triton kernel output (only the selected-block part is compared).
    tri, _, _, _ = parallel_nsa_fwd(
        q=q,
        k=k,
        v=v,
        block_indices=block_indices,
        block_counts=block_counts,
        block_size=block_size,
        window_size=0,
    )

    print("tri", tri)
    print("ref", ref)

    torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)