Unverified Commit 63f233ad authored by gushiqiao, committed by GitHub

[Fix] Fix vace and animate models config bug (#351)

parent 69c2f650
@@ -34,6 +34,10 @@ class FlashAttn2Weight(AttnWeightTemplate):
         max_seqlen_kv=None,
         model_cls=None,
     ):
+        if len(q.shape) == 3:
+            bs = 1
+        elif len(q.shape) == 4:
+            bs = q.shape[0]
         x = flash_attn_varlen_func(
             q,
             k,
@@ -42,7 +46,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
             cu_seqlens_kv,
             max_seqlen_q,
             max_seqlen_kv,
-        ).reshape(max_seqlen_q, -1)
+        ).reshape(bs * max_seqlen_q, -1)
         return x
@@ -63,23 +67,16 @@ class FlashAttn3Weight(AttnWeightTemplate):
         model_cls=None,
     ):
         if len(q.shape) == 3:
-            x = flash_attn_varlen_func_v3(
-                q,
-                k,
-                v,
-                cu_seqlens_q,
-                cu_seqlens_kv,
-                max_seqlen_q,
-                max_seqlen_kv,
-            ).reshape(max_seqlen_q, -1)
+            bs = 1
         elif len(q.shape) == 4:
-            x = flash_attn_varlen_func_v3(
-                q,
-                k,
-                v,
-                cu_seqlens_q,
-                cu_seqlens_kv,
-                max_seqlen_q,
-                max_seqlen_kv,
-            ).reshape(q.shape[0] * max_seqlen_q, -1)
+            bs = q.shape[0]
+        x = flash_attn_varlen_func_v3(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+        ).reshape(bs * max_seqlen_q, -1)
         return x
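Note: both FlashAttention hunks above now compute an explicit batch size so the flattened output is correct for batched 4-D inputs as well as packed 3-D ones. A minimal sketch of just the reshape bookkeeping, using plain torch tensors with made-up sizes rather than the flash-attn kernels:

    import torch

    def flatten_attn_output(out: torch.Tensor, max_seqlen_q: int) -> torch.Tensor:
        # 3-D output: packed tokens (seq, heads, head_dim) -> treat batch size as 1
        # 4-D output: batched tokens (bs, seq, heads, head_dim)
        bs = 1 if out.dim() == 3 else out.shape[0]
        # Collapse heads * head_dim into one feature axis, keeping bs * seq rows
        return out.reshape(bs * max_seqlen_q, -1)

    single = torch.randn(128, 12, 64)       # (seq, heads, head_dim)
    batched = torch.randn(2, 128, 12, 64)   # (bs, seq, heads, head_dim)
    assert flatten_attn_output(single, 128).shape == (128, 12 * 64)
    assert flatten_attn_output(batched, 128).shape == (2 * 128, 12 * 64)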
@@ -36,38 +36,15 @@ class SageAttn2Weight(AttnWeightTemplate):
         model_cls=None,
     ):
         q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
-        if model_cls == "hunyuan":
-            x1 = sageattn(
-                q[: cu_seqlens_q[1]].unsqueeze(0),
-                k[: cu_seqlens_kv[1]].unsqueeze(0),
-                v[: cu_seqlens_kv[1]].unsqueeze(0),
-                tensor_layout="NHD",
-            )
-            x2 = sageattn(
-                q[cu_seqlens_q[1] :].unsqueeze(0),
-                k[cu_seqlens_kv[1] :].unsqueeze(0),
-                v[cu_seqlens_kv[1] :].unsqueeze(0),
-                tensor_layout="NHD",
-            )
-            x = torch.cat((x1, x2), dim=1)
-            x = x.view(max_seqlen_q, -1)
-        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_animate", "wan2.2_moe_distill", "qwen_image"]:
-            if len(q.shape) == 3:
-                x = sageattn(
-                    q.unsqueeze(0),
-                    k.unsqueeze(0),
-                    v.unsqueeze(0),
-                    tensor_layout="NHD",
-                )
-                x = x.view(max_seqlen_q, -1)
-            elif len(q.shape) == 4:
-                x = sageattn(
-                    q,
-                    k,
-                    v,
-                    tensor_layout="NHD",
-                )
-                x = x.view(q.shape[0] * max_seqlen_q, -1)
-        else:
-            raise NotImplementedError(f"Model class '{model_cls}' is not implemented in this attention implementation")
+        if len(q.shape) == 3:
+            bs = 1
+            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+        elif len(q.shape) == 4:
+            bs = q.shape[0]
+        x = sageattn(
+            q,
+            k,
+            v,
+            tensor_layout="NHD",
+        ).view(bs * max_seqlen_q, -1)
         return x
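Note: sageattn is called with tensor_layout="NHD", which takes a leading batch dimension, so the rewritten branch promotes a packed 3-D input to a batch of one instead of special-casing each model class. A shape-only sketch (plain torch, not the sageattn kernel; sizes are illustrative):

    import torch

    q = torch.randn(4096, 12, 64)   # packed (seq, heads, head_dim)
    if q.dim() == 3:
        bs = 1
        q = q.unsqueeze(0)          # -> (1, seq, heads, head_dim): batch axis added for the NHD layout
    elif q.dim() == 4:
        bs = q.shape[0]
    assert q.shape == (1, 4096, 12, 64) and bs == 1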
@@ -20,8 +20,8 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
         kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
         kv = kv.view(T, -1, kv.shape[-1])
         q = phase.linear1_q.apply(x_feat)
-        k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config.num_heads)
-        q = rearrange(q, "S (H D) -> S H D", H=self.config.num_heads)
+        k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config["num_heads"])
+        q = rearrange(q, "S (H D) -> S H D", H=self.config["num_heads"])
         q = phase.q_norm.apply(q).view(T, q.shape[0] // T, q.shape[1], q.shape[2])
         k = phase.k_norm.apply(k)
......
@@ -5,8 +5,8 @@ from lightx2v.utils.envs import *
 class WanVaceTransformerInfer(WanOffloadTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
-        self.vace_blocks_num = len(self.config.vace_layers)
-        self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
+        self.vace_blocks_num = len(self.config["vace_layers"])
+        self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}

     def infer(self, weights, pre_infer_out):
         pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
......
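Note on the recurring pattern in this commit: attribute-style reads such as self.config.vace_layers are replaced with subscript reads such as self.config["vace_layers"]. A minimal illustration of why the former fails when the config object is a plain dict (the dict contents below are invented for the example):

    config = {"vace_layers": [0, 5, 10, 15], "num_heads": 12}

    # Attribute access only works on objects that expose keys as attributes;
    # on a plain dict it raises AttributeError.
    try:
        _ = config.vace_layers
    except AttributeError as err:
        print("attribute access failed:", err)

    # Subscript access works for any mapping.
    vace_blocks_num = len(config["vace_layers"])
    vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(config["vace_layers"])}
    print(vace_blocks_num, vace_blocks_mapping)   # 4 {0: 0, 5: 1, 10: 2, 15: 3}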
@@ -14,7 +14,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
         super().__init__(config)
         self.patch_size = (1, 2, 2)
         self.vace_blocks = WeightModuleList(
-            [WanVaceTransformerAttentionBlock(self.config.vace_layers[i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config.vace_layers))]
+            [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
         )
         self.add_module("vace_blocks", self.vace_blocks)
......
@@ -214,7 +214,7 @@ class DefaultRunner(BaseRunner):
             [src_video],
             [src_mask],
             [None if src_ref_images is None else src_ref_images.split(",")],
-            (self.config.target_width, self.config.target_height),
+            (self.config["target_width"], self.config["target_height"]),
         )
         self.src_ref_images = src_ref_images
......
...@@ -11,7 +11,7 @@ try: ...@@ -11,7 +11,7 @@ try:
from decord import VideoReader from decord import VideoReader
except ImportError: except ImportError:
VideoReader = None VideoReader = None
logger.info("If you need run animate model, please install decord.") logger.info("If you want to run animate model, please install decord.")
from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
@@ -28,7 +28,7 @@ from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
 class WanAnimateRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        assert self.config.task == "animate"
+        assert self.config["task"] == "animate"

     def inputs_padding(self, array, target_len):
         idx = 0
@@ -161,11 +161,11 @@ class WanAnimateRunner(WanRunner):
         pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0))  # c t h w
         ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0))  # c t h w
-        mask_ref = self.get_i2v_mask(1, self.config.lat_h, self.config.lat_w, 1)
+        mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1)
         y_ref = torch.concat([mask_ref, ref_latents])

         if self.mask_reft_len > 0:
-            if self.config.replace_flag:
+            if self.config["replace_flag"]:
                 y_reft = self.vae_encoder.encode(
                     torch.concat(
                         [
@@ -183,9 +183,9 @@ class WanAnimateRunner(WanRunner):
                 mask_pixel_values = mask_pixel_values[:, 0, :, :]
                 msk_reft = self.get_i2v_mask(
-                    self.config.lat_t,
-                    self.config.lat_h,
-                    self.config.lat_w,
+                    self.latent_t,
+                    self.latent_h,
+                    self.latent_w,
                     self.mask_reft_len,
                     mask_pixel_values=mask_pixel_values.unsqueeze(0),
                 )
@@ -198,31 +198,31 @@ class WanAnimateRunner(WanRunner):
                                 size=(H, W),
                                 mode="bicubic",
                             ),
-                            torch.zeros(3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
+                            torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
                         ],
                         dim=1,
                     )
                     .cuda()
                     .unsqueeze(0)
                 )
-                msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
+                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
         else:
-            if self.config.replace_flag:
+            if self.config["replace_flag"]:
                 mask_pixel_values = 1 - mask_pixel_values
                 mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                 mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                 mask_pixel_values = mask_pixel_values[:, 0, :, :]
                 y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0))
                 msk_reft = self.get_i2v_mask(
-                    self.config.lat_t,
-                    self.config.lat_h,
-                    self.config.lat_w,
+                    self.latent_t,
+                    self.latent_h,
+                    self.latent_w,
                     self.mask_reft_len,
                     mask_pixel_values=mask_pixel_values.unsqueeze(0),
                 )
             else:
-                y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
-                msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
+                y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
+                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
         y_reft = torch.concat([msk_reft, y_reft])

         y = torch.concat([y_ref, y_reft], dim=1)
@@ -230,35 +230,39 @@ class WanAnimateRunner(WanRunner):
         return y, pose_latents

     def prepare_input(self):
-        src_pose_path = self.config.get("src_pose_path", None)
-        src_face_path = self.config.get("src_face_path", None)
-        src_ref_path = self.config.get("src_ref_images", None)
+        src_pose_path = self.config["src_pose_path"] if "src_pose_path" in self.config else None
+        src_face_path = self.config["src_face_path"] if "src_face_path" in self.config else None
+        src_ref_path = self.config["src_ref_images"] if "src_ref_images" in self.config else None
         self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
         self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1)  # chw
+        self.latent_t = self.config["target_video_length"] // self.config["vae_stride"][0] + 1
+        self.latent_h = self.refer_pixel_values.shape[-2] // self.config["vae_stride"][1]
+        self.latent_w = self.refer_pixel_values.shape[-1] // self.config["vae_stride"][2]
+        self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w]
         self.real_frame_len = len(self.cond_images)
         target_len = self.get_valid_len(
             self.real_frame_len,
-            self.config.target_video_length,
-            overlap=self.config.get("refert_num", 1),
+            self.config["target_video_length"],
+            overlap=self.config["refert_num"] if "refert_num" in self.config else 1,
         )
         logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))
         self.cond_images = self.inputs_padding(self.cond_images, target_len)
         self.face_images = self.inputs_padding(self.face_images, target_len)
-        if self.config.get("replace_flag", False):
-            src_bg_path = self.config.get("src_bg_path")
-            src_mask_path = self.config.get("src_mask_path")
+        if self.config["replace_flag"] if "replace_flag" in self.config else False:
+            src_bg_path = self.config["src_bg_path"]
+            src_mask_path = self.config["src_mask_path"]
             self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
             self.bg_images = self.inputs_padding(self.bg_images, target_len)
             self.mask_images = self.inputs_padding(self.mask_images, target_len)

     def get_video_segment_num(self):
         total_frames = len(self.cond_images)
-        self.move_frames = self.config.target_video_length - self.config.refert_num
-        if total_frames <= self.config.target_video_length:
+        self.move_frames = self.config["target_video_length"] - self.config["refert_num"]
+        if total_frames <= self.config["target_video_length"]:
             self.video_segment_num = 1
         else:
-            self.video_segment_num = 1 + (total_frames - self.config.target_video_length + self.move_frames - 1) // self.move_frames
+            self.video_segment_num = 1 + (total_frames - self.config["target_video_length"] + self.move_frames - 1) // self.move_frames

     def init_run(self):
         self.all_out_frames = []
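Note: prepare_input now derives the latent grid from the configured VAE strides and the reference image size, replacing the set_target_shape helper that is removed further down in this diff. A worked example of the arithmetic with illustrative values (the real numbers come from the runner's config and inputs):

    # Illustrative values only
    target_video_length = 77        # frames per segment
    vae_stride = (4, 8, 8)          # temporal, height, width strides
    ref_h, ref_w = 832, 480         # reference image resolution in pixels
    refert_num = 1                  # frames carried over between segments

    latent_t = target_video_length // vae_stride[0] + 1    # 77 // 4 + 1 = 20
    latent_h = ref_h // vae_stride[1]                       # 832 // 8 = 104
    latent_w = ref_w // vae_stride[2]                       # 480 // 8 = 60
    latent_shape = [16, latent_t + 1, latent_h, latent_w]   # extra slot up front; run_vae_decoder later drops index 0

    # Segment count for a longer clip, mirroring get_video_segment_num
    total_frames = 200
    move_frames = target_video_length - refert_num          # 76 new frames per segment
    segments = 1 if total_frames <= target_video_length else 1 + (total_frames - target_video_length + move_frames - 1) // move_frames
    print(latent_shape, segments)                            # [16, 21, 104, 60] 3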
@@ -267,10 +271,10 @@ class WanAnimateRunner(WanRunner):

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE()))
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
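Note: for a plain dict config, the new conditional expressions read the same value the old dict.get calls did; a quick check with a toy dict and hypothetical keys:

    config = {"unload_modules": True}   # "lazy_load" intentionally missing

    old_style = config.get("lazy_load", False) or config.get("unload_modules", False)
    new_style = (config["lazy_load"] if "lazy_load" in config else False) or (config["unload_modules"] if "unload_modules" in config else False)
    assert old_style and new_style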
@@ -278,11 +282,11 @@

     def init_run_segment(self, segment_idx):
         start = segment_idx * self.move_frames
-        end = start + self.config.target_video_length
+        end = start + self.config["target_video_length"]
         if start == 0:
             self.mask_reft_len = 0
         else:
-            self.mask_reft_len = self.config.refert_num
+            self.mask_reft_len = self.config["refert_num"]
         conditioning_pixel_values = torch.tensor(
             np.stack(self.cond_images[start:end]) / 127.5 - 1,
@@ -300,17 +304,17 @@ class WanAnimateRunner(WanRunner):
             height, width = self.refer_images.shape[:2]
             refer_t_pixel_values = torch.zeros(
                 3,
-                self.config.refert_num,
+                self.config["refert_num"],
                 height,
                 width,
                 device="cuda",
                 dtype=GET_DTYPE(),
             )  # c t h w
         else:
-            refer_t_pixel_values = self.gen_video[0, :, -self.config.refert_num :].transpose(0, 1).clone().detach()  # c t h w
+            refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach()  # c t h w

         bg_pixel_values, mask_pixel_values = None, None
-        if self.config.replace_flag:
+        if self.config["replace_flag"] if "replace_flag" in self.config else False:
             bg_pixel_values = torch.tensor(
                 np.stack(self.bg_images[start:end]) / 127.5 - 1,
                 device="cuda",
@@ -341,24 +345,17 @@ class WanAnimateRunner(WanRunner):
             self.gen_video = self.gen_video[:, :, self.config["refert_num"] :]
         self.all_out_frames.append(self.gen_video.cpu())

-    def process_images_after_vae_decoder(self, save_video=True):
-        self.gen_video = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
+    def process_images_after_vae_decoder(self):
+        self.gen_video_final = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
         del self.all_out_frames
         gc.collect()
-        super().process_images_after_vae_decoder(save_video)
-
-    def set_target_shape(self):
-        self.config.target_video_length = self.config.target_video_length
-        self.config.lat_h = self.refer_pixel_values.shape[-2] // 8
-        self.config.lat_w = self.refer_pixel_values.shape[-1] // 8
-        self.config.lat_t = self.config.target_video_length // 4 + 1
-        self.config.target_shape = [16, self.config.lat_t + 1, self.config.lat_h, self.config.lat_w]
+        super().process_images_after_vae_decoder()

     def run_image_encoder(self, img):  # CHW
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.image_encoder = self.load_image_encoder()
         clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE())
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.image_encoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -366,7 +363,7 @@ class WanAnimateRunner(WanRunner):

     def load_transformer(self):
         model = WanAnimateModel(
-            self.config.model_path,
+            self.config["model_path"],
             self.config,
             self.init_device,
         )
......
@@ -17,13 +17,13 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
 class WanVaceRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        assert self.config.task == "vace"
+        assert self.config["task"] == "vace"
         self.vid_proc = VaceVideoProcessor(
-            downsample=tuple([x * y for x, y in zip(self.config.vae_stride, self.config.patch_size)]),
+            downsample=tuple([x * y for x, y in zip(self.config["vae_stride"], self.config["patch_size"])]),
             min_area=720 * 1280,
             max_area=720 * 1280,
-            min_fps=self.config.get("fps", 16),
-            max_fps=self.config.get("fps", 16),
+            min_fps=self.config["fps"] if "fps" in self.config else 16,
+            max_fps=self.config["fps"] if "fps" in self.config else 16,
             zero_start=True,
             seq_len=75600,
             keep_last=True,
@@ -31,7 +31,7 @@ class WanVaceRunner(WanRunner):

     def load_transformer(self):
         model = WanVaceModel(
-            self.config.model_path,
+            self.config["model_path"],
             self.config,
             self.init_device,
         )
@@ -57,7 +57,7 @@
                 src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                 image_sizes.append(src_video[i].shape[2:])
             elif sub_src_video is None:
-                src_video[i] = torch.zeros((3, self.config.target_video_length, image_size[0], image_size[1]), device=device)
+                src_video[i] = torch.zeros((3, self.config["target_video_length"], image_size[0], image_size[1]), device=device)
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
@@ -89,7 +89,7 @@
         return src_video, src_mask, src_ref_images

     def run_vae_encoder(self, frames, ref_images, masks):
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.vae_encoder = self.load_vae_encoder()
         if ref_images is None:
             ref_images = [None] * len(frames)
@@ -118,11 +118,11 @@
             latent = torch.cat([*ref_latent, latent], dim=1)
             cat_latents.append(latent)
         self.latent_shape = list(cat_latents[0].shape)
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
-        return self.get_vae_encoder_output(cat_latents, masks, ref_images)
+        return self.get_vae_encoder_output(cat_latents, masks, ref_images), self.set_input_info_latent_shape()

     def get_vae_encoder_output(self, cat_latents, masks, ref_images):
         if ref_images is None:
@@ -133,15 +133,15 @@
         result_masks = []
         for mask, refs in zip(masks, ref_images):
             c, depth, height, width = mask.shape
-            new_depth = int((depth + 3) // self.config.vae_stride[0])
-            height = 2 * (int(height) // (self.config.vae_stride[1] * 2))
-            width = 2 * (int(width) // (self.config.vae_stride[2] * 2))
+            new_depth = int((depth + 3) // self.config["vae_stride"][0])
+            height = 2 * (int(height) // (self.config["vae_stride"][1] * 2))
+            width = 2 * (int(width) // (self.config["vae_stride"][2] * 2))

             # reshape
             mask = mask[0, :, :, :]
-            mask = mask.view(depth, height, self.config.vae_stride[1], width, self.config.vae_stride[1])  # depth, height, 8, width, 8
+            mask = mask.view(depth, height, self.config["vae_stride"][1], width, self.config["vae_stride"][1])  # depth, height, 8, width, 8
             mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
-            mask = mask.reshape(self.config.vae_stride[1] * self.config.vae_stride[2], depth, height, width)  # 8*8, depth, height, width
+            mask = mask.reshape(self.config["vae_stride"][1] * self.config["vae_stride"][2], depth, height, width)  # 8*8, depth, height, width

             # interpolation
             mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
@@ -161,7 +161,7 @@

     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.vae_decoder = self.load_vae_decoder()

         if self.src_ref_images is not None:
@@ -172,7 +172,7 @@
         images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
......
@@ -345,6 +345,7 @@ def quantize_model(
     weights,
     w_bit=8,
     target_keys=["attn", "ffn"],
+    adapter_keys=None,
     key_idx=2,
     ignore_key=None,
     linear_dtype=torch.int8,
@@ -375,18 +376,21 @@ def quantize_model(
         tensor = weights[key]

-        # Skip non-tensors, small tensors, and non-2D tensors
+        # Skip non-tensors and non-2D tensors
         if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
             if tensor.dtype != non_linear_dtype:
                 weights[key] = tensor.to(non_linear_dtype)
             continue

         # Check if key matches target modules
         parts = key.split(".")
         if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
-            if tensor.dtype != non_linear_dtype:
-                weights[key] = tensor.to(non_linear_dtype)
-            continue
+            if adapter_keys is not None and not any(adapter_key in parts for adapter_key in adapter_keys):
+                if tensor.dtype != non_linear_dtype:
+                    weights[key] = tensor.to(non_linear_dtype)
+                continue

         try:
             # Quantize tensor and store results
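Note: the filter above decides, per weight key, whether the tensor is quantized or only cast to the non-linear dtype. A standalone sketch of that routing as introduced by the new adapter_keys branch (the key names and helper below are made up for illustration; the real function also performs the quantization itself):

    def route_key(key, target_keys, adapter_keys, key_idx=2):
        # Return "quantize" or "cast" for a 2-D weight key, mirroring the filter above.
        parts = key.split(".")
        if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
            # Non-target key: cast unless it belongs to a configured adapter module.
            if adapter_keys is not None and not any(a in parts for a in adapter_keys):
                return "cast"
            # Adapter keys fall through and are quantized like the target modules.
        return "quantize"

    target_keys = ["self_attn", "cross_attn", "ffn"]
    adapter_keys = ["linear1_kv", "linear1_q", "linear2"]
    print(route_key("blocks.0.self_attn.q.weight", target_keys, adapter_keys))                     # quantize
    print(route_key("face_adapter.fuser_blocks.0.linear1_kv.weight", target_keys, adapter_keys))   # quantize (adapter)
    print(route_key("blocks.0.norm1.weight", target_keys, adapter_keys))                           # cast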
@@ -511,6 +515,7 @@ def convert_weights(args):
         converted_weights,
         w_bit=args.bits,
         target_keys=args.target_keys,
+        adapter_keys=args.adapter_keys,
         key_idx=args.key_idx,
         ignore_key=args.ignore_key,
         linear_dtype=args.linear_dtype,
@@ -535,6 +540,8 @@ def convert_weights(args):
         match = block_pattern.search(key)
         if match:
             block_idx = match.group(1)
+            if args.model_type == "wan_animate_dit" and "face_adapter" in key:
+                block_idx = str(int(block_idx) * 5)
             block_groups[block_idx][key] = tensor
         else:
             non_block_weights[key] = tensor
@@ -635,7 +642,7 @@ def main():
     parser.add_argument(
         "-t",
         "--model_type",
-        choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip"],
+        choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip", "wan_animate_dit"],
         default="wan_dit",
         help="Model type",
     )
@@ -684,6 +691,7 @@
             "target_keys": ["self_attn", "cross_attn", "ffn"],
             "ignore_key": ["ca", "audio"],
         },
+        "wan_animate_dit": {"key_idx": 2, "target_keys": ["self_attn", "cross_attn", "ffn"], "adapter_keys": ["linear1_kv", "linear1_q", "linear2"], "ignore_key": None},
         "hunyuan_dit": {
             "key_idx": 2,
             "target_keys": [
@@ -710,6 +718,7 @@
     }
     args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
+    args.adapter_keys = model_type_keys_map[args.model_type]["adapter_keys"] if "adapter_keys" in model_type_keys_map[args.model_type] else None
     args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
     args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
......