Unverified commit f37ce0b2 authored by Yang Yong (雍洋), committed by GitHub

Support nbhd attention (#427)

parent 6062ef24
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "nbhd_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 5,
"sample_shift": 3,
"enable_cfg": true,
"cpu_offload": false
}
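This config routes the first self-attention through the new nbhd_attn backend and uses flash_attn3 for both cross-attentions. As a rough illustration of the geometry the mask code below operates on, the following sketch derives the sequence length and block counts for this 81x480x832 setting; the 4x temporal / 8x spatial VAE strides and 2x2 patchify factor are assumptions about the Wan-style pipeline, not values stated in this diff.

# Illustrative sketch only -- the compression factors below are assumptions, not part of this commit.
T, H, W = 81, 480, 832                   # from the config above
vae_t, vae_s, patch = 4, 8, 2            # assumed temporal/spatial VAE strides and patch size
num_frame = (T - 1) // vae_t + 1         # 21 latent frames
tokens_per_frame = (H // (vae_s * patch)) * (W // (vae_s * patch))       # 30 * 52 = 1560
seqlen = num_frame * tokens_per_frame                                    # 32760 tokens
block_size = 128                                                         # NbhdAttnWeight.block_size
block_num = (seqlen + block_size - 1) // block_size                      # 256 block rows/cols
blocks_per_frame = (seqlen // num_frame + block_size - 1) // block_size  # 13 blocks per frame
print(num_frame, seqlen, block_num, blocks_per_frame)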
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .nbhd_attn import NbhdAttnWeight
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
......
import torch
from loguru import logger

try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate
def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
    """
    a : number of blocks per frame
    block_num : number of blocks per row/column of the block-level mask
    num_frame : total number of frames
    """
    i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
    j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]
    # 1. attention sink: every query block attends to the first frame's blocks (j <= a)
    mask_sink = j_indices <= a
    # 2. self-attention within the same frame
    n = i_indices // a
    mask_self = (j_indices >= n * a) & (j_indices < (n + 1) * a)
    # 3. cross-frame attention: for each frame offset n, keep two bands centered at
    #    |i - j| == n * a, with half-width a/2 for the adjacent frame and a/8 beyond
    mask_cross = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
    for n in range(1, num_frame):
        if n == 1:
            width = 1 / 2 * a
        elif n >= 2:
            width = 1 / 8 * a
        mask_1 = (i_indices - j_indices + (n * a + width) >= 0) & (i_indices - j_indices + (n * a - width) < 0)
        mask_2 = (i_indices - j_indices - (n * a - width) > 0) & (i_indices - j_indices - (n * a + width) <= 0)
        mask_cross = mask_cross | mask_1 | mask_2
    # combine all masks
    mask = mask_sink | mask_self | mask_cross
    return mask
def generate_qk_ranges(mask, block_size, seqlen):
    indices = torch.nonzero(mask, as_tuple=False)  # shape: [N, 2]
    i_indices = indices[:, 0]  # [N]
    j_indices = indices[:, 1]  # [N]
    q_start = i_indices * block_size  # [N]
    q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen)  # [N]
    k_start = j_indices * block_size  # [N]
    k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen)  # [N]
    q_ranges = torch.stack([q_start, q_end], dim=1)  # [N, 2]
    k_ranges = torch.stack([k_start, k_end], dim=1)  # [N, 2]
    return q_ranges, k_ranges
@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
    block_size = 128
    # Mask state is cached at class level so all instances share the same ranges.
    seqlen = None
    num_frame = None
    q_ranges = None
    k_ranges = None
    attn_type_map = None

    def __init__(self):
        self.config = {}

    @classmethod
    def prepare_mask(cls, seqlen, num_frame):
        if seqlen == cls.seqlen and num_frame == cls.num_frame:
            return
        block_num = (seqlen + cls.block_size - 1) // cls.block_size
        block_num_per_frame = (seqlen // num_frame + cls.block_size - 1) // cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame, device="cpu")
        q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
        # One attention-type flag per (q, k) range pair; all zeros here.
        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
        q_ranges = q_ranges.to(torch.int32).to("cuda")
        k_ranges = k_ranges.to(torch.int32).to("cuda")
        cls.seqlen = seqlen
        cls.num_frame = num_frame
        cls.q_ranges = q_ranges
        cls.k_ranges = k_ranges
        cls.attn_type_map = attn_type_map
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}, num_frame={num_frame}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")
    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        if magi_ffa_func is None:
            raise ImportError("nbhd_attn requires the magi_attention package (flex_flash_attn_func could not be imported)")
        # Hard-coded latent frame count (21 matches the 81-frame default in the config above).
        num_frame = 21
        self.prepare_mask(seqlen=q.shape[0], num_frame=num_frame)
        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=self.q_ranges,
            k_ranges=self.k_ranges,
            attn_type_map=self.attn_type_map,
            auto_range_merge=True,
        )[0]
        # Merge heads: [seqlen, head_num, head_dim] -> [seqlen, head_num * head_dim]
        return out.reshape(out.shape[0], -1)
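
For a quick sanity check of the block mask, here is a small sketch (not part of the commit) that builds the mask and ranges for a toy 4-frame layout and reports the same sparsity figure that prepare_mask logs. It assumes generate_nbhd_mask and generate_qk_ranges from the file above are in scope, e.g. pasted into the same script, since only the package-relative module location appears in this diff.

# Toy example: 4 frames, 3 blocks per frame, 128-token blocks.
blocks_per_frame, num_frame, block_size = 3, 4, 128
block_num = blocks_per_frame * num_frame
mask = generate_nbhd_mask(blocks_per_frame, block_num, num_frame)
q_ranges, k_ranges = generate_qk_ranges(mask, block_size, seqlen=block_num * block_size)
print(mask.int())                              # 12 x 12 block-level keep/drop mask
print(1 - mask.sum().item() / mask.numel())    # sparsity, as logged by prepare_mask
print(q_ranges[:4], k_ranges[:4])              # token ranges passed to flex_flash_attn_func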