Unverified Commit 9c6d30b5 authored by Zhuguanyu Wu, committed by GitHub

Fix bug where apply_weight was not called in some conditions (#392)

parent bf19c132
@@ -18,8 +18,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
 
-    def __init__(self, model_path, config, device):
-        super().__init__(model_path, config, device)
+    def __init__(self, model_path, config, device, model_type="wan2.1"):
+        super().__init__(model_path, config, device, model_type)
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
......
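WanDistillModel simply threads the new model_type argument through to WanModel; the "wan2.1" default keeps existing call sites working unchanged.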
@@ -44,23 +44,19 @@ class WanModel(CompiledMethodsMixin):
     pre_weight_class = WanPreWeights
     transformer_weight_class = WanTransformerWeights
 
-    def __init__(self, model_path, config, device):
+    def __init__(self, model_path, config, device, model_type="wan2.1"):
         super().__init__()
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
+        self.model_type = model_type
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
         else:
             self.seq_p_group = None
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            self.init_empty_model = True
-        else:
-            self.init_empty_model = False
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.get("dit_quantized", False)
         if self.dit_quantized:
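Previously, __init__ set a single init_empty_model flag whenever any lora_configs were present, regardless of which model the LoRAs actually targeted. That decision now happens later, per model_type, in the _should_init_empty_model() method added in the next hunk.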
@@ -110,6 +106,20 @@ class WanModel(CompiledMethodsMixin):
             return True
         return False
 
+    def _should_init_empty_model(self):
+        if self.config.get("lora_configs") and self.config.lora_configs:
+            if self.model_type in ["wan2.1"]:
+                return True
+            if self.model_type in ["wan2.2_moe_high_noise"]:
+                for lora_config in self.config["lora_configs"]:
+                    if lora_config["name"] == "high_noise_model":
+                        return True
+            if self.model_type in ["wan2.2_moe_low_noise"]:
+                for lora_config in self.config["lora_configs"]:
+                    if lora_config["name"] == "low_noise_model":
+                        return True
+        return False
+
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
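To see why the per-model check matters, here is a minimal, self-contained sketch of the same gating logic. The standalone function name and the plain-dict config are illustrative stand-ins for the real WanModel and its config object, not part of the commit:

```python
def should_init_empty_model(model_type: str, config: dict) -> bool:
    # Mirrors WanModel._should_init_empty_model from this commit:
    # only skip applying base weights when a LoRA actually targets this model.
    lora_configs = config.get("lora_configs")
    if lora_configs:
        if model_type == "wan2.1":
            return True
        if model_type == "wan2.2_moe_high_noise":
            return any(lc["name"] == "high_noise_model" for lc in lora_configs)
        if model_type == "wan2.2_moe_low_noise":
            return any(lc["name"] == "low_noise_model" for lc in lora_configs)
    return False

config = {"lora_configs": [{"name": "high_noise_model"}]}
print(should_init_empty_model("wan2.2_moe_high_noise", config))  # True: LoRA path loads weights
print(should_init_empty_model("wan2.2_moe_low_noise", config))   # False: _apply_weights runs (the fix)
```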
@@ -254,7 +264,7 @@ class WanModel(CompiledMethodsMixin):
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
 
-        if not self.init_empty_model:
+        if not self._should_init_empty_model():
             self._apply_weights()
 
     def _apply_weights(self, weight_dict=None):
......
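This gate is where the bug manifested: with lora_configs present, the old flag was True for every model instance, so a wan2.2 MoE expert that no LoRA targeted never had _apply_weights called and was left without weights, the "no apply_weight" failure named in the commit title.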
@@ -108,6 +108,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 self.high_noise_model_path,
                 self.config,
                 self.init_device,
+                model_type="wan2.2_moe_high_noise",
             )
             high_lora_wrapper = WanLoraWrapper(high_noise_model)
             for lora_config in self.config["lora_configs"]:
@@ -122,6 +123,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 self.high_noise_model_path,
                 self.config,
                 self.init_device,
+                model_type="wan2.2_moe_high_noise",
             )
         if use_low_lora:
@@ -129,6 +131,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 self.low_noise_model_path,
                 self.config,
                 self.init_device,
+                model_type="wan2.2_moe_low_noise",
             )
             low_lora_wrapper = WanLoraWrapper(low_noise_model)
             for lora_config in self.config["lora_configs"]:
@@ -143,6 +146,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 self.low_noise_model_path,
                 self.config,
                 self.init_device,
+                model_type="wan2.2_moe_low_noise",
             )
 
         return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
......
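Each WanDistillModel constructed by Wan22MoeDistillRunner is now tagged with the expert it represents, on both the LoRA and non-LoRA code paths, so that _should_init_empty_model() can match the tag against the names in lora_configs.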
@@ -466,11 +466,13 @@ class Wan22MoeRunner(WanRunner):
             self.high_noise_model_path,
             self.config,
             self.init_device,
+            model_type="wan2.2_moe_high_noise",
         )
         low_noise_model = WanModel(
             self.low_noise_model_path,
             self.config,
             self.init_device,
+            model_type="wan2.2_moe_low_noise",
         )
         if self.config.get("lora_configs") and self.config["lora_configs"]:
......
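The non-distill Wan22MoeRunner gets the same tags. The diff implies a lora_configs schema in which each entry carries at least a "name" matching one of the two MoE experts; a sketch of such a config follows, where any key beyond "name" is an assumption not confirmed by this diff:

```python
# Hypothetical lora_configs excerpt: only the "name" key is confirmed by this
# diff; the checkpoint-path key is an assumption for illustration.
config = {
    "lora_configs": [
        {"name": "high_noise_model", "path": "high_noise_lora.safetensors"},
        # No "low_noise_model" entry: after this fix, the low-noise WanModel
        # applies its base weights instead of being left empty.
    ]
}
```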
@@ -127,6 +127,14 @@ class LockableDict(dict):
         self.update(other)
         return self
 
+    # ========== Attribute-style access (EasyDict-like behavior) ==========
+    def __getattr__(self, key: str):
+        """Allow attribute-style access: d.key instead of d['key']"""
+        try:
+            return self[key]
+        except KeyError:
+            raise AttributeError(f"'LockableDict' object has no attribute '{key}'")
+
     # ========== Internal utilities ==========
     def _ensure_unlocked(self) -> None:
         if self._locked:
......
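The new __getattr__ is presumably needed because _should_init_empty_model evaluates self.config.lora_configs with attribute syntax, which would fail on a plain dict-backed config. A minimal, self-contained sketch of the behavior (standalone class for illustration; the real LockableDict's locking machinery is omitted):

```python
class LockableDict(dict):
    # Standalone illustration of the new attribute-style access.
    def __getattr__(self, key: str):
        """Allow attribute-style access: d.key instead of d['key']"""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(f"'LockableDict' object has no attribute '{key}'")

d = LockableDict(lora_configs=[{"name": "high_noise_model"}])
print(d.lora_configs)     # attribute access now works, like EasyDict
print(d["lora_configs"])  # item access is unchanged
```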
@@ -21,7 +21,7 @@ if __name__ == "__main__":
     messages = []
     for i, (image_path, prompt) in enumerate(img_prompts.items()):
-        messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
+        messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
 
     logger.info(f"urls: {urls}")
......
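The example script now sends the raw image_path string instead of a base64-encoded payload, which also makes the "image_path" key match its contents; presumably the server resolves the path itself.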