"vscode:/vscode.git/clone" did not exist on "9332b71ff9ca6d844176a65fd78992a7c0307277"
Commit 4e895b2a authored by gushiqiao

Fix causvid model load bug

parent 230e786a
 import os
 import torch
-from safetensors import safe_open
+from lightx2v.utils.utils import find_torch_model_path
 from lightx2v.common.ops.attn.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (

@@ -32,23 +32,12 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        ckpt_folder = "causvid_models"
-        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
-        if os.path.exists(safetensors_path):
-            with safe_open(safetensors_path, framework="pt") as f:
-                weight_dict = {
-                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
-                    .pin_memory()
-                    .to(self.device)
-                    for key in f.keys()
-                }
-            return weight_dict
-
-        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
+        ckpt_path = find_torch_model_path(self.config, self.model_path, "causvid_model.pt")
         if os.path.exists(ckpt_path):
             weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
             weight_dict = {
-                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+                key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+                for key in weight_dict.keys()
             }
             return weight_dict
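For context on the hunk above: the fixed _load_ckpt keeps a single dtype rule, casting ordinary weights to GET_DTYPE() and sensitive layers to GET_SENSITIVE_DTYPE() unless unified_dtype is set. A minimal standalone sketch of that rule, using torch.bfloat16/torch.float32 as assumed stand-ins for the two dtype helpers and "norm" as a hypothetical sensitive-layer pattern (not taken from the repo):

import torch

def cast_weight(key, tensor, unified_dtype, sensitive_layer,
                main_dtype=torch.bfloat16, sensitive_dtype=torch.float32):
    # Mirror of the dict comprehension above: sensitive layers keep the
    # higher-precision dtype unless a single unified dtype is requested.
    if unified_dtype or all(s not in key for s in sensitive_layer):
        return tensor.to(main_dtype)
    return tensor.to(sensitive_dtype)

weights = {
    "blocks.0.attn.q.weight": torch.zeros(4, 4),
    "blocks.0.norm.weight": torch.zeros(4),
}
casted = {k: cast_weight(k, v, unified_dtype=False, sensitive_layer=["norm"]) for k, v in weights.items()}
# casted["blocks.0.attn.q.weight"].dtype -> torch.bfloat16
# casted["blocks.0.norm.weight"].dtype   -> torch.float32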
@@ -28,14 +28,8 @@ class WanLoraWrapper:
         return lora_name

     def _load_lora_file(self, file_path):
-        use_bfloat16 = GET_DTYPE() == "BF16"
-        if self.model.config and hasattr(self.model.config, "get"):
-            use_bfloat16 = self.model.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
-            if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
-            else:
-                tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
         return tensor_dict

     def apply_lora(self, lora_name, alpha=1.0):
@@ -52,7 +46,7 @@ class WanLoraWrapper:
         self.model._init_weights(weight_dict)

         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
-        del lora_weights  # delete to free GPU memory
+        del lora_weights
         return True

     @torch.no_grad()
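The LoRA loader change removes the use_bfloat16 config branch, so every LoRA tensor is now cast to the runtime dtype returned by GET_DTYPE(). A minimal, self-contained sketch of that pattern (the torch.bfloat16 default and file path below are placeholders, not taken from the repo):

import torch
from safetensors import safe_open

def load_lora_tensors(file_path, target_dtype=torch.bfloat16):
    # Read every tensor from the safetensors file and cast it to the
    # runtime dtype in one pass, as the simplified _load_lora_file does.
    with safe_open(file_path, framework="pt") as f:
        return {key: f.get_tensor(key).to(target_dtype) for key in f.keys()}

# Example (hypothetical path):
# lora_weights = load_lora_tensors("loras/causvid_lora.safetensors")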