Commit 57f5e32b authored by GoatWu

bug fixed for causvid

parent 34df26f6
@@ -12,6 +12,7 @@ from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
     WanTransformerInferCausVid,
 )
 from lightx2v.utils.envs import *
+from safetensors import safe_open


 class WanCausVidModel(WanModel):
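For readers unfamiliar with the newly imported API: safetensors' safe_open memory-maps the file and loads tensors lazily by name, which the rewritten loader below relies on. A minimal sketch (the file path here is illustrative, not from this commit):

import torch
from safetensors import safe_open

# Open without reading the whole file; framework="pt" yields torch.Tensors.
with safe_open("causal_model.safetensors", framework="pt") as f:
    for key in f.keys():            # tensor names stored in the file
        tensor = f.get_tensor(key)  # loads only this one tensor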
@@ -28,18 +29,22 @@ class WanCausVidModel(WanModel):
         self.transformer_infer_class = WanTransformerInferCausVid

     def _load_ckpt(self, use_bf16, skip_bf16):
-        use_bfloat16 = GET_DTYPE() == "BF16"
-        ckpt_path = os.path.join(self.model_path, "causal_model.pt")
-        if not os.path.exists(ckpt_path):
-            return super()._load_ckpt(use_bf16, skip_bf16)
-        weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
-        dtype = torch.bfloat16 if use_bfloat16 else None
-        for key, value in weight_dict.items():
-            weight_dict[key] = value.to(device=self.device, dtype=dtype)
-        return weight_dict
+        ckpt_folder = "causvid_models"
+        safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
+        if os.path.exists(safetensors_path):
+            with safe_open(safetensors_path, framework="pt") as f:
+                weight_dict = {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return weight_dict
+        ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
+        if os.path.exists(ckpt_path):
+            weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+            weight_dict = {
+                key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
+            }
+            return weight_dict
+        return super()._load_ckpt(use_bf16, skip_bf16)

     @torch.no_grad()
     def infer(self, inputs, kv_start, kv_end):
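Net effect of this hunk: the old loader looked for causal_model.pt directly under model_path and keyed the bfloat16 decision on GET_DTYPE() alone, ignoring the use_bf16/skip_bf16 arguments. The new loader honors both arguments, searches a causvid_models/ subfolder, prefers the safetensors file over the legacy .pt, and pins host memory before the device copy. The per-tensor conversion shared by both branches, pulled out here into a hypothetical standalone helper for readability (cast_and_move is not part of the commit):

import torch

def cast_and_move(tensor, key, use_bf16, skip_bf16, device):
    # Mirrors the dict comprehensions above: cast to bfloat16 unless the key
    # contains one of the skip_bf16 substrings (use_bf16 forces the cast),
    # then pin host memory so the copy to `device` can run faster.
    if use_bf16 or all(s not in key for s in skip_bf16):
        tensor = tensor.to(torch.bfloat16)
    return tensor.pin_memory().to(device)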
@@ -24,11 +24,11 @@ import torch.distributed as dist
 class WanCausVidRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        self.num_frame_per_block = self.model.config.num_frame_per_block
-        self.num_frames = self.model.config.num_frames
-        self.frame_seq_length = self.model.config.frame_seq_length
-        self.infer_blocks = self.model.config.num_blocks
-        self.num_fragments = self.model.config.num_fragments
+        self.num_frame_per_block = self.config.num_frame_per_block
+        self.num_frames = self.config.num_frames
+        self.frame_seq_length = self.config.frame_seq_length
+        self.infer_blocks = self.config.num_blocks
+        self.num_fragments = self.config.num_fragments

     def load_transformer(self):
         if self.config.lora_path:
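This hunk, in the CausVid runner, is the core of the fix: the per-block settings were read from self.model.config, but the runner's own self.config is what carries them (and self.model may not even exist yet when __init__ runs, since the transformer is created later in load_transformer). A sketch of the fields the runner now expects directly on its config; the values below are illustrative placeholders, not from the commit:

from types import SimpleNamespace

# Illustrative stand-in for the project's real config object.
config = SimpleNamespace(
    num_frame_per_block=3,   # placeholder value
    num_frames=21,           # placeholder value
    frame_seq_length=1560,   # placeholder value
    num_blocks=7,            # placeholder value
    num_fragments=1,         # placeholder value
    lora_path=None,
)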