Commit d7c99b0c authored by Yang Yong(雍洋), committed by GitHub

refactor seko model adapter (#298)



## Contributing Guidelines

We have prepared a `pre-commit` hook to enforce consistent code
formatting across the project. If your code already complies with the
standards, you should not see any errors; otherwise, you can clean up
your code by following the steps below:

1. Install the required dependencies:

```shell
    pip install ruff pre-commit
```

2. Then, run the following command before committing:

```shell
    pre-commit run --all-files
```

3. Finally, please double-check your code to ensure it complies with the
following additional specifications as much as possible (a small
illustrative sketch follows this list):
- Avoid hard-coding local paths: Make sure your submissions do not
include hard-coded local paths, as these paths are specific to
individual development environments and can cause compatibility issues.
Use relative paths or configuration files instead.
- Clear error handling: Implement clear error-handling mechanisms in
your code so that error messages can accurately indicate the location of
the problem, possible causes, and suggested solutions, facilitating
quick debugging.
- Detailed comments and documentation: Add comments to complex code
sections and provide comprehensive documentation to explain the
functionality of the code, input-output requirements, and potential
error scenarios.
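
To make these points concrete, here is a small, purely illustrative sketch (the function and config path below are hypothetical, not part of this repository) that takes its path from configuration instead of hard-coding it, raises a clear error, and documents its behavior:

```python
import json
from pathlib import Path


def load_adapter_config(config_path: str) -> dict:
    """Load an adapter config; the location comes from the caller/config, not from code."""
    path = Path(config_path)
    if not path.exists():
        # Clear error handling: say where it failed, why, and what to do next.
        raise FileNotFoundError(f"Adapter config not found at '{path}'. Pass a valid path or set it in your config file.")
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)
```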

Thank you for your contributions!

---------
Co-authored-by: gushiqiao <975033167@qq.com>
parent 091a2a85
@@ -175,6 +175,10 @@ class WeightModuleList(WeightModule):
     def __getitem__(self, idx):
         return self._list[idx]
 
+    def __setitem__(self, idx, module):
+        self._list[idx] = module
+        self.add_module(str(idx), module)
+
     def __len__(self):
         return len(self._list)
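
For context on the new `__setitem__`: replacing an entry in the backing list alone would leave the previously registered submodule in place, so the method also re-registers the new module. A rough analogy using plain `torch.nn` (this toy class is illustrative only, not the project's `WeightModule` API):

```python
import torch.nn as nn


class SimpleModuleList(nn.Module):
    """Toy stand-in used only to illustrate the __setitem__ + add_module pattern."""

    def __init__(self, modules):
        super().__init__()
        self._list = list(modules)
        for i, m in enumerate(self._list):
            self.add_module(str(i), m)

    def __setitem__(self, idx, module):
        self._list[idx] = module
        self.add_module(str(idx), module)  # re-register so parameters()/state_dict() see the new module


layers = SimpleModuleList([nn.Linear(4, 4)])
layers[0] = nn.Linear(4, 8)
assert dict(layers.named_children())["0"].out_features == 8
```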
......
@@ -2,7 +2,6 @@ try:
     import flash_attn
 except ModuleNotFoundError:
     flash_attn = None
-import math
 
 import torch
 import torch.nn as nn
@@ -10,8 +9,6 @@ import torch.nn.functional as F
 from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
-from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8
-
 
 
 def linear_interpolation(features, output_len: int):
     features = features.transpose(1, 2)
@@ -84,16 +81,17 @@ def calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens):
         n_query_tokens = n_tokens_per_rank
     if n_query_tokens > 0:
-        hidden_states_aligned = hidden_states[:, :n_query_tokens]
-        hidden_states_tail = hidden_states[:, n_query_tokens:]
+        hidden_states_aligned = hidden_states[:n_query_tokens]
+        hidden_states_tail = hidden_states[n_query_tokens:]
     else:
         # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
-        hidden_states_aligned = hidden_states[:, :1]
-        hidden_states_tail = hidden_states[:, 1:]
+        hidden_states_aligned = hidden_states[:1]
+        hidden_states_tail = hidden_states[1:]
     return n_query_tokens, hidden_states_aligned, hidden_states_tail
 
 
+'''
 class PerceiverAttentionCA(nn.Module):
     def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False, quantized=False, quant_scheme=None):
         super().__init__()
@@ -156,6 +154,7 @@ class PerceiverAttentionCA(nn.Module):
         )
         out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
         return self.to_out(out) * gate
+'''
 
 
 class AudioProjection(nn.Module):
@@ -258,9 +257,10 @@ class AudioAdapter(nn.Module):
         # self.num_tokens = num_tokens * 4
         self.num_tokens_x4 = num_tokens * 4
         self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
-        ca_num = math.ceil(base_num_layers / interval)
+        # ca_num = math.ceil(base_num_layers / interval)
        self.base_num_layers = base_num_layers
        self.interval = interval
+        """
        self.ca = nn.ModuleList(
            [
                PerceiverAttentionCA(
@@ -274,6 +274,7 @@ class AudioAdapter(nn.Module):
                for _ in range(ca_num)
            ]
        )
+        """
        self.dim = attention_head_dim * num_attention_heads
        if time_freq_dim > 0:
            self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
......
+import os
+
+import torch.distributed as dist
+
 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
 from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
 from lightx2v.models.networks.wan.model import WanModel
+from lightx2v.models.networks.wan.weights.audio.transformer_weights import WanAudioTransformerWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
-from lightx2v.models.networks.wan.weights.transformer_weights import (
-    WanTransformerWeights,
-)
+from lightx2v.utils.utils import load_weights
 
 
 class WanAudioModel(WanModel):
     pre_weight_class = WanPreWeights
     post_weight_class = WanPostWeights
-    transformer_weight_class = WanTransformerWeights
+    transformer_weight_class = WanAudioTransformerWeights
 
     def __init__(self, model_path, config, device):
+        self.config = config
+        self._load_adapter_ckpt()
         super().__init__(model_path, config, device)
 
+    def _load_adapter_ckpt(self):
+        if self.config.get("adapter_model_path", None) is None:
+            if self.config.get("adapter_quantized", False):
+                if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
+                    adapter_model_name = "audio_adapter_model_fp8.safetensors"
+                elif self.config.get("adapter_quant_scheme", None) == "int8":
+                    adapter_model_name = "audio_adapter_model_int8.safetensors"
+                else:
+                    raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
+            else:
+                adapter_model_name = "audio_adapter_model.safetensors"
+            self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
+        adapter_offload = self.config.get("cpu_offload", False)
+        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
+        if not adapter_offload and not dist.is_initialized():
+            for key, value in self.adapter_weights_dict.items():
+                self.adapter_weights_dict[key] = value.cuda()
+
     def _init_infer_class(self):
         super()._init_infer_class()
         self.pre_infer_class = WanAudioPreInfer
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer
-
-    def set_audio_adapter(self, audio_adapter):
-        self.audio_adapter = audio_adapter
-        self.transformer_infer.set_audio_adapter(self.audio_adapter)
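
For readers skimming the diff, the path-selection part of `_load_adapter_ckpt` boils down to the following standalone sketch (hypothetical helper name; the real method also stores the result on `config` and loads the weights via `load_weights`):

```python
import os


def resolve_adapter_path(config: dict) -> str:
    """Condensed sketch of the checkpoint selection done by WanAudioModel._load_adapter_ckpt."""
    if config.get("adapter_model_path") is not None:
        return config["adapter_model_path"]
    if config.get("adapter_quantized", False):
        scheme = config.get("adapter_quant_scheme")
        if scheme in ["fp8", "fp8-q8f"]:
            name = "audio_adapter_model_fp8.safetensors"
        elif scheme == "int8":
            name = "audio_adapter_model_int8.safetensors"
        else:
            raise ValueError(f"Unsupported quant_scheme: {scheme}")
    else:
        name = "audio_adapter_model.safetensors"
    return os.path.join(config["model_path"], name)


# e.g. a quantized fp8 setup resolves to <model_path>/audio_adapter_model_fp8.safetensors
print(resolve_adapter_path({"model_path": "/path/to/model", "adapter_quantized": True, "adapter_quant_scheme": "fp8"}))
```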
 import torch
 import torch.distributed as dist
+from loguru import logger
+
+try:
+    import flash_attn  # noqa: F401
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
+except ImportError:
+    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
+    flash_attn_varlen_func = None
 
 from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import calculate_n_query_tokens, get_qk_lens_audio_range
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
@@ -8,69 +16,70 @@ from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
 
 class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
-        self.num_tokens = 32
-        self.num_tokens_x4 = self.num_tokens * 4
-
-    def set_audio_adapter(self, audio_adapter):
-        self.audio_adapter = audio_adapter
-
-    @torch.no_grad()
-    def post_process(self, x, y, c_gate_msa, pre_infer_out):
-        x = super().post_process(x, y, c_gate_msa, pre_infer_out)
-        x = self.modify_hidden_states(
-            hidden_states=x.to(self.infer_dtype),
-            grid_sizes=pre_infer_out.grid_sizes.tensor,
-            ca_block=self.audio_adapter.ca[self.block_idx],
-            audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
-            t_emb=self.scheduler.audio_adapter_t_emb,
-            weight=1.0,
-            seq_p_group=self.seq_p_group,
-        )
-        return x
+        self.has_post_adapter = True
+        self.phases_num = 4
 
     @torch.no_grad()
-    def modify_hidden_states(self, hidden_states, grid_sizes, ca_block, audio_encoder_output, t_emb, weight, seq_p_group):
-        """thw specify the latent_frame, latent_height, latent_width after
-        hidden_states is patchified.
-        latent_frame does not include the reference images so that the
-        audios and hidden_states are strictly aligned
-        """
-        if len(hidden_states.shape) == 2:  # expand the batch-size dim
-            hidden_states = hidden_states.unsqueeze(0)  # bs = 1
+    def infer_post_adapter(self, phase, x, pre_infer_out):
+        grid_sizes = pre_infer_out.grid_sizes.tensor
+        audio_encoder_output = pre_infer_out.adapter_output["audio_encoder_output"]
         total_tokens = grid_sizes[0].prod()
         pre_frame_tokens = grid_sizes[0][1:].prod()
         n_tokens = total_tokens - pre_frame_tokens  # drop the ref-image tokens
-        ori_dtype = hidden_states.dtype
-        device = hidden_states.device
-        n_tokens_per_rank = torch.tensor(hidden_states.size(1), dtype=torch.int32, device=device)
-        if seq_p_group is not None:
-            sp_size = dist.get_world_size(seq_p_group)
-            sp_rank = dist.get_rank(seq_p_group)
+        ori_dtype = x.dtype
+        device = x.device
+        n_tokens_per_rank = torch.tensor(x.size(0), dtype=torch.int32, device=device)
+        if self.seq_p_group is not None:
+            sp_size = dist.get_world_size(self.seq_p_group)
+            sp_rank = dist.get_rank(self.seq_p_group)
         else:
             sp_size = 1
             sp_rank = 0
-        n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
+        n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(x, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
         q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 = get_qk_lens_audio_range(
-            n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=self.num_tokens_x4
+            n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=pre_frame_tokens, sp_rank=sp_rank, num_tokens_x4=128
        )
-        # ca_block: the cross-attention module
-        if self.audio_adapter.cpu_offload:
-            ca_block.to("cuda")
-        residual = ca_block(audio_encoder_output[:, t0:t1], hidden_states_aligned, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k) * weight
-        if self.audio_adapter.cpu_offload:
-            ca_block.to("cpu")
+        audio_encoder_output = audio_encoder_output[:, t0:t1].reshape(-1, audio_encoder_output.size(-1))
+        residual = self.perceiver_attention_ca(phase, audio_encoder_output, hidden_states_aligned, self.scheduler.audio_adapter_t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k)
         residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back as a residual
         if n_query_tokens == 0:
             residual = residual * 0.0
-        hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
-
-        if len(hidden_states.shape) == 3:
-            hidden_states = hidden_states.squeeze(0)  # bs = 1
-        return hidden_states
+        x = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=0)
+        return x
+
+    # @torch.no_grad()
+    def perceiver_attention_ca(self, phase, audio_encoder_output, latents, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k):
+        audio_encoder_output = phase.norm_kv.apply(audio_encoder_output)
+        shift, scale, gate = (t_emb + phase.shift_scale_gate.tensor)[0].chunk(3, dim=0)
+        norm_q = phase.norm_q.apply(latents)
+        latents = norm_q * (1 + scale) + shift
+        q = phase.to_q.apply(latents)
+        k, v = phase.to_kv.apply(audio_encoder_output).chunk(2, dim=-1)
+        q = q.view(q.size(0), self.num_heads, self.head_dim)
+        k = k.view(k.size(0), self.num_heads, self.head_dim)
+        v = v.view(v.size(0), self.num_heads, self.head_dim)
+        out = flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+            window_size=(-1, -1),
+            deterministic=False,
+        )
+        out = out.view(-1, self.num_heads * self.head_dim)
+        return phase.to_out.apply(out) * gate
@@ -215,11 +215,11 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                 x,
                 self.phase_params["y"],
                 self.phase_params["c_gate_msa"],
-                pre_infer_out,
             )
             if hasattr(cur_phase, "after_proj"):
                 pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
+        elif cur_phase_idx == 3:
+            x = self.infer_post_adapter(cur_phase, x, pre_infer_out)
         return x
 
     def clear_offload_params(self, pre_infer_out):
......
@@ -15,6 +15,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.attention_type = config.get("attention_type", "flash_attn2")
         self.blocks_num = config.num_layers
         self.phases_num = 3
+        self.has_post_adapter = False
         self.num_heads = config.num_heads
         self.head_dim = config.dim // config.num_heads
         self.window_size = config.get("window_size", (-1, -1))
@@ -106,9 +107,12 @@ class WanTransformerInfer(BaseTransformerInfer):
         x, attn_out = self.infer_cross_attn(block.compute_phases[1], x, pre_infer_out.context, y_out, gate_msa)
         y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y, c_gate_msa, pre_infer_out)
         if hasattr(block.compute_phases[2], "after_proj"):
             pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
+
+        if self.has_post_adapter:
+            x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out)
         return x
 
     def pre_process(self, modulation, embed0):
@@ -294,7 +298,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y
 
-    def post_process(self, x, y, c_gate_msa, pre_infer_out):
+    def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
         if self.sensitive_layer_dtype != self.infer_dtype:
             x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
         else:
......
@@ -136,8 +136,12 @@ class WanModel:
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
         safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
 
         weight_dict = {}
         for file_path in safetensors_files:
+            if self.config.get("adapter_model_path", None) is not None:
+                if self.config.adapter_model_path == file_path:
+                    continue
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
@@ -236,6 +240,9 @@ class WanModel:
             if self.config.get("device_mesh") is not None:
                 weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
+
+            if hasattr(self, "adapter_weights_dict"):
+                weight_dict.update(self.adapter_weights_dict)
             self.original_weight_dict = weight_dict
         else:
             self.original_weight_dict = weight_dict
......
+from lightx2v.common.modules.weight_module import WeightModule
+from lightx2v.models.networks.wan.weights.transformer_weights import WanTransformerWeights
+from lightx2v.utils.registry_factory import (
+    LN_WEIGHT_REGISTER,
+    MM_WEIGHT_REGISTER,
+    TENSOR_REGISTER,
+)
+
+
+class WanAudioTransformerWeights(WanTransformerWeights):
+    def __init__(self, config):
+        super().__init__(config)
+        for i in range(self.blocks_num):
+            self.blocks[i].compute_phases.append(
+                WanAudioAdapterCA(
+                    i,
+                    f"ca",
+                    self.task,
+                    self.mm_type,
+                    self.config,
+                    self.blocks[i].lazy_load,
+                    self.blocks[i].lazy_load_file,
+                )
+            )
+
+
+class WanAudioAdapterCA(WeightModule):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+        super().__init__()
+        self.block_index = block_index
+        self.mm_type = mm_type
+        self.task = task
+        self.config = config
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
+        self.add_module(
+            "to_q",
+            MM_WEIGHT_REGISTER[self.mm_type](
+                f"{block_prefix}.{block_index}.to_q.weight",
+                f"{block_prefix}.{block_index}.to_q.bias",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
+        self.add_module(
+            "to_kv",
+            MM_WEIGHT_REGISTER[self.mm_type](
+                f"{block_prefix}.{block_index}.to_kv.weight",
+                f"{block_prefix}.{block_index}.to_kv.bias",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
+        self.add_module(
+            "to_out",
+            MM_WEIGHT_REGISTER[self.mm_type](
+                f"{block_prefix}.{block_index}.to_out.weight",
+                f"{block_prefix}.{block_index}.to_out.bias",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
+        self.add_module(
+            "norm_kv",
+            LN_WEIGHT_REGISTER["Default"](
+                f"{block_prefix}.{block_index}.norm_kv.weight",
+                f"{block_prefix}.{block_index}.norm_kv.bias",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
+        self.add_module(
+            "norm_q",
+            LN_WEIGHT_REGISTER["Default"](),
+        )
+        self.add_module(
+            "shift_scale_gate",
+            TENSOR_REGISTER["Default"](
+                f"{block_prefix}.{block_index}.shift_scale_gate",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
@@ -734,18 +734,9 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
         )
-        audio_adapter.to(device)
-        if self.config.get("adapter_quantized", False):
-            if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f"]:
-                model_name = "audio_adapter_model_fp8.safetensors"
-            elif self.config.get("adapter_quant_scheme", None) == "int8":
-                model_name = "audio_adapter_model_int8.safetensors"
-            else:
-                raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
-        else:
-            model_name = "audio_adapter_model.safetensors"
 
-        weights_dict = load_weights(os.path.join(self.config["model_path"], model_name), cpu_offload=audio_adapter_offload)
+        audio_adapter.to(device)
+        weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=audio_adapter_offload, remove_key="ca")
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())
@@ -754,7 +745,6 @@ class WanAudioRunner(WanRunner):  # type:ignore
         with ProfilingContext4DebugL2("Load audio encoder and adapter"):
             self.audio_encoder = self.load_audio_encoder()
             self.audio_adapter = self.load_audio_adapter()
-            self.model.set_audio_adapter(self.audio_adapter)
 
     def set_target_shape(self):
         """Set target shape for generation"""
......
@@ -369,7 +369,7 @@ def quantize_model(
     for key in pbar:
         pbar.set_postfix(current_key=key, refresh=False)
-        if ignore_key is not None and ignore_key in key:
+        if ignore_key is not None and any(ig_key in key for ig_key in ignore_key):
             del weights[key]
             continue
@@ -682,7 +682,7 @@ def main():
         "wan_dit": {
             "key_idx": 2,
             "target_keys": ["self_attn", "cross_attn", "ffn"],
-            "ignore_key": None,
+            "ignore_key": ["ca", "audio"],
         },
         "hunyuan_dit": {
             "key_idx": 2,
......
@@ -4,7 +4,7 @@ from safetensors.torch import save_file
 from lightx2v.utils.quant_utils import FloatQuantizer
 
-model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_model.safetensors"
+model_path = "/data/nvme0/gushiqiao/models/Lightx2v_models/SekoTalk-Distill/audio_adapter_model.safetensors"
 
 state_dict = {}
 with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
@@ -13,10 +13,10 @@ with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
 
 new_state_dict = {}
-new_model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_model_fp8.safetensors"
+new_model_path = "/data/nvme0/gushiqiao/models/Lightx2v_models/seko-new/SekoTalk-Distill-fp8/audio_adapter_model_fp8.safetensors"
 
 for key in state_dict.keys():
-    if key.startswith("ca") and ".to" in key and "weight" in key and "to_kv" not in key:
+    if key.startswith("ca") and ".to" in key and "weight" in key:
         print(key, state_dict[key].dtype)
         weight = state_dict[key].to(torch.float32).cuda()
......