"vscode:/vscode.git/clone" did not exist on "bedb57f4a761efa81b725db58f260a34f6bd5872"
Commit 4e895b2a authored by gushiqiao

Fix causvid model load bug

parent 230e786a
 import os
 import torch
 from safetensors import safe_open
+from lightx2v.utils.utils import find_torch_model_path
 from lightx2v.common.ops.attn.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
@@ -32,23 +32,12 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        ckpt_folder = "causvid_models"
-        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
-        if os.path.exists(safetensors_path):
-            with safe_open(safetensors_path, framework="pt") as f:
-                weight_dict = {
-                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
-                    .pin_memory()
-                    .to(self.device)
-                    for key in f.keys()
-                }
-            return weight_dict
-
-        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
+        ckpt_path = find_torch_model_path(self.config, self.model_path, "causvid_model.pt")
         if os.path.exists(ckpt_path):
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
             weight_dict = {
-                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+                for key in weight_dict.keys()
             }
             return weight_dict
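For readers of the diff, the dtype-selection rule in the new comprehension can be unrolled as the sketch below. This is an illustrative rewrite, not code from the repository: the main_dtype/sensitive_dtype parameters stand in for GET_DTYPE()/GET_SENSITIVE_DTYPE(), and the device and dtype defaults are assumptions made for the example.

import torch

def cast_weights(weight_dict, unified_dtype, sensitive_layer,
                 device="cuda",
                 main_dtype=torch.bfloat16,       # stand-in for GET_DTYPE()
                 sensitive_dtype=torch.float32):  # stand-in for GET_SENSITIVE_DTYPE()
    # Unrolled form of the dict comprehension used in _load_ckpt.
    out = {}
    for key, tensor in weight_dict.items():
        # Keys matching a "sensitive" pattern keep a higher-precision dtype,
        # unless unified_dtype forces every tensor to the main compute dtype.
        if unified_dtype or all(s not in key for s in sensitive_layer):
            tensor = tensor.to(main_dtype)
        else:
            tensor = tensor.to(sensitive_dtype)
        # pin_memory() on the CPU tensor speeds up the host-to-device copy.
        out[key] = tensor.pin_memory().to(device)
    return out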
@@ -28,14 +28,8 @@ class WanLoraWrapper:
         return lora_name

     def _load_lora_file(self, file_path):
-        use_bfloat16 = GET_DTYPE() == "BF16"
-        if self.model.config and hasattr(self.model.config, "get"):
-            use_bfloat16 = self.model.config.get("use_bfloat16", True)
-
         with safe_open(file_path, framework="pt") as f:
-            if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
-            else:
-                tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
         return tensor_dict

     def apply_lora(self, lora_name, alpha=1.0):
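The simplified loader now casts every LoRA tensor to the single compute dtype. A minimal standalone sketch of the same pattern, assuming torch.bfloat16 as a stand-in for GET_DTYPE():

import torch
from safetensors import safe_open

def load_lora_tensors(file_path, dtype=torch.bfloat16):  # dtype stands in for GET_DTYPE()
    # Read every tensor from the .safetensors file and cast it to the
    # unified compute dtype, mirroring the simplified _load_lora_file.
    with safe_open(file_path, framework="pt") as f:
        return {key: f.get_tensor(key).to(dtype) for key in f.keys()}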
@@ -52,7 +46,7 @@ class WanLoraWrapper:
         self.model._init_weights(weight_dict)

         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
-        del lora_weights  # delete to free GPU memory
+        del lora_weights
         return True

     @torch.no_grad()