Unverified Commit e2116d7b authored by Muyang Li, committed by GitHub

chore: release v0.3.0

parents 6098c419 d94c2078
......@@ -9,7 +9,9 @@ image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev
mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
precision = get_precision() # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-fill-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors"
)
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
......
......@@ -9,7 +9,9 @@ precision = get_precision()
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
......
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision
def main():
pipeline_init_kwargs = {}
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
precision = get_precision() # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16, **pipeline_init_kwargs
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
).images[0]
image.save(f"flux.1-schnell-qencoder-{precision}.png")
if __name__ == "__main__":
main()
......@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision
precision = get_precision() # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
......
......@@ -4,7 +4,9 @@ from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
)
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
......@@ -28,4 +30,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
image.save("sana1.6b-int4.png")
......@@ -3,7 +3,9 @@ from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
)
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
......@@ -23,4 +25,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
image.save("sana1.6b-int4.png")
......@@ -3,7 +3,9 @@ from diffusers import SanaPAGPipeline
from nunchaku import NunchakuSanaTransformer2DModel
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
"mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors", pag_layers=8
)
pipe = SanaPAGPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
......@@ -24,4 +26,4 @@ image = pipe(
pag_scale=2.0,
num_inference_steps=20,
).images[0]
image.save("sana_1600m_pag-int4.png")
image.save("sana1.6b_pag-int4.png")
__version__ = "0.3.0dev"
__version__ = "0.3.0"
import argparse
import json
import os
from pathlib import Path
import torch
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
def merge_safetensors(
pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
subfolder = kwargs.get("subfolder", None)
comfy_config_path = kwargs.get("comfy_config_path", None)
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.exists():
dirpath = pretrained_model_name_or_path if subfolder is None else pretrained_model_name_or_path / subfolder
unquantized_part_path = dirpath / "unquantized_layers.safetensors"
transformer_block_path = dirpath / "transformer_blocks.safetensors"
config_path = dirpath / "config.json"
if comfy_config_path is None:
comfy_config_path = dirpath / "comfy_config.json"
else:
download_kwargs = {
"subfolder": subfolder,
"repo_type": "model",
"revision": kwargs.get("revision", None),
"cache_dir": kwargs.get("cache_dir", None),
"local_dir": kwargs.get("local_dir", None),
"user_agent": kwargs.get("user_agent", None),
"force_download": kwargs.get("force_download", False),
"proxies": kwargs.get("proxies", None),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": kwargs.get("token", None),
"local_files_only": kwargs.get("local_files_only", None),
"headers": kwargs.get("headers", None),
"endpoint": kwargs.get("endpoint", None),
"resume_download": kwargs.get("resume_download", None),
"force_filename": kwargs.get("force_filename", None),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
)
config_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="config.json", **download_kwargs
)
comfy_config_path = hf_hub_download(
repo_id=str(pretrained_model_name_or_path), filename="comfy_config.json", **download_kwargs
)
unquantized_part_sd = load_state_dict_in_safetensors(unquantized_part_path)
transformer_block_sd = load_state_dict_in_safetensors(transformer_block_path)
state_dict = unquantized_part_sd
state_dict.update(transformer_block_sd)
precision = "int4"
for v in state_dict.values():
assert isinstance(v, torch.Tensor)
if v.dtype in [
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
]:
precision = "fp4"
quantization_config = {
"method": "svdquant",
"weight": {
"dtype": "fp4_e2m1_all" if precision == "fp4" else "int4",
"scale_dtype": [None, "fp8_e4m3_nan"] if precision == "fp4" else None,
"group_size": 16 if precision == "fp4" else 64,
},
"activation": {
"dtype": "fp4_e2m1_all" if precision == "fp4" else "int4",
"scale_dtype": "fp8_e4m3_nan" if precision == "fp4" else None,
"group_size": 16 if precision == "fp4" else 64,
},
}
return state_dict, {
"config": Path(config_path).read_text(),
"comfy_config": Path(comfy_config_path).read_text(),
"model_class": "NunchakuFluxTransformer2dModel",
"quantization_config": json.dumps(quantization_config),
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input-path",
type=Path,
required=True,
help="Path to model directory. It can also be a huggingface repo.",
)
parser.add_argument("-o", "--output-path", type=Path, required=True, help="Path to output path")
args = parser.parse_args()
state_dict, metadata = merge_safetensors(args.input_path)
output_path = Path(args.output_path)
dirpath = output_path.parent
dirpath.mkdir(parents=True, exist_ok=True)
save_file(state_dict, output_path, metadata=metadata)
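For reference, a minimal sketch of driving this merge step programmatically rather than via the CLI, assuming a locally downloaded legacy folder that contains `unquantized_layers.safetensors`, `transformer_blocks.safetensors`, `config.json`, and `comfy_config.json` (the input and output paths are illustrative):

```python
from pathlib import Path

from safetensors.torch import save_file

from nunchaku.merge_safetensors import merge_safetensors

# Merge the legacy multi-file layout into a single state dict plus string metadata.
state_dict, metadata = merge_safetensors("path/to/legacy-svdq-int4-flux.1-dev")

# Write one self-contained checkpoint that the new from_pretrained() accepts directly.
output_path = Path("svdq-int4_r32-flux.1-dev.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)
save_file(state_dict, output_path, metadata=metadata)
```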
import json
import logging
import os
from pathlib import Path
import torch
from accelerate import init_empty_weights
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig, T5Config, T5EncoderModel
from transformers import T5Config, T5EncoderModel
from ...utils import load_state_dict_in_safetensors
from .linear import W4Linear
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class NunchakuT5EncoderModel(T5EncoderModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike,
config: PretrainedConfig | str | os.PathLike | None = None,
cache_dir: str | os.PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
**kwargs,
):
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
config_path = os.path.join(dirname, "config.json")
else:
shared_kwargs = {
"repo_id": pretrained_model_name_or_path,
"subfolder": subfolder,
"repo_type": "model",
"revision": revision,
"library_name": kwargs.get("library_name"),
"library_version": kwargs.get("library_version"),
"cache_dir": cache_dir,
"local_dir": kwargs.get("local_dir"),
"user_agent": kwargs.get("user_agent"),
"force_download": force_download,
"proxies": kwargs.get("proxies"),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": token,
"local_files_only": local_files_only,
"headers": kwargs.get("headers"),
"endpoint": kwargs.get("endpoint"),
"resume_download": kwargs.get("resume_download"),
"force_filename": kwargs.get("force_filename"),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
qmodel_path = hf_hub_download(filename="svdq-t5.safetensors", **shared_kwargs)
config_path = hf_hub_download(filename="config.json", **shared_kwargs)
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
# Load the config file
config = T5Config.from_json_file(config_path)
config = json.loads(metadata["config"])
config = T5Config(**config)
# Initialize model on 'meta' device (no memory allocation for weights)
with init_empty_weights():
t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
t5_encoder.eval()
# Load the model weights from the safetensors file
state_dict = load_file(qmodel_path)
# Load the model weights from the safetensors file
named_modules = {}
for name, module in t5_encoder.named_modules():
assert isinstance(name, str)
if isinstance(module, nn.Linear):
if f"{name}.qweight" in state_dict:
print(f"Switching {name} to W4Linear")
logger.debug(f"Switching {name} to W4Linear")
qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
# modeling_t5.py: T5DenseGatedActDense needs dtype of weight
qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
......
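As a usage sketch for the quantized T5 encoder under the new single-file naming, combined with the schnell transformer (both repo paths follow the examples and tests elsewhere in this diff; the prompt and output filename are illustrative):

```python
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision

precision = get_precision()  # 'int4' or 'fp4', depending on your GPU
# Quantized T5 text encoder loaded from a single safetensors file.
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    guidance_scale=0,
).images[0]
image.save(f"flux.1-schnell-qencoder-{precision}.png")
```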
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
import diffusers
......@@ -16,7 +18,7 @@ from ..._C import QuantizedFluxModel
from ..._C import utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
from ...utils import get_precision, load_state_dict_in_safetensors
from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors
from .utils import NunchakuModelLoaderMixin, pad_tensor
SVD_RANK = 32
......@@ -260,14 +262,21 @@ class EmbedND(nn.Module):
def load_quantized_module(
path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False, bf16: bool = True
path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
device: str | torch.device = "cuda",
use_fp4: bool = False,
offload: bool = False,
bf16: bool = True,
) -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
cutils.disable_memory_auto_release()
m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
m.load(path)
if isinstance(path_or_state_dict, dict):
m.loadDict(path_or_state_dict, True)
else:
m.load(str(path_or_state_dict))
return m
......@@ -307,25 +316,45 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
self._quantized_part_vectors: dict[str, torch.Tensor] = {}
self._original_in_channels = in_channels
# Comfyui LoRA related
# ComfyUI LoRA related
self.comfy_lora_meta_list = []
self.comfy_lora_sd_list = []
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
offload = kwargs.get("offload", False)
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
)
metadata = None
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
(".safetensors", ".sft")
):
transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
quantized_part_sd = {}
unquantized_part_sd = {}
for k, v in model_state_dict.items():
if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
quantized_part_sd[k] = v
else:
unquantized_part_sd[k] = v
precision = get_precision(device=device)
quantization_config = json.loads(metadata["quantization_config"])
check_hardware_compatibility(quantization_config, device)
else:
transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
pretrained_model_name_or_path, **kwargs
)
# get the default LoRA branch and all the vectors
quantized_part_sd = load_file(transformer_block_path)
# get the default LoRA branch and all the vectors
quantized_part_sd = load_file(transformer_block_path)
unquantized_part_sd = load_file(unquantized_part_path)
new_quantized_part_sd = {}
for k, v in quantized_part_sd.items():
if v.ndim == 1:
......@@ -348,7 +377,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
new_quantized_part_sd[k] = v
transformer._quantized_part_sd = new_quantized_part_sd
m = load_quantized_module(
transformer_block_path,
quantized_part_sd,
device=device,
use_fp4=precision == "fp4",
offload=offload,
......@@ -357,11 +386,13 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_part_sd = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_part_sd, strict=False)
transformer._unquantized_part_sd = unquantized_part_sd
return transformer
if kwargs.get("return_metadata", False):
return transformer, metadata
else:
return transformer
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
print("Injecting quantized module")
......
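To illustrate the new single-file loading path added above, a short sketch of loading a merged checkpoint and inspecting the metadata embedded at merge time (the repo path follows the other examples in this release and is illustrative):

```python
import json

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # 'fp4' on Blackwell GPUs, otherwise 'int4'
# return_metadata=True additionally returns the safetensors metadata dictionary.
transformer, metadata = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
    return_metadata=True,
)
quantization_config = json.loads(metadata["quantization_config"])
print(quantization_config["weight"]["dtype"])  # "int4" or "fp4_e2m1_all"
```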
import os
from pathlib import Path
from typing import Optional
import torch
......@@ -139,21 +140,48 @@ class NunchakuSanaTransformerBlocks(nn.Module):
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
pag_layers = kwargs.get("pag_layers", [])
precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
transformer, unquantized_part_path, transformer_block_path = cls._build_model(
pretrained_model_name_or_path, **kwargs
)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_state_dict, strict=False)
return transformer
metadata = None
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
(".safetensors", ".sft")
):
transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path)
quantized_part_sd = {}
unquantized_part_sd = {}
for k, v in model_state_dict.items():
if k.startswith("transformer_blocks."):
quantized_part_sd[k] = v
else:
unquantized_part_sd[k] = v
m = load_quantized_module(
transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
transformer.load_state_dict(unquantized_part_sd, strict=False)
else:
transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
pretrained_model_name_or_path, **kwargs
)
m = load_quantized_module(
transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
)
transformer.inject_quantized_module(m, device)
transformer.to_empty(device=device)
unquantized_state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(unquantized_state_dict, strict=False)
if kwargs.get("return_metadata", False):
return transformer, metadata
else:
return transformer
def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
......@@ -162,7 +190,7 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
def load_quantized_module(
net: SanaTransformer2DModel,
path: str,
path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
device: str | torch.device = "cuda",
pag_layers: int | list[int] | None = None,
use_fp4: bool = False,
......@@ -177,7 +205,10 @@ def load_quantized_module(
m = QuantizedSanaModel()
cutils.disable_memory_auto_release()
m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
m.load(path)
if isinstance(path_or_state_dict, dict):
m.loadDict(path_or_state_dict, True)
else:
m.load(str(path_or_state_dict))
return m
......
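Analogously for Sana, a usage sketch with the single-file checkpoint named earlier in this diff (the torch_dtype and prompt are illustrative; in practice you may also need to match the dtype of the VAE and text encoder as in the full examples):

```python
import torch
from diffusers import SanaPipeline

from nunchaku import NunchakuSanaTransformer2DModel

# Single-file checkpoint following the new naming convention.
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
)
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe(
    "A cat holding a sign that says hello world",
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana1.6b-int4.png")
```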
import json
import logging
import os
from pathlib import Path
from typing import Any, Optional
import torch
......@@ -6,12 +9,44 @@ from diffusers import __version__
from huggingface_hub import constants, hf_hub_download
from torch import nn
from nunchaku.utils import ceil_divide
from nunchaku.utils import ceil_divide, load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class NunchakuModelLoaderMixin:
@classmethod
def _build_model(
cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
# Load the config file
config = json.loads(metadata["config"])
with torch.device("meta"):
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
return transformer, state_dict, metadata
@classmethod
def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> tuple[nn.Module, str, str]:
def _build_model_legacy(
cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
) -> tuple[nn.Module, str, str]:
logger.warning(
"Loading models from a folder will be deprecated in v0.4. "
"Please download the latest safetensors model, or use one of the following tools to "
"merge your model into a single file: the CLI utility `python -m nunchaku.merge_safetensors` "
"or the ComfyUI workflow `merge_safetensors.json`."
)
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
......@@ -41,10 +76,10 @@ class NunchakuModelLoaderMixin:
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
)
cache_dir = kwargs.pop("cache_dir", None)
......@@ -70,7 +105,6 @@ class NunchakuModelLoaderMixin:
with torch.device("meta"):
transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
return transformer, unquantized_part_path, transformer_block_path
......
import os
import warnings
from pathlib import Path
import safetensors
import torch
from huggingface_hub import hf_hub_download
def fetch_or_download(path: str, repo_type: str = "model") -> str:
if not os.path.exists(path):
hf_repo_id = os.path.dirname(path)
filename = os.path.basename(path)
path = hf_hub_download(repo_id=hf_repo_id, filename=filename, repo_type=repo_type)
return path
def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
path = Path(path)
if path.exists():
return path
parts = path.parts
if len(parts) < 3:
raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
repo_id = "/".join(parts[:2])
sub_path = Path(*parts[2:])
filename = sub_path.name
subfolder = sub_path.parent if sub_path.parent != Path(".") else None
path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
return Path(path)
def ceil_divide(x: int, divisor: int) -> int:
......@@ -31,29 +43,22 @@ def ceil_divide(x: int, divisor: int) -> int:
def load_state_dict_in_safetensors(
path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
"""Load state dict in SafeTensors.
Args:
path (`str`):
file path.
device (`str` | `torch.device`, optional, defaults to `"cpu"`):
device.
filter_prefix (`str`, optional, defaults to `""`):
filter prefix.
Returns:
`dict`:
loaded SafeTensors.
"""
path: str | os.PathLike[str],
device: str | torch.device = "cpu",
filter_prefix: str = "",
return_metadata: bool = False,
) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, str]]:
state_dict = {}
with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
metadata = f.metadata()
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
if return_metadata:
return state_dict, metadata
else:
return state_dict
def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
......@@ -73,7 +78,9 @@ def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str =
def get_precision(
precision: str = "auto", device: str | torch.device = "cuda", pretrained_model_name_or_path: str | None = None
precision: str = "auto",
device: str | torch.device = "cuda",
pretrained_model_name_or_path: str | os.PathLike[str] | None = None,
) -> str:
assert precision in ("auto", "int4", "fp4")
if precision == "auto":
......@@ -84,10 +91,10 @@ def get_precision(
precision = "fp4" if sm == "120" else "int4"
if pretrained_model_name_or_path is not None:
if precision == "int4":
if "fp4" in pretrained_model_name_or_path:
if "fp4" in str(pretrained_model_name_or_path):
warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
elif precision == "fp4":
if "int4" in pretrained_model_name_or_path:
if "int4" in str(pretrained_model_name_or_path):
warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
return precision
......@@ -128,3 +135,21 @@ def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> in
return memory // (1024**2)
else:
return memory
def check_hardware_compatibility(quantization_config: dict, device: str | torch.device = "cuda"):
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
sm = f"{capability[0]}{capability[1]}"
if sm == "120": # Blackwell GPUs can only use the fp4 models
if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
raise ValueError('Please use "fp4" quantization for Blackwell GPUs. ')
elif sm in ["75", "80", "86", "89"]:
if quantization_config["weight"]["dtype"] != "int4":
raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs. ')
else:
raise ValueError(
f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
"Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
)
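A brief sketch of how this check pairs with the metadata stored in a merged checkpoint, using only the helpers defined in this file (the repo path is illustrative):

```python
import json

from nunchaku.utils import check_hardware_compatibility, load_state_dict_in_safetensors

# Load the single-file checkpoint together with its embedded metadata (weights go to CPU here).
state_dict, metadata = load_state_dict_in_safetensors(
    "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors",
    return_metadata=True,
)
quantization_config = json.loads(metadata["quantization_config"])
# Raises ValueError if the quantization format does not match the current GPU architecture.
check_hardware_compatibility(quantization_config, device="cuda")
```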
# Nunchaku Tests
Nunchaku uses pytest as its testing framework.
## Setting Up Test Environments
After installing `nunchaku` as described in the [README](../README.md#installation), you can install the test dependencies with:
```shell
pip install -r tests/requirements.txt
```
## Running the Tests
```shell
HF_TOKEN=$YOUR_HF_TOKEN pytest -v tests/flux/test_flux_memory.py
HF_TOKEN=$YOUR_HF_TOKEN pytest -v tests/flux --ignore=tests/flux/test_flux_memory.py
......@@ -15,7 +19,7 @@ HF_TOKEN=$YOUR_HF_TOKEN pytest -v tests/sana
```
> **Note:** `$YOUR_HF_TOKEN` refers to your Hugging Face access token, required to download models and datasets. You can create one at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> If you've already logged in using `huggingface-cli login`, you can skip setting this environment variable.
> If you've already logged in using `huggingface-cli login`, you can skip setting this environment variable.
Some tests generate images using the original 16-bit models. You can cache these results to speed up future test runs by setting the environment variable `NUNCHAKU_TEST_CACHE_ROOT`. If not set, the images will be saved in `test_results/ref`.
......@@ -27,9 +31,9 @@ To test visual output correctness, you can:
1. **Generate reference images:** Use the original 16-bit model to produce a small number of reference images (e.g., 4).
2. **Generate comparison images:** Run your method using the **same inputs and seeds** to ensure deterministic outputs. You can control the seed by setting the `generator` parameter in the diffusers pipeline.
1. **Generate comparison images:** Run your method using the **same inputs and seeds** to ensure deterministic outputs. You can control the seed by setting the `generator` parameter in the diffusers pipeline.
3. **Compute similarity:** Evaluate the similarity between your outputs and the reference images using the [LPIPS](https://arxiv.org/abs/1801.03924) metric. Use the `compute_lpips` function provided in [`tests/flux/utils.py`](flux/utils.py):
1. **Compute similarity:** Evaluate the similarity between your outputs and the reference images using the [LPIPS](https://arxiv.org/abs/1801.03924) metric. Use the `compute_lpips` function provided in [`tests/flux/utils.py`](flux/utils.py):
```python
lpips = compute_lpips(dir1, dir2)
......
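Putting the steps above together, a minimal sketch of the workflow (directories, prompt, and threshold are illustrative; the `compute_lpips` import path assumes you run from the repo root):

```python
import os

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision
from tests.flux.utils import compute_lpips  # helper referenced above

precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Fix the seed so the comparison images are deterministic.
os.makedirs("test_results/ours", exist_ok=True)
image = pipeline(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    guidance_scale=0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("test_results/ours/cat.png")

# Compare against the 16-bit reference images generated with the same prompts and seeds.
lpips = compute_lpips("test_results/ref", "test_results/ours")
assert lpips < 0.15  # illustrative threshold; the real tests use precision-dependent values
```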
......@@ -13,7 +13,9 @@ def test_device_id():
precision = get_precision() # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
torch_dtype = torch.float16 if is_turing("cuda:1") else torch.bfloat16
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
torch_dtype=torch_dtype,
device="cuda:1",
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
......
......@@ -16,7 +16,9 @@ from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
precision = get_precision() # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = PuLIDFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
......
......@@ -22,11 +22,13 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
precision = get_precision()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors", offload=cpu_offload
)
}
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
......
......@@ -7,7 +7,7 @@ from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.136 if get_precision() == "int4" else 0.145)]
"height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.151 if get_precision() == "int4" else 0.145)]
)
def test_flux_schnell_qencoder(height: int, width: int, use_qencoder: bool, expected_lpips: float):
run_test(
......
......@@ -9,7 +9,7 @@ from .utils import run_test
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
(1024, 1024, "flashattn2", False, 0.141 if get_precision() == "int4" else 0.126),
(1024, 1024, "nunchaku-fp16", False, 0.139 if get_precision() == "int4" else 0.126),
(1920, 1080, "nunchaku-fp16", False, 0.190 if get_precision() == "int4" else 0.138),
(2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
......