Commit 774ccfe7 authored by gushiqiao's avatar gushiqiao
Browse files

Fix

parent a92ea6e8
......@@ -20,17 +20,13 @@ class WanDistillModel(WanModel):
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _load_ckpt(self):
use_bfloat16 = GET_DTYPE() == "BF16"
def _load_ckpt(self, use_bf16, skip_bf16):
ckpt_path = os.path.join(self.model_path, "distill_model.pt")
if not os.path.exists(ckpt_path):
# 文件不存在,调用父类的 _load_ckpt 方法
return super()._load_ckpt()
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
dtype = torch.bfloat16 if use_bfloat16 else None
for key, value in weight_dict.items():
weight_dict[key] = value.to(device=self.device, dtype=dtype)
weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
return weight_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment