Unverified commit 837feba7, authored by Yang Yong (雍洋), committed by GitHub

Update sparse attention (#432)

Replace the hardcoded num_frame = 21 in NbhdAttnWeight and SvgAttnWeight with a config-derived attnmap_frame_num, set once as a class attribute on the weight classes, and key each prepare_mask cache on seqlen alone.

parent fc231d3d
@@ -11,11 +11,11 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from .template import AttnWeightTemplate


-def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
+def generate_nbhd_mask(a, block_num, attnmap_frame_num, device="cpu"):
     """
     a : block num per frame
     block_num : block num per col/row
-    num_frame : total frame num
+    attnmap_frame_num : total frame num
     """
     i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
     j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]
@@ -29,7 +29,7 @@ def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
     # 3. cross-frame attention
     mask_cross = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
-    for n in range(1, num_frame):
+    for n in range(1, attnmap_frame_num):
         if n == 1:
             width = 1 / 2 * a
         elif n >= 2:
@@ -67,7 +67,7 @@ def generate_qk_ranges(mask, block_size, seqlen):
 class NbhdAttnWeight(AttnWeightTemplate):
     block_size = 128
     seqlen = None
-    num_frame = None
+    attnmap_frame_num = None
     q_ranges = None
     k_ranges = None
     attn_type_map = None
@@ -76,22 +76,21 @@ class NbhdAttnWeight(AttnWeightTemplate):
         self.config = {}

     @classmethod
-    def prepare_mask(cls, seqlen, num_frame):
-        if seqlen == cls.seqlen and num_frame == cls.num_frame:
+    def prepare_mask(cls, seqlen):
+        if seqlen == cls.seqlen:
             return
         block_num = (seqlen + cls.block_size - 1) // cls.block_size
-        block_num_per_frame = (seqlen // num_frame + cls.block_size - 1) // cls.block_size
-        mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame, device="cpu")
+        block_num_per_frame = (seqlen // cls.attnmap_frame_num + cls.block_size - 1) // cls.block_size
+        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, device="cpu")
         q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
         attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
         q_ranges = q_ranges.to(torch.int32).to("cuda")
         k_ranges = k_ranges.to(torch.int32).to("cuda")
         cls.seqlen = seqlen
-        cls.num_frame = num_frame
         cls.q_ranges = q_ranges
         cls.k_ranges = k_ranges
         cls.attn_type_map = attn_type_map
-        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}, num_frame={num_frame}")
+        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
         sparsity = 1 - mask.sum().item() / mask.numel()
         logger.info(f"Attention sparsity: {sparsity}")
@@ -111,8 +110,7 @@ class NbhdAttnWeight(AttnWeightTemplate):
         k: [seqlen, head_num, head_dim]
         v: [seqlen, head_num, head_dim]
         """
-        num_frame = 21
-        self.prepare_mask(seqlen=q.shape[0], num_frame=num_frame)
+        self.prepare_mask(seqlen=q.shape[0])
         out = magi_ffa_func(
             q,
             k,
......
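For reference, a minimal sketch of the block-count arithmetic that `NbhdAttnWeight.prepare_mask` now performs with the class-level `attnmap_frame_num`. The concrete numbers (21 latent frames, 1560 tokens per frame) are assumed Wan-style values for illustration, not taken from this diff:

```python
# Illustrative sketch of the ceiling-division block counts in prepare_mask.
# seqlen and attnmap_frame_num below are assumed example values.
block_size = 128
attnmap_frame_num = 21             # previously hardcoded as num_frame = 21
seqlen = attnmap_frame_num * 1560  # 32760 tokens in total

# Ceiling division, matching the (x + b - 1) // b pattern in the diff
block_num = (seqlen + block_size - 1) // block_size  # blocks per row/col of the attention map
block_num_per_frame = (seqlen // attnmap_frame_num + block_size - 1) // block_size  # blocks per frame

print(block_num, block_num_per_frame)  # -> 256 13
```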
@@ -299,8 +299,8 @@ class SvgAttnWeight(AttnWeightTemplate):
     sample_mse_max_row = None
     num_sampled_rows = None
     context_length = None
-    num_frame = None
-    frame_size = None
+    attnmap_frame_num = None
+    seqlen = None
     sparsity = None
     mask_name_list = ["spatial", "temporal"]
     attention_masks = None
@@ -325,18 +325,18 @@ class SvgAttnWeight(AttnWeightTemplate):
         self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")

     @classmethod
-    def prepare_mask(cls, num_frame, frame_size):
+    def prepare_mask(cls, seqlen):
         # Use class attributes so updates affect all instances of this class
-        if num_frame == cls.num_frame and frame_size == cls.frame_size:
+        if seqlen == cls.seqlen:
             return
-        cls.num_frame = num_frame
-        cls.frame_size = frame_size
-        cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, num_frame, frame_size) for mask_name in cls.mask_name_list]
-        multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, num_frame, frame_size)
+        frame_size = seqlen // cls.attnmap_frame_num
+        cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
+        multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
         cls.block_mask = prepare_flexattention(
-            1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, num_frame, frame_size, diag_width=diag_width, multiplier=multiplier
+            1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier
         )
-        logger.info(f"SvgAttnWeight Update: num_frame={num_frame}, frame_size={frame_size}")
+        cls.seqlen = seqlen
+        logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")

     def apply(
         self,
@@ -354,18 +354,19 @@ class SvgAttnWeight(AttnWeightTemplate):
         v = v.unsqueeze(0).transpose(1, 2)
         bs, num_heads, seq_len, dim = q.size()
-        num_frame = 21
-        self.prepare_mask(num_frame=num_frame, frame_size=seq_len // num_frame)
+        self.prepare_mask(seq_len)
         sampled_mses = self.sample_mse(q, k, v)
         best_mask_idx = torch.argmin(sampled_mses, dim=0)
         output_hidden_states = torch.zeros_like(q)
         query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
-        query_out, key_out, value_out = self.fast_sparse_head_placement(q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.num_frame, self.frame_size)
+        query_out, key_out, value_out = self.fast_sparse_head_placement(
+            q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num
+        )
         hidden_states = self.sparse_attention(query_out, key_out, value_out)
-        wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.num_frame, self.frame_size)
+        wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num)
         return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
......
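Both `prepare_mask` implementations now cache on `seqlen` alone, deriving `frame_size` as `seqlen // attnmap_frame_num` instead of taking it as a parameter. A toy sketch of that caching contract (the class name is a stand-in, not the one in the diff):

```python
# Toy version of the seqlen-keyed cache used by the prepare_mask classmethods.
class SparseMaskCacheSketch:
    attnmap_frame_num = 21  # set once from config before first use
    seqlen = None

    @classmethod
    def prepare_mask(cls, seqlen):
        if seqlen == cls.seqlen:
            return  # cache hit: masks for this sequence length were already built
        frame_size = seqlen // cls.attnmap_frame_num  # integer division
        # ... rebuild attention masks / block mask here, as in the diff ...
        cls.seqlen = seqlen
```

Note that the integer division assumes `seqlen` is an exact multiple of `attnmap_frame_num`, which holds when the token sequence is frames × tokens-per-frame.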
@@ -192,6 +192,8 @@ class WanSelfAttention(WeightModule):
                 context_length=self.config.get("svg_context_length", 0),
                 sparsity=self.config.get("svg_sparsity", 0.25),
             )
+        if self.config["self_attn_1_type"] in ["svg_attn", "nbhd_attn"]:
+            attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
         self.add_module("self_attn_1", attention_weights_cls())
         if self.config["seq_parallel"]:
......
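The assignment above sets `attnmap_frame_num` on the weight class itself rather than an instance, matching the "Use class attributes so updates affect all instances" comment in `SvgAttnWeight.prepare_mask`: the value is visible to the classmethods and to every instance created afterwards. A toy illustration with a hypothetical stand-in class:

```python
class AttnWeightSketch:
    attnmap_frame_num = None  # class-level slot, shared by all instances

# Mirrors: attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
AttnWeightSketch.attnmap_frame_num = 21

a, b = AttnWeightSketch(), AttnWeightSketch()
assert a.attnmap_frame_num == b.attnmap_frame_num == 21
```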
@@ -71,6 +71,8 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
         config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1
+    config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
     return config
......
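The new `set_config` line maps pixel-space frames to latent (attention-map) frames via the temporal VAE stride and patch size. As a worked example under assumed Wan-style defaults (`target_video_length = 81`, `vae_stride[0] = 4`, `patch_size[0] = 1`; these values are not part of this diff), the formula reproduces the 21 that both attention classes previously hardcoded:

```python
# Worked example of the attnmap_frame_num formula with assumed config values.
target_video_length = 81  # already rounded so (n - 1) is divisible by vae_stride[0]
vae_stride_t = 4          # temporal VAE stride, i.e. config["vae_stride"][0]
patch_size_t = 1          # temporal patch size, i.e. config["patch_size"][0]

attnmap_frame_num = ((target_video_length - 1) // vae_stride_t + 1) // patch_size_t
print(attnmap_frame_num)  # -> 21, the value formerly hardcoded as num_frame
```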