Commit f62e3109 authored by gushiqiao's avatar gushiqiao
Browse files

Fix

parent adf8df9d
......@@ -36,7 +36,6 @@ class WanDistillModel(WanModel):
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(weight_dict.keys())
weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
}
......
......@@ -69,7 +69,6 @@ class WanRunner(DefaultRunner):
clip_quantized_ckpt = None
clip_quant_scheme = None
print(clip_quant_scheme)
image_encoder = CLIPModel(
dtype=torch.float16,
device=self.init_device,
......@@ -107,7 +106,7 @@ class WanRunner(DefaultRunner):
else:
t5_quant_scheme = None
t5_quantized_ckpt = None
print(t5_quant_scheme)
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment