Unverified Commit 9c6d30b5 authored by Zhuguanyu Wu's avatar Zhuguanyu Wu Committed by GitHub
Browse files

Fix bug where apply_weights was not called under some conditions (#392)

parent bf19c132
......@@ -18,8 +18,8 @@ class WanDistillModel(WanModel):
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def __init__(self, model_path, config, device, model_type="wan2.1"):
super().__init__(model_path, config, device, model_type)
def _load_ckpt(self, unified_dtype, sensitive_layer):
# For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
......
......@@ -44,23 +44,19 @@ class WanModel(CompiledMethodsMixin):
pre_weight_class = WanPreWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
def __init__(self, model_path, config, device, model_type="wan2.1"):
super().__init__()
self.model_path = model_path
self.config = config
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.model_type = model_type
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
if self.config.get("lora_configs") and self.config.lora_configs:
self.init_empty_model = True
else:
self.init_empty_model = False
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.dit_quantized = self.config.get("dit_quantized", False)
if self.dit_quantized:
......@@ -110,6 +106,20 @@ class WanModel(CompiledMethodsMixin):
return True
return False
def _should_init_empty_model(self):
if self.config.get("lora_configs") and self.config.lora_configs:
if self.model_type in ["wan2.1"]:
return True
if self.model_type in ["wan2.2_moe_high_noise"]:
for lora_config in self.config["lora_configs"]:
if lora_config["name"] == "high_noise_model":
return True
if self.model_type in ["wan2.2_moe_low_noise"]:
for lora_config in self.config["lora_configs"]:
if lora_config["name"] == "low_noise_model":
return True
return False
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
......@@ -254,7 +264,7 @@ class WanModel(CompiledMethodsMixin):
# Initialize weight containers
self.pre_weight = self.pre_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
if not self.init_empty_model:
if not self._should_init_empty_model():
self._apply_weights()
def _apply_weights(self, weight_dict=None):
......
......@@ -108,6 +108,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config["lora_configs"]:
......@@ -122,6 +123,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
if use_low_lora:
......@@ -129,6 +131,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_lora_wrapper = WanLoraWrapper(low_noise_model)
for lora_config in self.config["lora_configs"]:
......@@ -143,6 +146,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
......
......@@ -466,11 +466,13 @@ class Wan22MoeRunner(WanRunner):
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
if self.config.get("lora_configs") and self.config["lora_configs"]:
......
......@@ -127,6 +127,14 @@ class LockableDict(dict):
self.update(other)
return self
# ========== Attribute-style access (EasyDict-like behavior) ==========
def __getattr__(self, key: str):
"""Allow attribute-style access: d.key instead of d['key']"""
try:
return self[key]
except KeyError:
raise AttributeError(f"'LockableDict' object has no attribute '{key}'")
# ========== Internal utilities ==========
def _ensure_unlocked(self) -> None:
if self._locked:
......
......@@ -21,7 +21,7 @@ if __name__ == "__main__":
messages = []
for i, (image_path, prompt) in enumerate(img_prompts.items()):
messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
logger.info(f"urls: {urls}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment