"test/vscode:/vscode.git/clone" did not exist on "2b80978859794eb9dcf0156066a4e7c7b7abc713"
Commit 774ccfe7 authored by gushiqiao's avatar gushiqiao
Browse files

Fix

parent a92ea6e8
...@@ -20,17 +20,13 @@ class WanDistillModel(WanModel): ...@@ -20,17 +20,13 @@ class WanDistillModel(WanModel):
def __init__(self, model_path, config, device): def __init__(self, model_path, config, device):
super().__init__(model_path, config, device) super().__init__(model_path, config, device)
def _load_ckpt(self): def _load_ckpt(self, use_bf16, skip_bf16):
use_bfloat16 = GET_DTYPE() == "BF16"
ckpt_path = os.path.join(self.model_path, "distill_model.pt") ckpt_path = os.path.join(self.model_path, "distill_model.pt")
if not os.path.exists(ckpt_path): if not os.path.exists(ckpt_path):
# 文件不存在,调用父类的 _load_ckpt 方法 # 文件不存在,调用父类的 _load_ckpt 方法
return super()._load_ckpt() return super()._load_ckpt()
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
dtype = torch.bfloat16 if use_bfloat16 else None
for key, value in weight_dict.items():
weight_dict[key] = value.to(device=self.device, dtype=dtype)
return weight_dict return weight_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment