Unverified Commit 5182f8f8 authored by Muyang Li, committed by GitHub

feat: single-file model loading (#413)

* add a script to merge models

* finished

* try to merge t5

* merge the config into meta files

* rewrite the t5 model loading

* consider the case of subfolder

* merged the qencoder files

* make the linter happy and fix the tests

* pass tests

* add deprecation messages

* add a script to merge models

* schnell script runnable

* update sana

* modify the model paths

* fix the model paths

* style: make the linter happy

* remove the debugging assertion

* chore: fix the qencoder lpips

* fix the lpips
parent 8401d290
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision


def main():
    pipeline_init_kwargs = {}
    text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
    pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

    precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    ).to("cuda")
    image = pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    ).images[0]
    image.save(f"flux.1-schnell-qencoder-{precision}.png")


if __name__ == "__main__":
    main()
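For context, the single-file loader introduced by this commit also accepts a plain local path to a merged checkpoint (the diffs below branch on `.is_file()` and a `.safetensors`/`.sft` suffix). A minimal sketch, assuming a locally downloaded merged file; the local path is illustrative and not part of this diff:

# Sketch only: load a Nunchaku FLUX transformer from a merged single-file checkpoint.
# The local path below is an assumed example, not a file shipped with this commit.
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "/path/to/svdq-int4_r32-flux.1-schnell.safetensors"  # assumed local merged checkpoint
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")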
@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
...
@@ -4,7 +4,9 @@ from diffusers import SanaPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
 
-transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
+transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
+)
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
     transformer=transformer,
@@ -28,4 +30,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m-int4.png")
+image.save("sana1.6b-int4.png")
@@ -3,7 +3,9 @@ from diffusers import SanaPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 
-transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
+transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
+)
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
     transformer=transformer,
@@ -23,4 +25,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m-int4.png")
+image.save("sana1.6b-int4.png")
@@ -3,7 +3,9 @@ from diffusers import SanaPAGPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 
-transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
+transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors", pag_layers=8
+)
 pipe = SanaPAGPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
     transformer=transformer,
@@ -24,4 +26,4 @@ image = pipe(
     pag_scale=2.0,
     num_inference_steps=20,
 ).images[0]
-image.save("sana_1600m_pag-int4.png")
+image.save("sana1.6b_pag-int4.png")
import argparse
import os
from pathlib import Path

import torch
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import save_file

from .utils import load_state_dict_in_safetensors


def merge_models_into_a_single_file(
    pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    subfolder = kwargs.get("subfolder", None)
    if isinstance(pretrained_model_name_or_path, str):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
    if pretrained_model_name_or_path.exists():
        dirpath = pretrained_model_name_or_path if subfolder is None else pretrained_model_name_or_path / subfolder
        unquantized_part_path = dirpath / "unquantized_layers.safetensors"
        transformer_block_path = dirpath / "transformer_blocks.safetensors"
        config_path = dirpath / "config.json"
        comfy_config_path = dirpath / "comfy_config.json"
    else:
        download_kwargs = {
            "subfolder": subfolder,
            "repo_type": "model",
            "revision": kwargs.get("revision", None),
            "cache_dir": kwargs.get("cache_dir", None),
            "local_dir": kwargs.get("local_dir", None),
            "user_agent": kwargs.get("user_agent", None),
            "force_download": kwargs.get("force_download", False),
            "proxies": kwargs.get("proxies", None),
            "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
            "token": kwargs.get("token", None),
            "local_files_only": kwargs.get("local_files_only", None),
            "headers": kwargs.get("headers", None),
            "endpoint": kwargs.get("endpoint", None),
            "resume_download": kwargs.get("resume_download", None),
            "force_filename": kwargs.get("force_filename", None),
            "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
        }
        unquantized_part_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
        )
        transformer_block_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
        )
        config_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="config.json", **download_kwargs
        )
        comfy_config_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="comfy_config.json", **download_kwargs
        )

    # Merge the two weight shards into a single state dict and carry the configs as metadata.
    unquantized_part_sd = load_state_dict_in_safetensors(unquantized_part_path)
    transformer_block_sd = load_state_dict_in_safetensors(transformer_block_path)
    state_dict = unquantized_part_sd
    state_dict.update(transformer_block_sd)
    return state_dict, {
        "config": Path(config_path).read_text(),
        "comfy_config": Path(comfy_config_path).read_text(),
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        type=Path,
        required=True,
        help="Path to the model directory. It can also be a Hugging Face repo ID.",
    )
    parser.add_argument("-o", "--output-path", type=Path, required=True, help="Path to the output safetensors file.")
    args = parser.parse_args()

    state_dict, metadata = merge_models_into_a_single_file(args.input_path)
    output_path = Path(args.output_path)
    dirpath = output_path.parent
    dirpath.mkdir(parents=True, exist_ok=True)
    save_file(state_dict, output_path, metadata=metadata)
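As a usage note, the deprecation message added later in this commit refers to this script as the CLI utility python -m nunchaku.merge_models. A sketch of the programmatic path, assuming that module location and using an example legacy repo ID and output filename (both illustrative, and assuming the legacy repo ships the expected unquantized_layers/transformer_blocks/config files):

# Sketch only: merge a legacy multi-file Nunchaku checkpoint into one .safetensors
# file whose metadata carries the config JSON. Module path, input repo ID, and
# output filename are assumptions for illustration.
from pathlib import Path

from safetensors.torch import save_file

from nunchaku.merge_models import merge_models_into_a_single_file  # module path assumed from the deprecation notice

state_dict, metadata = merge_models_into_a_single_file("mit-han-lab/svdq-int4-flux.1-schnell")
output_path = Path("svdq-int4_r32-flux.1-schnell.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)
save_file(state_dict, output_path, metadata=metadata)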
import argparse
import os
from pathlib import Path

import torch
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import save_file

from .utils import load_state_dict_in_safetensors


def merge_config_into_model(
    pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    subfolder = kwargs.get("subfolder", None)
    if isinstance(pretrained_model_name_or_path, str):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
    if pretrained_model_name_or_path.exists():
        dirpath = pretrained_model_name_or_path if subfolder is None else pretrained_model_name_or_path / subfolder
        model_path = dirpath / "awq-int4-flux.1-t5xxl.safetensors"
        config_path = dirpath / "config.json"
    else:
        download_kwargs = {
            "subfolder": subfolder,
            "repo_type": "model",
            "revision": kwargs.get("revision", None),
            "cache_dir": kwargs.get("cache_dir", None),
            "local_dir": kwargs.get("local_dir", None),
            "user_agent": kwargs.get("user_agent", None),
            "force_download": kwargs.get("force_download", False),
            "proxies": kwargs.get("proxies", None),
            "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
            "token": kwargs.get("token", None),
            "local_files_only": kwargs.get("local_files_only", None),
            "headers": kwargs.get("headers", None),
            "endpoint": kwargs.get("endpoint", None),
            "resume_download": kwargs.get("resume_download", None),
            "force_filename": kwargs.get("force_filename", None),
            "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
        }
        model_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="awq-int4-flux.1-t5xxl.safetensors", **download_kwargs
        )
        config_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="config.json", **download_kwargs
        )

    # Keep the weights as-is and embed the config JSON in the safetensors metadata.
    model_path = Path(model_path)
    config_path = Path(config_path)
    state_dict = load_state_dict_in_safetensors(model_path)
    metadata = {"config": config_path.read_text()}
    return state_dict, metadata


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        type=Path,
        default="mit-han-lab/nunchaku-t5",
        help="Path to the model directory. It can also be a Hugging Face repo ID.",
    )
    parser.add_argument("-o", "--output-path", type=Path, required=True, help="Path to the output safetensors file.")
    args = parser.parse_args()

    state_dict, metadata = merge_config_into_model(args.input_path)
    output_path = Path(args.output_path)
    dirpath = output_path.parent
    dirpath.mkdir(parents=True, exist_ok=True)
    save_file(state_dict, output_path, metadata=metadata)
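Once the config is folded into the file's metadata, the T5 encoder loads straight from the single file. A brief sketch using the merged repo path that the updated tests below point at (a local path to the same .safetensors file should work the same way):

# Sketch only: load the merged single-file quantized T5 encoder.
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)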
+import json
+import logging
 import os
+from pathlib import Path
 
 import torch
 from accelerate import init_empty_weights
-from huggingface_hub import constants, hf_hub_download
-from safetensors.torch import load_file
 from torch import nn
-from transformers import PretrainedConfig, T5Config, T5EncoderModel
+from transformers import T5Config, T5EncoderModel
 
+from ...utils import load_state_dict_in_safetensors
 from .linear import W4Linear
 
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 
 class NunchakuT5EncoderModel(T5EncoderModel):
     @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str | os.PathLike,
-        config: PretrainedConfig | str | os.PathLike | None = None,
-        cache_dir: str | os.PathLike | None = None,
-        force_download: bool = False,
-        local_files_only: bool = False,
-        token: str | bool | None = None,
-        revision: str = "main",
-        **kwargs,
-    ):
-        subfolder = kwargs.get("subfolder", None)
-        if os.path.exists(pretrained_model_name_or_path):
-            dirname = (
-                pretrained_model_name_or_path
-                if subfolder is None
-                else os.path.join(pretrained_model_name_or_path, subfolder)
-            )
-            qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
-            config_path = os.path.join(dirname, "config.json")
-        else:
-            shared_kwargs = {
-                "repo_id": pretrained_model_name_or_path,
-                "subfolder": subfolder,
-                "repo_type": "model",
-                "revision": revision,
-                "library_name": kwargs.get("library_name"),
-                "library_version": kwargs.get("library_version"),
-                "cache_dir": cache_dir,
-                "local_dir": kwargs.get("local_dir"),
-                "user_agent": kwargs.get("user_agent"),
-                "force_download": force_download,
-                "proxies": kwargs.get("proxies"),
-                "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
-                "token": token,
-                "local_files_only": local_files_only,
-                "headers": kwargs.get("headers"),
-                "endpoint": kwargs.get("endpoint"),
-                "resume_download": kwargs.get("resume_download"),
-                "force_filename": kwargs.get("force_filename"),
-                "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
-            }
-            qmodel_path = hf_hub_download(filename="svdq-t5.safetensors", **shared_kwargs)
-            config_path = hf_hub_download(filename="config.json", **shared_kwargs)
+    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        state_dict = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
 
         # Load the config file
-        config = T5Config.from_json_file(config_path)
+        metadata = state_dict.pop("__metadata__", {})
+        config = json.loads(metadata["config"])
+        config = T5Config(**config)
 
         # Initialize model on 'meta' device (no memory allocation for weights)
         with init_empty_weights():
             t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
 
         t5_encoder.eval()
 
-        # Load the model weights from the safetensors file
-        state_dict = load_file(qmodel_path)
-
         named_modules = {}
         for name, module in t5_encoder.named_modules():
             assert isinstance(name, str)
             if isinstance(module, nn.Linear):
                 if f"{name}.qweight" in state_dict:
-                    print(f"Switching {name} to W4Linear")
+                    logger.debug(f"Switching {name} to W4Linear")
                     qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
                     # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                     qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
...
 import logging
 import os
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
 import diffusers
@@ -260,14 +261,21 @@ class EmbedND(nn.Module):
 def load_quantized_module(
-    path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False, bf16: bool = True
+    path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
+    device: str | torch.device = "cuda",
+    use_fp4: bool = False,
+    offload: bool = False,
+    bf16: bool = True,
 ) -> QuantizedFluxModel:
     device = torch.device(device)
     assert device.type == "cuda"
 
     m = QuantizedFluxModel()
     cutils.disable_memory_auto_release()
     m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
-    m.load(path)
+    if isinstance(path_or_state_dict, dict):
+        m.loadDict(path_or_state_dict, True)
+    else:
+        m.load(str(path_or_state_dict))
     return m
@@ -313,19 +321,35 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
     @classmethod
     @utils.validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
         device = kwargs.get("device", "cuda")
         if isinstance(device, str):
             device = torch.device(device)
         offload = kwargs.get("offload", False)
         torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
         precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
-        transformer, unquantized_part_path, transformer_block_path = cls._build_model(
-            pretrained_model_name_or_path, **kwargs
-        )
-        # get the default LoRA branch and all the vectors
-        quantized_part_sd = load_file(transformer_block_path)
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+            (".safetensors", ".sft")
+        ):
+            transformer, model_state_dict = cls._build_model(pretrained_model_name_or_path, **kwargs)
+            quantized_part_sd = {}
+            unquantized_part_sd = {}
+            for k, v in model_state_dict.items():
+                if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
+                    quantized_part_sd[k] = v
+                else:
+                    unquantized_part_sd[k] = v
+        else:
+            transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
+                pretrained_model_name_or_path, **kwargs
+            )
+            # get the default LoRA branch and all the vectors
+            quantized_part_sd = load_file(transformer_block_path)
+            unquantized_part_sd = load_file(unquantized_part_path)
         new_quantized_part_sd = {}
         for k, v in quantized_part_sd.items():
             if v.ndim == 1:
@@ -348,7 +372,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
                 new_quantized_part_sd[k] = v
         transformer._quantized_part_sd = new_quantized_part_sd
         m = load_quantized_module(
-            transformer_block_path,
+            quantized_part_sd,
             device=device,
             use_fp4=precision == "fp4",
             offload=offload,
@@ -357,7 +381,6 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
         transformer.inject_quantized_module(m, device)
         transformer.to_empty(device=device)
 
-        unquantized_part_sd = load_file(unquantized_part_path)
         transformer.load_state_dict(unquantized_part_sd, strict=False)
         transformer._unquantized_part_sd = unquantized_part_sd
...
 import os
+from pathlib import Path
 from typing import Optional
 
 import torch
@@ -139,20 +140,43 @@ class NunchakuSanaTransformerBlocks(nn.Module):
 class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
     @classmethod
     @utils.validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
         device = kwargs.get("device", "cuda")
+        if isinstance(device, str):
+            device = torch.device(device)
         pag_layers = kwargs.get("pag_layers", [])
         precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
-        transformer, unquantized_part_path, transformer_block_path = cls._build_model(
-            pretrained_model_name_or_path, **kwargs
-        )
-        m = load_quantized_module(
-            transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
-        )
-        transformer.inject_quantized_module(m, device)
-        transformer.to_empty(device=device)
-        unquantized_state_dict = load_file(unquantized_part_path)
-        transformer.load_state_dict(unquantized_state_dict, strict=False)
+
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+            (".safetensors", ".sft")
+        ):
+            transformer, model_state_dict = cls._build_model(pretrained_model_name_or_path)
+            quantized_part_sd = {}
+            unquantized_part_sd = {}
+            for k, v in model_state_dict.items():
+                if k.startswith("transformer_blocks."):
+                    quantized_part_sd[k] = v
+                else:
+                    unquantized_part_sd[k] = v
+            m = load_quantized_module(
+                transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+            )
+            transformer.inject_quantized_module(m, device)
+            transformer.to_empty(device=device)
+            transformer.load_state_dict(unquantized_part_sd, strict=False)
+        else:
+            transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
+                pretrained_model_name_or_path, **kwargs
+            )
+            m = load_quantized_module(
+                transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+            )
+            transformer.inject_quantized_module(m, device)
+            transformer.to_empty(device=device)
+            unquantized_state_dict = load_file(unquantized_part_path)
+            transformer.load_state_dict(unquantized_state_dict, strict=False)
         return transformer
 
     def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
@@ -162,7 +186,7 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
 def load_quantized_module(
     net: SanaTransformer2DModel,
-    path: str,
+    path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
     device: str | torch.device = "cuda",
     pag_layers: int | list[int] | None = None,
     use_fp4: bool = False,
@@ -177,7 +201,10 @@ def load_quantized_module(
     m = QuantizedSanaModel()
     cutils.disable_memory_auto_release()
     m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
-    m.load(path)
+    if isinstance(path_or_state_dict, dict):
+        m.loadDict(path_or_state_dict, True)
+    else:
+        m.load(str(path_or_state_dict))
     return m
...
+import json
+import logging
 import os
+from pathlib import Path
 from typing import Any, Optional
 
 import torch
@@ -6,12 +9,45 @@ from diffusers import __version__
 from huggingface_hub import constants, hf_hub_download
 from torch import nn
 
-from nunchaku.utils import ceil_divide
+from nunchaku.utils import ceil_divide, load_state_dict_in_safetensors
+
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 
 
 class NunchakuModelLoaderMixin:
+    @classmethod
+    def _build_model(
+        cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
+    ) -> tuple[nn.Module, dict[str, torch.Tensor]]:
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        state_dict = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
+
+        # Load the config file
+        metadata = state_dict.pop("__metadata__", {})
+        config = json.loads(metadata["config"])
+
+        with torch.device("meta"):
+            transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+
+        return transformer, state_dict
+
     @classmethod
-    def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> tuple[nn.Module, str, str]:
+    def _build_model_legacy(
+        cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
+    ) -> tuple[nn.Module, str, str]:
+        logger.warning(
+            "Loading models from a folder will be deprecated in v0.4. "
+            "Please download the latest safetensors model, or use one of the following tools to "
+            "merge your model into a single file: the CLI utility `python -m nunchaku.merge_models` "
+            "or the ComfyUI node `MergeFolderIntoSingleFile`."
+        )
         subfolder = kwargs.get("subfolder", None)
         if os.path.exists(pretrained_model_name_or_path):
             dirname = (
@@ -41,10 +77,10 @@ class NunchakuModelLoaderMixin:
                 "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
             }
             unquantized_part_path = hf_hub_download(
-                repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
+                repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
             )
             transformer_block_path = hf_hub_download(
-                repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
+                repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
            )
 
         cache_dir = kwargs.pop("cache_dir", None)
@@ -70,7 +106,6 @@ class NunchakuModelLoaderMixin:
         with torch.device("meta"):
             transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
         return transformer, unquantized_part_path, transformer_block_path
...
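The single-file format works because a safetensors file can carry a string-to-string metadata dict alongside its tensors: the merge scripts store the model config JSON there, and _build_model reads it back via load_state_dict_in_safetensors(..., return_metadata=True). A small sketch of that round trip, with made-up tensor contents and a placeholder file name:

# Sketch of the config-in-metadata round trip used by the single-file format.
# The tensor and the file name are placeholders, not values from this commit.
import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

config = {"num_layers": 2, "hidden_size": 16}  # stand-in for a real model config
save_file({"w": torch.zeros(16, 16)}, "demo.safetensors", metadata={"config": json.dumps(config)})

with safe_open("demo.safetensors", framework="pt", device="cpu") as f:
    restored = json.loads(f.metadata()["config"])
assert restored == config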
-import os
 import warnings
+from os import PathLike
+from pathlib import Path
 
 import safetensors
 import torch
 from huggingface_hub import hf_hub_download
 
 
-def fetch_or_download(path: str, repo_type: str = "model") -> str:
-    if not os.path.exists(path):
-        hf_repo_id = os.path.dirname(path)
-        filename = os.path.basename(path)
-        path = hf_hub_download(repo_id=hf_repo_id, filename=filename, repo_type=repo_type)
-    return path
+def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
+    path = Path(path)
+
+    if path.exists():
+        return path
+
+    parts = path.parts
+    if len(parts) < 3:
+        raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
+    repo_id = "/".join(parts[:2])
+    sub_path = Path(*parts[2:])
+    filename = sub_path.name
+    subfolder = sub_path.parent if sub_path.parent != Path(".") else None
+    path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
+    return Path(path)
 
 
 def ceil_divide(x: int, divisor: int) -> int:
@@ -31,7 +43,10 @@ def ceil_divide(x: int, divisor: int) -> int:
 def load_state_dict_in_safetensors(
-    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
+    path: str | PathLike[str],
+    device: str | torch.device = "cpu",
+    filter_prefix: str = "",
+    return_metadata: bool = False,
 ) -> dict[str, torch.Tensor]:
     """Load state dict in SafeTensors.
@@ -49,6 +64,9 @@ def load_state_dict_in_safetensors(
     """
     state_dict = {}
     with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
+        metadata = f.metadata()
+        if return_metadata:
+            state_dict["__metadata__"] = metadata
         for k in f.keys():
             if filter_prefix and not k.startswith(filter_prefix):
                 continue
...
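To make the new path convention concrete: when the given path does not exist on disk, fetch_or_download now treats the first two path components as the Hub repo ID and everything after them as an optional subfolder plus filename. The sketch below is an illustrative, standalone re-implementation of just that splitting step (the helper name is made up); it mirrors the logic above without calling the Hub.

# Sketch only: how a single-string Nunchaku path resolves into Hub download arguments.
from pathlib import Path

def split_hub_path(path: str) -> tuple[str, str | None, str]:
    parts = Path(path).parts
    if len(parts) < 3:
        raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
    repo_id = "/".join(parts[:2])
    sub_path = Path(*parts[2:])
    subfolder = None if sub_path.parent == Path(".") else str(sub_path.parent)
    return repo_id, subfolder, sub_path.name

# "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
# -> repo_id="mit-han-lab/nunchaku-t5", subfolder=None, filename="awq-int4-flux.1-t5xxl.safetensors"
print(split_hub_path("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"))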
@@ -13,7 +13,9 @@ def test_device_id():
     precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
     torch_dtype = torch.float16 if is_turing("cuda:1") else torch.bfloat16
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-        f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
+        f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
+        torch_dtype=torch_dtype,
+        device="cuda:1",
     )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
...
@@ -16,7 +16,9 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_pulid():
     precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    )
     pipeline = PuLIDFluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-dev",
...
@@ -22,11 +22,13 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
     precision = get_precision()
     pipeline_init_kwargs = {
         "transformer": NunchakuFluxTransformer2dModel.from_pretrained(
-            f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
+            f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors", offload=cpu_offload
         )
     }
     if use_qencoder:
-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
...
@@ -7,7 +7,7 @@ from .utils import run_test
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.136 if get_precision() == "int4" else 0.145)]
+    "height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.151 if get_precision() == "int4" else 0.145)]
 )
 def test_flux_schnell_qencoder(height: int, width: int, use_qencoder: bool, expected_lpips: float):
     run_test(
...
@@ -112,7 +112,9 @@ def test_flux_teacache(
     # Then, generate results with the 4-bit model
     if not already_generate(results_dir_4_bit, 1):
-        transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+            f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+        )
         pipeline = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
         ).to("cuda")
...
@@ -63,27 +63,6 @@ def test_flux_fill_dev():
     )
 
 
-# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
-# def test_flux_dev_canny_lora():
-#     run_test(
-#         precision=get_precision(),
-#         model_name="flux.1-dev",
-#         dataset_name="MJHQ-control",
-#         task="canny",
-#         dtype=torch.bfloat16,
-#         height=1024,
-#         width=1024,
-#         num_inference_steps=30,
-#         guidance_scale=30,
-#         attention_impl="nunchaku-fp16",
-#         cpu_offload=False,
-#         lora_names="canny",
-#         lora_strengths=0.85,
-#         cache_threshold=0,
-#         expected_lpips=0.081,
-#     )
-
-
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_depth_lora():
     run_test(
...
@@ -12,7 +12,7 @@ from ..utils import compute_lpips
 def test_lora_reset():
     precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-        f"mit-han-lab/svdq-{precision}-flux.1-dev", offload=True
+        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors", offload=True
     )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
@@ -44,4 +44,4 @@ def test_lora_reset():
     lpips = compute_lpips(os.path.join(save_dir, "before.png"), os.path.join(save_dir, "after.png"))
     print(f"LPIPS: {lpips}")
-    assert lpips < 0.179 * 1.1
+    assert lpips < 0.232 * 1.1
@@ -28,12 +28,12 @@ ORIGINAL_REPO_MAP = {
 }
 
 NUNCHAKU_REPO_PATTERN_MAP = {
-    "flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
-    "flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
-    "shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
-    "flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
-    "flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
-    "flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
+    "flux.1-schnell": "mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
+    "flux.1-dev": "mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
+    "shuttle-jaguar": "mit-han-lab/nunchaku-shuttle-jaguar/svdq-{precision}_r32-shuttle-jaguar.safetensors",
+    "flux.1-canny-dev": "mit-han-lab/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors",
+    "flux.1-depth-dev": "mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors",
+    "flux.1-fill-dev": "mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors",
 }
 
 LORA_PATH_MAP = {
@@ -285,7 +285,9 @@ def run_test(
     if task == "redux":
         pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
     elif use_qencoder:
-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
     if cpu_offload:
...