"vscode:/vscode.git/clone" did not exist on "fcff6f6593bf9a9b0cd190fe04cc552d2fcb14c6"
Unverified Commit 837feba7 authored by Yang Yong (雍洋), committed by GitHub

Update sparse attention (#432)

parent fc231d3d
@@ -11,11 +11,11 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 from .template import AttnWeightTemplate


-def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
+def generate_nbhd_mask(a, block_num, attnmap_frame_num, device="cpu"):
     """
     a : block num per frame
     block_num : block num per col/row
-    num_frame : total frame num
+    attnmap_frame_num : total frame num
     """
     i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
     j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]
@@ -29,7 +29,7 @@ def generate_nbhd_mask(a, block_num, num_frame, device="cpu"):
     # 3. cross-frame attention
     mask_cross = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
-    for n in range(1, num_frame):
+    for n in range(1, attnmap_frame_num):
         if n == 1:
             width = 1 / 2 * a
         elif n >= 2:
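As a reading aid for this hunk: the mask is built on a `[block_num, block_num]` grid of query-block × key-block indices produced by the two broadcasted `arange` tensors above, and the renamed `attnmap_frame_num` controls how many frames the cross-frame band loop covers. A minimal, hypothetical sketch of that broadcast pattern (a simplified within-frame term only, not the function's actual neighborhood rule):

```python
import torch

def within_frame_mask_sketch(a, block_num, device="cpu"):
    # a: blocks per frame, block_num: blocks per row/col of the attention map
    i = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
    j = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]
    # Query block i may attend to key block j when both lie in the same frame.
    return (i // a) == (j // a)  # bool mask of shape [block_num, block_num]

print(within_frame_mask_sketch(a=2, block_num=6).int())
```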
@@ -67,7 +67,7 @@ def generate_qk_ranges(mask, block_size, seqlen):
 class NbhdAttnWeight(AttnWeightTemplate):
     block_size = 128
     seqlen = None
-    num_frame = None
+    attnmap_frame_num = None
     q_ranges = None
     k_ranges = None
     attn_type_map = None
@@ -76,22 +76,21 @@ class NbhdAttnWeight(AttnWeightTemplate):
         self.config = {}

     @classmethod
-    def prepare_mask(cls, seqlen, num_frame):
-        if seqlen == cls.seqlen and num_frame == cls.num_frame:
+    def prepare_mask(cls, seqlen):
+        if seqlen == cls.seqlen:
             return
         block_num = (seqlen + cls.block_size - 1) // cls.block_size
-        block_num_per_frame = (seqlen // num_frame + cls.block_size - 1) // cls.block_size
-        mask = generate_nbhd_mask(block_num_per_frame, block_num, num_frame, device="cpu")
+        block_num_per_frame = (seqlen // cls.attnmap_frame_num + cls.block_size - 1) // cls.block_size
+        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, device="cpu")
         q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
         attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
         q_ranges = q_ranges.to(torch.int32).to("cuda")
         k_ranges = k_ranges.to(torch.int32).to("cuda")
         cls.seqlen = seqlen
-        cls.num_frame = num_frame
         cls.q_ranges = q_ranges
         cls.k_ranges = k_ranges
         cls.attn_type_map = attn_type_map
-        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}, num_frame={num_frame}")
+        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
         sparsity = 1 - mask.sum().item() / mask.numel()
         logger.info(f"Attention sparsity: {sparsity}")
@@ -111,8 +110,7 @@ class NbhdAttnWeight(AttnWeightTemplate):
         k: [seqlen, head_num, head_dim]
         v: [seqlen, head_num, head_dim]
         """
-        num_frame = 21
-        self.prepare_mask(seqlen=q.shape[0], num_frame=num_frame)
+        self.prepare_mask(seqlen=q.shape[0])
         out = magi_ffa_func(
             q,
             k,
...
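The removed hard-coded `num_frame = 21` means `apply` no longer assumes a fixed frame count: `attnmap_frame_num` now lives on the class (set from config, see the `WanSelfAttention` hunk below) and the ranges are rebuilt only when `seqlen` changes. A stripped-down sketch of that class-level caching guard (not the real class, just the pattern):

```python
class MaskCacheSketch:
    seqlen = None
    attnmap_frame_num = None  # wired in once from config before first use

    @classmethod
    def prepare_mask(cls, seqlen):
        if seqlen == cls.seqlen:
            return  # cached ranges still match this sequence length
        # ... rebuild mask, q_ranges, k_ranges from cls.attnmap_frame_num here ...
        cls.seqlen = seqlen

MaskCacheSketch.attnmap_frame_num = 21
MaskCacheSketch.prepare_mask(32760)  # rebuilds
MaskCacheSketch.prepare_mask(32760)  # early-returns on the cache hit
```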
@@ -299,8 +299,8 @@ class SvgAttnWeight(AttnWeightTemplate):
     sample_mse_max_row = None
     num_sampled_rows = None
     context_length = None
-    num_frame = None
-    frame_size = None
+    attnmap_frame_num = None
+    seqlen = None
     sparsity = None
     mask_name_list = ["spatial", "temporal"]
     attention_masks = None
@@ -325,18 +325,18 @@ class SvgAttnWeight(AttnWeightTemplate):
         self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")

     @classmethod
-    def prepare_mask(cls, num_frame, frame_size):
+    def prepare_mask(cls, seqlen):
         # Use class attributes so updates affect all instances of this class
-        if num_frame == cls.num_frame and frame_size == cls.frame_size:
+        if seqlen == cls.seqlen:
             return
-        cls.num_frame = num_frame
-        cls.frame_size = frame_size
-        cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, num_frame, frame_size) for mask_name in cls.mask_name_list]
-        multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, num_frame, frame_size)
+        frame_size = seqlen // cls.attnmap_frame_num
+        cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
+        multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
         cls.block_mask = prepare_flexattention(
-            1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, num_frame, frame_size, diag_width=diag_width, multiplier=multiplier
+            1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier
         )
-        logger.info(f"SvgAttnWeight Update: num_frame={num_frame}, frame_size={frame_size}")
+        cls.seqlen = seqlen
+        logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")

     def apply(
         self,
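With only `seqlen` as input, `SvgAttnWeight.prepare_mask` derives the per-frame token count as `seqlen // cls.attnmap_frame_num`, which implicitly assumes the sequence length is a whole multiple of the frame count. A small sanity check of that derivation (values illustrative, not from this diff):

```python
seqlen = 32760
attnmap_frame_num = 21

frame_size = seqlen // attnmap_frame_num  # 1560 tokens per latent frame
assert frame_size * attnmap_frame_num == seqlen, "seqlen should divide evenly into attnmap_frame_num frames"
print(frame_size)
```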
@@ -354,18 +354,19 @@ class SvgAttnWeight(AttnWeightTemplate):
         v = v.unsqueeze(0).transpose(1, 2)
         bs, num_heads, seq_len, dim = q.size()

-        num_frame = 21
-        self.prepare_mask(num_frame=num_frame, frame_size=seq_len // num_frame)
+        self.prepare_mask(seq_len)

         sampled_mses = self.sample_mse(q, k, v)
         best_mask_idx = torch.argmin(sampled_mses, dim=0)
         output_hidden_states = torch.zeros_like(q)
         query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
-        query_out, key_out, value_out = self.fast_sparse_head_placement(q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.num_frame, self.frame_size)
+        query_out, key_out, value_out = self.fast_sparse_head_placement(
+            q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num
+        )
         hidden_states = self.sparse_attention(query_out, key_out, value_out)
-        wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.num_frame, self.frame_size)
+        wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num)
         return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
...
@@ -192,6 +192,8 @@ class WanSelfAttention(WeightModule):
                 context_length=self.config.get("svg_context_length", 0),
                 sparsity=self.config.get("svg_sparsity", 0.25),
             )
+        if self.config["self_attn_1_type"] in ["svg_attn", "nbhd_attn"]:
+            attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
         self.add_module("self_attn_1", attention_weights_cls())

         if self.config["seq_parallel"]:
...
@@ -71,6 +71,8 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
         config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1

+    config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
+
     return config
...
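`attnmap_frame_num` is the latent frame count the attention map is laid out over: pixel frames compressed first by the temporal VAE stride and then by the temporal patch size. Plugging in typical Wan-style values (81 target frames, vae_stride[0] = 4, patch_size[0] = 1; illustrative, not read from this diff) recovers the 21 that the attention classes previously hard-coded:

```python
target_video_length = 81  # illustrative pixel-frame count
vae_stride_t = 4          # assumed temporal VAE stride
patch_size_t = 1          # assumed temporal patch size

attnmap_frame_num = ((target_video_length - 1) // vae_stride_t + 1) // patch_size_t
print(attnmap_frame_num)  # (80 // 4 + 1) // 1 = 21
```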