Commit 5b56dc56 authored by Dongz, committed by GitHub

[major]: deprecated attention functions (#35)

parent ad0237f9
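
The change removes direct calls to the functional `attention(attention_type=..., ...)` helper from `lightx2v.attentions` and routes every call site through an attention weight module obtained from `ATTN_WEIGHT_REGISTER` and invoked via `.apply(...)`. The sketch below illustrates the pattern with a minimal, hypothetical registry and a torch-SDPA backend; the real registry and weight-module classes live in `lightx2v.utils.registry_factory` and `lightx2v.common.modules.weight_module`, and their internals may differ.

    # Hypothetical stand-ins that mimic the pattern this commit adopts; not the real lightx2v classes.
    import torch
    import torch.nn.functional as F

    ATTN_WEIGHT_REGISTER = {}  # maps an attention_type string to a weight-module class

    def register_attn(name):
        def decorator(cls):
            ATTN_WEIGHT_REGISTER[name] = cls
            return cls
        return decorator

    @register_attn("torch_sdpa")
    class TorchSDPAAttnWeight:
        # Stateless attention "weight": the backend is fixed when the module is constructed.
        def apply(self, q, k, v, attn_mask=None, **kwargs):
            # q, k, v arrive as [B, L, H, D]; SDPA expects [B, H, L, D].
            q, k, v = (t.transpose(1, 2) for t in (q, k, v))
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            return out.transpose(1, 2)

    # Old (deprecated) call site:
    #     attn = attention(attention_type=self.attention_type, q=q, k=k, v=v)
    # New pattern: the weights class constructs the backend once ...
    double_attn = ATTN_WEIGHT_REGISTER["torch_sdpa"]()
    # ... and inference code only calls .apply(...):
    q = k = v = torch.randn(1, 16, 8, 64)
    attn = double_attn.apply(q=q, k=k, v=v)

Binding the backend once in the weights class means inference code no longer has to thread `attention_type` through every call, and backend-specific state (such as the Sparge checkpoints loaded later in this diff) gets a natural home.
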
 import torch
 import numpy as np
 from einops import rearrange
-from lightx2v.attentions import attention
 from .utils import taylor_cache_init, derivative_approximation, taylor_formula
 from ..utils_bf16 import apply_rotary_emb
 from ..transformer_infer import HunyuanTransformerInfer
@@ -118,8 +117,7 @@ class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer):
         v = torch.cat((img_v, txt_v), dim=0)
         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.double_attn.apply(
                 q=q,
                 k=k,
                 v=v,
@@ -284,8 +282,7 @@ class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer):
         k = torch.cat((img_k, txt_k), dim=0)
         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.single_attn.apply(
                 q=q,
                 k=k,
                 v=v,
...
 import torch
 import math
 from einops import rearrange
-from lightx2v.attentions import attention

 class HunyuanPreInfer:
@@ -107,7 +106,7 @@ class HunyuanPreInfer:
         normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
         qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
         q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
+        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
         out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
         out_1 = txt_in_input_embed + out * gate_msa
         out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
@@ -126,7 +125,7 @@ class HunyuanPreInfer:
         q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
+        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
         out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
         out_1 = txt_in_input_embed + out * gate_msa
...
 import torch
 from einops import rearrange
-from lightx2v.attentions import attention
 from .utils_bf16 import apply_rotary_emb
 from lightx2v.common.offload.manager import WeightStreamManager
 from lightx2v.utils.envs import *
@@ -120,8 +119,7 @@ class HunyuanTransformerInfer:
         v = torch.cat((img_v, txt_v), dim=0)
         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.double_attn.apply(
                 q=q,
                 k=k,
                 v=v,
@@ -263,8 +261,7 @@ class HunyuanTransformerInfer:
         k = torch.cat((img_k, txt_k), dim=0)
         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.single_attn.apply(
                 q=q,
                 k=k,
                 v=v,
...
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule
@@ -79,3 +79,6 @@ class HunyuanPreWeights(WeightModule):
         self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
         self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
         self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
+        # attention weights section
+        self.add_module("txt_in_attn_1", ATTN_WEIGHT_REGISTER["torch_sdpa"]())
...
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -40,12 +40,16 @@ class HunyuanTransformerDoubleBlock(WeightModule):
         self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
         self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))
+        # attention weights section
+        self.add_module("double_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())

 class HunyuanTransformerSingleBlock(WeightModule):
     def __init__(self, block_index, config):
         super().__init__()
         self.block_index = block_index
         self.config = config
+        self.sparge = config.get("sparge", False)
         if self.config["do_mm_calib"]:
             mm_type = "Calib"
@@ -57,3 +61,11 @@ class HunyuanTransformerSingleBlock(WeightModule):
         self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
         self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
         self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
+        # attention weights section
+        if self.sparge:
+            # load sparge attention weights
+            #! todo
+            pass
+        else:
+            self.add_module("single_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
...
 import torch
 import math
 from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
-from lightx2v.attentions import attention
 from lightx2v.common.offload.manager import WeightStreamManager
 from lightx2v.utils.envs import *
 from ..transformer_infer import WanTransformerInfer
@@ -125,8 +124,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q=q, k=self.kv_cache[block_idx]["k"][:kv_end], k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
         if not self.parallel_attention:
-            attn_out = attention(
-                attention_type=self.attention_type,
+            attn_out = weights.self_attn_1.apply(
                 q=q,
                 k=self.kv_cache[block_idx]["k"][:kv_end],
                 v=self.kv_cache[block_idx]["v"][:kv_end],
@@ -164,8 +162,15 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
-        attn_out = attention(
-            attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
+        attn_out = weights.cross_attn_1.apply(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_k,
+            max_seqlen_q=lq,
+            max_seqlen_kv=lk,
+            model_cls=self.config["model_cls"],
         )
         # TODO: Implement I2V inference for causvid model
...
 import torch
 from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
-from lightx2v.attentions import attention
 from lightx2v.common.offload.manager import WeightStreamManager
 from lightx2v.utils.envs import *
...
@@ -25,6 +25,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
+        self.sparge = config.get("sparge", False)
         self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
         self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
@@ -44,8 +45,8 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("ffn_0", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias"))
         self.add_module("ffn_2", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias"))
-        # attention weights
-        if self.config["sparge"]:
+        # attention weights section
+        if self.sparge:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}"))
             self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
@@ -61,7 +62,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
         # load attn weights
-        if self.config["sparge"]:
+        if self.sparge:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
...
@@ -20,7 +20,6 @@ def get_default_config():
         "strength_model": 1.0,
         "mm_config": {},
         "use_prompt_enhancer": False,
-        "sparge": False,
     }
     return default_config
...
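
With the `"sparge": False` default removed from `get_default_config()`, the weight classes above fall back to `config.get("sparge", False)`, so the flag becomes optional. A hypothetical config fragment for the new path, using only keys that appear in this diff (the checkpoint path is a placeholder):

    config = {
        "attention_type": "torch_sdpa",               # selects the ATTN_WEIGHT_REGISTER entry
        "mm_config": {},
        "do_mm_calib": False,
        "sparge": True,                               # optional; absent means False via config.get
        "sparge_ckpt": "/path/to/sparge_weights.pt",  # required when sparge is True
    }
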