Unverified Commit 63f233ad authored by gushiqiao, committed by GitHub

[Fix] Fix vace and animate models config bug (#351)

parent 69c2f650
......@@ -34,6 +34,10 @@ class FlashAttn2Weight(AttnWeightTemplate):
max_seqlen_kv=None,
model_cls=None,
):
if len(q.shape) == 3:
bs = 1
elif len(q.shape) == 4:
bs = q.shape[0]
x = flash_attn_varlen_func(
q,
k,
......@@ -42,7 +46,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(max_seqlen_q, -1)
).reshape(bs * max_seqlen_q, -1)
return x
......@@ -63,23 +67,16 @@ class FlashAttn3Weight(AttnWeightTemplate):
model_cls=None,
):
if len(q.shape) == 3:
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(max_seqlen_q, -1)
bs = 1
elif len(q.shape) == 4:
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(q.shape[0] * max_seqlen_q, -1)
bs = q.shape[0]
x = flash_attn_varlen_func_v3(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
).reshape(bs * max_seqlen_q, -1)
return x
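Note on the FlashAttn2/FlashAttn3 hunks above: the varlen attention output was previously flattened with max_seqlen_q alone, which drops the batch dimension when q arrives as a 4-D (batched) tensor; the fix derives bs from q's rank before the reshape. A minimal standalone sketch of the shape logic (toy tensors and the helper name are illustrative, not part of the repository):

import torch

def flatten_varlen_output(q: torch.Tensor, attn_out: torch.Tensor, max_seqlen_q: int) -> torch.Tensor:
    # A 3-D q means the batch is already packed into the token axis (varlen layout), so bs = 1;
    # a 4-D q keeps an explicit leading batch dimension.
    if q.dim() == 3:
        bs = 1
    elif q.dim() == 4:
        bs = q.shape[0]
    else:
        raise ValueError(f"unexpected q rank: {q.dim()}")
    # Fold (tokens, heads, head_dim) into (bs * max_seqlen_q, heads * head_dim).
    return attn_out.reshape(bs * max_seqlen_q, -1)

q = torch.randn(2, 128, 12, 64)      # (bs, seq, heads, head_dim)
out = torch.randn(2 * 128, 12, 64)   # packed output, as returned by a varlen attention kernel
print(flatten_varlen_output(q, out, max_seqlen_q=128).shape)  # torch.Size([256, 768])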
......@@ -36,38 +36,15 @@ class SageAttn2Weight(AttnWeightTemplate):
model_cls=None,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if model_cls == "hunyuan":
x1 = sageattn(
q[: cu_seqlens_q[1]].unsqueeze(0),
k[: cu_seqlens_kv[1]].unsqueeze(0),
v[: cu_seqlens_kv[1]].unsqueeze(0),
tensor_layout="NHD",
)
x2 = sageattn(
q[cu_seqlens_q[1] :].unsqueeze(0),
k[cu_seqlens_kv[1] :].unsqueeze(0),
v[cu_seqlens_kv[1] :].unsqueeze(0),
tensor_layout="NHD",
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_animate", "wan2.2_moe_distill", "qwen_image"]:
if len(q.shape) == 3:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
v.unsqueeze(0),
tensor_layout="NHD",
)
x = x.view(max_seqlen_q, -1)
elif len(q.shape) == 4:
x = sageattn(
q,
k,
v,
tensor_layout="NHD",
)
x = x.view(q.shape[0] * max_seqlen_q, -1)
else:
raise NotImplementedError(f"Model class '{model_cls}' is not implemented in this attention implementation")
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
x = sageattn(
q,
k,
v,
tensor_layout="NHD",
).view(bs * max_seqlen_q, -1)
return x
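The SageAttn2 path above drops the per-model branching (hunyuan vs. the long wan/qwen whitelist) in favor of the same rank check: 3-D packed inputs gain a singleton batch dimension before the batched "NHD" kernel call, while 4-D inputs pass through. A standalone sketch of that normalization step (toy shapes only; sageattn itself is not called here):

import torch

def normalize_qkv(q, k, v):
    # SageAttention's "NHD" layout expects (batch, seq, heads, head_dim);
    # packed 3-D inputs get a singleton batch dimension, 4-D inputs already have one.
    if q.dim() == 3:
        bs = 1
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    elif q.dim() == 4:
        bs = q.shape[0]
    else:
        raise ValueError(f"unexpected q rank: {q.dim()}")
    return bs, q, k, v

qkv = [torch.randn(100, 12, 64) for _ in range(3)]
bs, q, k, v = normalize_qkv(*qkv)
print(bs, q.shape)  # 1 torch.Size([1, 100, 12, 64])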
......@@ -20,8 +20,8 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
kv = kv.view(T, -1, kv.shape[-1])
q = phase.linear1_q.apply(x_feat)
k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config.num_heads)
q = rearrange(q, "S (H D) -> S H D", H=self.config.num_heads)
k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config["num_heads"])
q = rearrange(q, "S (H D) -> S H D", H=self.config["num_heads"])
q = phase.q_norm.apply(q).view(T, q.shape[0] // T, q.shape[1], q.shape[2])
k = phase.k_norm.apply(k)
......
......@@ -5,8 +5,8 @@ from lightx2v.utils.envs import *
class WanVaceTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.vace_blocks_num = len(self.config.vace_layers)
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
self.vace_blocks_num = len(self.config["vace_layers"])
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}
def infer(self, weights, pre_infer_out):
pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
......
......@@ -14,7 +14,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
super().__init__(config)
self.patch_size = (1, 2, 2)
self.vace_blocks = WeightModuleList(
[WanVaceTransformerAttentionBlock(self.config.vace_layers[i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config.vace_layers))]
[WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
)
self.add_module("vace_blocks", self.vace_blocks)
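The config hunks here (num_heads, vace_layers, and the ones that follow) are the core of the fix: the vace/animate configs apparently behave like plain dictionaries, where attribute-style access raises AttributeError and only subscript access is safe. A short illustration; the AttrDict wrapper at the end is hypothetical and not part of lightx2v:

config = {"num_heads": 12, "vace_layers": [0, 5, 10, 15]}

config["num_heads"]   # 12 -- subscript access works on a plain dict
# config.num_heads    # AttributeError: 'dict' object has no attribute 'num_heads'

# Hypothetical convenience wrapper that would accept both styles:
class AttrDict(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

cfg = AttrDict(config)
assert cfg.num_heads == cfg["num_heads"] == 12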
......
......@@ -214,7 +214,7 @@ class DefaultRunner(BaseRunner):
[src_video],
[src_mask],
[None if src_ref_images is None else src_ref_images.split(",")],
(self.config.target_width, self.config.target_height),
(self.config["target_width"], self.config["target_height"]),
)
self.src_ref_images = src_ref_images
......
......@@ -11,7 +11,7 @@ try:
from decord import VideoReader
except ImportError:
VideoReader = None
logger.info("If you need run animate model, please install decord.")
logger.info("If you want to run animate model, please install decord.")
from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
......@@ -28,7 +28,7 @@ from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
class WanAnimateRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
assert self.config.task == "animate"
assert self.config["task"] == "animate"
def inputs_padding(self, array, target_len):
idx = 0
......@@ -161,11 +161,11 @@ class WanAnimateRunner(WanRunner):
pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0)) # c t h w
ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0)) # c t h w
mask_ref = self.get_i2v_mask(1, self.config.lat_h, self.config.lat_w, 1)
mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1)
y_ref = torch.concat([mask_ref, ref_latents])
if self.mask_reft_len > 0:
if self.config.replace_flag:
if self.config["replace_flag"]:
y_reft = self.vae_encoder.encode(
torch.concat(
[
......@@ -183,9 +183,9 @@ class WanAnimateRunner(WanRunner):
mask_pixel_values = mask_pixel_values[:, 0, :, :]
msk_reft = self.get_i2v_mask(
self.config.lat_t,
self.config.lat_h,
self.config.lat_w,
self.latent_t,
self.latent_h,
self.latent_w,
self.mask_reft_len,
mask_pixel_values=mask_pixel_values.unsqueeze(0),
)
......@@ -198,31 +198,31 @@ class WanAnimateRunner(WanRunner):
size=(H, W),
mode="bicubic",
),
torch.zeros(3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
],
dim=1,
)
.cuda()
.unsqueeze(0)
)
msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
else:
if self.config.replace_flag:
if self.config["replace_flag"]:
mask_pixel_values = 1 - mask_pixel_values
mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
mask_pixel_values = mask_pixel_values[:, 0, :, :]
y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0))
msk_reft = self.get_i2v_mask(
self.config.lat_t,
self.config.lat_h,
self.config.lat_w,
self.latent_t,
self.latent_h,
self.latent_w,
self.mask_reft_len,
mask_pixel_values=mask_pixel_values.unsqueeze(0),
)
else:
y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
y_reft = torch.concat([msk_reft, y_reft])
y = torch.concat([y_ref, y_reft], dim=1)
......@@ -230,35 +230,39 @@ class WanAnimateRunner(WanRunner):
return y, pose_latents
def prepare_input(self):
src_pose_path = self.config.get("src_pose_path", None)
src_face_path = self.config.get("src_face_path", None)
src_ref_path = self.config.get("src_ref_images", None)
src_pose_path = self.config["src_pose_path"] if "src_pose_path" in self.config else None
src_face_path = self.config["src_face_path"] if "src_face_path" in self.config else None
src_ref_path = self.config["src_ref_images"] if "src_ref_images" in self.config else None
self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1) # chw
self.latent_t = self.config["target_video_length"] // self.config["vae_stride"][0] + 1
self.latent_h = self.refer_pixel_values.shape[-2] // self.config["vae_stride"][1]
self.latent_w = self.refer_pixel_values.shape[-1] // self.config["vae_stride"][2]
self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w]
self.real_frame_len = len(self.cond_images)
target_len = self.get_valid_len(
self.real_frame_len,
self.config.target_video_length,
overlap=self.config.get("refert_num", 1),
self.config["target_video_length"],
overlap=self.config["refert_num"] if "refert_num" in self.config else 1,
)
logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))
self.cond_images = self.inputs_padding(self.cond_images, target_len)
self.face_images = self.inputs_padding(self.face_images, target_len)
if self.config.get("replace_flag", False):
src_bg_path = self.config.get("src_bg_path")
src_mask_path = self.config.get("src_mask_path")
if self.config["replace_flag"] if "replace_flag" in self.config else False:
src_bg_path = self.config["src_bg_path"]
src_mask_path = self.config["src_mask_path"]
self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
self.bg_images = self.inputs_padding(self.bg_images, target_len)
self.mask_images = self.inputs_padding(self.mask_images, target_len)
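prepare_input now derives the latent grid directly from the reference image and the VAE strides (replacing the removed set_target_shape further down). A worked example with illustrative values; vae_stride (4, 8, 8) and 16 latent channels match the numbers hard-coded in the code this commit replaces:

target_video_length = 77
vae_stride = (4, 8, 8)
ref_h, ref_w = 480, 832                                # reference image resolution (H, W), illustrative

latent_t = target_video_length // vae_stride[0] + 1    # 20 temporal latents for the generated clip
latent_h = ref_h // vae_stride[1]                      # 60
latent_w = ref_w // vae_stride[2]                      # 104
latent_shape = [16, latent_t + 1, latent_h, latent_w]  # extra slot holds the reference-image latent
print(latent_shape)                                    # [16, 21, 60, 104]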
def get_video_segment_num(self):
total_frames = len(self.cond_images)
self.move_frames = self.config.target_video_length - self.config.refert_num
if total_frames <= self.config.target_video_length:
self.move_frames = self.config["target_video_length"] - self.config["refert_num"]
if total_frames <= self.config["target_video_length"]:
self.video_segment_num = 1
else:
self.video_segment_num = 1 + (total_frames - self.config.target_video_length + self.move_frames - 1) // self.move_frames
self.video_segment_num = 1 + (total_frames - self.config["target_video_length"] + self.move_frames - 1) // self.move_frames
def init_run(self):
self.all_out_frames = []
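The segment arithmetic in get_video_segment_num is unchanged apart from the dict-style access: each window after the first advances by target_video_length - refert_num frames, reusing refert_num frames as temporal reference, and the remainder is covered with a ceiling division. A small self-contained check with made-up numbers:

def video_segment_num(total_frames: int, target_video_length: int, refert_num: int) -> int:
    move_frames = target_video_length - refert_num
    if total_frames <= target_video_length:
        return 1
    # ceiling division over the frames left after the first window
    return 1 + (total_frames - target_video_length + move_frames - 1) // move_frames

# 200 conditioning frames, 77-frame windows, 1-frame overlap -> 1 + ceil(123 / 76) = 3 segments
print(video_segment_num(200, 77, 1))  # 3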
......@@ -267,10 +271,10 @@ class WanAnimateRunner(WanRunner):
@ProfilingContext4DebugL1("Run VAE Decoder")
def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_decoder = self.load_vae_decoder()
images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE()))
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
......@@ -278,11 +282,11 @@ class WanAnimateRunner(WanRunner):
def init_run_segment(self, segment_idx):
start = segment_idx * self.move_frames
end = start + self.config.target_video_length
end = start + self.config["target_video_length"]
if start == 0:
self.mask_reft_len = 0
else:
self.mask_reft_len = self.config.refert_num
self.mask_reft_len = self.config["refert_num"]
conditioning_pixel_values = torch.tensor(
np.stack(self.cond_images[start:end]) / 127.5 - 1,
......@@ -300,17 +304,17 @@ class WanAnimateRunner(WanRunner):
height, width = self.refer_images.shape[:2]
refer_t_pixel_values = torch.zeros(
3,
self.config.refert_num,
self.config["refert_num"],
height,
width,
device="cuda",
dtype=GET_DTYPE(),
) # c t h w
else:
refer_t_pixel_values = self.gen_video[0, :, -self.config.refert_num :].transpose(0, 1).clone().detach() # c t h w
refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach() # c t h w
bg_pixel_values, mask_pixel_values = None, None
if self.config.replace_flag:
if self.config["replace_flag"] if "replace_flag" in self.config else False:
bg_pixel_values = torch.tensor(
np.stack(self.bg_images[start:end]) / 127.5 - 1,
device="cuda",
......@@ -341,24 +345,17 @@ class WanAnimateRunner(WanRunner):
self.gen_video = self.gen_video[:, :, self.config["refert_num"] :]
self.all_out_frames.append(self.gen_video.cpu())
def process_images_after_vae_decoder(self, save_video=True):
self.gen_video = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
def process_images_after_vae_decoder(self):
self.gen_video_final = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
del self.all_out_frames
gc.collect()
super().process_images_after_vae_decoder(save_video)
def set_target_shape(self):
self.config.target_video_length = self.config.target_video_length
self.config.lat_h = self.refer_pixel_values.shape[-2] // 8
self.config.lat_w = self.refer_pixel_values.shape[-1] // 8
self.config.lat_t = self.config.target_video_length // 4 + 1
self.config.target_shape = [16, self.config.lat_t + 1, self.config.lat_h, self.config.lat_w]
super().process_images_after_vae_decoder()
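process_images_after_vae_decoder now stores the stitched result in gen_video_final and no longer forwards a save_video flag; the stitching itself is a concat along the frame axis followed by trimming the padding frames that inputs_padding added to fill the last segment. A toy sketch of that concat-and-trim step (arbitrary shapes, not real decoder output):

import torch

real_frame_len = 200
segments = [torch.zeros(1, 3, 77, 64, 64), torch.zeros(1, 3, 76, 64, 64), torch.zeros(1, 3, 76, 64, 64)]
video = torch.cat(segments, dim=2)[:, :, :real_frame_len]   # (batch, channels, frames, H, W)
print(video.shape)  # torch.Size([1, 3, 200, 64, 64])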
def run_image_encoder(self, img): # CHW
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.image_encoder = self.load_image_encoder()
clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE())
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
del self.image_encoder
torch.cuda.empty_cache()
gc.collect()
......@@ -366,7 +363,7 @@ class WanAnimateRunner(WanRunner):
def load_transformer(self):
model = WanAnimateModel(
self.config.model_path,
self.config["model_path"],
self.config,
self.init_device,
)
......
......@@ -17,13 +17,13 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
class WanVaceRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
assert self.config.task == "vace"
assert self.config["task"] == "vace"
self.vid_proc = VaceVideoProcessor(
downsample=tuple([x * y for x, y in zip(self.config.vae_stride, self.config.patch_size)]),
downsample=tuple([x * y for x, y in zip(self.config["vae_stride"], self.config["patch_size"])]),
min_area=720 * 1280,
max_area=720 * 1280,
min_fps=self.config.get("fps", 16),
max_fps=self.config.get("fps", 16),
min_fps=self.config["fps"] if "fps" in self.config else 16,
max_fps=self.config["fps"] if "fps" in self.config else 16,
zero_start=True,
seq_len=75600,
keep_last=True,
......@@ -31,7 +31,7 @@ class WanVaceRunner(WanRunner):
def load_transformer(self):
model = WanVaceModel(
self.config.model_path,
self.config["model_path"],
self.config,
self.init_device,
)
......@@ -57,7 +57,7 @@ class WanVaceRunner(WanRunner):
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, self.config.target_video_length, image_size[0], image_size[1]), device=device)
src_video[i] = torch.zeros((3, self.config["target_video_length"], image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
......@@ -89,7 +89,7 @@ class WanVaceRunner(WanRunner):
return src_video, src_mask, src_ref_images
def run_vae_encoder(self, frames, ref_images, masks):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_encoder = self.load_vae_encoder()
if ref_images is None:
ref_images = [None] * len(frames)
......@@ -118,11 +118,11 @@ class WanVaceRunner(WanRunner):
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
self.latent_shape = list(cat_latents[0].shape)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
return self.get_vae_encoder_output(cat_latents, masks, ref_images)
return self.get_vae_encoder_output(cat_latents, masks, ref_images), self.set_input_info_latent_shape()
def get_vae_encoder_output(self, cat_latents, masks, ref_images):
if ref_images is None:
......@@ -133,15 +133,15 @@ class WanVaceRunner(WanRunner):
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.config.vae_stride[0])
height = 2 * (int(height) // (self.config.vae_stride[1] * 2))
width = 2 * (int(width) // (self.config.vae_stride[2] * 2))
new_depth = int((depth + 3) // self.config["vae_stride"][0])
height = 2 * (int(height) // (self.config["vae_stride"][1] * 2))
width = 2 * (int(width) // (self.config["vae_stride"][2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(depth, height, self.config.vae_stride[1], width, self.config.vae_stride[1]) # depth, height, 8, width, 8
mask = mask.view(depth, height, self.config["vae_stride"][1], width, self.config["vae_stride"][1]) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(self.config.vae_stride[1] * self.config.vae_stride[2], depth, height, width) # 8*8, depth, height, width
mask = mask.reshape(self.config["vae_stride"][1] * self.config["vae_stride"][2], depth, height, width) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
......@@ -161,7 +161,7 @@ class WanVaceRunner(WanRunner):
@ProfilingContext4DebugL1("Run VAE Decoder")
def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
self.vae_decoder = self.load_vae_decoder()
if self.src_ref_images is not None:
......@@ -172,7 +172,7 @@ class WanVaceRunner(WanRunner):
images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
......
......@@ -345,6 +345,7 @@ def quantize_model(
weights,
w_bit=8,
target_keys=["attn", "ffn"],
adapter_keys=None,
key_idx=2,
ignore_key=None,
linear_dtype=torch.int8,
......@@ -375,18 +376,21 @@ def quantize_model(
tensor = weights[key]
# Skip non-tensors, small tensors, and non-2D tensors
# Skip non-tensors and non-2D tensors
if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
continue
# Check if key matches target modules
parts = key.split(".")
if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
continue
if adapter_keys is not None and not any(adapter_key in parts for adapter_key in adapter_keys):
if tensor.dtype != non_linear_dtype:
weights[key] = tensor.to(non_linear_dtype)
continue
try:
# Quantize tensor and store results
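The new adapter_keys argument lets the converter quantize the animate face-adapter projections (linear1_kv, linear1_q, linear2) that the parts[key_idx] check against target_keys would otherwise skip. The exact nesting of the new check is not fully visible in the collapsed diff, so the sketch below is one plausible reading, with hypothetical key names:

def should_quantize(key: str, target_keys, adapter_keys=None, key_idx: int = 2) -> bool:
    parts = key.split(".")
    in_targets = len(parts) > key_idx and parts[key_idx] in target_keys
    in_adapters = adapter_keys is not None and any(a in parts for a in adapter_keys)
    return in_targets or in_adapters

targets = ["self_attn", "cross_attn", "ffn"]
adapters = ["linear1_kv", "linear1_q", "linear2"]
print(should_quantize("blocks.0.self_attn.q.weight", targets, adapters))                    # True
print(should_quantize("face_adapter.fuser_blocks.0.linear1_kv.weight", targets, adapters))  # True
print(should_quantize("blocks.0.norm1.weight", targets))                                    # False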
......@@ -511,6 +515,7 @@ def convert_weights(args):
converted_weights,
w_bit=args.bits,
target_keys=args.target_keys,
adapter_keys=args.adapter_keys,
key_idx=args.key_idx,
ignore_key=args.ignore_key,
linear_dtype=args.linear_dtype,
......@@ -535,6 +540,8 @@ def convert_weights(args):
match = block_pattern.search(key)
if match:
block_idx = match.group(1)
if args.model_type == "wan_animate_dit" and "face_adapter" in key:
block_idx = str(int(block_idx) * 5)
block_groups[block_idx][key] = tensor
else:
non_block_weights[key] = tensor
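For wan_animate_dit, face-adapter weights are re-indexed before per-block grouping: fuser block i is filed under DiT block 5*i. The stride of 5 is taken directly from the diff; whether it reflects adapters being interleaved every fifth transformer block is an assumption. The remapping in isolation (hypothetical key name):

def remap_block_idx(model_type: str, key: str, block_idx: str) -> str:
    if model_type == "wan_animate_dit" and "face_adapter" in key:
        return str(int(block_idx) * 5)   # fuser block i grouped with DiT block 5*i
    return block_idx

print(remap_block_idx("wan_animate_dit", "face_adapter.fuser_blocks.3.linear1_q.weight", "3"))  # "15"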
......@@ -635,7 +642,7 @@ def main():
parser.add_argument(
"-t",
"--model_type",
choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip"],
choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip", "wan_animate_dit"],
default="wan_dit",
help="Model type",
)
......@@ -684,6 +691,7 @@ def main():
"target_keys": ["self_attn", "cross_attn", "ffn"],
"ignore_key": ["ca", "audio"],
},
"wan_animate_dit": {"key_idx": 2, "target_keys": ["self_attn", "cross_attn", "ffn"], "adapter_keys": ["linear1_kv", "linear1_q", "linear2"], "ignore_key": None},
"hunyuan_dit": {
"key_idx": 2,
"target_keys": [
......@@ -710,6 +718,7 @@ def main():
}
args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
args.adapter_keys = model_type_keys_map[args.model_type]["adapter_keys"] if "adapter_keys" in model_type_keys_map[args.model_type] else None
args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
......