# svg_attn.py
# Triton token-placement kernels, attention-mask helpers, and the "svg_attn"
# attention weight built on torch flex_attention.
import math
from functools import lru_cache
from math import ceil

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from loguru import logger
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


# Copies attention outputs from `hidden_states` into `hidden_states_out`, one
# (cfg, head, token-block) tile per program. Heads whose `best_mask_idx` entry is
# non-zero have their video tokens moved from index `patch_id * num_frame + frame_id`
# to index `frame_id * frame_size + patch_id`; all other heads are copied unchanged.
# The output tensor is addressed with the *input* strides, so both tensors are
# expected to share the same memory layout.
# NOTE(review): `tl.arange(0, head_dim)` / `tl.arange(0, BLOCK_SIZE)` presumably
# require power-of-two extents - confirm against the Triton version in use.
@triton.jit
def wan_hidden_states_placement_kernel(
    hidden_states_ptr,  # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
    hidden_states_out_ptr,  # [cfg, num_heads, seq_len, head_dim]
    best_mask_idx_ptr,  # [cfg, num_heads]
    hidden_states_stride_b,
    hidden_states_stride_h,
    hidden_states_stride_s,
    hidden_states_stride_d,
    mask_idx_stride_b,
    mask_idx_stride_h,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    context_length: tl.constexpr,
    num_frame: tl.constexpr,
    frame_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Copy hidden_states to output
    # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
    # Grid axes: 0 -> cfg (batch) index, 1 -> head index, 2 -> token block index.
    cfg = tl.program_id(0)
    head = tl.program_id(1)
    block_id = tl.program_id(2)

    start_id = block_id * BLOCK_SIZE
    # NOTE: end_id is computed but never read below; the tail of the sequence is
    # handled by `offset_mask` instead.
    end_id = start_id + BLOCK_SIZE
    end_id = tl.where(end_id > seq_len, seq_len, end_id)

    # Load best mask idx (0 is spatial, 1 is temporal)
    is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)

    # Token indices handled by this program; `offset_mask` disables the lanes that
    # fall past the end of the sequence in the last block.
    offset_token = tl.arange(0, BLOCK_SIZE) + start_id
    offset_mask = offset_token < seq_len
    offset_d = tl.arange(0, head_dim)

    if is_temporal:
        # Source index is interpreted as patch_id * num_frame + frame_id and the
        # destination index is frame_id * frame_size + patch_id.
        patch_id = offset_token // num_frame
        frame_id = offset_token - patch_id * num_frame
        # Tokens at index >= seq_len - context_length keep their position.
        offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)

        offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
        offset_hidden_states = hidden_states_ptr + offset_load

        offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
        offset_hidden_states_out = hidden_states_out_ptr + offset_store

        # Maybe tune the pipeline here
        hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
        tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
    else:
        # Spatial head: plain copy, load and store offsets are identical.
        offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
        offset_hidden_states = hidden_states_ptr + offset_load

        offset_store = offset_load
        offset_hidden_states_out = hidden_states_out_ptr + offset_store

        # Maybe tune the pipeline here
        hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
        tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])


def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
    """Launch ``wan_hidden_states_placement_kernel`` over every (cfg, head, token block).

    Args:
        hidden_states: 4-D tensor unpacked as ``(cfg, num_heads, seq_len, head_dim)``.
        hidden_states_out: Destination tensor, written in place. Only the strides of
            ``hidden_states`` are handed to the kernel, so the destination is
            addressed with the source layout.
        best_mask_idx: Per ``(cfg, head)`` selector read by the kernel
            (0 is spatial, 1 is temporal).
        context_length: Number of trailing tokens that are never moved.
        num_frame: Number of frames in the video part of the sequence.
        frame_size: Number of tokens per frame.

    Returns:
        ``hidden_states_out``.

    Raises:
        AssertionError: If ``seq_len != context_length + num_frame * frame_size``.
    """
    BLOCK_SIZE = 128
    cfg, num_heads, seq_len, head_dim = hidden_states.shape
    assert seq_len == context_length + num_frame * frame_size

    # Ceil-divide the sequence into BLOCK_SIZE-token tiles.
    full_blocks, tail_tokens = divmod(seq_len, BLOCK_SIZE)
    token_blocks = full_blocks + 1 if tail_tokens else full_blocks
    launch_grid = (cfg, num_heads, token_blocks)

    kernel_args = (
        hidden_states,
        hidden_states_out,
        best_mask_idx,
        # The tensor is 4-D (see the unpacking above), so this expands to the
        # batch / head / sequence / feature strides in that order.
        *hidden_states.stride(),
        best_mask_idx.stride(0),
        best_mask_idx.stride(1),
        seq_len,
        head_dim,
        context_length,
        num_frame,
        frame_size,
        BLOCK_SIZE,
    )
    wan_hidden_states_placement_kernel[launch_grid](*kernel_args)

    return hidden_states_out


# Copies query / key / value into their `*_out` buffers, one (cfg, head, token-block)
# tile per program. Heads whose `best_mask_idx` entry is non-zero have their video
# tokens moved from index `frame_id * frame_size + patch_id` to index
# `patch_id * num_frame + frame_id`; all other heads are copied unchanged.
# All six tensors are addressed with the *query* strides, so they are expected to
# share the same memory layout.
# NOTE(review): `tl.arange(0, head_dim)` / `tl.arange(0, BLOCK_SIZE)` presumably
# require power-of-two extents - confirm against the Triton version in use.
@triton.jit
def wan_sparse_head_placement_kernel(
    query_ptr,
    key_ptr,
    value_ptr,  # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
    query_out_ptr,
    key_out_ptr,
    value_out_ptr,  # [cfg, num_heads, seq_len, head_dim]
    best_mask_idx_ptr,  # [cfg, num_heads]
    query_stride_b,
    query_stride_h,
    query_stride_s,
    query_stride_d,
    mask_idx_stride_b,
    mask_idx_stride_h,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    context_length: tl.constexpr,
    num_frame: tl.constexpr,
    frame_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Copy query, key, value to output
    # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
    # Grid axes: 0 -> cfg (batch) index, 1 -> head index, 2 -> token block index.
    cfg = tl.program_id(0)
    head = tl.program_id(1)
    block_id = tl.program_id(2)

    start_id = block_id * BLOCK_SIZE
    # NOTE: end_id is computed but never read below; the tail of the sequence is
    # handled by `offset_mask` instead.
    end_id = start_id + BLOCK_SIZE
    end_id = tl.where(end_id > seq_len, seq_len, end_id)

    # Load best mask idx (0 is spatial, 1 is temporal)
    is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)

    # Token indices handled by this program; `offset_mask` disables the lanes that
    # fall past the end of the sequence in the last block.
    offset_token = tl.arange(0, BLOCK_SIZE) + start_id
    offset_mask = offset_token < seq_len
    offset_d = tl.arange(0, head_dim)

    if is_temporal:
        # Source index is interpreted as frame_id * frame_size + patch_id and the
        # destination index is patch_id * num_frame + frame_id.
        frame_id = offset_token // frame_size
        patch_id = offset_token - frame_id * frame_size
        # Tokens at index >= seq_len - context_length keep their position.
        offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)

        offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
        offset_query = query_ptr + offset_load
        offset_key = key_ptr + offset_load
        offset_value = value_ptr + offset_load

        offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
        offset_query_out = query_out_ptr + offset_store
        offset_key_out = key_out_ptr + offset_store
        offset_value_out = value_out_ptr + offset_store

        # Maybe tune the pipeline here
        query = tl.load(offset_query, mask=offset_mask[:, None])
        tl.store(offset_query_out, query, mask=offset_mask[:, None])
        key = tl.load(offset_key, mask=offset_mask[:, None])
        tl.store(offset_key_out, key, mask=offset_mask[:, None])
        value = tl.load(offset_value, mask=offset_mask[:, None])
        tl.store(offset_value_out, value, mask=offset_mask[:, None])

    else:
        # Spatial head: plain copy, load and store offsets are identical.
        offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
        offset_query = query_ptr + offset_load
        offset_key = key_ptr + offset_load
        offset_value = value_ptr + offset_load

        offset_store = offset_load
        offset_query_out = query_out_ptr + offset_store
        offset_key_out = key_out_ptr + offset_store
        offset_value_out = value_out_ptr + offset_store

        # Maybe tune the pipeline here
        query = tl.load(offset_query, mask=offset_mask[:, None])
        tl.store(offset_query_out, query, mask=offset_mask[:, None])
        key = tl.load(offset_key, mask=offset_mask[:, None])
        tl.store(offset_key_out, key, mask=offset_mask[:, None])
        value = tl.load(offset_value, mask=offset_mask[:, None])
        tl.store(offset_value_out, value, mask=offset_mask[:, None])


def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
    """Launch ``wan_sparse_head_placement_kernel`` over every (cfg, head, token block).

    The three ``*_out`` tensors are filled in place; nothing is returned. Only the
    strides of ``query`` are handed to the kernel, so key, value and all outputs
    are addressed with the query layout.

    Args:
        query, key, value: Inputs; ``query`` is unpacked as
            ``(cfg, num_heads, seq_len, head_dim)``.
        query_out, key_out, value_out: Destination tensors.
        best_mask_idx: Per ``(cfg, head)`` selector read by the kernel
            (0 is spatial, 1 is temporal).
        context_length: Number of trailing tokens that are never moved.
        num_frame: Number of frames in the video part of the sequence.
        frame_size: Number of tokens per frame.

    Raises:
        AssertionError: If ``seq_len != context_length + num_frame * frame_size``.
    """
    BLOCK_SIZE = 128
    cfg, num_heads, seq_len, head_dim = query.shape
    assert seq_len == context_length + num_frame * frame_size

    # Ceil-divide the sequence into BLOCK_SIZE-token tiles.
    full_blocks, tail_tokens = divmod(seq_len, BLOCK_SIZE)
    token_blocks = full_blocks + 1 if tail_tokens else full_blocks
    launch_grid = (cfg, num_heads, token_blocks)

    sources = (query, key, value)
    destinations = (query_out, key_out, value_out)
    wan_sparse_head_placement_kernel[launch_grid](
        *sources,
        *destinations,
        best_mask_idx,
        # `query` is 4-D (see the unpacking above), so this expands to the
        # batch / head / sequence / feature strides in that order.
        *query.stride(),
        best_mask_idx.stride(0),
        best_mask_idx.stride(1),
        seq_len,
        head_dim,
        context_length,
        num_frame,
        frame_size,
        BLOCK_SIZE,
    )


def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
    """Build a ``mask_mod(b, h, q_idx, kv_idx)`` callable for a banded attention mask.

    A (query, key) pair is kept when the key index lies in the first
    ``token_per_frame`` positions, or when ``|q_idx - kv_idx|`` is at most
    ``mul * token_per_frame`` rounded up to a multiple of 128.

    ``context_length``, ``prompt_length`` and ``num_frames`` are accepted but not
    used by the returned function.
    """

    def temporal_mask_mod(b, h, q_idx, kv_idx):
        # Band half-width, rounded up to the next multiple of 128.
        band_width = math.ceil(mul * token_per_frame / 128) * 128
        inside_band = torch.abs(q_idx - kv_idx) <= band_width

        # Keys that belong to the first frame are always visible.
        in_first_frame = kv_idx < token_per_frame
        return in_first_frame | inside_band

    return temporal_mask_mod


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
    return block_mask


def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
    """Create the block mask for a sequence of ``context_length + num_frame * frame_size`` tokens.

    The mask function comes from ``generate_temporal_head_mask_mod`` and the block
    mask from ``create_block_mask_cached`` with ``B=None``, ``H=None`` and
    ``_compile=True``. ``cfg_size``, ``num_head``, ``head_dim`` and ``dtype`` are
    accepted but not used here.

    Raises:
        AssertionError: If ``diag_width`` differs from ``multiplier``.
    """
    assert diag_width == multiplier, f"{diag_width} is not equivalent to {multiplier}"

    total_tokens = context_length + num_frame * frame_size
    return create_block_mask_cached(
        generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier),
        None,
        None,
        total_tokens,
        total_tokens,
        device=device,
        _compile=True,
    )


def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
    """Convert a sparsity value into a band width expressed in frames.

    With ``S = context_length + num_frame * frame_size``, the sparsity is first
    reduced by ``2 * S * context_length / S**2``; the result ``s`` is then mapped
    to ``S * (1 - sqrt(1 - s)) / frame_size``.

    Returns:
        The width as a float number of frames.
    """
    n_tokens = context_length + num_frame * frame_size
    n_pairs = n_tokens**2

    # Discount the score-matrix entries attributed to the context tokens.
    context_pairs = 2 * n_tokens * context_length
    adjusted = (sparsity * n_pairs - context_pairs) / n_pairs

    return n_tokens * (1 - math.sqrt(1 - adjusted)) / frame_size


def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size, device="cuda"):
    """Build the boolean reference mask used to score the spatial / temporal patterns.

    The mask is assembled on the CPU as a ``(seq_len, seq_len)`` bool tensor with
    ``seq_len = context_length + num_frame * frame_size``:

    * every query may attend to the first ``frame_size`` keys (first-frame sink);
    * 128x128 blocks ``(i, j)`` of the video part are enabled when
      ``abs(i - j) < (2 * frame_size) // 128``.

    For any ``mask_name`` other than ``"spatial"`` the mask is additionally
    re-indexed with ``reshape(frame_size, num_frame, frame_size, num_frame)
    .permute(1, 0, 3, 2)``. That reshape only succeeds when the mask has exactly
    ``(num_frame * frame_size) ** 2`` elements, i.e. when ``context_length == 0``.

    Args:
        mask_name: ``"spatial"`` for the plain band; anything else selects the
            permuted (temporal) variant.
        sample_mse_max_row: Number of leading rows to keep.
        context_length: Number of non-video tokens in the sequence.
        num_frame: Number of frames.
        frame_size: Number of tokens per frame.
        device: Device the returned rows are moved to. Defaults to ``"cuda"``,
            matching the previous hard-coded ``.cuda()`` call.

    Returns:
        Bool tensor holding the first ``sample_mse_max_row`` rows of the mask.
    """
    # TODO: fix hard coded mask
    block_size = 128
    video_len = num_frame * frame_size
    seq_len = context_length + video_len

    # Allocate the bool mask directly; no float32 scratch tensor is needed.
    mask = torch.zeros((seq_len, seq_len), dtype=torch.bool, device="cpu")
    mask[:, :frame_size] = True  # First Frame Sink

    # Blocks with abs(i - j) < band form one contiguous column range per block
    # row, so a single slice assignment per row replaces the inner j-loop.
    band = (frame_size * 2) // block_size
    num_block = math.ceil(video_len / block_size)
    for i in range(num_block):
        first_col_block = max(0, i - band + 1)
        last_col_block = min(num_block, i + band)  # exclusive
        if first_col_block < last_col_block:
            mask[i * block_size : (i + 1) * block_size, first_col_block * block_size : last_col_block * block_size] = True

    if mask_name != "spatial":
        mask = mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(video_len, video_len)

    return mask[:sample_mse_max_row].to(device)


@ATTN_WEIGHT_REGISTER("svg_attn")
class SvgAttnWeight(AttnWeightTemplate):
    """Attention weight that picks a spatial or temporal sparse pattern per head.

    For every head, a few query rows are sampled and the full-attention output is
    compared against the output under each reference mask (``sample_mse``). Heads
    are then re-laid out according to the winning pattern, attended with
    ``flex_attention`` under a shared block mask, and the result is moved back to
    the original token order.

    Configuration lives in class attributes so it is shared by all instances.
    """

    # Set by `prepare`.
    head_num = None
    head_dim = None
    sample_mse_max_row = None  # upper bound (exclusive) for sampled query rows
    num_sampled_rows = None  # number of query rows sampled per call
    context_length = None
    sparsity = None
    # Number of frames in the sequence; `prepare` does not set it, so it has to be
    # assigned on the class before `apply` is called.
    attnmap_frame_num = None
    # Sequence length the cached masks below were built for.
    seqlen = None
    mask_name_list = ["spatial", "temporal"]
    attention_masks = None
    block_mask = None

    @classmethod
    def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity):
        """Store the shared configuration and raise the torch._dynamo cache limits.

        NOTE: the masks cached by `prepare_mask` are keyed on the sequence length
        only; calling `prepare` again with different settings does not invalidate
        them.
        """
        cls.head_num = head_num
        cls.head_dim = head_dim
        cls.sample_mse_max_row = sample_mse_max_row
        cls.num_sampled_rows = num_sampled_rows
        cls.context_length = context_length
        cls.sparsity = sparsity
        torch._dynamo.config.cache_size_limit = 192 * 3
        torch._dynamo.config.accumulated_cache_size_limit = 192 * 3
        logger.info(
            f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}"
        )

    def __init__(self):
        self.config = {}
        self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")

    @classmethod
    def prepare_mask(cls, seqlen):
        """(Re)build the reference masks and the block mask for ``seqlen`` tokens.

        Does nothing when the masks for this sequence length are already cached.

        Raises:
            TypeError: If ``attnmap_frame_num`` has not been set on the class.
        """
        # Use class attributes so updates affect all instances of this class
        if seqlen == cls.seqlen:
            return

        if cls.attnmap_frame_num is None:
            # Fail with an actionable message instead of the bare `int // None` error.
            raise TypeError("SvgAttnWeight.attnmap_frame_num must be set to the number of frames before masks can be prepared")

        frame_size = seqlen // cls.attnmap_frame_num
        cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
        multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
        cls.block_mask = prepare_flexattention(
            1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier
        )

        cls.seqlen = seqlen
        logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run the sparse attention on ``q``, ``k``, ``v``.

        The inputs are laid out as ``(seq_len, num_heads, head_dim)``; a batch
        dimension of one is added internally. ``cu_seqlens_*``, ``max_seqlen_*``
        and ``model_cls`` are accepted but not used.

        Returns:
            Tensor of shape ``(seq_len, num_heads * head_dim)``.
        """
        # (seq_len, heads, dim) -> (1, heads, seq_len, dim)
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)
        bs, num_heads, seq_len, dim = q.size()

        self.prepare_mask(seq_len)
        frame_size = seq_len // self.attnmap_frame_num

        # Choose, per head, the reference mask with the lowest sampled error.
        sampled_mses = self.sample_mse(q, k, v)
        best_mask_idx = torch.argmin(sampled_mses, dim=0)

        output_hidden_states = torch.zeros_like(q)
        query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

        query_out, key_out, value_out = self.fast_sparse_head_placement(q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, frame_size)

        # Pass the block mask built in `prepare_mask`; without it flex_attention
        # computes dense attention and the prepared sparsity pattern is ignored.
        hidden_states = self.sparse_attention(query_out, key_out, value_out, block_mask=self.block_mask)
        wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, frame_size)

        return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)

    def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
        """Fill the ``*_out`` tensors via ``wan_sparse_head_placement`` and return them."""
        wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
        return query_out, key_out, value_out

    def sample_mse(self, query, key, value):
        """Score each reference mask by its error on a random sample of query rows.

        Returns:
            Tensor of shape ``(len(self.attention_masks), cfg, num_heads)`` holding,
            per mask and head, the mean squared difference between the masked and
            the unmasked attention output on the sampled rows.
        """
        cfg, num_heads, seq_len, dim = query.size()
        num_sampled_rows = min(self.num_sampled_rows, seq_len)
        # Never sample past the end of the sequence: both `query` and the reference
        # masks (which keep at most `seq_len` rows) are indexed with these rows.
        max_row = min(self.sample_mse_max_row, seq_len)
        sampled_rows = torch.randint(low=0, high=max_row, size=(num_sampled_rows,))
        # Move the indices once instead of on every indexing operation below.
        sampled_rows = sampled_rows.to(query.device)

        sampled_q = query[:, :, sampled_rows, :]
        sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)

        sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
        sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value)  # (cfg, num_heads, num_sampled_rows, dim)

        sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)

        # Only have Tri-diagonal and Striped
        for mask_idx, attn_mask in enumerate(self.attention_masks):
            sampled_attention_mask = attn_mask[sampled_rows, :]
            sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
            sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
            sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
            mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
            sampled_mses[mask_idx] = mse

        return sampled_mses


if __name__ == "__main__":
    # Smoke test on random bf16 inputs laid out as (seq_len, num_heads, head_dim).
    seq_len, num_heads, head_dim = 32130, 40, 128
    q, k, v = (torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16).cuda() for _ in range(3))

    SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25)
    # prepare() does not set the frame count, and prepare_mask() divides the
    # sequence length by it, so it has to be assigned before apply() is called.
    # NOTE(review): 21 is an assumed value, chosen because it divides seq_len
    # exactly (32130 = 21 * 1530) - confirm against the intended model.
    SvgAttnWeight.attnmap_frame_num = 21
    svg_attn = SvgAttnWeight()
    print("SvgAttnWeight initialized.")

    out = svg_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")