Commit 6b32b743 authored by gushiqiao

Support loading bf16 weights and converting them to fp32

parent 978e3b32
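
The gist of the change, as a minimal standalone sketch (the helper name load_widened is hypothetical; safe_open/get_tensor are the same safetensors calls the diff uses): checkpoints stored in bf16 or fp16 can now be read and widened at load time, instead of only fp32 tensors being recognized by the dtype gate.

import torch
from safetensors import safe_open

def load_widened(path):
    # Hypothetical helper: read a (possibly bf16) checkpoint and widen
    # floating-point tensors to fp32; non-float tensors pass through as stored.
    out = {}
    with safe_open(path, framework="pt") as f:
        for k in f.keys():
            t = f.get_tensor(k)
            if t.dtype in (torch.float16, torch.bfloat16, torch.float32):
                t = t.to(torch.float32)
            out[k] = t
    return out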
@@ -37,6 +37,10 @@ try:
 except ImportError:
     gguf = None
+try:
+    import marlin_cuda_quant
+except ModuleNotFoundError:
+    marlin_cuda_quant = None
 
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -683,6 +687,36 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
         super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+
+@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
+class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
+    """
+    Name: "W-int4-group128-sym-Marlin"
+    Quant int4 x FP16:
+        Weight: int4 per-group (group size 128) symmetric
+        Kernel: Marlin
+    """
+
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+        self.load_func = self.load_quantized
+
+    def load(self, weight_dict):
+        assert not self.lazy_load
+        self.load_func(weight_dict)
+        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
+        if self.bias_name is not None:
+            self.bias = weight_dict[self.bias_name]
+            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+        else:
+            self.bias = None
+
+    def apply(self, input_tensor):
+        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
+        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
+        if hasattr(self, "bias") and self.bias is not None:
+            output_tensor.add_(self.bias)
+        return output_tensor
 
 if __name__ == "__main__":
     weight_dict = {
...
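For intuition about the new Marlin path, here is a pure-PyTorch reference of what the W-int4-group128-sym scheme computes (the unpacked layout below is an assumption for clarity; Marlin stores the weight in its own tiled packing, and marlin_cuda_quant.mul fuses the dequantize and GEMM):

import torch

def int4_group128_sym_matmul_ref(x, w_int4, scales, group_size=128):
    # x: (..., in_features) fp16 activations.
    # w_int4: (in_features, out_features) int8 tensor holding symmetric
    #   int4 values in [-8, 7] (one int per element, unpacked).
    # scales: (in_features // group_size, out_features), one scale per
    #   group of 128 consecutive input channels.
    scale_full = scales.repeat_interleave(group_size, dim=0)  # (in_features, out_features)
    w_fp = w_int4.to(x.dtype) * scale_full.to(x.dtype)        # symmetric dequant: w = q * s
    return x @ w_fp

This also explains the output shape in apply() above: weight_scale has shape (num_groups, out_features), so weight_scale.shape[1] is the output feature count.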
@@ -37,7 +37,7 @@ class WanCausVidModel(WanModel):
         if os.path.exists(safetensors_path):
             with safe_open(safetensors_path, framework="pt") as f:
                 weight_dict = {
-                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
+                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()
                 }
         return weight_dict
...
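All of these loaders stage tensors through pinned (page-locked) host memory before the device copy. A minimal sketch of the idiom (this diff's copies are synchronous; non_blocking=True is shown only as the usual companion, not something the commit adds):

import torch

def stage_to_gpu(t, device="cuda"):
    # pin_memory() moves the host copy into page-locked memory so the
    # host-to-device transfer can use DMA; with a pinned source the copy
    # may also be issued asynchronously.
    return t.pin_memory().to(device, non_blocking=True)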
 import os
+import json
 import torch
 import torch.distributed as dist
 from loguru import logger
...
@@ -103,7 +103,7 @@ class WanModel:
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
-            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()}
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
@@ -134,11 +134,11 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt") as f:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
-                if f.get_tensor(k).dtype == torch.float:
+                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                     if unified_dtype or all(s not in k for s in sensitive_layer):
                         weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                 else:
                     weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
@@ -152,11 +152,11 @@ class WanModel:
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
-                if f.get_tensor(k).dtype == torch.float:
+                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                     if unified_dtype or all(s not in k for s in sensitive_layer):
                         pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                 else:
                     pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
...
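The widened gate in these loops only recasts floating-point tensors (torch.float is an alias of torch.float32); integer or packed quantized tensors keep their stored dtype. A compact restatement, with stand-in dtypes assumed for GET_DTYPE() and GET_SENSITIVE_DTYPE():

import torch

FLOAT_DTYPES = (torch.float16, torch.bfloat16, torch.float)  # torch.float == torch.float32

def cast_weight(t, key, unified_dtype, sensitive_layer, device):
    if t.dtype in FLOAT_DTYPES:
        # Sensitive layers (matched by substring) get the higher-precision
        # dtype unless a single unified dtype is forced.
        if unified_dtype or all(s not in key for s in sensitive_layer):
            target = torch.bfloat16  # stand-in for GET_DTYPE()
        else:
            target = torch.float32   # stand-in for GET_SENSITIVE_DTYPE()
        return t.pin_memory().to(target).to(device)
    return t.pin_memory().to(device)  # e.g. int4-packed weights pass through

Before this commit the sensitive-layer branch left tensors in their stored dtype; now they are explicitly cast, which is what lets a bf16 checkpoint serve fp32 sensitive layers.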