# dsa.py: DeepSeek sparse attention (DSA) built from fused indexer and
# sparse-MLA kernels, plus a dense PyTorch reference and a CUDA self-test.
from typing import Optional
import torch
import torch.nn.functional as F
from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from einops import einsum, repeat
from utils import get_abs_err, get_err_ratio


class RegsiterLossFunction(torch.autograd.Function):
    """Identity on ``x`` that injects ``loss`` into the backward graph.

    Forward returns ``x`` unchanged. Backward passes the incoming gradient
    straight through to ``x`` and hands ``loss`` a gradient of ones, so the
    loss is back-propagated as if it had been added to the objective with
    weight 1, without altering the value of ``x``.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # Only loss's shape/dtype/device are needed in backward.
        ctx.save_for_backward(loss)
        return x

    @staticmethod
    def backward(ctx, grad):
        (loss,) = ctx.saved_tensors
        # ones_like matches loss exactly (0-dim for reduction="sum"), instead of
        # relying on autograd to reduce a shape-[1] gradient back to [].
        return grad, torch.ones_like(loss)


# Correctly spelled alias; the misspelled name is kept for existing callers.
RegisterLossFunction = RegsiterLossFunction
register_loss = RegsiterLossFunction.apply


def ref_deepseek_sparse_attention_innner(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    """Dense fp32 reference for DSA on one batch of equal-length sequences.

    Computes the indexer scores, selects the top-``topk`` keys per query
    under a causal mask, runs softmax attention restricted to those keys,
    and attaches a KL-divergence indexer loss to the output via
    ``register_loss`` so that ``o.backward(...)`` also back-propagates it.

    Args:
        q: queries, indexed ``[b, s, h, d]`` by the einsums below.
        kv: shared keys/values ``[b, s, d]``; the first ``dim_v`` channels
            are also used as the values.
        index_q: indexer queries ``[b, s, h, k]``.
        index_k: indexer keys ``[b, s, k]``.
        weights: per-head indexer weights ``[b, s, h]``.
        topk: keys selected per query (``torch.topk`` requires ``topk <= s``).
        dim_v: number of value channels taken from ``kv``.
        sm_scale: attention scale; defaults to ``kv.shape[-1] ** -0.5``.
        index_sm_scale: indexer scale; defaults to ``index_q.shape[-1] ** -0.5``.

    Returns:
        ``(o, topk_indices)``: ``o`` is ``[b, s, h, dim_v]`` in ``q``'s
        original dtype; ``topk_indices`` is ``[b, s, topk]`` (int64). For
        queries with fewer than ``topk`` causal keys, the surplus indices
        point at masked (future) positions and are ignored by the attention.
    """
    dtype = q.dtype
    device = q.device
    q, kv, index_q, index_k, weights = (t.to(torch.float32) for t in (q, kv, index_q, index_k, weights))

    if index_sm_scale is None:
        index_sm_scale = index_q.shape[-1] ** -0.5
    if sm_scale is None:
        sm_scale = kv.shape[-1] ** -0.5

    b, s = index_q.shape[:2]
    h = q.shape[-2]

    # [s, s]: query i may attend to key j only when j <= i.
    causal_mask = torch.arange(s, device=device)[:, None] >= torch.arange(s, device=device)[None, :]

    # Indexer: ReLU(q_idx . k_idx), weighted per head, summed over heads.
    index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2")
    index_logits = F.relu(index_logits)
    index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale
    index_logits = torch.where(causal_mask, index_logits, float("-inf"))
    topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
    topk_logits = torch.gather(index_logits, dim=-1, index=topk_indices)
    index_topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)

    # Attention may only see keys that are both causal and selected by the indexer.
    index_mask = torch.zeros((b, s, s), dtype=torch.bool, device=device).scatter_(
        dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool)
    )
    mask = repeat(causal_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h)

    k, v = kv, kv[..., :dim_v]
    logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale
    logits = torch.where(mask, logits, float("-inf"))
    attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
    o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d")

    # Indexer target: attention mass per selected key, summed over heads and
    # renormalised over the top-k set.
    head_sum = attn_score.sum(dim=-2)  # [b, s1, s2]
    attn_topk_score = torch.gather(head_sum, dim=-1, index=topk_indices)
    attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)

    # Clipping keeps masked entries (log 0 = -inf) finite; the target is
    # detached so this loss only trains the indexer.
    loss = F.kl_div(
        index_topk_score.clip(-100, 0),
        attn_topk_score.detach().log().clip(-100, 0),
        log_target=True,
        reduction="sum",
    )
    o = register_loss(o, loss)

    return o.to(dtype), topk_indices


def ref_deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
    index_sm_scale: Optional[float] = None,
):
    """Reference DSA over packed variable-length sequences.

    ``offsets`` holds cumulative sequence boundaries along dim 0 of every
    input (e.g. ``[0, s0, s0 + s1, ...]``). Each sequence is run separately
    through ``ref_deepseek_sparse_attention_innner`` with a batch axis of 1,
    and the per-sequence outputs are concatenated back along dim 0.

    Returns:
        ``(o, topk_indices)`` concatenated over all sequences. ``topk_indices``
        are local to each sequence (relative to its start offset).
    """
    # One host transfer up front instead of a device read per slice bound.
    bounds = offsets.tolist()
    all_o, all_topk_indices = [], []
    for start, end in zip(bounds[:-1], bounds[1:]):
        o, topk_indices = ref_deepseek_sparse_attention_innner(
            q[None, start:end],
            kv[None, start:end],
            index_q[None, start:end],
            index_k[None, start:end],
            weights[None, start:end],
            topk,
            dim_v,
            sm_scale,
            index_sm_scale,
        )
        all_o.append(o.squeeze(0))
        all_topk_indices.append(topk_indices.squeeze(0))
    o = torch.cat(all_o, dim=0)
    topk_indices = torch.cat(all_topk_indices, dim=0)
    return o, topk_indices


class DSAFunction(torch.autograd.Function):
    """Autograd wrapper chaining the fused DSA kernels.

    Forward: ``indexer_topk_reducesum_interface`` selects top-k keys per
    query, then ``sparse_mla_fwd_interface`` attends over those keys.
    Backward: q/kv gradients come from ``sparse_mla_bwd``; index_q, weights
    and index_k gradients come from ``indexer_bwd_interface``, which is fed
    the per-key attention mass recomputed by
    ``sparse_mla_topk_reducesum_interface``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        kv: torch.Tensor,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        offsets: torch.Tensor,
        topk: int,
        dim_v: int,
        sm_scale: Optional[float] = None,
    ):
        """Return ``(o, topk_indices)``; ``topk_indices`` is non-differentiable."""
        topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets)
        # The sparse-MLA kernels take kv / indices with an extra axis at dim -2
        # (presumably a kv-group axis of size 1 — TODO confirm against sparse_mla_fwd).
        o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
        ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets)
        # The indices are an integer selection result, not something to differentiate.
        ctx.mark_non_differentiable(topk_indices)
        ctx.topk = topk
        ctx.dim_v = dim_v
        ctx.sm_scale = sm_scale
        return o, topk_indices

    @staticmethod
    def backward(
        ctx,
        do: torch.Tensor,
        _1: torch.Tensor,
    ):
        """Return gradients for (q, kv, index_q, index_k, weights); None for the rest.

        ``_1`` is the (unused) gradient slot of the ``topk_indices`` output.
        """
        q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
        kv_grouped = kv.unsqueeze(-2)
        indices_grouped = topk_indices.unsqueeze(-2)
        attn_score = sparse_mla_topk_reducesum_interface(
            q, kv_grouped, indices_grouped, lse, offsets, dim_v=ctx.dim_v
        ).squeeze(-2)
        dq, dkv = sparse_mla_bwd(q, kv_grouped, o, do, indices_grouped, lse, offsets, sm_scale=ctx.sm_scale)
        dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets)
        # One entry per forward input, in forward's order; offsets, topk,
        # dim_v and sm_scale receive no gradient.
        return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None


def deepseek_sparse_attention(
    q: torch.Tensor,
    kv: torch.Tensor,
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    offsets: torch.Tensor,
    topk: int,
    dim_v: int,
    sm_scale: Optional[float] = None,
):
    """Run DeepSeek sparse attention through the fused-kernel autograd path.

    Thin functional entry point over ``DSAFunction``; returns whatever its
    forward returns, i.e. ``(o, topk_indices)``.
    """
    # Forwarded positionally, in DSAFunction.forward's parameter order.
    forward_args = (q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)
    return DSAFunction.apply(*forward_args)


def test_kernel(
    B=1,
    S=2048,
    H=16,
    D=512,
    tail_D=64,
    index_D=128,
    topk=64,
):
    """Compare the fused DSA kernels against the fp32 PyTorch reference.

    Prints the absolute error and error ratio for the output and for every
    input gradient. Then prints the average per-query overlap between the
    kernel's and the reference's top-k index sets. Requires a CUDA device.

    ``B`` is accepted for signature compatibility but unused. Inputs are
    packed along dim 0 and split by ``offsets`` into two sequences of length
    ``S // 2``.
    """
    torch.manual_seed(42)
    # Sample on the CPU generator and then move, keeping the same random stream.
    q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
    kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
    index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
    weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
    index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
    do = torch.randn((S, H, D)).cuda().bfloat16()
    offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()
    inputs = (q, kv, index_q, index_k, weights)

    def pop_grads():
        # Collect and clear .grad so the next backward starts from zero.
        grads = tuple(t.grad for t in inputs)
        for t in inputs:
            t.grad = None
        return grads

    def report(name, actual, expected):
        print(f"{name} err: {get_abs_err(actual, expected):.6f} ratio: {get_err_ratio(actual, expected):.6f}")

    o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    o.backward(do)
    q_grad, kv_grad, index_q_grad, index_k_grad, weights_grad = pop_grads()

    ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
    ref_o.backward(do)
    ref_q_grad, ref_kv_grad, ref_index_q_grad, ref_index_k_grad, ref_weights_grad = pop_grads()

    report("o", o, ref_o)
    report("q.grad", q_grad, ref_q_grad)
    report("kv.grad", kv_grad, ref_kv_grad)
    # Only the first 64 heads are compared (all of them when H <= 64).
    report("index_q.grad", index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :])
    report("index_k.grad", index_k_grad, ref_index_k_grad)
    report("weights.grad", weights_grad, ref_weights_grad)

    # Transfer each index tensor to the host once rather than once per row.
    ref_rows = ref_topk_indices.cpu().to(torch.int32).numpy()
    trt_rows = topk_indices.cpu().to(torch.int32).numpy()
    intersections = []
    for ref_np, trt_np in zip(ref_rows, trt_rows):
        # The kernel marks unused slots with -1; compare only the valid positions.
        mask = trt_np != -1
        set_ref = set(ref_np[mask])
        set_trt = set(trt_np[mask])
        if not set_ref:
            continue
        intersections.append(len(set_ref & set_trt) / len(set_ref))
    if intersections:
        print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))


# Run the CUDA comparison only when executed as a script, not on import.
if __name__ == "__main__":
    test_kernel()