Commit 26105692 authored by gushiqiao, committed by GitHub

Support loading bf16 weights and converting them to fp32

parents e7e74dcd 94bd4599
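In effect, the loaders below now treat any floating-point checkpoint tensor (fp16, bf16, or fp32) as convertible on load, rather than only fp32; non-float tensors (e.g. already-quantized weights) still pass through as stored. A minimal standalone sketch of that rule, using a stand-in target dtype in place of the repo's GET_DTYPE()/GET_SENSITIVE_DTYPE() helpers:

import torch

TARGET_DTYPE = torch.float32  # stand-in for the repo's dtype getters (assumption)

def convert_on_load(weight_dict):
    # Floating-point tensors (fp16 / bf16 / fp32) are cast to the target dtype;
    # anything else (e.g. already-quantized integer weights) is kept as stored.
    out = {}
    for k, t in weight_dict.items():
        if t.dtype in [torch.float16, torch.bfloat16, torch.float]:
            out[k] = t.to(TARGET_DTYPE)
        else:
            out[k] = t
    return out

# Hypothetical checkpoint entries, for illustration only.
ckpt = {
    "blocks.0.ffn.weight": torch.randn(8, 8, dtype=torch.bfloat16),
    "blocks.0.ffn.weight_int8": torch.randint(-128, 128, (8, 8), dtype=torch.int8),
}
converted = convert_on_load(ckpt)
print(converted["blocks.0.ffn.weight"].dtype)       # torch.float32
print(converted["blocks.0.ffn.weight_int8"].dtype)  # torch.int8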
@@ -37,6 +37,11 @@ try:
except ImportError:
    gguf = None
# Optional Marlin int4 GEMM extension; fall back to None when it is not installed.
try:
    import marlin_cuda_quant
except ModuleNotFoundError:
    marlin_cuda_quant = None
class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):

@@ -684,6 +689,38 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: "W-int4-group128-sym-Marlin"
    Quant int4 x FP16:
        Weight: int4, per-group (group size 128), symmetric
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        # Marlin weights are loaded eagerly; lazy loading is not supported here.
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def apply(self, input_tensor):
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        # Fused int4 x fp16 GEMM; the trailing -1s leave kernel tuning parameters at their defaults.
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor

if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
...
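For reference, the "W-int4-group128-sym" scheme above stores weights as symmetric int4 values in groups of 128 along the input dimension, with one scale per group and output column (weight_scale has shape (in_features // 128, out_features), which is why apply() sizes its output from weight_scale.shape[1]). The sketch below spells out, in plain PyTorch with made-up tensor names, the math the fused Marlin kernel implements; the real kernel additionally uses a packed weight layout and the workspace buffer, which are not reproduced here.

import torch

def dequant_int4_group128_sym(q_weight, scales, group_size=128):
    """Dequantize symmetric per-group int4 weights (values in [-8, 7], no zero point).

    q_weight: (in_features, out_features) integer tensor holding int4 values.
    scales:   (in_features // group_size, out_features) per-group scales.
    """
    in_features, out_features = q_weight.shape
    w = q_weight.to(scales.dtype).reshape(in_features // group_size, group_size, out_features)
    return (w * scales.unsqueeze(1)).reshape(in_features, out_features)

# Toy check of the math behind apply(): y = x @ dequant(W) (+ bias).
# float32 here only so it runs on CPU; the real path feeds fp16 activations
# and int4-packed weights to marlin_cuda_quant.mul.
x = torch.randn(4, 4096)
q_w = torch.randint(-8, 8, (4096, 8192), dtype=torch.int8)
s = torch.rand(4096 // 128, 8192)
y = x @ dequant_int4_group128_sym(q_w, s)
print(y.shape)  # torch.Size([4, 8192])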
@@ -37,7 +37,10 @@ class WanCausVidModel(WanModel):
        if os.path.exists(safetensors_path):
            with safe_open(safetensors_path, framework="pt") as f:
                weight_dict = {
-                   key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
+                   key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
+                   .pin_memory()
+                   .to(self.device)
+                   for key in f.keys()
                }
            return weight_dict
...
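Both this loader and the ones in WanModel below route each tensor by key: unless unified_dtype is set, keys matching a "sensitive" layer keep the higher-precision dtype returned by GET_SENSITIVE_DTYPE(), while everything else is cast to the unified compute dtype from GET_DTYPE(). A minimal sketch of that routing with stand-in dtype getters and a hypothetical sensitive_layer set, since the real ones are defined elsewhere in the repo:

import torch

# Stand-ins for the repo's GET_DTYPE / GET_SENSITIVE_DTYPE helpers (assumed behavior).
def GET_DTYPE():
    return torch.bfloat16

def GET_SENSITIVE_DTYPE():
    return torch.float32

sensitive_layer = {"norm", "embedding"}  # hypothetical substrings

def route_dtype(key, tensor, unified_dtype):
    # Cast to the unified dtype unless the key names a sensitive layer.
    if unified_dtype or all(s not in key for s in sensitive_layer):
        return tensor.to(GET_DTYPE())
    return tensor.to(GET_SENSITIVE_DTYPE())

print(route_dtype("blocks.0.attn.qkv.weight", torch.randn(4, 4), unified_dtype=False).dtype)  # torch.bfloat16
print(route_dtype("blocks.0.norm1.weight", torch.randn(4), unified_dtype=False).dtype)        # torch.float32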
import json
import os
import torch
@@ -103,7 +104,10 @@ class WanModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
-           return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+           return {
+               key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
+               for key in f.keys()
+           }
    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
@@ -134,11 +138,11 @@ class WanModel:
            with safe_open(safetensor_path, framework="pt") as f:
                logger.info(f"Loading weights from {safetensor_path}")
                for k in f.keys():
-                   if f.get_tensor(k).dtype == torch.float:
+                   if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                        if unified_dtype or all(s not in k for s in sensitive_layer):
                            weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                        else:
-                           weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                           weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                    else:
                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
@@ -152,11 +156,11 @@ class WanModel:
        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
-               if f.get_tensor(k).dtype == torch.float:
+               if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                    if unified_dtype or all(s not in k for s in sensitive_layer):
                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                    else:
-                       pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                       pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                else:
                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
...