Commit 94bd4599 authored by gushiqiao

Support loading bf16 weights and converting them to fp32

parent 6b32b743
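In rough terms, the new behavior: bf16 tensors in a checkpoint can now be read and upcast to fp32. A minimal sketch using safetensors, where model.safetensors and the key name are placeholders, not paths from this repo:

import torch
from safetensors.torch import save_file, load_file

# Write a toy bf16 checkpoint, then load it back upcast to fp32.
save_file({"w": torch.randn(4, 4, dtype=torch.bfloat16)}, "model.safetensors")
weights = {k: v.to(torch.float32) for k, v in load_file("model.safetensors").items()}
print(weights["w"].dtype)  # torch.float32 (bf16 -> fp32 upcast is exact)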
......@@ -42,6 +42,7 @@ try:
except ModuleNotFoundError:
    marlin_cuda_quant = None


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
......@@ -687,6 +688,7 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)


@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
......@@ -710,14 +712,15 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def apply(self, input_tensor):
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
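For reference, a pure-PyTorch sketch of the arithmetic a group-128 symmetric int4 matmul performs. This is an illustrative stand-in for the Marlin CUDA kernel invoked via marlin_cuda_quant.mul above, and the unpacked weight layout (int4 values held in an int8 tensor) is a simplifying assumption:

import torch

def int4_group128_matmul_ref(x, q, scales, group_size=128):
    # q: (in_features, out_features) int4 values stored as int8 in [-8, 7].
    # scales: (in_features // group_size, out_features), one scale per group.
    # Symmetric quantization: w = q * scale, with no zero point.
    w = q.to(x.dtype) * scales.repeat_interleave(group_size, dim=0).to(x.dtype)
    return x @ w

x = torch.randn(2, 256)
q = torch.randint(-8, 8, (256, 512), dtype=torch.int8)
s = torch.rand(256 // 128, 512)
y = int4_group128_matmul_ref(x, q, s)  # shape (2, 512)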
if __name__ == "__main__":
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
......
......@@ -37,7 +37,10 @@ class WanCausVidModel(WanModel):
        if os.path.exists(safetensors_path):
            with safe_open(safetensors_path, framework="pt") as f:
                weight_dict = {
                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                    .pin_memory()
                    .to(self.device)
                    for key in f.keys()
                }
        return weight_dict
......
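To make the dtype selection in the weight_dict comprehension above concrete, here is a self-contained sketch. The example values (torch.bfloat16 standing in for GET_DTYPE(), torch.float32 for GET_SENSITIVE_DTYPE()) and the "norm" substring are illustrative assumptions, not values confirmed by this commit:

import torch

def pick_dtype(key, unified_dtype, sensitive_layer):
    # Sensitive layers (e.g. norms) keep a higher-precision dtype;
    # everything else uses the model's unified compute dtype.
    if unified_dtype or all(s not in key for s in sensitive_layer):
        return torch.bfloat16  # stand-in for GET_DTYPE()
    return torch.float32       # stand-in for GET_SENSITIVE_DTYPE()

print(pick_dtype("blocks.0.attn.qkv.weight", False, ["norm"]))  # torch.bfloat16
print(pick_dtype("blocks.0.norm1.weight", False, ["norm"]))     # torch.float32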
import os
import json
import torch
import torch.distributed as dist
from loguru import logger
......@@ -103,7 +104,10 @@ class WanModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt") as f:
            return {
                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
                for key in f.keys()
            }

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
......
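The .pin_memory().to(self.device) chain used in both loaders stages each tensor in page-locked host memory before the device copy. A minimal sketch of the pattern; the non_blocking flag is an assumption for illustration, not something this commit uses:

import torch

t = torch.randn(1024, 1024, dtype=torch.bfloat16)
if torch.cuda.is_available():
    # Pinned (page-locked) host memory lets the host-to-device copy run via DMA;
    # non_blocking=True lets it overlap with other work on the CUDA stream.
    t = t.pin_memory().to("cuda", non_blocking=True)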