Commit db6296f2 authored by helloyongyang

Support FlashInfer for nbhd attention

parent adc66e8d
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
-from .nbhd_attn import NbhdAttnWeight
+from .nbhd_attn import NbhdAttnWeight, NbhdAttnWeightFlashInfer
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
@@ -6,6 +6,11 @@ try:
except ImportError:
    magi_ffa_func = None
try:
    import flashinfer
except ImportError:
    flashinfer = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@@ -124,3 +129,68 @@ class NbhdAttnWeight(AttnWeightTemplate):
            auto_range_merge=True,
        )[0]
        return out.reshape(out.shape[0], -1)

@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer")
class NbhdAttnWeightFlashInfer(AttnWeightTemplate):
    block_size = 128
    seqlen = None
    attnmap_frame_num = None
    coefficient = [1.0, 0.5, 0.056]
    min_width = 1.0
    sparse_wrapper = None

    def __init__(self):
        self.config = {}

    @classmethod
    @torch.compiler.disable
    def prepare_mask(cls, seqlen, head_num, head_dim):
        if seqlen == cls.seqlen:
            return
        block_num = (seqlen + cls.block_size - 1) // cls.block_size
        block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
        mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
        block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size
        block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1)
        block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1)
        float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
        cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
        cls.sparse_wrapper.plan(
            block_mask_map=mask,
            block_row_sz=block_rowcol_size,
            block_col_sz=block_rowcol_size,
            num_qo_heads=head_num,
            num_kv_heads=head_num,
            head_dim=head_dim,
            q_data_type=torch.bfloat16,
        )
        cls.seqlen = seqlen
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2])
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        out = self.sparse_wrapper.run(q, k, v)
        out = out.transpose(0, 1)
        return out.reshape(out.shape[0], -1)
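
For orientation, a minimal usage sketch of the new class (not part of the commit). The import path and the tensor sizes are assumptions for illustration; only the class name, its class attributes, and the [seqlen, head_num, head_dim] layout expected by apply() come from the diff above. flashinfer must be installed, and attnmap_frame_num must be set before the first call, as the model-side wiring in the next hunk does.

import torch
from lightx2v.attentions.nbhd_attn import NbhdAttnWeightFlashInfer  # hypothetical import path

NbhdAttnWeightFlashInfer.attnmap_frame_num = 21  # illustrative; normally copied from config["attnmap_frame_num"]

attn = NbhdAttnWeightFlashInfer()
seqlen, head_num, head_dim = 75600, 40, 128  # illustrative sizes
q = torch.randn(seqlen, head_num, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# The first call plans the block-sparse kernel: with these sizes prepare_mask
# builds ceil(75600 / 128) = 591 blocks per axis, and the last block covers
# 75600 - 128 * 590 = 80 tokens, which is what the block_rowcol_size[-1]
# override accounts for. Later calls with the same seqlen reuse the plan.
out = attn.apply(q, k, v)  # shape [seqlen, head_num * head_dim]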
@@ -192,10 +192,10 @@ class WanSelfAttention(WeightModule):
                context_length=self.config.get("svg_context_length", 0),
                sparsity=self.config.get("svg_sparsity", 0.25),
            )
-        if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn"]:
+        if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn", "nbhd_attn_flashinfer"]:
            attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
        # nbhd_attn setting
-        if self.config["self_attn_1_type"] == "nbhd_attn":
+        if self.config["self_attn_1_type"] in ["nbhd_attn", "nbhd_attn_flashinfer"]:
            if "nbhd_attn_setting" in self.config:
                if "coefficient" in self.config["nbhd_attn_setting"]:
                    attention_weights_cls.coefficient = self.config["nbhd_attn_setting"]["coefficient"]
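
For completeness, a hedged sketch of the config entries this branch reads. Only the key names and the coefficient default come from the code shown; the values and the surrounding config layout are illustrative assumptions.

config = {
    "self_attn_1_type": "nbhd_attn_flashinfer",  # routes self-attention to the FlashInfer nbhd path
    "attnmap_frame_num": 21,  # illustrative value; copied onto the attention weight class above
    "nbhd_attn_setting": {
        "coefficient": [1.0, 0.5, 0.056],  # optional override; matches the class default
    },
}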