Commit 6b32b743 authored by gushiqiao

Support loading bf16 weights and converting them to fp32

parent 978e3b32
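The diffs below extend the weight-loading paths so that fp16/bf16/fp32 tensors are all converted on load, with "sensitive" layers routed to a separate, higher-precision dtype (e.g. bf16 checkpoints promoted to fp32). The following is a minimal standalone sketch of that rule; the load_weights helper, the default dtypes (standing in for GET_DTYPE()/GET_SENSITIVE_DTYPE()), and the key names are illustrative only, not the repo's actual API.

import torch
from safetensors import safe_open

def load_weights(path, unified_dtype, sensitive_layer,
                 main_dtype=torch.bfloat16, sensitive_dtype=torch.float32):
    # Sketch of the commit's rule: floating-point weights (fp16/bf16/fp32) are cast
    # to the main dtype, except keys matching a sensitive-layer substring, which are
    # promoted to a higher-precision dtype (bf16 -> fp32).
    weights = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for k in f.keys():
            t = f.get_tensor(k)
            if t.dtype in (torch.float16, torch.bfloat16, torch.float32):
                if unified_dtype or all(s not in k for s in sensitive_layer):
                    weights[k] = t.to(main_dtype)        # regular layers
                else:
                    weights[k] = t.to(sensitive_dtype)   # sensitive layers kept in fp32
            else:
                weights[k] = t                           # quantized/integer tensors untouched
    return weights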
@@ -37,6 +37,10 @@ try:
except ImportError:
    gguf = None
try:
    import marlin_cuda_quant
except ModuleNotFoundError:
    marlin_cuda_quant = None

class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -683,7 +687,37 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: "W-int4-group128-sym-Marlin"
    Quant int4 x FP16:
        Weight: int4 per-group sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def apply(self, input_tensor):
        output_tensor = torch.empty(
            input_tensor.shape[:-1] + (self.weight_scale.shape[1],),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )
        # The trailing -1 arguments let the Marlin kernel choose its launch configuration.
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor

if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
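For context, a rough usage sketch of the new Marlin weight class, assuming load_quantized (inherited from MMWeightQuantTemplate) populates self.weight and self.weight_scale from the "<weight_name>" and "<weight_name>_scale" entries of the dict; the layer name, shapes, and random values below are placeholders rather than a valid Marlin packing, so the kernel output would be meaningless. Real tensors come from a Marlin-format export of the checkpoint.

import torch

# Illustrative shapes for one 4096 -> 8192 linear layer, group size 128 (hypothetical key names).
in_features, out_features, group = 4096, 8192, 128
name = "blocks.0.attn.q"
weight_dict = {
    f"{name}.weight": torch.randint(0, 2**31 - 1, (in_features // 16, out_features * 16 // 8), dtype=torch.int32, device="cuda"),  # Marlin-packed int4
    f"{name}.weight_scale": torch.rand(in_features // group, out_features, dtype=torch.half, device="cuda"),
    f"{name}.weight_workspace": torch.zeros(out_features // 128 * 16, dtype=torch.int32, device="cuda"),
    f"{name}.bias": torch.zeros(out_features, dtype=torch.half, device="cuda"),
}

mm = MMWeightWint4group128Marlin(f"{name}.weight", f"{name}.bias")
mm.load(weight_dict)                   # pulls the packed weight, scales, workspace and bias
x = torch.randn(2, 77, in_features, dtype=torch.half, device="cuda")
y = mm.apply(x)                        # fused int4 dequant + FP16 GEMM via the Marlin kernel
print(y.shape)                         # torch.Size([2, 77, 8192])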
@@ -37,7 +37,7 @@ class WanCausVidModel(WanModel):
        if os.path.exists(safetensors_path):
            with safe_open(safetensors_path, framework="pt") as f:
                weight_dict = {
                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()
                }
            return weight_dict
......
import os
import json
import torch
import torch.distributed as dist
from loguru import logger
@@ -103,7 +103,7 @@ class WanModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()}

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
@@ -134,11 +134,11 @@ class WanModel:
        with safe_open(safetensor_path, framework="pt") as f:
            logger.info(f"Loading weights from {safetensor_path}")
            for k in f.keys():
                if f.get_tensor(k).dtype == torch.float:
                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                    if unified_dtype or all(s not in k for s in sensitive_layer):
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                    else:
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                else:
                    weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
@@ -152,11 +152,11 @@ class WanModel:
        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                if f.get_tensor(k).dtype == torch.float:
                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                    if unified_dtype or all(s not in k for s in sensitive_layer):
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                    else:
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                else:
                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
......
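The per-key test unified_dtype or all(s not in k for s in sensitive_layer), repeated in every hunk above, is what routes a tensor to GET_DTYPE() versus GET_SENSITIVE_DTYPE(). A small sketch of the matching logic follows; the substrings and key names are made up for illustration, not the repo's actual sensitive-layer list.

sensitive_layer = {"norm", "embedding", "modulation"}   # illustrative substrings

def is_sensitive(key, unified_dtype=False):
    # Negation of the condition in the diffs: a key is "sensitive" when any listed
    # substring occurs in it and a single unified dtype has not been requested.
    return not unified_dtype and any(s in key for s in sensitive_layer)

print(is_sensitive("blocks.0.cross_attn.norm_q.weight"))   # True  -> cast with GET_SENSITIVE_DTYPE()
print(is_sensitive("blocks.0.self_attn.q.weight"))         # False -> cast with GET_DTYPE()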