Commit 94bd4599 authored by gushiqiao

Support loading bf16 weights and converting them to fp32

parent 6b32b743
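The hunks below mostly show the loading paths this capability flows through. As background, a minimal sketch of what "load bf16 weights and convert them to fp32" can look like with safetensors; the function name and the hard-coded torch.float32 target are illustrative assumptions, since the repository routes target dtypes through GET_DTYPE()/GET_SENSITIVE_DTYPE() as the diff shows.

```python
import torch
from safetensors import safe_open

def load_bf16_checkpoint_as_fp32(path, device="cuda"):
    # Illustrative only: read each bf16 tensor from a safetensors file and
    # upcast it to fp32 before moving it to the target device.
    weights = {}
    with safe_open(path, framework="pt") as f:
        for key in f.keys():
            weights[key] = f.get_tensor(key).to(torch.float32).to(device)
    return weights
```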
@@ -42,6 +42,7 @@ try:
except ModuleNotFoundError:
    marlin_cuda_quant = None


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
@@ -687,6 +688,7 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)


@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
@@ -710,14 +712,15 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def apply(self, input_tensor):
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor


if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
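For context on the Marlin path above: marlin_cuda_quant.mul fuses dequantization of the W-int4-group128-sym weight with the matrix multiply and writes into the preallocated output_tensor. A rough, unfused PyTorch sketch of the same arithmetic, assuming the 4-bit values have already been unpacked to signed integers; the real kernel consumes a packed layout, so this is illustrative only.

```python
import torch

def int4_group128_sym_matmul_ref(x, w_q, scales, bias=None, group_size=128):
    # x:      (..., in_features)
    # w_q:    (in_features, out_features) signed int values in [-8, 7] (unpacked)
    # scales: (in_features // group_size, out_features) per-group scales
    in_features, out_features = w_q.shape
    w = w_q.to(scales.dtype).view(-1, group_size, out_features)
    w = (w * scales.unsqueeze(1)).view(in_features, out_features)  # dequantize per group
    out = x @ w.to(x.dtype)
    if bias is not None:
        out = out + bias
    return out
```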
@@ -37,7 +37,10 @@ class WanCausVidModel(WanModel):
        if os.path.exists(safetensors_path):
            with safe_open(safetensors_path, framework="pt") as f:
                weight_dict = {
-                   key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()
+                   key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
+                   .pin_memory()
+                   .to(self.device)
+                   for key in f.keys()
                }
        return weight_dict
......
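The ternary inside the comprehension above picks a dtype per tensor name: every tensor gets the main dtype unless its name matches one of the "sensitive" layer patterns, which stay in higher precision. A small restatement of that predicate as a standalone helper; the helper name and the concrete dtypes are assumptions, since the repository reads them from GET_DTYPE()/GET_SENSITIVE_DTYPE().

```python
import torch

def pick_dtype(key, unified_dtype, sensitive_layer,
               main_dtype=torch.bfloat16, sensitive_dtype=torch.float32):
    # Hypothetical helper: use main_dtype unless the tensor name contains one
    # of the "sensitive" substrings, which keep the higher-precision dtype.
    if unified_dtype or all(s not in key for s in sensitive_layer):
        return main_dtype
    return sensitive_dtype

# e.g. pick_dtype("blocks.0.norm.weight", False, ("norm", "embedding"))
# returns the sensitive dtype, since "norm" matches the key.
```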
-import os
import json
+import os
import torch
import torch.distributed as dist
from loguru import logger
@@ -103,7 +104,10 @@ class WanModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
-           return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()}
+           return {
+               key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+               for key in f.keys()
+           }

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
......
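Both loaders, like the pinned_bias buffer in the Marlin class earlier, stage tensors in pinned (page-locked) host memory before moving them to the device. A minimal illustration of that pattern, not repository code; pinned source memory is what allows the host-to-device copy to be issued without blocking the host.

```python
import torch

def stage_and_upload(t, device="cuda"):
    # Copy into a pinned host buffer, then issue the H2D transfer
    # asynchronously (non_blocking requires the pinned source).
    pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    pinned.copy_(t)
    return pinned.to(device, non_blocking=True)
```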