svg2_attn.py 12.8 KB
Newer Older
Yang Yong (雍洋)'s avatar
Yang Yong (雍洋) committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
from typing import Optional

# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
    import flashinfer
except ImportError:
    flashinfer = None

import torch
import triton
import triton.language as tl

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .svg2_attn_utils import (
    batch_kmeans_Euclid,
    identify_dynamic_map,
)
from .template import AttnWeightTemplate


@triton.jit
def _permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Gather rows along the sequence axis: Y[bh, s, :] = X[bh, IDX[bh, s], :].

    Program axis 0 selects the flattened (batch, head) slice and axis 1 selects a
    tile of BLOCK_S consecutive output positions (permute_tensor_by_labels_triton
    launches it with grid (BH, cdiv(S, BLOCK_S))). Each program copies all D
    features of its BLOCK_S tokens with one masked load and one masked store.

    Args:
        X_ptr: Source data, addressed as a contiguous row-major [BH, S, D] array.
        IDX_ptr: Source-row index per output position, addressed as a contiguous [BH, S] array.
        Y_ptr: Destination, addressed as a contiguous row-major [BH, S, D] array.
        S: Sequence length (compile-time constant).
        D: Features per token (compile-time constant).
        BLOCK_S: Tokens handled per program (compile-time constant).

    NOTE(review): tl.arange needs a power-of-two extent, so D and BLOCK_S presumably
    must be powers of two. Offsets are formed from tl.program_id values, which are
    32-bit; confirm BH * S * D stays below 2**31 for the largest inputs.
    """

    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)

    # Output sequence positions owned by this program; the last tile may run past S.
    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S

    # Gather source indices for these tokens (masked-out lanes read index 0 but never store).
    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)

    # Broadcast to create 2-D pointer matrix (BLOCK_S, D)
    d_offsets = tl.arange(0, D)

    # Read from the permuted source row, write to the natural output row.
    src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]

    # Only the token axis needs masking; every feature column 0..D-1 is in range.
    full_mask = token_mask[:, None]

    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)


def permute_tensor_by_labels_triton(
    tensor: torch.Tensor,
    labels: Optional[torch.Tensor],
    dim: int,
    *,
    sorted_indices: Optional[torch.Tensor] = None,
):
    """
    Permute `tensor` along `dim` according to ascending order of `labels`.

    Only 4-D tensors of shape [B, H, S, D] permuted along the sequence axis
    (``dim == 2``) are supported. The result satisfies
    ``out[b, h, s] = tensor[b, h, sorted_indices[b * H + h, s]]``.
    CUDA tensors go through the Triton gather kernel ``_permute_kernel``;
    tensors on any other device (and empty tensors) use an equivalent
    ``torch.gather`` reference path.

    Args:
        tensor: Tensor of shape [B, H, S, D].
        labels: Per-token labels holding B*H*S elements laid out as B*H rows of
            S values (e.g. [B*H, S] or [B, H, S]); argsorted along the last
            dimension. Ignored when ``sorted_indices`` is given.
        dim: Axis to permute; must be 2.
        sorted_indices: Optional precomputed permutation with the same layout as
            ``labels`` (for example, reusing the key order to permute values).

    Returns:
        Tuple ``(permuted_tensor, sorted_indices)``: the permuted [B, H, S, D]
        tensor, and the int32 permutation on ``tensor.device`` (same shape as
        ``labels`` / the supplied indices).

    Raises:
        AssertionError: If ``dim != 2``, ``tensor`` is not 4-D, neither
            ``labels`` nor ``sorted_indices`` is provided, or the permutation
            does not hold exactly B*H*S elements.
    """
    assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
    assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"

    B, H, S, D = tensor.shape
    BH = B * H

    # Determine sorted indices (ascending label order within each row).
    if sorted_indices is None:
        assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
        sorted_indices = torch.argsort(labels.to(tensor.device), dim=-1)
    # The kernel reads a dense int32 index array, which must live on the same device as the data.
    sorted_indices = sorted_indices.to(device=tensor.device, dtype=torch.int32).contiguous()
    # A wrongly sized index array would make the kernel read out of bounds.
    assert sorted_indices.numel() == BH * S, "sorted_indices must hold B*H*S elements"

    # Flatten tensor (B and H folded together) and view indices as [BH, S].
    inp_flat = tensor.reshape(BH, S, D).contiguous()
    idx_flat = sorted_indices.reshape(BH, S)

    if inp_flat.is_cuda and inp_flat.numel() > 0:
        out_flat = torch.empty_like(inp_flat)
        BLOCK_S = 64  # number of tokens per program, tunable
        grid = (BH, triton.cdiv(S, BLOCK_S))
        _permute_kernel[grid](inp_flat, idx_flat, out_flat, S, D, BLOCK_S, num_warps=4)
    else:
        # Reference path: out[bh, s, :] = inp[bh, idx[bh, s], :]
        gather_idx = idx_flat.long().unsqueeze(-1).expand(BH, S, D)
        out_flat = torch.gather(inp_flat, 1, gather_idx)

    return out_flat.reshape(B, H, S, D), sorted_indices


@triton.jit
def _inverse_permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Scatter rows back along the sequence axis: Y[bh, IDX[bh, s], :] = X[bh, s, :].

    This inverts _permute_kernel for the same index array. Program axis 0 selects
    the flattened (batch, head) slice and axis 1 a tile of BLOCK_S consecutive
    source positions (apply_inverse_permutation_triton launches it with grid
    (BH, cdiv(S, BLOCK_S))). Rows of Y that no index targets are not written.

    Args:
        X_ptr: Permuted data, addressed as a contiguous row-major [BH, S, D] array.
        IDX_ptr: Destination row per source position, addressed as a contiguous [BH, S] array.
        Y_ptr: Output, addressed as a contiguous row-major [BH, S, D] array.
        S: Sequence length (compile-time constant).
        D: Features per token (compile-time constant).
        BLOCK_S: Tokens handled per program (compile-time constant).

    NOTE(review): tl.arange needs a power-of-two extent, so D and BLOCK_S presumably
    must be powers of two. Offsets use 32-bit program ids; confirm BH * S * D < 2**31.
    """

    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)

    # Source sequence positions owned by this program; the last tile may run past S.
    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S

    # Read in natural order, write to the position given by the index array.
    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_pos_idx = s_offsets.to(tl.int32)
    dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)

    d_offsets = tl.arange(0, D)

    src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]

    # Masked-out lanes (index 0 fallback) never store, so they cannot clobber row 0.
    full_mask = token_mask[:, None]

    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)


def apply_inverse_permutation_triton(
    permuted_tensor: torch.Tensor,
    sorted_indices: torch.Tensor,
    dim: int,
):
    """
    Undo the permutation applied by `permute_tensor_by_labels_triton`.

    The result satisfies
    ``out[b, h, sorted_indices[b * H + h, s]] = permuted_tensor[b, h, s]``.
    CUDA tensors use the Triton scatter kernel ``_inverse_permute_kernel``;
    tensors on any other device (and empty tensors) use an equivalent
    ``Tensor.scatter_`` reference path.

    Args:
        permuted_tensor: (B, H, S, D).
        sorted_indices: Permutation holding B*H*S elements laid out as B*H rows
            of S values, e.g. (B, H, S) or (B*H, S), as returned by
            `permute_tensor_by_labels_triton`. Each row is expected to be a
            permutation of 0..S-1; output rows it never targets are left
            uninitialized.
        dim: Dimension along which to apply the inverse permutation; must be 2
            (the sequence-length dimension).

    Returns:
        Tensor of shape (B, H, S, D) in the original token order.

    Raises:
        AssertionError: If ``dim != 2``, the tensor is not 4-D, or
            ``sorted_indices`` does not hold exactly B*H*S elements.
    """
    assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
    assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"

    B, H, S, D = permuted_tensor.shape
    BH = B * H

    # The kernel reads a dense int32 index array on the same device as the data.
    sorted_indices = sorted_indices.to(device=permuted_tensor.device, dtype=torch.int32).contiguous()
    # A wrongly sized index array would make the kernel access memory out of bounds.
    assert sorted_indices.numel() == BH * S, "sorted_indices must hold B*H*S elements"

    # Flatten inputs (B and H folded together).
    inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
    idx_flat = sorted_indices.reshape(BH, S)
    out_flat = torch.empty_like(inp_flat)

    if inp_flat.is_cuda and inp_flat.numel() > 0:
        BLOCK_S = 64  # number of tokens per program, tunable
        grid = (BH, triton.cdiv(S, BLOCK_S))
        _inverse_permute_kernel[grid](inp_flat, idx_flat, out_flat, S, D, BLOCK_S, num_warps=4)
    else:
        # Reference path: out[bh, idx[bh, s], :] = inp[bh, s, :]
        scatter_idx = idx_flat.long().unsqueeze(-1).expand(BH, S, D)
        out_flat.scatter_(1, scatter_idx, inp_flat)

    return out_flat.reshape(B, H, S, D)


@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
    """SVG2 semantic-aware block-sparse attention.

    Pipeline per call:
      1. Cluster queries and keys per (batch, head) with batched Euclidean k-means.
      2. Build a block map over (query-cluster, key-cluster) pairs with
         ``identify_dynamic_map``.
      3. Reorder q/k/v so that tokens of one cluster are contiguous.
      4. Run FlashInfer variable-block-sparse attention on the reordered tensors.
      5. Scatter the output back to the original token order.

    Centroids are cached on the instance. The first call runs ``kmeans_iter_init``
    iterations from scratch. Later calls run ``kmeans_iter_step`` iterations
    warm-started from the previous centroids, as long as the (batch, heads, dim)
    layout still matches the cached centroids.
    """

    # Class-level defaults; kmeans_clustering shadows centroids_init per instance.
    centroids_init = False
    num_q_centroids = 300  # query clusters per (batch, head)
    num_k_centroids = 1000  # key clusters per (batch, head)
    kmeans_iter_init = 50  # k-means iterations on a cold start
    top_p_kmeans = 0.9  # forwarded to identify_dynamic_map
    min_kc_ratio = 0.10  # forwarded to identify_dynamic_map
    kmeans_iter_step = 2  # warm-started k-means iterations on later calls

    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run SVG2 sparse attention on one packed sequence.

        Args:
            q, k, v: Tensors of shape [S, H, D]; a batch dimension of 1 is added internally.
            cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, model_cls:
                Accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape [S, H * D].
        """
        # [S, H, D] -> [1, H, S, D]
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)
        bs, num_heads, seq_len, dim = q.size()
        q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)

        # is_cpu=False skips the check that the block map / cluster sizes live on CPU.
        output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)

        # Undo the query permutation so rows line up with the original token order.
        attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)

        # [1, H, S, D] -> [S, H * D]
        return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)

    def dynamic_block_sparse_fwd_flashinfer(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        is_cpu: bool = True,
    ):
        """
        Launcher for the FlashInfer variable-block-sparse attention kernel.

        Batch and head axes are folded together, so every (batch, head) pair is
        planned as its own FlashInfer head with its own block layout.

        Args:
            q (torch.Tensor): Query tensor, shape [B, H, S, D].
            k (torch.Tensor): Key tensor, shape [B, H, S, D].
            v (torch.Tensor): Value tensor, shape [B, H, S, D].
            block_mask_map (torch.Tensor): Boolean block mask, shape [B, H, qc_num, kc_num].
            block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num].
            block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num].
            is_cpu (bool): When True, assert that block_mask_map, block_row_sz and
                block_col_sz are CPU tensors (FlashInfer's default is to plan on CPU).
                The flag is not otherwise used here. Default is True.

        Returns:
            torch.Tensor: Attention output of shape [B, H, S, D], in q's token order.

        Raises:
            ImportError: If flashinfer could not be imported.
            AssertionError: On inconsistent shapes, devices, or block sizes.
        """
        if flashinfer is None:
            raise ImportError(
                "flashinfer is required for svg2_attn; please reinstall it by referring to "
                "https://github.com/svg-project/Sparse-VideoGen"
            )

        # Input shape check
        B, H, S, D = q.shape
        qc_num = block_row_sz.shape[-1]
        kc_num = block_col_sz.shape[-1]
        assert block_mask_map.shape == (B, H, qc_num, kc_num)

        if is_cpu:
            assert all(t.device == torch.device("cpu") for t in (block_mask_map, block_row_sz, block_col_sz))

        # Every (batch, head) must partition the same total number of key / query tokens.
        assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
        assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])

        # Prepare flashinfer wrapper.
        # NOTE(review): no dtype is passed, so both buffers use torch's default dtype
        # (float32 unless changed): about 512 MiB and 4 GiB, allocated on every call.
        # Confirm the dtype flashinfer expects before shrinking or caching these.
        float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
        vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
        wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
        # Only the sparse-indices buffer is replaced; the others are the wrapper's own.
        wrapper.reset_workspace_buffer(
            float_workspace_buffer=wrapper._float_workspace_buffer,
            int_workspace_buffer=wrapper._int_workspace_buffer,
            vector_sparse_indices_buffer=vector_sparse_indices_buffer,
            vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
        )

        # Fold batch into the head axis.
        block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
        block_row_sz = block_row_sz.reshape(B * H, qc_num)
        block_col_sz = block_col_sz.reshape(B * H, kc_num)

        wrapper.plan(
            block_mask_map=block_mask_map,
            block_row_sz=block_row_sz,
            block_col_sz=block_col_sz,
            num_qo_heads=B * H,
            num_kv_heads=B * H,
            head_dim=D,
            q_data_type=q.dtype,
            kv_data_type=k.dtype,
        )

        q = q.reshape(B * H, S, D)
        k = k.reshape(B * H, S, D)
        v = v.reshape(B * H, S, D)
        o = wrapper.run(q, k, v)  # [num_qo_heads, qo_len, head_dim]
        return o.reshape(B, H, S, D)

    def semantic_aware_permutation(self, query, key, value):
        """Cluster q/k, build the block-sparse map, and reorder q/k/v by cluster.

        Args:
            query, key, value: Tensors of shape [cfg, H, S, D].

        Returns:
            Tuple ``(q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes,
            k_cluster_sizes, q_sorted_indices)``:
            - q/k/v are reordered along the sequence axis, with values following the key order.
            - Cluster sizes have shape [cfg, H, num_q_centroids] / [cfg, H, num_k_centroids].
            - q_sorted_indices undoes the query permutation.
        """
        cfg, num_heads, seq_len, dim = query.size()

        # 1. Kmeans clustering
        qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)

        # 2. Identify dynamic map
        q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
        k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)

        dynamic_map = identify_dynamic_map(
            qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
            kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
            q_cluster_sizes,
            k_cluster_sizes,
            self.top_p_kmeans,
            self.min_kc_ratio,
        )

        # 3. Permute the query, key, value (values reuse the key permutation)
        q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
        k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
        v_permuted, _ = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)

        return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices

    def kmeans_clustering(self, query, key):
        """Cluster queries and keys, warm-starting from cached centroids when compatible.

        Runs ``kmeans_init`` on the first call, or when the cached centroids were
        computed for a different (batch, heads, dim) layout. Otherwise runs
        ``kmeans_step``.

        Returns:
            The 8-tuple ``(qlabels, qcentroids, qcluster_sizes, qiter,
            klabels, kcentroids, kcluster_sizes, kiter)``.
        """
        if self.centroids_init and self._cached_centroids_match(query):
            return self.kmeans_step(query, key)
        result = self.kmeans_init(query, key)
        self.centroids_init = True
        return result

    def _cached_centroids_match(self, query):
        """Return True if the cached centroids have the element count implied by `query`'s layout."""
        q_centroids = getattr(self, "q_centroids", None)
        k_centroids = getattr(self, "k_centroids", None)
        if q_centroids is None or k_centroids is None:
            return False
        cfg, num_heads, _, dim = query.size()
        return q_centroids.numel() == cfg * num_heads * self.num_q_centroids * dim and k_centroids.numel() == cfg * num_heads * self.num_k_centroids * dim

    def kmeans_init(self, query, key):
        """Cold-start k-means (``kmeans_iter_init`` iterations) for q and k, and cache the centroids."""
        cfg, num_heads, seq_len, dim = query.size()
        # reshape (not view): query/key may be non-contiguous, e.g. transposed in apply().
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.reshape(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.reshape(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)

        self.q_centroids = qcentroids
        self.k_centroids = kcentroids

        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter

    def kmeans_step(self, query, key):
        """Warm-started k-means (``kmeans_iter_step`` iterations) from the cached centroids; updates the cache."""
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
            query.reshape(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_q_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.q_centroids,
        )
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
            key.reshape(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_k_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.k_centroids,
        )

        self.q_centroids = qcentroids
        self.k_centroids = kcentroids

        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter


if __name__ == "__main__":
    # Manual smoke test (needs a CUDA device and flashinfer): random bf16 q/k/v in
    # the [S, H, D] layout that Svg2AttnWeight.apply expects.
    qkv_shape = (32130, 40, 128)
    q, k, v = (torch.randn(*qkv_shape, dtype=torch.bfloat16).cuda() for _ in range(3))

    svg2_attn = Svg2AttnWeight()
    print("Svg2AttnWeight initialized.")

    out = svg2_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")