Commit 397ce244 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Merge pull request #150 from ModelTC/dev_quant

Fix
parents cc2a283a f62e3109
...@@ -36,7 +36,6 @@ class WanDistillModel(WanModel): ...@@ -36,7 +36,6 @@ class WanDistillModel(WanModel):
if os.path.exists(ckpt_path): if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}") logger.info(f"Loading weights from {ckpt_path}")
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(weight_dict.keys())
weight_dict = { weight_dict = {
key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys() key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
} }
......
...@@ -69,7 +69,6 @@ class WanRunner(DefaultRunner): ...@@ -69,7 +69,6 @@ class WanRunner(DefaultRunner):
clip_quantized_ckpt = None clip_quantized_ckpt = None
clip_quant_scheme = None clip_quant_scheme = None
print(clip_quant_scheme)
image_encoder = CLIPModel( image_encoder = CLIPModel(
dtype=torch.float16, dtype=torch.float16,
device=self.init_device, device=self.init_device,
...@@ -107,7 +106,7 @@ class WanRunner(DefaultRunner): ...@@ -107,7 +106,7 @@ class WanRunner(DefaultRunner):
else: else:
t5_quant_scheme = None t5_quant_scheme = None
t5_quantized_ckpt = None t5_quantized_ckpt = None
print(t5_quant_scheme)
text_encoder = T5EncoderModel( text_encoder = T5EncoderModel(
text_len=self.config["text_len"], text_len=self.config["text_len"],
dtype=torch.bfloat16, dtype=torch.bfloat16,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment