Commit a5535e49 authored by gushiqiao, committed by GitHub

Merge pull request #63 from ModelTC/dev_fix

Fix
parents 2ef8e74e 6c5817f8
@@ -66,9 +66,9 @@ class WanModel:
         use_bfloat16 = self.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).pin_memory().to(torch.bfloat16).to(self.device) for key in f.keys()}
             else:
-                tensor_dict = {key: f.get_tensor(key).to(self.device) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).pin_memory().to(self.device) for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self):
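
(Editorial note: the hunk above stages each CPU tensor in page-locked, i.e. pinned, host memory before the dtype cast and device copy; CUDA can DMA directly out of pinned pages, and pinning is also the precondition for asynchronous host-to-device transfers. A minimal sketch of the same pattern, where the checkpoint path and device are illustrative assumptions rather than values from this repo:)

    import torch
    from safetensors import safe_open

    device = torch.device("cuda")  # assumption: a CUDA device is available

    # Hypothetical checkpoint path, for illustration only.
    with safe_open("model.safetensors", framework="pt") as f:
        tensor_dict = {
            # pin_memory() copies the CPU tensor into page-locked memory,
            # so the final .to(device) copy can be served by DMA.
            key: f.get_tensor(key).pin_memory().to(torch.bfloat16).to(device)
            for key in f.keys()
        }
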
@@ -107,9 +107,9 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
-                weight_dict[k] = f.get_tensor(k)
+                weight_dict[k] = f.get_tensor(k).pin_memory()
                 if weight_dict[k].dtype == torch.float:
-                    weight_dict[k] = weight_dict[k].to(torch.bfloat16)
+                    weight_dict[k] = weight_dict[k].pin_memory().to(torch.bfloat16)
         return weight_dict
@@ -121,9 +121,9 @@ class WanModel:
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
             for k in f.keys():
-                pre_post_weight_dict[k] = f.get_tensor(k)
+                pre_post_weight_dict[k] = f.get_tensor(k).pin_memory()
                 if pre_post_weight_dict[k].dtype == torch.float:
-                    pre_post_weight_dict[k] = pre_post_weight_dict[k].to(torch.bfloat16)
+                    pre_post_weight_dict[k] = pre_post_weight_dict[k].pin_memory().to(torch.bfloat16)

         safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
         safetensors_files = glob.glob(safetensors_pattern)
@@ -134,9 +134,9 @@ class WanModel:
         with safe_open(file_path, framework="pt") as f:
             for k in f.keys():
                 if "modulation" in k:
-                    transformer_weight_dict[k] = f.get_tensor(k)
+                    transformer_weight_dict[k] = f.get_tensor(k).pin_memory()
                     if transformer_weight_dict[k].dtype == torch.float:
-                        transformer_weight_dict[k] = transformer_weight_dict[k].to(torch.bfloat16)
+                        transformer_weight_dict[k] = transformer_weight_dict[k].pin_memory().to(torch.bfloat16)
         return pre_post_weight_dict, transformer_weight_dict
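
(Editorial note on the three lazy-load hunks: they open their files with device=str(self.device), and Tensor.pin_memory() is defined only for CPU tensors, so the added calls assume self.device resolves to the CPU at this point. A defensive variant, my own sketch under that assumption and not the committed code, pins only tensors still on the host and pairs pinning with non_blocking=True, which is where pinned memory actually pays off:)

    import torch

    def stage_and_transfer(t: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Pin only host tensors; pinning a CUDA tensor raises a RuntimeError.
        if t.device.type == "cpu" and not t.is_pinned():
            t = t.pin_memory()
        # non_blocking=True is only truly asynchronous from pinned host memory.
        return t.to(device, non_blocking=True)
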