Commit d7c99b0c authored by Yang Yong (雍洋), committed by GitHub

refactor seko model adapter (#298)



## Contributing Guidelines

We have prepared a `pre-commit` hook to enforce consistent code
formatting across the project. If your code already complies with the
standards, you should not see any errors. You can clean up your code by
following the steps below:

1. Install the required dependencies:

```shell
    pip install ruff pre-commit
```

2. Then, run the following command before committing:

```shell
    pre-commit run --all-files
```

3. Finally, please double-check your code to ensure it follows the
additional guidelines below as much as possible (a short illustrative
sketch follows the list):
- Avoid hard-coding local paths: Make sure your submissions do not
include hard-coded local paths, as these paths are specific to
individual development environments and can cause compatibility issues.
Use relative paths or configuration files instead.
- Clear error handling: Implement clear error-handling mechanisms in
your code so that error messages can accurately indicate the location of
the problem, possible causes, and suggested solutions, facilitating
quick debugging.
- Detailed comments and documentation: Add comments to complex code
sections and provide comprehensive documentation to explain the
functionality of the code, input-output requirements, and potential
error scenarios.
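
As a rough illustration of the three points above (not code from this
repository), here is a minimal sketch; the file name `configs/adapter.json`
and the helper `load_adapter_config` are hypothetical:

```python
import json
from pathlib import Path


def load_adapter_config(config_file: str = "configs/adapter.json") -> dict:
    """Load an adapter config from a path relative to the repository.

    Args:
        config_file: JSON config path, relative to this file (hypothetical example).

    Returns:
        The parsed configuration dictionary.

    Raises:
        FileNotFoundError: If the config file does not exist, with a hint on where to place it.
        ValueError: If the file is not valid JSON.
    """
    # Resolve relative to the repo instead of hard-coding a local absolute path.
    config_path = Path(__file__).resolve().parent / config_file
    if not config_path.exists():
        raise FileNotFoundError(
            f"Config not found at {config_path}. "
            "Expected a JSON file relative to the repo, e.g. configs/adapter.json."
        )
    try:
        return json.loads(config_path.read_text())
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in {config_path}: {e}") from e
```

The point is simply that paths are resolved relative to the project, errors say
where the problem is and how to fix it, and the docstring states inputs,
outputs, and failure modes.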

Thank you for your contributions!

---------
Co-authored-by: gushiqiao <975033167@qq.com>
parent 091a2a85
@@ -175,6 +175,10 @@ class WeightModuleList(WeightModule):
def __getitem__(self, idx):
return self._list[idx]
def __setitem__(self, idx, module):
self._list[idx] = module
self.add_module(str(idx), module)
def __len__(self):
return len(self._list)
@@ -2,7 +2,6 @@ try:
import flash_attn
except ModuleNotFoundError:
flash_attn = None
import math
import torch
import torch.nn as nn
@@ -10,8 +9,6 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
@@ -84,16 +81,17 @@ def calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank,
n_query_tokens = n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
hidden_states_aligned = hidden_states[:n_query_tokens]
hidden_states_tail = hidden_states[n_query_tokens:]
else:
# for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
hidden_states_aligned = hidden_states[:1]
hidden_states_tail = hidden_states[1:]
return n_query_tokens, hidden_states_aligned, hidden_states_tail
'''
class PerceiverAttentionCA(nn.Module):
def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False, quantized=False, quant_scheme=None):
super().__init__()
@@ -156,6 +154,7 @@ class PerceiverAttentionCA(nn.Module):
)
out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
return self.to_out(out) * gate
'''
class AudioProjection(nn.Module):
@@ -258,9 +257,10 @@ class AudioAdapter(nn.Module):
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
ca_num = math.ceil(base_num_layers / interval)
# ca_num = math.ceil(base_num_layers / interval)
self.base_num_layers = base_num_layers
self.interval = interval
"""
self.ca = nn.ModuleList(
[
PerceiverAttentionCA(
@@ -274,6 +274,7 @@
for _ in range(ca_num)
]
)
"""
self.dim = attention_head_dim * num_attention_heads
if time_freq_dim > 0:
self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
import os
import torch.distributed as dist
from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.utils import load_weights
class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
transformer_weight_class = WanAudioTransformerWeights
def __init__(self, model_path, config, device):
self.config = config
self._load_adapter_ckpt()
super().__init__(model_path, config, device)
def _load_adapter_ckpt(self):
if self.config.get("adapter_model_path", None) is None:
if self.config.get("adapter_quantized", False):
if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
adapter_model_name = "audio_adapter_model_fp8.safetensors"
elif self.config.get("adapter_quant_scheme", None) == "int8":
adapter_model_name = "audio_adapter_model_int8.safetensors"
else:
raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
else:
adapter_model_name = "audio_adapter_model.safetensors"
self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
adapter_offload = self.config.get("cpu_offload", False)
self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
if not adapter_offload and not dist.is_initialized():
for key, value in self.adapter_weights_dict.items():
self.adapter_weights_dict[key] = value.cuda()
def _init_infer_class(self):
super()._init_infer_class()
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
self.transformer_infer_class = WanAudioTransformerInfer
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
self.transformer_infer.set_audio_adapter(self.audio_adapter)
import torch
import torch.distributed as dist
from loguru import logger
try:
import flash_attn # noqa: F401
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import calculate_n_query_tokens, get_qk_lens_audio_range
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
@@ -8,69 +16,70 @@ from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffl
class WanAudioTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.num_tokens = 32
self.num_tokens_x4 = self.num_tokens * 4
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
@torch.no_grad()
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
x = self.modify_hidden_states(
hidden_states=x.to(self.infer_dtype),
grid_sizes=pre_infer_out.grid_sizes.tensor,
ca_block=self.audio_adapter.ca[self.block_idx],
audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
t_emb=self.scheduler.audio_adapter_t_emb,
weight=1.0,
seq_p_group=self.seq_p_group,
)
return x
self.has_post_adapter = True
self.phases_num = 4
@torch.no_grad()
def modify_hidden_states(self, hidden_states, grid_sizes, ca_block, audio_encoder_output, t_emb, weight, seq_p_group):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2:  # expand the batch-size dim
hidden_states = hidden_states.unsqueeze(0) # bs = 1
def infer_post_adapter(self, phase, x, pre_infer_out):
grid_sizes = pre_infer_out.grid_sizes.tensor
audio_encoder_output = pre_infer_out.adapter_output["audio_encoder_output"]
total_tokens = grid_sizes[0].prod()
pre_frame_tokens = grid_sizes[0][1:].prod()
n_tokens = total_tokens - pre_frame_tokens  # drop the ref-image tokens
ori_dtype = hidden_states.dtype
device = hidden_states.device
n_tokens_per_rank = torch.tensor(hidden_states.size(1), dtype=torch.int32, device=device)
ori_dtype = x.dtype
device = x.device
n_tokens_per_rank = torch.tensor(x.size(0), dtype=torch.int32, device=device)
if seq_p_group is not None:
sp_size = dist.get_world_size(seq_p_group)
sp_rank = dist.get_rank(seq_p_group)
if self.seq_p_group is not None:
sp_size = dist.get_world_size(self.seq_p_group)
sp_rank = dist.get_rank(self.seq_p_group)
else:
sp_size = 1
sp_rank = 0
n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(x, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 = get_qk_lens_audio_range(
n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=self.num_tokens_x4
n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=128
)
# ca_block: the cross-attention module
if self.audio_adapter.cpu_offload:
ca_block.to("cuda")
residual = ca_block(audio_encoder_output[:, t0:t1], hidden_states_aligned, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k) * weight
if self.audio_adapter.cpu_offload:
ca_block.to("cpu")
audio_encoder_output = audio_encoder_output[:, t0:t1].reshape(-1, audio_encoder_output.size(-1))
residual = self.perceiver_attention_ca(phase, audio_encoder_output, hidden_states_aligned, self.scheduler.audio_adapter_t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k)
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
x = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=0)
return x
if len(hidden_states.shape) == 3:  # squeeze the batch dim back (bs = 1)
hidden_states = hidden_states.squeeze(0) # bs = 1
return hidden_states
@torch.no_grad()
def perceiver_attention_ca(self, phase, audio_encoder_output, latents, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k):
audio_encoder_output = phase.norm_kv.apply(audio_encoder_output)
shift, scale, gate = (t_emb + phase.shift_scale_gate.tensor)[0].chunk(3, dim=0)
norm_q = phase.norm_q.apply(latents)
latents = norm_q * (1 + scale) + shift
q = phase.to_q.apply(latents)
k, v = phase.to_kv.apply(audio_encoder_output).chunk(2, dim=-1)
q = q.view(q.size(0), self.num_heads, self.head_dim)
k = k.view(k.size(0), self.num_heads, self.head_dim)
v = v.view(v.size(0), self.num_heads, self.head_dim)
out = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
)
out = out.view(-1, self.num_heads * self.head_dim)
return phase.to_out.apply(out) * gate
@@ -215,11 +215,11 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
x,
self.phase_params["y"],
self.phase_params["c_gate_msa"],
pre_infer_out,
)
if hasattr(cur_phase, "after_proj"):
pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
elif cur_phase_idx == 3:
x = self.infer_post_adapter(cur_phase, x, pre_infer_out)
return x
def clear_offload_params(self, pre_infer_out):
@@ -15,6 +15,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.attention_type = config.get("attention_type", "flash_attn2")
self.blocks_num = config.num_layers
self.phases_num = 3
self.has_post_adapter = False
self.num_heads = config.num_heads
self.head_dim = config.dim // config.num_heads
self.window_size = config.get("window_size", (-1, -1))
@@ -106,9 +107,12 @@ class WanTransformerInfer(BaseTransformerInfer):
x, attn_out = self.infer_cross_attn(block.compute_phases[1], x, pre_infer_out.context, y_out, gate_msa)
y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
if hasattr(block.compute_phases[2], "after_proj"):
pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
if self.has_post_adapter:
x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out)
return x
def pre_process(self, modulation, embed0):
@@ -294,7 +298,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return y
def post_process(self, x, y, c_gate_msa, pre_infer_out):
def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
else:
@@ -136,8 +136,12 @@ class WanModel:
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config.adapter_model_path == file_path:
continue
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
@@ -236,6 +240,9 @@ class WanModel:
if self.config.get("device_mesh") is not None:
weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
if hasattr(self, "adapter_weights_dict"):
weight_dict.update(self.adapter_weights_dict)
self.original_weight_dict = weight_dict
else:
self.original_weight_dict = weight_dict
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.models.networks.wan.weights.transformer_weights import WanTransformerWeights
from lightx2v.utils.registry_factory import (
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
TENSOR_REGISTER,
)
class WanAudioTransformerWeights(WanTransformerWeights):
def __init__(self, config):
super().__init__(config)
for i in range(self.blocks_num):
self.blocks[i].compute_phases.append(
WanAudioAdapterCA(
i,
f"ca",
self.task,
self.mm_type,
self.config,
self.blocks[i].lazy_load,
self.blocks[i].lazy_load_file,
)
)
class WanAudioAdapterCA(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.add_module(
"to_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{block_index}.to_q.weight",
f"{block_prefix}.{block_index}.to_q.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"to_kv",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{block_index}.to_kv.weight",
f"{block_prefix}.{block_index}.to_kv.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"to_out",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{block_index}.to_out.weight",
f"{block_prefix}.{block_index}.to_out.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"norm_kv",
LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{block_index}.norm_kv.weight",
f"{block_prefix}.{block_index}.norm_kv.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"norm_q",
LN_WEIGHT_REGISTER["Default"](),
)
self.add_module(
"shift_scale_gate",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{block_index}.shift_scale_gate",
self.lazy_load,
self.lazy_load_file,
),
)
@@ -734,18 +734,9 @@ class WanAudioRunner(WanRunner): # type:ignore
quant_scheme=self.config.get("adapter_quant_scheme", None),
cpu_offload=audio_adapter_offload,
)
audio_adapter.to(device)
if self.config.get("adapter_quantized", False):
if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
model_name = "audio_adapter_model_fp8.safetensors"
elif self.config.get("adapter_quant_scheme", None) == "int8":
model_name = "audio_adapter_model_int8.safetensors"
else:
raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
else:
model_name = "audio_adapter_model.safetensors"
weights_dict = load_weights(os.path.join(self.config["model_path"], model_name), cpu_offload=audio_adapter_offload)
audio_adapter.to(device)
weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca")
audio_adapter.load_state_dict(weights_dict, strict=False)
return audio_adapter.to(dtype=GET_DTYPE())
@@ -754,7 +745,6 @@ class WanAudioRunner(WanRunner): # type:ignore
with ProfilingContext4DebugL2("Load audio encoder and adapter"):
self.audio_encoder = self.load_audio_encoder()
self.audio_adapter = self.load_audio_adapter()
self.model.set_audio_adapter(self.audio_adapter)
def set_target_shape(self):
"""Set target shape for generation"""
@@ -369,7 +369,7 @@ def quantize_model(
for key in pbar:
pbar.set_postfix(current_key=key, refresh=False)
if ignore_key is not None and ignore_key in key:
if ignore_key is not None and any(ig_key in key for ig_key in ignore_key):
del weights[key]
continue
@@ -682,7 +682,7 @@ def main():
"wan_dit": {
"key_idx": 2,
"target_keys": ["self_attn", "cross_attn", "ffn"],
"ignore_key": None,
"ignore_key": ["ca", "audio"],
},
"hunyuan_dit": {
"key_idx": 2,
@@ -4,7 +4,7 @@ from safetensors.torch import save_file
from lightx2v.utils.quant_utils import FloatQuantizer
model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_model.safetensors"
model_path = "/data/nvme0/gushiqiao/models/Lightx2v_models/SekoTalk-Distill/audio_adapter_model.safetensors"
state_dict = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
@@ -13,10 +13,10 @@ with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
new_state_dict = {}
new_model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_model_fp8.safetensors"
new_model_path = "/data/nvme0/gushiqiao/models/Lightx2v_models/seko-new/SekoTalk-Distill-fp8/audio_adapter_model_fp8.safetensors"
for key in state_dict.keys():
if key.startswith("ca") and ".to" in key and "weight" in key and "to_kv" not in key:
if key.startswith("ca") and ".to" in key and "weight" in key:
print(key, state_dict[key].dtype)
weight = state_dict[key].to(torch.float32).cuda()