Unverified Commit 5182f8f8 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

feat: single-file model loading (#413)

* add a script to merge models

* finished

* try to merge t5

* merge the config into meta files

* rewrite the t5 model loading

* consider the case of subfolder

* merged the qencoder files

* make the linter happy and fix the tests

* pass tests

* add deprecation messages

* add a script to merge models

* schnell script runnable

* update sana

* modify the model paths

* fix the model paths

* style: make the linter happy

* remove the debugging assertion

* chore: fix the qencoder lpips

* fix the lpips
parent 8401d290
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision
def main():
pipeline_init_kwargs = {}
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16, **pipeline_init_kwargs
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save(f"flux.1-schnell-qencoder-{precision}.png")
if __name__ == "__main__":
main()
......@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
......
......@@ -4,7 +4,9 @@ from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
)
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
......@@ -28,4 +30,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
image.save("sana1.6b-int4.png")
......@@ -3,7 +3,9 @@ from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
)
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
......@@ -23,4 +25,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
image.save("sana1.6b-int4.png")
......@@ -3,7 +3,9 @@ from diffusers import SanaPAGPipeline
from nunchaku import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors", pag_layers=8
)
pipe = SanaPAGPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
......@@ -24,4 +26,4 @@ image = pipe(
pag_scale=2.0,
num_inference_steps=20,
).images[0]
image.save("sana_1600m_pag-int4.png")
image.save("sana1.6b_pag-int4.png")
import argparse
import os
from pathlib import Path
import torch
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
def merge_models_into_a_single_file(
pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
subfolder = kwargs.get("subfolder", None)
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.exists():
dirpath = pretrained_model_name_or_path if subfolder is None else pretrained_model_name_or_path / subfolder
unquantized_part_path = dirpath / "unquantized_layers.safetensors"
transformer_block_path = dirpath / "transformer_blocks.safetensors"
config_path = dirpath / "config.json"
comfy_config_path = dirpath / "comfy_config.json"
else:
download_kwargs = {
"subfolder": subfolder,
"repo_type": "model",
"revision": kwargs.get("revision", None),
"cache_dir": kwargs.get("cache_dir", None),
"local_dir": kwargs.get("local_dir", None),
"user_agent": kwargs.get("user_agent", None),
"force_download": kwargs.get("force_download", False),
"proxies": kwargs.get("proxies", None),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": kwargs.get("token", None),
"local_files_only": kwargs.get("local_files_only", None),
"headers": kwargs.get("headers", None),
"endpoint": kwargs.get("endpoint", None),
"resume_download": kwargs.get("resume_download", None),
"force_filename": kwargs.get("force_filename", None),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
)
config_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="config.json", **download_kwargs
)
comfy_config_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="comfy_config.json", **download_kwargs
)
unquantized_part_sd = load_state_dict_in_safetensors(unquantized_part_path)
transformer_block_sd = load_state_dict_in_safetensors(transformer_block_path)
state_dict = unquantized_part_sd
state_dict.update(transformer_block_sd)
return state_dict, {
"config": Path(config_path).read_text(),
"comfy_config": Path(comfy_config_path).read_text(),
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-path",
type=Path,
required=True,
help="Path to model directory. It can also be a huggingface repo.",
)
parser.add_argument("-o", "--output-path", type=Path, required=True, help="Path to output path")
args = parser.parse_args()
state_dict, metadata = merge_models_into_a_single_file(args.input_path)
output_path = Path(args.output_path)
dirpath = output_path.parent
dirpath.mkdir(parents=True, exist_ok=True)
save_file(state_dict, output_path, metadata=metadata)
import argparse
import os
from pathlib import Path
import torch
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
def merge_config_into_model(
pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
subfolder = kwargs.get("subfolder", None)
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.exists():
dirpath = pretrained_model_name_or_path if subfolder is None else pretrained_model_name_or_path / subfolder
model_path = dirpath / "awq-int4-flux.1-t5xxl.safetensors"
config_path = dirpath / "config.json"
else:
download_kwargs = {
"subfolder": subfolder,
"repo_type": "model",
"revision": kwargs.get("revision", None),
"cache_dir": kwargs.get("cache_dir", None),
"local_dir": kwargs.get("local_dir", None),
"user_agent": kwargs.get("user_agent", None),
"force_download": kwargs.get("force_download", False),
"proxies": kwargs.get("proxies", None),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": kwargs.get("token", None),
"local_files_only": kwargs.get("local_files_only", None),
"headers": kwargs.get("headers", None),
"endpoint": kwargs.get("endpoint", None),
"resume_download": kwargs.get("resume_download", None),
"force_filename": kwargs.get("force_filename", None),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
model_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="awq-int4-flux.1-t5xxl.safetensors", **download_kwargs
)
config_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="config.json", **download_kwargs
)
model_path = Path(model_path)
config_path = Path(config_path)
state_dict = load_state_dict_in_safetensors(model_path)
metadata = {"config": config_path.read_text()}
return state_dict, metadata
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-path",
type=Path,
default="mit-han-lab/nunchaku-t5",
help="Path to model directory. It can also be a huggingface repo.",
)
parser.add_argument("-o", "--output-path", type=Path, required=True, help="Path to output path")
args = parser.parse_args()
state_dict, metadata = merge_config_into_model(args.input_path)
output_path = Path(args.output_path)
dirpath = output_path.parent
dirpath.mkdir(parents=True, exist_ok=True)
save_file(state_dict, output_path, metadata=metadata)
import json
import logging
import os
from pathlib import Path
import torch
from accelerate import init_empty_weights
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig, T5Config, T5EncoderModel
from transformers import T5Config, T5EncoderModel
from ...utils import load_state_dict_in_safetensors
from .linear import W4Linear
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class NunchakuT5EncoderModel(T5EncoderModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike,
config: PretrainedConfig | str | os.PathLike | None = None,
cache_dir: str | os.PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
**kwargs,
):
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
config_path = os.path.join(dirname, "config.json")
else:
shared_kwargs = {
"repo_id": pretrained_model_name_or_path,
"subfolder": subfolder,
"repo_type": "model",
"revision": revision,
"library_name": kwargs.get("library_name"),
"library_version": kwargs.get("library_version"),
"cache_dir": cache_dir,
"local_dir": kwargs.get("local_dir"),
"user_agent": kwargs.get("user_agent"),
"force_download": force_download,
"proxies": kwargs.get("proxies"),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": token,
"local_files_only": local_files_only,
"headers": kwargs.get("headers"),
"endpoint": kwargs.get("endpoint"),
"resume_download": kwargs.get("resume_download"),
"force_filename": kwargs.get("force_filename"),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
qmodel_path = hf_hub_download(filename="svdq-t5.safetensors", **shared_kwargs)
config_path = hf_hub_download(filename="config.json", **shared_kwargs)
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
state_dict = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
# Load the config file
config = T5Config.from_json_file(config_path)
metadata = state_dict.pop("__metadata__", {})
config = json.loads(metadata["config"])
config = T5Config(**config)
# Initialize model on 'meta' device (no memory allocation for weights)
with init_empty_weights():
t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
t5_encoder.eval()
# Load the model weights from the safetensors file
state_dict = load_file(qmodel_path)
# Load the model weights from the safetensors file
named_modules = {}
for name, module in t5_encoder.named_modules():
assert isinstance(name, str)
if isinstance(module, nn.Linear):
if f"{name}.qweight" in state_dict:
print(f"Switching {name} to W4Linear")
logger.debug(f"Switching {name} to W4Linear")
qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
# modeling_t5.py: T5DenseGatedActDense needs dtype of weight
qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
......
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
import diffusers
......@@ -260,14 +261,21 @@ class EmbedND(nn.Module):
def load_quantized_module(
path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False, bf16: bool = True
path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
device: str | torch.device = "cuda",
use_fp4: bool = False,
offload: bool = False,
bf16: bool = True,
) -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
cutils.disable_memory_auto_release()
m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
m.load(path)
if isinstance(path_or_state_dict, dict):
m.loadDict(path_or_state_dict, True)
else:
m.load(str(path_or_state_dict))
return m
......@@ -313,19 +321,35 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
offload = kwargs.get("offload", False)
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
)
# get the default LoRA branch and all the vectors
quantized_part_sd = load_file(transformer_block_path)
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
(".safetensors", ".sft")
):
transformer, model_state_dict = cls._build_model(pretrained_model_name_or_path, **kwargs)
quantized_part_sd = {}
unquantized_part_sd = {}
for k, v in model_state_dict.items():
if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
quantized_part_sd[k] = v
else:
unquantized_part_sd[k] = v
else:
transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
pretrained_model_name_or_path, **kwargs
)
# get the default LoRA branch and all the vectors
quantized_part_sd = load_file(transformer_block_path)
unquantized_part_sd = load_file(unquantized_part_path)
new_quantized_part_sd = {}
for k, v in quantized_part_sd.items():
if v.ndim == 1:
......@@ -348,7 +372,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
new_quantized_part_sd[k] = v
transformer._quantized_part_sd = new_quantized_part_sd
m = load_quantized_module(
transformer_block_path,
quantized_part_sd,
device=device,
use_fp4=precision == "fp4",
offload=offload,
......@@ -357,7 +381,6 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_part_sd = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_part_sd, strict=False)
transformer._unquantized_part_sd = unquantized_part_sd
......
import os
from pathlib import Path
from typing import Optional
import torch
......@@ -139,20 +140,43 @@ class NunchakuSanaTransformerBlocks(nn.Module):
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
pag_layers = kwargs.get("pag_layers", [])
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_state_dict, strict=False)
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
(".safetensors", ".sft")
):
transformer, model_state_dict = cls._build_model(pretrained_model_name_or_path)
quantized_part_sd = {}
unquantized_part_sd = {}
for k, v in model_state_dict.items():
if k.startswith("transformer_blocks."):
quantized_part_sd[k] = v
else:
unquantized_part_sd[k] = v
m = load_quantized_module(
transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
transformer.load_state_dict(unquantized_part_sd, strict=False)
else:
transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
pretrained_model_name_or_path, **kwargs
)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_state_dict, strict=False)
return transformer
def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
......@@ -162,7 +186,7 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
def load_quantized_module(
net: SanaTransformer2DModel,
path: str,
path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
device: str | torch.device = "cuda",
pag_layers: int | list[int] | None = None,
use_fp4: bool = False,
......@@ -177,7 +201,10 @@ def load_quantized_module(
m = QuantizedSanaModel()
cutils.disable_memory_auto_release()
m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.load(path)
if isinstance(path_or_state_dict, dict):
m.loadDict(path_or_state_dict, True)
else:
m.load(str(path_or_state_dict))
return m
......
import json
import logging
import os
from pathlib import Path
from typing import Any, Optional
import torch
......@@ -6,12 +9,45 @@ from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from torch import nn
from nunchaku.utils import ceil_divide
from nunchaku.utils import ceil_divide, load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class NunchakuModelLoaderMixin:
@classmethod
def _build_model(
cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[nn.Module, dict[str, torch.Tensor]]:
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
state_dict = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
# Load the config file
metadata = state_dict.pop("__metadata__", {})
config = json.loads(metadata["config"])
with torch.device("meta"):
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
return transformer, state_dict
@classmethod
def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> tuple[nn.Module, str, str]:
def _build_model_legacy(
cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
) -> tuple[nn.Module, str, str]:
logger.warning(
"Loading models from a folder will be deprecated in v0.4. "
"Please download the latest safetensors model, or use one of the following tools to "
"merge your model into a single file: the CLI utility `python -m nunchaku.merge_models` "
"or the ComfyUI node `MergeFolderIntoSingleFile`."
)
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
......@@ -41,10 +77,10 @@ class NunchakuModelLoaderMixin:
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
)
cache_dir = kwargs.pop("cache_dir", None)
......@@ -70,7 +106,6 @@ class NunchakuModelLoaderMixin:
with torch.device("meta"):
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
return transformer, unquantized_part_path, transformer_block_path
......
import os
import warnings
from os import PathLike
from pathlib import Path
import safetensors
import torch
from huggingface_hub import hf_hub_download
def fetch_or_download(path: str, repo_type: str = "model") -> str:
if not os.path.exists(path):
hf_repo_id = os.path.dirname(path)
filename = os.path.basename(path)
path = hf_hub_download(repo_id=hf_repo_id, filename=filename, repo_type=repo_type)
return path
def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
path = Path(path)
if path.exists():
return path
parts = path.parts
if len(parts) < 3:
raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
repo_id = "/".join(parts[:2])
sub_path = Path(*parts[2:])
filename = sub_path.name
subfolder = sub_path.parent if sub_path.parent != Path(".") else None
path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
return Path(path)
def ceil_divide(x: int, divisor: int) -> int:
......@@ -31,7 +43,10 @@ def ceil_divide(x: int, divisor: int) -> int:
def load_state_dict_in_safetensors(
path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
path: str | PathLike[str],
device: str | torch.device = "cpu",
filter_prefix: str = "",
return_metadata: bool = False,
) -> dict[str, torch.Tensor]:
"""Load state dict in SafeTensors.
......@@ -49,6 +64,9 @@ def load_state_dict_in_safetensors(
"""
state_dict = {}
with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
metadata = f.metadata()
if return_metadata:
state_dict["__metadata__"] = metadata
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
......
......@@ -13,7 +13,9 @@ def test_device_id():
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
torch_dtype = torch.float16 if is_turing("cuda:1") else torch.bfloat16
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
torch_dtype=torch_dtype,
device="cuda:1",
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
......
......@@ -16,7 +16,9 @@ from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = PuLIDFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
......
......@@ -22,11 +22,13 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
precision = get_precision()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
......
......@@ -7,7 +7,7 @@ from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.136 if get_precision() == "int4" else 0.145)]
"height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.151 if get_precision() == "int4" else 0.145)]
)
def test_flux_schnell_qencoder(height: int, width: int, use_qencoder: bool, expected_lpips: float):
run_test(
......
......@@ -112,7 +112,9 @@ def test_flux_teacache(
# Then, generate results with the 4-bit model
if not already_generate(results_dir_4_bit, 1):
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
......
......@@ -63,27 +63,6 @@ def test_flux_fill_dev():
)
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_dev_canny_lora():
# run_test(
# precision=get_precision(),
# model_name="flux.1-dev",
# dataset_name="MJHQ-control",
# task="canny",
# dtype=torch.bfloat16,
# height=1024,
# width=1024,
# num_inference_steps=30,
# guidance_scale=30,
# attention_impl="nunchaku-fp16",
# cpu_offload=False,
# lora_names="canny",
# lora_strengths=0.85,
# cache_threshold=0,
# expected_lpips=0.081,
# )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora():
run_test(
......
......@@ -12,7 +12,7 @@ from ..utils import compute_lpips
def test_lora_reset():
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-dev", offload=True
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors", offload=True
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......@@ -44,4 +44,4 @@ def test_lora_reset():
lpips = compute_lpips(os.path.join(save_dir, "before.png"), os.path.join(save_dir, "after.png"))
print(f"LPIPS: {lpips}")
assert lpips < 0.179 * 1.1
assert lpips < 0.232 * 1.1
......@@ -28,12 +28,12 @@ ORIGINAL_REPO_MAP = {
}
NUNCHAKU_REPO_PATTERN_MAP = {
"flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
"flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
"shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
"flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
"flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
"flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
"flux.1-schnell": "mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
"flux.1-dev": "mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
"shuttle-jaguar": "mit-han-lab/nunchaku-shuttle-jaguar/svdq-{precision}_r32-shuttle-jaguar.safetensors",
"flux.1-canny-dev": "mit-han-lab/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors",
"flux.1-depth-dev": "mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors",
"flux.1-fill-dev": "mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors",
}
LORA_PATH_MAP = {
......@@ -285,7 +285,9 @@ def run_test(
if task == "redux":
pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
elif use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
if cpu_offload:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment