Unverified commit 94e593ce, authored by gushiqiao and committed by GitHub

[Recon] remove mm_config and support load single safetensors file (#375)

parent 954df466
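The practical effect of this commit on configuration: the nested mm_config block is gone, and quantization is now driven by flat top-level keys that appear throughout the diff below. A minimal before/after sketch (key names are taken from this diff; paths and values are illustrative only):

# Old style: quantization selected via a nested mm_config block (removed here).
old_config = {
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
        "weight_auto_quant": False,
    },
}

# New style: flat keys, short registry aliases, and an optional single-file checkpoint.
new_config = {
    "dit_quantized": True,
    "dit_quant_scheme": "fp8-vllm",  # short alias registered below
    "dit_quantized_ckpt": "/path/to/quant_model.safetensors",  # a single file now works
}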
@@ -16,14 +16,10 @@ class WeightModule:
def load(self, weight_dict):
for _, module in self._modules.items():
if hasattr(module, "set_config"):
module.set_config(self.config["mm_config"])
if hasattr(module, "load"):
module.load(weight_dict)
for _, parameter in self._parameters.items():
if hasattr(parameter, "set_config"):
parameter.set_config(self.config["mm_config"])
if hasattr(parameter, "load"):
parameter.load(weight_dict)
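With the set_config calls removed, a child held in _modules or _parameters only needs to expose load(weight_dict); it no longer receives mm_config from its parent. A minimal sketch of a conforming child (hypothetical class, for illustration only):

class SimpleParam:
    """Hypothetical leaf parameter that WeightModule.load can drive."""

    def __init__(self, name):
        self.name = name
        self.tensor = None

    def load(self, weight_dict):
        # The parent no longer calls set_config, so all state comes from the dict.
        self.tensor = weight_dict[self.name]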
......
from .mm_weight import *
from .mm_weight_calib import *
from abc import ABCMeta, abstractmethod
import torch
from loguru import logger
from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
@@ -362,7 +361,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
return destination
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
@MM_WEIGHT_REGISTER("fp8-vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
@@ -397,7 +396,7 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
@MM_WEIGHT_REGISTER("int8-vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
@@ -432,7 +431,7 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
@@ -462,7 +461,7 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
return output_tensor.squeeze(0)
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
@@ -493,52 +492,7 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
return output_tensor.squeeze(0)
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
"""
Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm
Quant MM:
Weight: fp8 perblock 128x128 sym
Act: fp8 perchannel-pergroup group=128 dynamic sym
Kernel: Deepgemm
Reference: https://github.com/deepseek-ai/DeepGEMM
Example:
Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)
Act : torch.Size([1024, 2048]), torch.float8_e4m3fn
Act Scale: torch.Size([1024, 16]), torch.float32
Weight : torch.Size([4096, 2048]), torch.float8_e4m3fn
Weight Scale: torch.Size([32, 16]), torch.float32
Out : torch.Size([1024, 4096]), self.infer_dtype
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_fp8_perblock128_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
deep_gemm.gemm_fp8_fp8_bf16_nt(
(input_tensor_quant, input_tensor_scale),
(self.weight, self.weight_scale),
output_tensor,
)
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor
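For the 128-block scheme above, the scale shapes quoted in the docstring follow directly from the group size: one scale per 128 activation columns per row, and one scale per 128x128 weight block. A quick shape check (numbers from the docstring's example; ceil-division per 128-wide group is the assumed convention):

group = 128
act_rows, in_features, out_features = 1024, 2048, 4096

act_scale_shape = (act_rows, (in_features + group - 1) // group)      # (1024, 16)
weight_scale_shape = ((out_features + group - 1) // group,            # (32, 16)
                      (in_features + group - 1) // group)
assert act_scale_shape == (1024, 16)
assert weight_scale_shape == (32, 16)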
@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
@@ -572,72 +526,7 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuant
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.bias,
)
return output_tensor
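The fp8 per-channel classes in this block all rely on a dynamic per-row (per-token) symmetric quantization of the activation, delegated to vLLM or the Sgl kernel. As a rough PyTorch reference for what such a quantizer computes (a sketch only, assuming e4m3 with a maximum representable value of 448; the real kernels are fused and much faster):

import torch

def fp8_per_row_dynamic_quant(x: torch.Tensor, eps: float = 1e-12):
    # One scale per row: map the per-row absolute maximum onto the e4m3 range (max 448).
    amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(eps).float()
    scale = amax / 448.0
    x_q = (x.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return x_q, scale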
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm")
class MMWeightWfp8channelAfp8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl-ActVllm
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
"""
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.fp8_scaled_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.infer_dtype,
bias=self.bias,
)
return output_tensor
@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
@@ -667,7 +556,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
@@ -702,7 +591,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
return output_tensor
@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao")
@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
@@ -746,7 +635,7 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
"""
Name: "W-int4-group128-sym-Marlin
@@ -779,42 +668,3 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor
if __name__ == "__main__":
weight_dict = {
"xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
"xx.bias": torch.randn(8192).to(torch.bfloat16),
"xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
}
mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
mm_weight.set_config({"weight_auto_quant": False})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
logger.info(output_tensor.shape)
weight_dict = {
"xx.weight": torch.randn(8192, 4096),
"xx.bias": torch.randn(8192).to(torch.bfloat16),
}
mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
mm_weight.set_config({"weight_auto_quant": True})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
logger.info(output_tensor.shape)
weight_dict = {
"xx.weight": torch.randn(8192, 4096),
"xx.bias": torch.randn(8192).to(torch.bfloat16),
}
mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
mm_weight.set_config({"weight_auto_quant": True})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
logger.info(output_tensor.shape)
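The removed self-test above used the long registry keys; under the short aliases introduced by this commit the equivalent lookup would read as follows (a sketch only; the diff does not show whether a set_config step is still required for pre-quantized weights, so it is omitted here):

weight_dict = {
    "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
    "xx.bias": torch.randn(8192).to(torch.bfloat16),
    "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
}
mm_weight = MM_WEIGHT_REGISTER["fp8-vllm"]("xx.weight", "xx.bias")
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
logger.info(mm_weight.apply(input_tensor).shape)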
import torch
from lightx2v.utils.quant_utils import FloatQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from .mm_weight import MMWeight
@MM_WEIGHT_REGISTER("Calib")
class MMWeightCalib(MMWeight):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
assert self.config and self.config.get("mm_type", "Default") != "Default"
self.weight = weight_dict[self.weight_name]
self.get_quantizer()
shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
self.realq_weight = self.realq_weight.view(shape_and_dtype["tensor"][0]).contiguous().to(shape_and_dtype["tensor"][1])
self.scales = self.scales.view(shape_and_dtype["scales"][0]).contiguous().to(shape_and_dtype["scales"][1])
if self.zeros is not None:
self.zeros = self.zeros.view(shape_and_dtype["zeros"][0]).contiguous().to(shape_and_dtype["zeros"][1])
def apply(self, input_tensor):
return super().apply(input_tensor)
def get_quantizer(self):
if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
self.w_quantizer = FloatQuantizer(**self.w_setting)
self.a_quantizer = FloatQuantizer(**self.a_setting)
self.act_dynamic_quant = True
else:
raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")
def get_quant_shape_and_dtype(self, shape):
if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
return {
"tensor": (shape, torch.float8_e5m2),
"scales": ((shape[0], 1), torch.float32),
"zeros": None,
}
else:
raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")
@@ -33,8 +33,7 @@ class QwenImageTransformerModel:
self.in_channels = transformer_config["in_channels"]
self.attention_kwargs = {}
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
self.dit_quantized = self.config["dit_quantized"]
self._init_infer_class()
self._init_weights()
@@ -62,7 +61,7 @@ class QwenImageTransformerModel:
if weight_dict is None:
is_weight_loader = self._should_load_weights()
if is_weight_loader:
if not self.dit_quantized or self.weight_auto_quant:
if not self.dit_quantized:
# Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
......
@@ -10,11 +10,9 @@ class QwenImagePostWeights(WeightModule):
super().__init__()
self.task = config["task"]
self.config = config
if config["do_mm_calib"]:
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
raise NotImplementedError
......
@@ -12,11 +12,9 @@ class QwenImageTransformerWeights(WeightModule):
self.blocks_num = config["num_layers"]
self.task = config["task"]
self.config = config
if config["do_mm_calib"]:
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num))
self.add_module("blocks", blocks)
......
import glob
import os
import torch
@@ -19,8 +18,7 @@ class WanDistillModel(WanModel):
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device, ckpt_config_key="dit_distill_ckpt"):
self.ckpt_config_key = ckpt_config_key
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _load_ckpt(self, unified_dtype, sensitive_layer):
@@ -34,21 +32,4 @@ class WanDistillModel(WanModel):
for key in weight_dict.keys()
}
return weight_dict
if self.config.get("enable_dynamic_cfg", False):
safetensors_path = find_hf_model_path(self.config, self.model_path, self.ckpt_config_key, subdir="distill_cfg_models")
else:
safetensors_path = find_hf_model_path(self.config, self.model_path, self.ckpt_config_key, subdir="distill_models")
if os.path.isfile(safetensors_path):
logger.info(f"loading checkpoint from {safetensors_path} ...")
safetensors_files = glob.glob(safetensors_path)
else:
logger.info(f"loading checkpoint from {safetensors_path} ...")
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
return super()._load_ckpt(unified_dtype, sensitive_layer)
import gc
import json
import os
import torch
@@ -62,35 +61,9 @@ class WanModel(CompiledMethodsMixin):
self.init_empty_model = False
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
if self.dit_quantized:
dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
if self.config["model_cls"] == "wan2.1_distill":
dit_quant_scheme = "distill_" + dit_quant_scheme
if dit_quant_scheme == "gguf":
self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
self.config["use_gguf"] = True
else:
self.dit_quantized_ckpt = find_hf_model_path(
config,
self.model_path,
"dit_quantized_ckpt",
subdir=dit_quant_scheme,
)
quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
if os.path.exists(quant_config_path):
with open(quant_config_path, "r") as f:
quant_model_config = json.load(f)
self.config.update(quant_model_config)
else:
self.dit_quantized_ckpt = None
assert not self.config.get("lazy_load", False)
self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
self.dit_quantized = self.config.get("dit_quantized", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
assert self.config.get("dit_quant_scheme", "Default") in ["Default-Force-FP32", "fp8-vllm", "int8-vllm", "fp8-q8f", "int8-q8f", "fp8-b128-deepgemm", "fp8-sgl", "int8-sgl", "int8-torchao"]
self.device = device
self._init_infer_class()
self._init_weights()
@@ -152,40 +125,50 @@ class WanModel(CompiledMethodsMixin):
}
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
if self.config.get("dit_original_ckpt", None):
safetensors_path = self.config["dit_original_ckpt"]
else:
safetensors_path = self.model_path
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == file_path:
continue
logger.info(f"Loading weights from {file_path}")
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
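The file-or-directory resolution added here is repeated almost verbatim in _load_quant_ckpt below; a hypothetical helper (not part of this diff) that both methods could share would look like this:

def resolve_safetensors_files(path):
    """Return the list of .safetensors files for either a single file or a directory."""
    if os.path.isdir(path):
        return sorted(glob.glob(os.path.join(path, "*.safetensors")))
    return [path]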
def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
ckpt_path = self.dit_quantized_ckpt
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
index_path = os.path.join(ckpt_path, index_files[0])
logger.info(f" Using safetensors index: {index_path}")
if self.config.get("dit_quantized_ckpt", None):
safetensors_path = self.config["dit_quantized_ckpt"]
else:
safetensors_path = self.model_path
with open(index_path, "r") as f:
index_data = json.load(f)
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
weight_dict = {}
for filename in set(index_data["weight_map"].values()):
safetensor_path = os.path.join(ckpt_path, filename)
for safetensor_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == safetensor_path:
continue
with safe_open(safetensor_path, framework="pt") as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
if any(remove_key in k for remove_key in remove_keys):
continue
if f.get_tensor(k).dtype in [
torch.float16,
torch.bfloat16,
@@ -200,7 +183,7 @@ class WanModel(CompiledMethodsMixin):
return weight_dict
def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer): # Need rewrite
lazy_load_model_path = self.dit_quantized_ckpt
logger.info(f"Loading splited quant model from {lazy_load_model_path}")
pre_post_weight_dict = {}
@@ -247,7 +230,7 @@ class WanModel(CompiledMethodsMixin):
if weight_dict is None:
is_weight_loader = self._should_load_weights()
if is_weight_loader:
if not self.dit_quantized or self.weight_auto_quant:
if not self.dit_quantized:
# Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
......
@@ -18,12 +18,10 @@ class WanTransformerWeights(WeightModule):
self.blocks_num = config["num_layers"]
self.task = config["task"]
self.config = config
if config["do_mm_calib"]:
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
# non blocks weights
......
@@ -783,7 +783,7 @@ class WanAudioRunner(WanRunner): # type:ignore
"""Load transformer with LoRA support"""
base_model = WanAudioModel(self.config["model_path"], self.config, self.init_device)
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
assert not self.config.get("dit_quantized", False)
lora_wrapper = WanLoraWrapper(base_model)
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
......
@@ -78,6 +78,21 @@ class MultiDistillModelStruct(MultiModelStruct):
class Wan22MoeDistillRunner(WanDistillRunner):
def __init__(self, config):
super().__init__(config)
self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
if not os.path.isdir(self.high_noise_model_path):
self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
elif self.config.get("high_noise_original_ckpt", None):
self.high_noise_model_path = self.config["high_noise_original_ckpt"]
self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
if not os.path.isdir(self.low_noise_model_path):
self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
self.low_noise_model_path = self.config["low_noise_original_ckpt"]
def load_transformer(self):
use_high_lora, use_low_lora = False, False
@@ -90,7 +105,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
if use_high_lora:
high_noise_model = WanModel(
os.path.join(self.config["model_path"], "high_noise_model"),
self.high_noise_model_path,
self.config,
self.init_device,
)
@@ -104,15 +119,14 @@ class Wan22MoeDistillRunner(WanDistillRunner):
logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
high_noise_model = WanDistillModel(
os.path.join(self.config["model_path"], "distill_models", "high_noise_model"),
self.high_noise_model_path,
self.config,
self.init_device,
ckpt_config_key="dit_distill_ckpt_high",
)
if use_low_lora:
low_noise_model = WanModel(
os.path.join(self.config["model_path"], "low_noise_model"),
self.low_noise_model_path,
self.config,
self.init_device,
)
@@ -126,10 +140,9 @@ class Wan22MoeDistillRunner(WanDistillRunner):
logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
low_noise_model = WanDistillModel(
os.path.join(self.config["model_path"], "distill_models", "low_noise_model"),
self.low_noise_model_path,
self.config,
self.init_device,
ckpt_config_key="dit_distill_ckpt_low",
)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
......
@@ -29,7 +29,6 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size
@RUNNER_REGISTER("wan2.1")
@@ -48,7 +47,7 @@ class WanRunner(DefaultRunner):
self.init_device,
)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
assert not self.config.get("dit_quantized", False)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
@@ -445,22 +444,37 @@ class MultiModelStruct:
class Wan22MoeRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
if not os.path.isdir(self.high_noise_model_path):
self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
elif self.config.get("high_noise_original_ckpt", None):
self.high_noise_model_path = self.config["high_noise_original_ckpt"]
self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
if not os.path.isdir(self.low_noise_model_path):
self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
self.low_noise_model_path = self.config["low_noise_original_ckpt"]
def load_transformer(self):
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
high_noise_model = WanModel(
os.path.join(self.config["model_path"], "high_noise_model"),
self.high_noise_model_path,
self.config,
self.init_device,
)
low_noise_model = WanModel(
os.path.join(self.config["model_path"], "low_noise_model"),
self.low_noise_model_path,
self.config,
self.init_device,
)
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False) or self.config["mm_config"].get("weight_auto_quant", False)
assert not self.config.get("dit_quantized", False)
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
......
@@ -27,7 +27,7 @@ class WanSFRunner(WanRunner):
self.init_device,
)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
assert not self.config.get("dit_quantized", False)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
......
@@ -21,7 +21,6 @@ def get_default_config():
"use_ret_steps": False,
"use_bfloat16": True,
"lora_configs": None, # List of dicts with 'path' and 'strength' keys
"mm_config": {},
"use_prompt_enhancer": False,
"parallel": False,
"seq_parallel": False,
......
import glob
import os
import random
import subprocess
@@ -318,52 +317,6 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
raise FileNotFoundError(f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8", "distill_models", "distill_fp8", "distill_int8"]):
if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
return config.get(ckpt_config_key)
paths_to_check = [model_path]
if isinstance(subdir, list):
for sub in subdir:
paths_to_check.insert(0, os.path.join(model_path, sub))
else:
paths_to_check.insert(0, os.path.join(model_path, subdir))
for path in paths_to_check:
safetensors_pattern = os.path.join(path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if safetensors_files:
return path
raise FileNotFoundError(f"No Hugging Face model files (.safetensors) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
gguf_path = config.get(ckpt_config_key, None)
if gguf_path is None:
raise ValueError(f"GGUF path not found in config with key '{ckpt_config_key}'")
if not isinstance(gguf_path, str) or not gguf_path.endswith(".gguf"):
raise ValueError(f"GGUF path must be a string ending with '.gguf', got: {gguf_path}")
if os.sep in gguf_path or (os.altsep and os.altsep in gguf_path):
if os.path.exists(gguf_path):
logger.info(f"Found GGUF model file in: {gguf_path}")
return os.path.abspath(gguf_path)
else:
raise FileNotFoundError(f"GGUF file not found at path: {gguf_path}")
else:
# It's just a filename, search in predefined paths
paths_to_check = [config.model_path]
if subdir:
paths_to_check.append(os.path.join(config.model_path, subdir))
for path in paths_to_check:
gguf_file_path = os.path.join(path, gguf_path)
gguf_file = glob.glob(gguf_file_path)
if gguf_file:
logger.info(f"Found GGUF model file in: {gguf_file_path}")
return gguf_file_path
raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def load_safetensors(in_path, remove_key=None, include_keys=None):
"""加载safetensors文件或目录,支持按key包含筛选或排除"""
include_keys = include_keys or []
......
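The body of load_safetensors is collapsed above; as a rough sketch of what a loader with include/exclude key filtering typically looks like with the safetensors API (an assumption for illustration, not the code from this repository):

from safetensors import safe_open

def load_safetensors_sketch(in_path, remove_key=None, include_keys=None):
    include_keys = include_keys or []
    # Accept either a single .safetensors file or a directory of them.
    files = [in_path] if os.path.isfile(in_path) else glob.glob(os.path.join(in_path, "*.safetensors"))
    tensors = {}
    for file_path in files:
        with safe_open(file_path, framework="pt") as f:
            for k in f.keys():
                if remove_key and remove_key in k:
                    continue
                if include_keys and not any(ik in k for ik in include_keys):
                    continue
                tensors[k] = f.get_tensor(k)
    return tensors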