import torch
import torch.nn.functional as F
from einops import einsum, repeat

import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices

from utils import get_abs_err, get_err_ratio

BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"

pass_configs = {
    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}


@tl.jit(pass_configs=pass_configs)
def tl_indexer_bwd_impl(
    heads: int,
    dim: int,
    topk: int,
    sm_scale: Optional[float] = None,
    block_I: int = 32,
    num_stages: int = 0,
    num_threads: int = 128,
):
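    """Build the TileLang backward kernel for the indexer loss.

    Each program handles one token and accumulates gradients w.r.t. the index
    queries, per-head weights and index keys from the difference between the
    predicted top-k distribution (IndexScore) and the target attention
    distribution (AttnScore).
    """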
    assert num_stages == 0
    assert topk == tl.math.next_power_of_2(topk)
    assert topk % block_I == 0
    assert heads <= 64 and heads % 8 == 0
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    dtype: str = BF16
    accum_dtype: str = FP32
    index_q_shape = [seq_len, heads, dim]
    weights_shape = [seq_len, heads]
    index_k_shape = [seq_len, dim]
    shape_p = [seq_len, topk]
    topk_indices_shape = [seq_len, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]
    if sm_scale is None:
        sm_scale = dim**-0.5

    @T.prim_func
    def tl_indexer_bwd_kernel(
        IndexQ: T.Tensor(index_q_shape, dtype),
        Weights: T.Tensor(weights_shape, dtype),
        IndexK: T.Tensor(index_k_shape, dtype),
        dIndexQ: T.Tensor(index_q_shape, dtype),
        dWeights: T.Tensor(weights_shape, dtype),
        dIndexK: T.Tensor(index_k_shape, dtype),
        AttnScore: T.Tensor(shape_p, FP32),
        IndexScore: T.Tensor(shape_p, FP32),
        TopkIndices: T.Tensor(topk_indices_shape, INT32),
        Offsets: T.Tensor(offsets_shape, INT32),
        TokenIndices: T.Tensor(token_indices_shape, INT32),
    ):
        with T.Kernel(seq_len, threads=num_threads) as (bx):
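            # One program per flattened token; TokenIndices maps it to
            # (sequence index i_b, position within the sequence i_t).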
            i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
            bos = Offsets[i_b]
            num_blocks = T.ceildiv(topk, block_I)

            index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
            weights_shared = T.alloc_shared([heads], dtype=dtype)

            d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype)
            d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype)

            T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
            T.copy(Weights[bos + i_t, :], weights_shared)
            T.fill(d_index_q_frag, 0)
            T.fill(d_weights_frag, 0)

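            # Fold the softmax scale into q so the GEMM below yields scaled logits.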
            for i, j in T.Parallel(heads, dim):
                index_q_shared[i, j] = index_q_shared[i, j] * sm_scale

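            # Walk over the top-k selected keys in blocks of block_I.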
            for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
                i_st = bi_i * block_I
                i_ed = (bi_i + 1) * block_I

                indices_shared = T.alloc_shared([block_I], dtype=INT32)
                T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared)

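                # Gather the selected key rows; invalid (-1) or non-causal
                # positions contribute zeros.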
                index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
                for i, j in T.Parallel(block_I, dim):
                    pos = indices_shared[i]
                    index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0)

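                # Load the target (AttnScore) and predicted (IndexScore)
                # probabilities for this block of selected keys.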
                attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
                index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
                for i in T.Parallel(block_I):
                    attn_score_shared[i] = AttnScore[bos + i_t, i_st + i]
                    index_score_shared[i] = IndexScore[bos + i_t, i_st + i]

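                # Recompute the forward logits k @ (q * sm_scale)^T and apply ReLU.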
                logits = T.alloc_fragment((block_I, heads), accum_dtype)
                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )
                for i, j in T.Parallel(block_I, heads):
                    logits[i, j] = T.max(logits[i, j], 0)

                # dWeights: d(score)/d(weights[h]) = relu(logits), scaled by the
                # softmax gradient (index_score - attn_score).
                d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
                for i, j in T.Parallel(block_I, heads):
                    d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
                T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)

                d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
                d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype)
                d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype)

                for i, j in T.Parallel(block_I, heads):
                    d_relu = T.alloc_var(accum_dtype)
                    if logits[i, j] > 0:
                        d_relu = 1.0
                    else:
                        d_relu = 0.0
                    d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j]

                # dIndexQ: accumulate d_logits^T @ k into the per-token fragment.
                T.copy(d_logits_qk, d_logits_qk_cast1)
                T.gemm(
                    d_logits_qk_cast1,  # [BS, HQ]
                    index_k_shared,  # [BS, K]
                    d_index_q_frag,  # [HQ, K]
                    transpose_A=True,
                    transpose_B=False,
                    clear_accum=False,
                )

                # dIndexK: d_logits @ (q * sm_scale) for this block of keys.
                T.copy(d_logits_qk, d_logits_qk_cast2)
                d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype)
                T.gemm(
                    d_logits_qk_cast2,  # [BS, HQ]
                    index_q_shared,  # [HQ, K]
                    d_index_k_frag,  # [BS, K]
                    transpose_A=False,
                    transpose_B=False,
                    clear_accum=True,
                )

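                # Scatter dK back with atomics: the same key row can be selected
                # by many queries.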
                for i, j in T.Parallel(block_I, dim):
                    pos = indices_shared[i]
                    if (pos > -1) & (pos <= i_t):
                        T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])

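            # Apply sm_scale to dq (chain rule through the q pre-scaling) and write results.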
            for i, j in T.Parallel(heads, dim):
                d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale

            T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :])
            T.copy(d_weights_frag, dWeights[bos + i_t, :])

    return tl_indexer_bwd_kernel


def indexer_bwd_interface(
    q: torch.Tensor,
    weights: torch.Tensor,
    k: torch.Tensor,
    attn_score: torch.Tensor,
    index_score: torch.Tensor,
    topk_indices: torch.Tensor,
    offsets: torch.Tensor,
):
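    """Allocate gradient buffers, compile the kernel for (heads, dim, topk) and launch it."""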
    _, heads, dim, topk = *q.shape, topk_indices.shape[-1]
    token_indices = prepare_token_indices(offsets)
    dq = torch.zeros_like(q)
    dweights = torch.zeros_like(weights)
    dk = torch.zeros_like(k)
    kernel = tl_indexer_bwd_impl(heads, dim, topk)
    kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices)
    return dq, dweights, dk


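# Reference implementation: run the indexer forward in PyTorch and let autograd
# produce dQ, dWeights and dK for comparison with the TileLang kernel.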
def ref_indexer_bwd(
    Q: torch.Tensor,
    Weights: torch.Tensor,
    K: torch.Tensor,
    TopkIndices: torch.Tensor,
    AttnScore: torch.Tensor,
    offsets: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    Q.requires_grad_(True)
    Weights.requires_grad_(True)
    K.requires_grad_(True)
    softmax_scale = Q.shape[-1] ** -0.5
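    # Process each packed (variable-length) sequence independently.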
    all_loss = []
    all_log_topk_prob = []
    for i in range(offsets.shape[0] - 1):
        assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1]
        q = Q[offsets[i] : offsets[i + 1]]
        weights = Weights[offsets[i] : offsets[i + 1]]
        k = K[offsets[i] : offsets[i + 1]]
        topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
        attn_score = AttnScore[offsets[i] : offsets[i + 1]]
        s = q.shape[0]
        mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
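        # Forward pass: ReLU(q @ k^T * scale), weighted sum over heads, causal mask.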
        logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale
        logits = F.relu(logits)
        score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
        score = torch.where(mask, score, float("-inf"))
        topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
        log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
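        # KL divergence KL(attn_score || softmax(top-k scores)), with both
        # distributions clipped in log space for numerical safety.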
        loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum")
        all_loss.append(loss)
        all_log_topk_prob.append(log_topk_prob)
    loss = torch.stack(all_loss).sum()
    loss.backward()
    log_topk_prob = torch.cat(all_log_topk_prob, dim=0)
    return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad


def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=128,
    topk=64,
):
    torch.manual_seed(42)
    q = torch.randn((S, H, D)).cuda().bfloat16()
    w = torch.randn((S, H)).cuda().bfloat16()
    k = torch.randn((S, D)).cuda().bfloat16()
    offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()

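    # Build a synthetic target distribution: uniform attention over the
    # causally reachable top-k slots of each query.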
    all_attn_score = []
    for i in range(offsets.shape[0] - 1):
        seq_len = (offsets[i + 1] - offsets[i]).item()
        mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
        logits = torch.ones(seq_len, topk).cuda()
        logits = torch.where(mask, logits, float("-inf"))
        attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
        all_attn_score.append(attn_score)
    attn_score = torch.cat(all_attn_score, dim=0)

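    # Every query selects the first `topk` positions (all valid since each
    # sequence is at least `topk` tokens long).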
    topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
    index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets)

    dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)

    print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}")
    print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}")
    print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")


if __name__ == "__main__":
    test_kernel()