"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "6e2df9de64c5c581a0e6f8b0b60b72685f86a7cc"
Unverified commit 94e593ce authored by gushiqiao, committed by GitHub

[Recon] remove mm_config and support load single safetensors file (#375)

parent 954df466
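For orientation, a minimal sketch of the flattened configuration this change implies. The key names (dit_quantized, dit_quant_scheme, dit_original_ckpt, dit_quantized_ckpt) and the scheme values come straight from the diff below; the paths and concrete values are illustrative assumptions only:

config = {
    "model_path": "/path/to/model",                          # assumption: may be a directory or a single .safetensors file
    "dit_quantized": True,                                    # replaces the old mm_config["mm_type"] != "Default" check
    "dit_quant_scheme": "fp8-vllm",                           # one of the short registry keys, e.g. "int8-vllm", "fp8-sgl", "int8-torchao"
    "dit_quantized_ckpt": "/path/to/quantized.safetensors",   # optional; the loaders fall back to model_path when unset
}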
@@ -16,14 +16,10 @@ class WeightModule:
     def load(self, weight_dict):
         for _, module in self._modules.items():
-            if hasattr(module, "set_config"):
-                module.set_config(self.config["mm_config"])
             if hasattr(module, "load"):
                 module.load(weight_dict)
         for _, parameter in self._parameters.items():
-            if hasattr(parameter, "set_config"):
-                parameter.set_config(self.config["mm_config"])
             if hasattr(parameter, "load"):
                 parameter.load(weight_dict)
...
 from .mm_weight import *
-from .mm_weight_calib import *
 from abc import ABCMeta, abstractmethod
 import torch
-from loguru import logger
 from lightx2v.utils.envs import *
 from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
@@ -362,7 +361,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         return destination

-@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
+@MM_WEIGHT_REGISTER("fp8-vllm")
 class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
     """
     Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
@@ -397,7 +396,7 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
         return output_tensor

-@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
+@MM_WEIGHT_REGISTER("int8-vllm")
 class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
@@ -432,7 +431,7 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
         return output_tensor

-@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
+@MM_WEIGHT_REGISTER("fp8-q8f")
 class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
     """
     Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
@@ -462,7 +461,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
         return output_tensor.squeeze(0)

-@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
+@MM_WEIGHT_REGISTER("int8-q8f")
 class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
@@ -493,52 +492,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
         return output_tensor.squeeze(0)

-@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
+@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
-class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
-    """
-    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm
-    Quant MM:
-        Weight: fp8 perblock 128x128 sym
-        Act: fp8 perchannel-pergroup group=128 dynamic sym
-        Kernel: Deepgemm
-    Reference: https://github.com/deepseek-ai/DeepGEMM
-    Example:
-        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)
-        Act         : torch.Size([1024, 2048]), torch.float8_e4m3fn
-        Act Scale   : torch.Size([1024, 16]), torch.float32
-        Weight      : torch.Size([4096, 2048]), torch.float8_e4m3fn
-        Weight Scale: torch.Size([32, 16]), torch.float32
-        Out         : torch.Size([1024, 4096]), self.infer_dtype
-    """
-
-    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
-        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
-        self.load_func = self.load_fp8_perblock128_sym
-        self.weight_need_transpose = False
-        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm
-
-    def apply(self, input_tensor):
-        shape = (input_tensor.shape[0], self.weight.shape[0])
-        dtype = input_tensor.dtype
-        device = input_tensor.device
-        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        deep_gemm.gemm_fp8_fp8_bf16_nt(
-            (input_tensor_quant, input_tensor_scale),
-            (self.weight, self.weight_scale),
-            output_tensor,
-        )
-        if hasattr(self, "bias") and self.bias is not None:
-            output_tensor.add_(self.bias)
-        return output_tensor
-
-@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
 class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
     """
     Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
@@ -572,72 +526,7 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
         return output_tensor

-@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
+@MM_WEIGHT_REGISTER("fp8-sgl")
-class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
-    """
-    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
-    Quant MM:
-        Weight: fp8 perchannel sym
-        Act: fp8 perchannel dynamic sym
-        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
-    """
-
-    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
-        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
-        self.load_func = self.load_fp8_perchannel_sym
-        self.weight_need_transpose = True
-        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
-
-    def apply(self, input_tensor):
-        shape = (input_tensor.shape[0], self.weight.shape[1])
-        dtype = input_tensor.dtype
-        device = input_tensor.device
-        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        torch.ops._C.cutlass_scaled_mm(
-            output_tensor,
-            input_tensor_quant,
-            self.weight,
-            input_tensor_scale,
-            self.weight_scale,
-            self.bias,
-        )
-        return output_tensor
-
-@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
-class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
-    """
-    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm
-    Quant MM:
-        Weight: fp8 perchannel sym
-        Act: fp8 perchannel dynamic sym
-        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
-    """
-
-    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
-        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
-        self.load_func = self.load_fp8_perchannel_sym
-        self.weight_need_transpose = True
-        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
-
-    def apply(self, input_tensor):
-        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = sgl_kernel.fp8_scaled_mm(
-            input_tensor_quant,
-            self.weight,
-            input_tensor_scale,
-            self.weight_scale,
-            self.infer_dtype,
-            bias=self.bias,
-        )
-        return output_tensor
-
-@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
 class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
     """
     Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
@@ -667,7 +556,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
         return output_tensor

-@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
+@MM_WEIGHT_REGISTER("int8-sgl")
 class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
@@ -702,7 +591,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
         return output_tensor

-@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
+@MM_WEIGHT_REGISTER("int8-torchao")
 class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
@@ -746,7 +635,7 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
         super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

-@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
+@MM_WEIGHT_REGISTER("int4-g128-marlin")
 class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
     """
     Name: "W-int4-group128-sym-Marlin
@@ -779,42 +668,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         if hasattr(self, "bias") and self.bias is not None:
             output_tensor.add_(self.bias)
         return output_tensor
if __name__ == "__main__":
weight_dict = {
"xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
"xx.bias": torch.randn(8192).to(torch.bfloat16),
"xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
}
mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
mm_weight.set_config({"weight_auto_quant": False})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
logger.info(output_tensor.shape)
weight_dict = {
"xx.weight": torch.randn(8192, 4096),
"xx.bias": torch.randn(8192).to(torch.bfloat16),
}
mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
mm_weight.set_config({"weight_auto_quant": True})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
logger.info(output_tensor.shape)
weight_dict = {
"xx.weight": torch.randn(8192, 4096),
"xx.bias": torch.randn(8192).to(torch.bfloat16),
}
mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
mm_weight.set_config({"weight_auto_quant": True})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
logger.info(output_tensor.shape)
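For comparison, the removed demo rewritten against the new registry keys would look roughly like this. A sketch only: the short key "fp8-vllm" and the dropped set_config(...) step come from the hunks above, everything else mirrors the deleted block (print stands in for the removed loguru logger):

import torch
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

if __name__ == "__main__":
    # Pre-quantized fp8 weight plus per-channel scale, as in the removed demo above.
    weight_dict = {
        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
        "xx.bias": torch.randn(8192).to(torch.bfloat16),
        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
    }
    mm_weight = MM_WEIGHT_REGISTER["fp8-vllm"]("xx.weight", "xx.bias")  # was "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
    mm_weight.load(weight_dict)  # no set_config(...) call any more; mm_config is gone
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)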
-import torch
-from lightx2v.utils.quant_utils import FloatQuantizer
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from .mm_weight import MMWeight
-
-@MM_WEIGHT_REGISTER("Calib")
-class MMWeightCalib(MMWeight):
-    def __init__(self, weight_name, bias_name):
-        super().__init__(weight_name, bias_name)
-
-    def load(self, weight_dict):
-        assert self.config and self.config.get("mm_type", "Default") != "Default"
-        self.weight = weight_dict[self.weight_name]
-        self.get_quantizer()
-        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
-        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
-        self.realq_weight = self.realq_weight.view(shape_and_dtype["tensor"][0]).contiguous().to(shape_and_dtype["tensor"][1])
-        self.scales = self.scales.view(shape_and_dtype["scales"][0]).contiguous().to(shape_and_dtype["scales"][1])
-        if self.zeros is not None:
-            self.zeros = self.zeros.view(shape_and_dtype["zeros"][0]).contiguous().to(shape_and_dtype["zeros"][1])
-
-    def apply(self, input_tensor):
-        return super().apply(input_tensor)
-
-    def get_quantizer(self):
-        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
-            self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
-            self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
-            self.w_quantizer = FloatQuantizer(**self.w_setting)
-            self.a_quantizer = FloatQuantizer(**self.a_setting)
-            self.act_dynamic_quant = True
-        else:
-            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")
-
-    def get_quant_shape_and_dtype(self, shape):
-        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
-            return {
-                "tensor": (shape, torch.float8_e5m2),
-                "scales": ((shape[0], 1), torch.float32),
-                "zeros": None,
-            }
-        else:
-            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")
@@ -33,8 +33,7 @@ class QwenImageTransformerModel:
         self.in_channels = transformer_config["in_channels"]
         self.attention_kwargs = {}
-        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
-        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        self.dit_quantized = self.config["dit_quantized"]
         self._init_infer_class()
         self._init_weights()
@@ -62,7 +61,7 @@ class QwenImageTransformerModel:
         if weight_dict is None:
             is_weight_loader = self._should_load_weights()
             if is_weight_loader:
-                if not self.dit_quantized or self.weight_auto_quant:
+                if not self.dit_quantized:
                     # Load original weights
                     weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                 else:
...
@@ -10,11 +10,9 @@ class QwenImagePostWeights(WeightModule):
         super().__init__()
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
+        self.mm_type = config.get("dit_quant_scheme", "Default")
+        if self.mm_type != "Default":
+            assert config.get("dit_quantized") is True
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
             assert NotImplementedError
...
@@ -12,11 +12,9 @@ class QwenImageTransformerWeights(WeightModule):
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
+        self.mm_type = config.get("dit_quant_scheme", "Default")
+        if self.mm_type != "Default":
+            assert config.get("dit_quantized") is True
         blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num))
         self.add_module("blocks", blocks)
...
-import glob
 import os
 import torch
@@ -19,8 +18,7 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, ckpt_config_key="dit_distill_ckpt"):
-        self.ckpt_config_key = ckpt_config_key
+    def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
@@ -34,21 +32,4 @@ class WanDistillModel(WanModel):
                 for key in weight_dict.keys()
             }
             return weight_dict
+        return super()._load_ckpt(unified_dtype, sensitive_layer)
-        if self.config.get("enable_dynamic_cfg", False):
-            safetensors_path = find_hf_model_path(self.config, self.model_path, self.ckpt_config_key, subdir="distill_cfg_models")
-        else:
-            safetensors_path = find_hf_model_path(self.config, self.model_path, self.ckpt_config_key, subdir="distill_models")
-        if os.path.isfile(safetensors_path):
-            logger.info(f"loading checkpoint from {safetensors_path} ...")
-            safetensors_files = glob.glob(safetensors_path)
-        else:
-            logger.info(f"loading checkpoint from {safetensors_path} ...")
-            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
-        weight_dict = {}
-        for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
-            weight_dict.update(file_weights)
-        return weight_dict
 import gc
-import json
 import os
 import torch
@@ -62,35 +61,9 @@ class WanModel(CompiledMethodsMixin):
         self.init_empty_model = False
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
-        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
+        self.dit_quantized = self.config.get("dit_quantized", False)
-        if self.dit_quantized:
-            dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
-            if self.config["model_cls"] == "wan2.1_distill":
-                dit_quant_scheme = "distill_" + dit_quant_scheme
-            if dit_quant_scheme == "gguf":
-                self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
-                self.config["use_gguf"] = True
-            else:
-                self.dit_quantized_ckpt = find_hf_model_path(
-                    config,
-                    self.model_path,
-                    "dit_quantized_ckpt",
-                    subdir=dit_quant_scheme,
-                )
-            quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
-            if os.path.exists(quant_config_path):
-                with open(quant_config_path, "r") as f:
-                    quant_model_config = json.load(f)
-                    self.config.update(quant_model_config)
-        else:
-            self.dit_quantized_ckpt = None
-            assert not self.config.get("lazy_load", False)
-        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
         if self.dit_quantized:
-            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
+            assert self.config.get("dit_quant_scheme", "Default") in ["Default-Force-FP32", "fp8-vllm", "int8-vllm", "fp8-q8f", "int8-q8f", "fp8-b128-deepgemm", "fp8-sgl", "int8-sgl", "int8-torchao"]
         self.device = device
         self._init_infer_class()
         self._init_weights()
@@ -152,40 +125,50 @@ class WanModel(CompiledMethodsMixin):
         }

     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
-        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        if self.config.get("dit_original_ckpt", None):
+            safetensors_path = self.config["dit_original_ckpt"]
+        else:
+            safetensors_path = self.model_path
+        if os.path.isdir(safetensors_path):
+            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        else:
+            safetensors_files = [safetensors_path]
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
                 if self.config["adapter_model_path"] == file_path:
                     continue
+            logger.info(f"Loading weights from {file_path}")
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
     def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        ckpt_path = self.dit_quantized_ckpt
-        index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
-        if not index_files:
-            raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
-        index_path = os.path.join(ckpt_path, index_files[0])
-        logger.info(f" Using safetensors index: {index_path}")
-        with open(index_path, "r") as f:
-            index_data = json.load(f)
+        if self.config.get("dit_quantized_ckpt", None):
+            safetensors_path = self.config["dit_quantized_ckpt"]
+        else:
+            safetensors_path = self.model_path
+        if os.path.isdir(safetensors_path):
+            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
+        else:
+            safetensors_files = [safetensors_path]
         weight_dict = {}
-        for filename in set(index_data["weight_map"].values()):
-            safetensor_path = os.path.join(ckpt_path, filename)
+        for safetensor_path in safetensors_files:
+            if self.config.get("adapter_model_path", None) is not None:
+                if self.config["adapter_model_path"] == safetensor_path:
+                    continue
             with safe_open(safetensor_path, framework="pt") as f:
                 logger.info(f"Loading weights from {safetensor_path}")
                 for k in f.keys():
                     if any(remove_key in k for remove_key in remove_keys):
                         continue
                     if f.get_tensor(k).dtype in [
                         torch.float16,
                         torch.bfloat16,
@@ -200,7 +183,7 @@ class WanModel(CompiledMethodsMixin):
         return weight_dict

-    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
+    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):  # Need rewrite
         lazy_load_model_path = self.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")
         pre_post_weight_dict = {}
@@ -247,7 +230,7 @@ class WanModel(CompiledMethodsMixin):
         if weight_dict is None:
             is_weight_loader = self._should_load_weights()
             if is_weight_loader:
-                if not self.dit_quantized or self.weight_auto_quant:
+                if not self.dit_quantized:
                     # Load original weights
                     weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
                 else:
...
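Both _load_ckpt and _load_quant_ckpt above now accept either a checkpoint directory or a single .safetensors file, which is the second half of this change. A hedged sketch of the two supported configurations (paths are hypothetical):

# Directory: every *.safetensors under the path is merged into one weight_dict.
config["dit_original_ckpt"] = "/ckpts/wan/original"
# Single file: a non-directory path is loaded as-is.
config["dit_quantized_ckpt"] = "/ckpts/wan/model-fp8.safetensors"
# When neither key is set, both loaders fall back to self.model_path.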
@@ -18,12 +18,10 @@ class WanTransformerWeights(WeightModule):
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
+        self.mm_type = config.get("dit_quant_scheme", "Default")
+        if self.mm_type != "Default":
+            assert config.get("dit_quantized") is True
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
         self.add_module("blocks", self.blocks)
         # non blocks weights
...
@@ -783,7 +783,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         """Load transformer with LoRA support"""
         base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device)
         if self.config.get("lora_configs") and self.config["lora_configs"]:
-            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
+            assert not self.config.get("dit_quantized", False)
             lora_wrapper = WanLoraWrapper(base_model)
             for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
...
@@ -78,6 +78,21 @@ class MultiDistillModelStruct(MultiModelStruct):
 class Wan22MoeDistillRunner(WanDistillRunner):
     def __init__(self, config):
         super().__init__(config)
+        self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
+        if not os.path.isdir(self.high_noise_model_path):
+            self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
+        if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
+            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
+        elif self.config.get("high_noise_original_ckpt", None):
+            self.high_noise_model_path = self.config["high_noise_original_ckpt"]
+
+        self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
+        if not os.path.isdir(self.low_noise_model_path):
+            self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
+        if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
+            self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
+        elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
+            self.low_noise_model_path = self.config["low_noise_original_ckpt"]

     def load_transformer(self):
         use_high_lora, use_low_lora = False, False
@@ -90,7 +105,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
         if use_high_lora:
             high_noise_model = WanModel(
-                os.path.join(self.config["model_path"], "high_noise_model"),
+                self.high_noise_model_path,
                 self.config,
                 self.init_device,
             )
@@ -104,15 +119,14 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             high_noise_model = WanDistillModel(
-                os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
+                self.high_noise_model_path,
                 self.config,
                 self.init_device,
-                ckpt_config_key="dit_distill_ckpt_high",
             )
         if use_low_lora:
             low_noise_model = WanModel(
-                os.path.join(self.config["model_path"], "low_noise_model"),
+                self.low_noise_model_path,
                 self.config,
                 self.init_device,
             )
@@ -126,10 +140,9 @@ class Wan22MoeDistillRunner(WanDistillRunner):
                 logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
         else:
             low_noise_model = WanDistillModel(
-                os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
+                self.low_noise_model_path,
                 self.config,
                 self.init_device,
-                ckpt_config_key="dit_distill_ckpt_low",
             )

         return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
...
@@ -29,7 +29,6 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
-from lightx2v.utils.utils import best_output_size

 @RUNNER_REGISTER("wan2.1")
@@ -48,7 +47,7 @@ class WanRunner(DefaultRunner):
             self.init_device,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            assert not self.config.get("dit_quantized", False)
             lora_wrapper = WanLoraWrapper(model)
             for lora_config in self.config.lora_configs:
                 lora_path = lora_config["path"]
...
@@ -445,22 +444,37 @@ class MultiModelStruct:
 class Wan22MoeRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
+        self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
+        if not os.path.isdir(self.high_noise_model_path):
+            self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
+        if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
+            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
+        elif self.config.get("high_noise_original_ckpt", None):
+            self.high_noise_model_path = self.config["high_noise_original_ckpt"]

+        self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
+        if not os.path.isdir(self.low_noise_model_path):
+            self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
+        if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
+            self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
+        elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
+            self.low_noise_model_path = self.config["low_noise_original_ckpt"]

     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
         high_noise_model = WanModel(
-            os.path.join(self.config["model_path"], "high_noise_model"),
+            self.high_noise_model_path,
             self.config,
             self.init_device,
         )
         low_noise_model = WanModel(
-            os.path.join(self.config["model_path"], "low_noise_model"),
+            self.low_noise_model_path,
             self.config,
             self.init_device,
         )
         if self.config.get("lora_configs") and self.config["lora_configs"]:
-            assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
+            assert not self.config.get("dit_quantized", False)
             for lora_config in self.config["lora_configs"]:
                 lora_path = lora_config["path"]
...
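How the path resolution added in Wan22MoeRunner.__init__ (and its distill counterpart above) plays out, sketched with hypothetical paths:

# Default: the sub-folders next to model_path, e.g. /models/wan22/high_noise_model and /models/wan22/low_noise_model.
# If a sub-folder is missing, fall back to model_path/distill_models/{high,low}_noise_model.
# With "dit_quantized": true, "high_noise_quantized_ckpt" / "low_noise_quantized_ckpt" override the respective paths.
# Otherwise "high_noise_original_ckpt" wins for the high-noise model, and "low_noise_original_ckpt" only when dit_quantized is false.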
@@ -27,7 +27,7 @@ class WanSFRunner(WanRunner):
             self.init_device,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            assert not self.config.get("dit_quantized", False)
             lora_wrapper = WanLoraWrapper(model)
             for lora_config in self.config.lora_configs:
                 lora_path = lora_config["path"]
...
@@ -21,7 +21,6 @@ def get_default_config():
         "use_ret_steps": False,
         "use_bfloat16": True,
         "lora_configs": None,  # List of dicts with 'path' and 'strength' keys
-        "mm_config": {},
         "use_prompt_enhancer": False,
         "parallel": False,
         "seq_parallel": False,
...
-import glob
 import os
 import random
 import subprocess
@@ -318,52 +317,6 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
     raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")

-def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
-    if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
-        return config.get(ckpt_config_key)
-    paths_to_check = [model_path]
-    if isinstance(subdir, list):
-        for sub in subdir:
-            paths_to_check.insert(0, os.path.join(model_path, sub))
-    else:
-        paths_to_check.insert(0, os.path.join(model_path, subdir))
-    for path in paths_to_check:
-        safetensors_pattern = os.path.join(path, "*.safetensors")
-        safetensors_files = glob.glob(safetensors_pattern)
-        if safetensors_files:
-            return path
-    raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
-
-def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
-    gguf_path = config.get(ckpt_config_key, None)
-    if gguf_path is None:
-        raise ValueError(f"GGUF path not found in config with key '{ckpt_config_key}'")
-    if not isinstance(gguf_path, str) or not gguf_path.endswith(".gguf"):
-        raise ValueError(f"GGUF path must be a string ending with '.gguf', got: {gguf_path}")
-    if os.sep in gguf_path or (os.altsep and os.altsep in gguf_path):
-        if os.path.exists(gguf_path):
-            logger.info(f"Found GGUF model file in: {gguf_path}")
-            return os.path.abspath(gguf_path)
-        else:
-            raise FileNotFoundError(f"GGUF file not found at path: {gguf_path}")
-    else:
-        # It's just a filename, search in predefined paths
-        paths_to_check = [config.model_path]
-        if subdir:
-            paths_to_check.append(os.path.join(config.model_path, subdir))
-        for path in paths_to_check:
-            gguf_file_path = os.path.join(path, gguf_path)
-            gguf_file = glob.glob(gguf_file_path)
-            if gguf_file:
-                logger.info(f"Found GGUF model file in: {gguf_file_path}")
-                return gguf_file_path
-        raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")

 def load_safetensors(in_path, remove_key=None, include_keys=None):
     """Load a safetensors file or directory; supports include-based filtering or exclusion by key."""
     include_keys = include_keys or []
...