Unverified commit d9795972, authored by Zhuguanyu Wu, committed by GitHub

update checkpoint loading functions for wan22_moe_distill (#364)

Mainly for step-distillation experiments:
- Support specifying a particular checkpoint under the high-noise / low-noise folders to load (a config sketch follows below)
- Remove the redundant Wan22MoeDistill class
- Fix a bug caused by requests missing the seed parameter when multiple requests are issued at the same time
parent cd777631
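For reference, a minimal sketch of how the two new config keys are meant to be used. The checkpoint paths are taken from the config change in this commit, while the surrounding dict is an illustrative stand-in for the full wan22_moe_distill config rather than its real schema:

# Illustrative config excerpt (Python dict stand-in for the JSON config).
# Each MoE expert reads its own key, so the high-noise and low-noise
# checkpoints can be chosen independently. A value may point either to a
# single .safetensors file or to a directory of shards; see the
# _load_ckpt change further down.
distill_config = {
    "dit_distill_ckpt_high": "Wan2.2-I2V-A14B/distill_models/high_noise_model/distill_model.safetensors",
    "dit_distill_ckpt_low": "Wan2.2-I2V-A14B/distill_models/low_noise_model/distill_model.safetensors",
}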
@@ -24,5 +24,7 @@
         750,
         500,
         250
-    ]
+    ],
+    "dit_distill_ckpt_high": "Wan2.2-I2V-A14B/distill_models/high_noise_model/distill_model.safetensors",
+    "dit_distill_ckpt_low": "Wan2.2-I2V-A14B/distill_models/low_noise_model/distill_model.safetensors"
 }
@@ -19,7 +19,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights
 
-    def __init__(self, model_path, config, device):
+    def __init__(self, model_path, config, device, ckpt_config_key="dit_distill_ckpt"):
+        self.ckpt_config_key = ckpt_config_key
         super().__init__(model_path, config, device)
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
@@ -35,23 +36,19 @@ class WanDistillModel(WanModel):
             return weight_dict
 
         if self.config.get("enable_dynamic_cfg", False):
-            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_cfg_models")
+            safetensors_path = find_hf_model_path(self.config, self.model_path, self.ckpt_config_key, subdir="distill_cfg_models")
         else:
-            safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_models")
-        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+            safetensors_path = find_hf_model_path(self.config, self.model_path, self.ckpt_config_key, subdir="distill_models")
+        if os.path.isfile(safetensors_path):
+            logger.info(f"loading checkpoint from {safetensors_path} ...")
+            safetensors_files = glob.glob(safetensors_path)
+        else:
+            logger.info(f"loading checkpoint from {safetensors_path} ...")
+            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
 
         weight_dict = {}
         for file_path in safetensors_files:
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
 
         return weight_dict
-
-
-class Wan22MoeDistillModel(WanDistillModel, WanModel):
-    def __init__(self, model_path, config, device):
-        WanDistillModel.__init__(self, model_path, config, device)
-
-    @torch.no_grad()
-    def infer(self, inputs):
-        return WanModel.infer(self, inputs)
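The _load_ckpt change above is the core of the feature: after find_hf_model_path resolves the configured key, the result may now be a single .safetensors file (used as-is) or, as before, a directory that is globbed for *.safetensors shards. A standalone sketch of that branch, with a made-up helper name and no class context, just to isolate the logic:

import glob
import os


def resolve_safetensors_files(safetensors_path):
    # Hypothetical helper mirroring the new branch in WanDistillModel._load_ckpt:
    # a concrete file path is used directly, a directory is globbed for all
    # *.safetensors shards inside it.
    if os.path.isfile(safetensors_path):
        return glob.glob(safetensors_path)
    return glob.glob(os.path.join(safetensors_path, "*.safetensors"))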
@@ -2,7 +2,7 @@ import os
 
 from loguru import logger
 
-from lightx2v.models.networks.wan.distill_model import Wan22MoeDistillModel, WanDistillModel
+from lightx2v.models.networks.wan.distill_model import WanDistillModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
@@ -103,10 +103,11 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 high_lora_wrapper.apply_lora(lora_name, strength)
                 logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
-            high_noise_model = Wan22MoeDistillModel(
+            high_noise_model = WanDistillModel(
                 os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
                 self.config,
                 self.init_device,
+                ckpt_config_key="dit_distill_ckpt_high",
             )
 
         if use_low_lora:
@@ -124,10 +125,11 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 low_lora_wrapper.apply_lora(lora_name, strength)
                 logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
-            low_noise_model = Wan22MoeDistillModel(
+            low_noise_model = WanDistillModel(
                 os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
                 self.config,
                 self.init_device,
+                ckpt_config_key="dit_distill_ckpt_low",
             )
 
         return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
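Taken together with the hunks above, both experts are now plain WanDistillModel instances that differ only in which config key they read. A condensed sketch of that construction pattern, where config, device, and the model root are illustrative stand-ins for the runner's self.config, self.init_device, and self.config["model_path"]:

import os

from lightx2v.models.networks.wan.distill_model import WanDistillModel

# Illustrative placeholders for the runner's real config and device.
config = {
    "dit_distill_ckpt_high": "Wan2.2-I2V-A14B/distill_models/high_noise_model/distill_model.safetensors",
    "dit_distill_ckpt_low": "Wan2.2-I2V-A14B/distill_models/low_noise_model/distill_model.safetensors",
}
device = "cuda"
model_root = "/path/to/Wan2.2-I2V-A14B"

# Each expert loads its own checkpoint because only ckpt_config_key differs.
high_noise_model = WanDistillModel(
    os.path.join(model_root, "distill_models", "high_noise_model"),
    config,
    device,
    ckpt_config_key="dit_distill_ckpt_high",
)
low_noise_model = WanDistillModel(
    os.path.join(model_root, "distill_models", "low_noise_model"),
    config,
    device,
    ckpt_config_key="dit_distill_ckpt_low",
)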
@@ -21,7 +21,7 @@ if __name__ == "__main__":
 
     messages = []
     for i, (image_path, prompt) in enumerate(img_prompts.items()):
-        messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
+        messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_result_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
 
     logger.info(f"urls: {urls}")
@@ -15,7 +15,7 @@ if __name__ == "__main__":
 
     messages = []
     for i, prompt in enumerate(prompts):
-        messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": "", "save_result_path": f"./output_lightx2v_wan_t2v_{i + 1}.mp4"})
+        messages.append({"seed": 42, "prompt": prompt, "negative_prompt": negative_prompt, "image_path": "", "save_result_path": f"./output_lightx2v_wan_t2v_{i + 1}.mp4"})
 
     logger.info(f"urls: {urls}")
@@ -17,6 +17,7 @@ def create_i2v_messages(img_files, output_path):
         save_result_path = os.path.join(output_path, f"{prompt}.mp4")
         message = {
+            "seed": 42,
             "prompt": prompt,
             "negative_prompt": negative_prompt,
             "image_path": img_path,
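The three client-script changes above all address the same issue from the commit message: when several requests are sent at once, a message without a seed could trigger an error on the server, so every message now carries an explicit seed. For scripts that build their own payloads, a small defensive wrapper can apply the same default; the helper name and random fallback here are made up for illustration and are not part of lightx2v:

import random


def with_seed(message, seed=None):
    # Hypothetical helper: ensure every request payload carries a "seed" key,
    # mirroring the fix in the scripts above; falls back to a random seed.
    if "seed" not in message:
        message = dict(message, seed=seed if seed is not None else random.randint(0, 2**31 - 1))
    return message


message = with_seed(
    {"prompt": "a cat surfing", "negative_prompt": "", "image_path": "", "save_result_path": "./output_1.mp4"},
    seed=42,
)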