Unverified Commit e2116d7b authored by Muyang Li, committed by GitHub

chore: release v0.3.0

parents 6098c419 d94c2078
@@ -9,7 +9,9 @@ image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev
 mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-fill-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors"
+)
 pipe = FluxFillPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
...
@@ -9,7 +9,9 @@ precision = get_precision()
 pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
 ).to("cuda")
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+)
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     text_encoder=None,
...
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision


def main():
    pipeline_init_kwargs = {}
    text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
    pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
    precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    ).to("cuda")
    image = pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    ).images[0]
    image.save(f"flux.1-schnell-qencoder-{precision}.png")


if __name__ == "__main__":
    main()
@@ -5,7 +5,9 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
-transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+    f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
+)
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
 ).to("cuda")
...
@@ -4,7 +4,9 @@ from diffusers import SanaPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
-transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
+transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
+)
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
     transformer=transformer,
@@ -28,4 +30,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m-int4.png")
+image.save("sana1.6b-int4.png")
...
@@ -3,7 +3,9 @@ from diffusers import SanaPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
-transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
+transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
+)
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
     transformer=transformer,
@@ -23,4 +25,4 @@ image = pipe(
     generator=torch.Generator().manual_seed(42),
 ).images[0]
-image.save("sana_1600m-int4.png")
+image.save("sana1.6b-int4.png")
...
@@ -3,7 +3,9 @@ from diffusers import SanaPAGPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
-transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
+transformer = NunchakuSanaTransformer2DModel.from_pretrained(
+    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors", pag_layers=8
+)
 pipe = SanaPAGPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
     transformer=transformer,
@@ -24,4 +26,4 @@ image = pipe(
     pag_scale=2.0,
     num_inference_steps=20,
 ).images[0]
-image.save("sana_1600m_pag-int4.png")
+image.save("sana1.6b_pag-int4.png")
-__version__ = "0.3.0dev"
+__version__ = "0.3.0"
import argparse
import json
import os
from pathlib import Path

import torch
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import save_file

from .utils import load_state_dict_in_safetensors


def merge_safetensors(
    pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    subfolder = kwargs.get("subfolder", None)
    comfy_config_path = kwargs.get("comfy_config_path", None)
    if isinstance(pretrained_model_name_or_path, str):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
    if pretrained_model_name_or_path.exists():
        dirpath = pretrained_model_name_or_path if subfolder is None else pretrained_model_name_or_path / subfolder
        unquantized_part_path = dirpath / "unquantized_layers.safetensors"
        transformer_block_path = dirpath / "transformer_blocks.safetensors"
        config_path = dirpath / "config.json"
        if comfy_config_path is None:
            comfy_config_path = dirpath / "comfy_config.json"
    else:
        download_kwargs = {
            "subfolder": subfolder,
            "repo_type": "model",
            "revision": kwargs.get("revision", None),
            "cache_dir": kwargs.get("cache_dir", None),
            "local_dir": kwargs.get("local_dir", None),
            "user_agent": kwargs.get("user_agent", None),
            "force_download": kwargs.get("force_download", False),
            "proxies": kwargs.get("proxies", None),
            "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
            "token": kwargs.get("token", None),
            "local_files_only": kwargs.get("local_files_only", None),
            "headers": kwargs.get("headers", None),
            "endpoint": kwargs.get("endpoint", None),
            "resume_download": kwargs.get("resume_download", None),
            "force_filename": kwargs.get("force_filename", None),
            "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
        }
        unquantized_part_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
        )
        transformer_block_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
        )
        config_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="config.json", **download_kwargs
        )
        comfy_config_path = hf_hub_download(
            repo_id=str(pretrained_model_name_or_path), filename="comfy_config.json", **download_kwargs
        )
    unquantized_part_sd = load_state_dict_in_safetensors(unquantized_part_path)
    transformer_block_sd = load_state_dict_in_safetensors(transformer_block_path)
    state_dict = unquantized_part_sd
    state_dict.update(transformer_block_sd)
    precision = "int4"
    for v in state_dict.values():
        assert isinstance(v, torch.Tensor)
        if v.dtype in [
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2,
            torch.float8_e5m2fnuz,
            torch.float8_e8m0fnu,
        ]:
            precision = "fp4"
    quantization_config = {
        "method": "svdquant",
        "weight": {
            "dtype": "fp4_e2m1_all" if precision == "fp4" else "int4",
            "scale_dtype": [None, "fp8_e4m3_nan"] if precision == "fp4" else None,
            "group_size": 16 if precision == "fp4" else 64,
        },
        "activation": {
            "dtype": "fp4_e2m1_all" if precision == "fp4" else "int4",
            "scale_dtype": "fp8_e4m3_nan" if precision == "fp4" else None,
            "group_size": 16 if precision == "fp4" else 64,
        },
    }
    return state_dict, {
        "config": Path(config_path).read_text(),
        "comfy_config": Path(comfy_config_path).read_text(),
        "model_class": "NunchakuFluxTransformer2dModel",
        "quantization_config": json.dumps(quantization_config),
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        type=Path,
        required=True,
        help="Path to the model directory. It can also be a huggingface repo.",
    )
    parser.add_argument("-o", "--output-path", type=Path, required=True, help="Path to the output safetensors file")
    args = parser.parse_args()
    state_dict, metadata = merge_safetensors(args.input_path)
    output_path = Path(args.output_path)
    dirpath = output_path.parent
    dirpath.mkdir(parents=True, exist_ok=True)
    save_file(state_dict, output_path, metadata=metadata)
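
For context, a minimal usage sketch of the merge utility above; the checkpoint repo id and output filename are placeholders, and only names defined in this commit (`merge_safetensors`, its argparse interface, `save_file`) are assumed:

```python
# Sketch only: merge a legacy two-file Nunchaku checkpoint
# (unquantized_layers.safetensors + transformer_blocks.safetensors) into a
# single .safetensors file with the metadata expected by the v0.3.0 loaders.
# Equivalent CLI (per the deprecation notice below):
#   python -m nunchaku.merge_safetensors -i <model dir or repo> -o <output file>
from pathlib import Path

from safetensors.torch import save_file

from nunchaku.merge_safetensors import merge_safetensors

# Placeholder input: a local folder or a legacy Hugging Face repo id.
state_dict, metadata = merge_safetensors("mit-han-lab/svdq-int4-flux.1-dev")

output_path = Path("svdq-int4_r32-flux.1-dev.safetensors")  # placeholder output name
output_path.parent.mkdir(parents=True, exist_ok=True)
save_file(state_dict, output_path, metadata=metadata)
```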
+import json
+import logging
 import os
+from pathlib import Path
 import torch
 from accelerate import init_empty_weights
-from huggingface_hub import constants, hf_hub_download
-from safetensors.torch import load_file
 from torch import nn
-from transformers import PretrainedConfig, T5Config, T5EncoderModel
+from transformers import T5Config, T5EncoderModel
+from ...utils import load_state_dict_in_safetensors
 from .linear import W4Linear
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 class NunchakuT5EncoderModel(T5EncoderModel):
     @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str | os.PathLike,
-        config: PretrainedConfig | str | os.PathLike | None = None,
-        cache_dir: str | os.PathLike | None = None,
-        force_download: bool = False,
-        local_files_only: bool = False,
-        token: str | bool | None = None,
-        revision: str = "main",
-        **kwargs,
-    ):
-        subfolder = kwargs.get("subfolder", None)
-        if os.path.exists(pretrained_model_name_or_path):
-            dirname = (
-                pretrained_model_name_or_path
-                if subfolder is None
-                else os.path.join(pretrained_model_name_or_path, subfolder)
-            )
-            qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
-            config_path = os.path.join(dirname, "config.json")
-        else:
-            shared_kwargs = {
-                "repo_id": pretrained_model_name_or_path,
-                "subfolder": subfolder,
-                "repo_type": "model",
-                "revision": revision,
-                "library_name": kwargs.get("library_name"),
-                "library_version": kwargs.get("library_version"),
-                "cache_dir": cache_dir,
-                "local_dir": kwargs.get("local_dir"),
-                "user_agent": kwargs.get("user_agent"),
-                "force_download": force_download,
-                "proxies": kwargs.get("proxies"),
-                "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
-                "token": token,
-                "local_files_only": local_files_only,
-                "headers": kwargs.get("headers"),
-                "endpoint": kwargs.get("endpoint"),
-                "resume_download": kwargs.get("resume_download"),
-                "force_filename": kwargs.get("force_filename"),
-                "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
-            }
-            qmodel_path = hf_hub_download(filename="svdq-t5.safetensors", **shared_kwargs)
-            config_path = hf_hub_download(filename="config.json", **shared_kwargs)
+    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
+        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
         # Load the config file
-        config = T5Config.from_json_file(config_path)
+        config = json.loads(metadata["config"])
+        config = T5Config(**config)
         # Initialize model on 'meta' device (no memory allocation for weights)
         with init_empty_weights():
             t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
         t5_encoder.eval()
-        # Load the model weights from the safetensors file
-        state_dict = load_file(qmodel_path)
         named_modules = {}
         for name, module in t5_encoder.named_modules():
             assert isinstance(name, str)
             if isinstance(module, nn.Linear):
                 if f"{name}.qweight" in state_dict:
-                    print(f"Switching {name} to W4Linear")
+                    logger.debug(f"Switching {name} to W4Linear")
                     qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
                     # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                     qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
...
+import json
 import logging
 import os
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 import diffusers
@@ -16,7 +18,7 @@ from ..._C import QuantizedFluxModel
 from ..._C import utils as cutils
 from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
 from ...lora.flux.utils import is_nunchaku_format
-from ...utils import get_precision, load_state_dict_in_safetensors
+from ...utils import check_hardware_compatibility, get_precision, load_state_dict_in_safetensors
 from .utils import NunchakuModelLoaderMixin, pad_tensor
 SVD_RANK = 32
@@ -260,14 +262,21 @@ class EmbedND(nn.Module):
 def load_quantized_module(
-    path: str, device: str | torch.device = "cuda", use_fp4: bool = False, offload: bool = False, bf16: bool = True
+    path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
+    device: str | torch.device = "cuda",
+    use_fp4: bool = False,
+    offload: bool = False,
+    bf16: bool = True,
 ) -> QuantizedFluxModel:
     device = torch.device(device)
     assert device.type == "cuda"
     m = QuantizedFluxModel()
     cutils.disable_memory_auto_release()
     m.init(use_fp4, offload, bf16, 0 if device.index is None else device.index)
-    m.load(path)
+    if isinstance(path_or_state_dict, dict):
+        m.loadDict(path_or_state_dict, True)
+    else:
+        m.load(str(path_or_state_dict))
     return m
@@ -307,25 +316,45 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
         self._quantized_part_vectors: dict[str, torch.Tensor] = {}
         self._original_in_channels = in_channels
-        # Comfyui LoRA related
+        # ComfyUI LoRA related
         self.comfy_lora_meta_list = []
         self.comfy_lora_sd_list = []
     @classmethod
     @utils.validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
         device = kwargs.get("device", "cuda")
         if isinstance(device, str):
             device = torch.device(device)
         offload = kwargs.get("offload", False)
         torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
         precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
-        transformer, unquantized_part_path, transformer_block_path = cls._build_model(
-            pretrained_model_name_or_path, **kwargs
-        )
-        # get the default LoRA branch and all the vectors
-        quantized_part_sd = load_file(transformer_block_path)
+        metadata = None
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+            (".safetensors", ".sft")
+        ):
+            transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
+            quantized_part_sd = {}
+            unquantized_part_sd = {}
+            for k, v in model_state_dict.items():
+                if k.startswith(("transformer_blocks.", "single_transformer_blocks.")):
+                    quantized_part_sd[k] = v
+                else:
+                    unquantized_part_sd[k] = v
+            precision = get_precision(device=device)
+            quantization_config = json.loads(metadata["quantization_config"])
+            check_hardware_compatibility(quantization_config, device)
+        else:
+            transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
+                pretrained_model_name_or_path, **kwargs
+            )
+            # get the default LoRA branch and all the vectors
+            quantized_part_sd = load_file(transformer_block_path)
+            unquantized_part_sd = load_file(unquantized_part_path)
         new_quantized_part_sd = {}
         for k, v in quantized_part_sd.items():
             if v.ndim == 1:
@@ -348,7 +377,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
                 new_quantized_part_sd[k] = v
         transformer._quantized_part_sd = new_quantized_part_sd
         m = load_quantized_module(
-            transformer_block_path,
+            quantized_part_sd,
             device=device,
             use_fp4=precision == "fp4",
             offload=offload,
@@ -357,11 +386,13 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
         transformer.inject_quantized_module(m, device)
         transformer.to_empty(device=device)
-        unquantized_part_sd = load_file(unquantized_part_path)
         transformer.load_state_dict(unquantized_part_sd, strict=False)
         transformer._unquantized_part_sd = unquantized_part_sd
-        return transformer
+        if kwargs.get("return_metadata", False):
+            return transformer, metadata
+        else:
+            return transformer
     def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
         print("Injecting quantized module")
...
 import os
+from pathlib import Path
 from typing import Optional
 import torch
@@ -139,21 +140,48 @@ class NunchakuSanaTransformerBlocks(nn.Module):
 class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
     @classmethod
     @utils.validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
         device = kwargs.get("device", "cuda")
+        if isinstance(device, str):
+            device = torch.device(device)
         pag_layers = kwargs.get("pag_layers", [])
         precision = get_precision(kwargs.get("precision", "auto"), device, pretrained_model_name_or_path)
-        transformer, unquantized_part_path, transformer_block_path = cls._build_model(
-            pretrained_model_name_or_path, **kwargs
-        )
-        m = load_quantized_module(
-            transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
-        )
-        transformer.inject_quantized_module(m, device)
-        transformer.to_empty(device=device)
-        unquantized_state_dict = load_file(unquantized_part_path)
-        transformer.load_state_dict(unquantized_state_dict, strict=False)
-        return transformer
+        metadata = None
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        if pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
+            (".safetensors", ".sft")
+        ):
+            transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path)
+            quantized_part_sd = {}
+            unquantized_part_sd = {}
+            for k, v in model_state_dict.items():
+                if k.startswith("transformer_blocks."):
+                    quantized_part_sd[k] = v
+                else:
+                    unquantized_part_sd[k] = v
+            m = load_quantized_module(
+                transformer, quantized_part_sd, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+            )
+            transformer.inject_quantized_module(m, device)
+            transformer.to_empty(device=device)
+            transformer.load_state_dict(unquantized_part_sd, strict=False)
+        else:
+            transformer, unquantized_part_path, transformer_block_path = cls._build_model_legacy(
+                pretrained_model_name_or_path, **kwargs
+            )
+            m = load_quantized_module(
+                transformer, transformer_block_path, device=device, pag_layers=pag_layers, use_fp4=precision == "fp4"
+            )
+            transformer.inject_quantized_module(m, device)
+            transformer.to_empty(device=device)
+            unquantized_state_dict = load_file(unquantized_part_path)
+            transformer.load_state_dict(unquantized_state_dict, strict=False)
+        if kwargs.get("return_metadata", False):
+            return transformer, metadata
+        else:
+            return transformer
     def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
         self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
@@ -162,7 +190,7 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
 def load_quantized_module(
     net: SanaTransformer2DModel,
-    path: str,
+    path_or_state_dict: str | os.PathLike[str] | dict[str, torch.Tensor],
     device: str | torch.device = "cuda",
     pag_layers: int | list[int] | None = None,
     use_fp4: bool = False,
@@ -177,7 +205,10 @@ def load_quantized_module(
     m = QuantizedSanaModel()
     cutils.disable_memory_auto_release()
     m.init(net.config, pag_layers, use_fp4, net.dtype == torch.bfloat16, 0 if device.index is None else device.index)
-    m.load(path)
+    if isinstance(path_or_state_dict, dict):
+        m.loadDict(path_or_state_dict, True)
+    else:
+        m.load(str(path_or_state_dict))
     return m
...
+import json
+import logging
 import os
+from pathlib import Path
 from typing import Any, Optional
 import torch
@@ -6,12 +9,44 @@ from diffusers import __version__
 from huggingface_hub import constants, hf_hub_download
 from torch import nn
-from nunchaku.utils import ceil_divide
+from nunchaku.utils import ceil_divide, load_state_dict_in_safetensors
+# Get log level from environment variable (default to INFO)
+log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+# Configure logging
+logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
 class NunchakuModelLoaderMixin:
+    @classmethod
+    def _build_model(
+        cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
+    ) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
+        if isinstance(pretrained_model_name_or_path, str):
+            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
+        state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
+        # Load the config file
+        config = json.loads(metadata["config"])
+        with torch.device("meta"):
+            transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+        return transformer, state_dict, metadata
     @classmethod
-    def _build_model(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> tuple[nn.Module, str, str]:
+    def _build_model_legacy(
+        cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
+    ) -> tuple[nn.Module, str, str]:
+        logger.warning(
+            "Loading models from a folder will be deprecated in v0.4. "
+            "Please download the latest safetensors model, or use one of the following tools to "
+            "merge your model into a single file: the CLI utility `python -m nunchaku.merge_safetensors` "
+            "or the ComfyUI workflow `merge_safetensors.json`."
+        )
         subfolder = kwargs.get("subfolder", None)
         if os.path.exists(pretrained_model_name_or_path):
             dirname = (
@@ -41,10 +76,10 @@ class NunchakuModelLoaderMixin:
             "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
         }
         unquantized_part_path = hf_hub_download(
-            repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
+            repo_id=str(pretrained_model_name_or_path), filename="unquantized_layers.safetensors", **download_kwargs
         )
         transformer_block_path = hf_hub_download(
-            repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
+            repo_id=str(pretrained_model_name_or_path), filename="transformer_blocks.safetensors", **download_kwargs
         )
         cache_dir = kwargs.pop("cache_dir", None)
@@ -70,7 +105,6 @@ class NunchakuModelLoaderMixin:
         with torch.device("meta"):
             transformer = cls.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
         return transformer, unquantized_part_path, transformer_block_path
...
 import os
 import warnings
+from pathlib import Path
 import safetensors
 import torch
 from huggingface_hub import hf_hub_download
-def fetch_or_download(path: str, repo_type: str = "model") -> str:
-    if not os.path.exists(path):
-        hf_repo_id = os.path.dirname(path)
-        filename = os.path.basename(path)
-        path = hf_hub_download(repo_id=hf_repo_id, filename=filename, repo_type=repo_type)
-    return path
+def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
+    path = Path(path)
+    if path.exists():
+        return path
+    parts = path.parts
+    if len(parts) < 3:
+        raise ValueError(f"Path '{path}' is too short to extract repo_id and subfolder")
+    repo_id = "/".join(parts[:2])
+    sub_path = Path(*parts[2:])
+    filename = sub_path.name
+    subfolder = sub_path.parent if sub_path.parent != Path(".") else None
+    path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
+    return Path(path)
 def ceil_divide(x: int, divisor: int) -> int:
@@ -31,29 +43,22 @@ def ceil_divide(x: int, divisor: int) -> int:
 def load_state_dict_in_safetensors(
-    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
-) -> dict[str, torch.Tensor]:
-    """Load state dict in SafeTensors.
-
-    Args:
-        path (`str`):
-            file path.
-        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
-            device.
-        filter_prefix (`str`, optional, defaults to `""`):
-            filter prefix.
-
-    Returns:
-        `dict`:
-            loaded SafeTensors.
-    """
+    path: str | os.PathLike[str],
+    device: str | torch.device = "cpu",
+    filter_prefix: str = "",
+    return_metadata: bool = False,
+) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, str]]:
     state_dict = {}
     with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
+        metadata = f.metadata()
         for k in f.keys():
             if filter_prefix and not k.startswith(filter_prefix):
                 continue
             state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
-    return state_dict
+    if return_metadata:
+        return state_dict, metadata
+    else:
+        return state_dict
 def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
@@ -73,7 +78,9 @@ def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str =
 def get_precision(
-    precision: str = "auto", device: str | torch.device = "cuda", pretrained_model_name_or_path: str | None = None
+    precision: str = "auto",
+    device: str | torch.device = "cuda",
+    pretrained_model_name_or_path: str | os.PathLike[str] | None = None,
 ) -> str:
     assert precision in ("auto", "int4", "fp4")
     if precision == "auto":
@@ -84,10 +91,10 @@ def get_precision(
         precision = "fp4" if sm == "120" else "int4"
     if pretrained_model_name_or_path is not None:
         if precision == "int4":
-            if "fp4" in pretrained_model_name_or_path:
+            if "fp4" in str(pretrained_model_name_or_path):
                 warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
         elif precision == "fp4":
-            if "int4" in pretrained_model_name_or_path:
+            if "int4" in str(pretrained_model_name_or_path):
                 warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
     return precision
@@ -128,3 +135,21 @@ def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> in
         return memory // (1024**2)
     else:
         return memory
+def check_hardware_compatibility(quantization_config: dict, device: str | torch.device = "cuda"):
+    if isinstance(device, str):
+        device = torch.device(device)
+    capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
+    sm = f"{capability[0]}{capability[1]}"
+    if sm == "120":  # you can only use the fp4 models
+        if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
+            raise ValueError('Please use "fp4" quantization for Blackwell GPUs.')
+    elif sm in ["75", "80", "86", "89"]:
+        if quantization_config["weight"]["dtype"] != "int4":
+            raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.')
+    else:
+        raise ValueError(
+            f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
+            "Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
+        )
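
As a rough illustration of how the new checkpoint metadata is consumed, a hedged sketch that builds an int4 config shaped like the `quantization_config` metadata `merge_safetensors` writes and checks it against the current GPU; the dict values are illustrative:

```python
# Sketch only: the config mirrors the int4 "quantization_config" metadata
# written by nunchaku.merge_safetensors; values here are illustrative.
import torch

from nunchaku.utils import check_hardware_compatibility

quantization_config = {
    "method": "svdquant",
    "weight": {"dtype": "int4", "scale_dtype": None, "group_size": 64},
    "activation": {"dtype": "int4", "scale_dtype": None, "group_size": 64},
}

# Raises ValueError on GPUs that cannot run this precision
# (e.g. Blackwell, which requires the fp4 models).
check_hardware_compatibility(quantization_config, torch.device("cuda"))
```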
 # Nunchaku Tests
 Nunchaku uses pytest as its testing framework.
 ## Setting Up Test Environments
 After installing `nunchaku` as described in the [README](../README.md#installation), you can install the test dependencies with:
 ```shell
 pip install -r tests/requirements.txt
 ```
 ## Running the Tests
 ```shell
 HF_TOKEN=$YOUR_HF_TOKEN pytest -v tests/flux/test_flux_memory.py
 HF_TOKEN=$YOUR_HF_TOKEN pytest -v tests/flux --ignore=tests/flux/test_flux_memory.py
@@ -15,7 +19,7 @@ HF_TOKEN=$YOUR_HF_TOKEN pytest -v tests/sana
 ```
 > **Note:** `$YOUR_HF_TOKEN` refers to your Hugging Face access token, required to download models and datasets. You can create one at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
 > If you've already logged in using `huggingface-cli login`, you can skip setting this environment variable.
 Some tests generate images using the original 16-bit models. You can cache these results to speed up future test runs by setting the environment variable `NUNCHAKU_TEST_CACHE_ROOT`. If not set, the images will be saved in `test_results/ref`.
@@ -27,9 +31,9 @@ To test visual output correctness, you can:
 1. **Generate reference images:** Use the original 16-bit model to produce a small number of reference images (e.g., 4).
-2. **Generate comparison images:** Run your method using the **same inputs and seeds** to ensure deterministic outputs. You can control the seed by setting the `generator` parameter in the diffusers pipeline.
-3. **Compute similarity:** Evaluate the similarity between your outputs and the reference images using the [LPIPS](https://arxiv.org/abs/1801.03924) metric. Use the `compute_lpips` function provided in [`tests/flux/utils.py`](flux/utils.py):
+1. **Generate comparison images:** Run your method using the **same inputs and seeds** to ensure deterministic outputs. You can control the seed by setting the `generator` parameter in the diffusers pipeline.
+1. **Compute similarity:** Evaluate the similarity between your outputs and the reference images using the [LPIPS](https://arxiv.org/abs/1801.03924) metric. Use the `compute_lpips` function provided in [`tests/flux/utils.py`](flux/utils.py):
 ```shell
 lpips = compute_lpips(dir1, dir2)
...
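
For the workflow in the list above, a minimal sketch of the comparison step; the directory names, the threshold, and the import path for `compute_lpips` are placeholders, assuming the helper takes two image directories as in the snippet above:

```python
# Sketch: compare quantized outputs against 16-bit reference images with LPIPS.
# Directory names, import path, and the 0.15 threshold are illustrative placeholders.
from tests.flux.utils import compute_lpips

ref_dir = "test_results/ref/flux.1-schnell"   # images from the original 16-bit model
out_dir = "test_results/int4/flux.1-schnell"  # images from the quantized model

lpips = compute_lpips(ref_dir, out_dir)
assert lpips < 0.15, f"LPIPS {lpips:.3f} exceeds the expected threshold"
```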
@@ -13,7 +13,9 @@ def test_device_id():
     precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
     torch_dtype = torch.float16 if is_turing("cuda:1") else torch.bfloat16
     transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-        f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
+        f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
+        torch_dtype=torch_dtype,
+        device="cuda:1",
     )
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
...
@@ -16,7 +16,9 @@ from nunchaku.utils import get_precision, is_turing
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_pulid():
     precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
-    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
+    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    )
     pipeline = PuLIDFluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-dev",
...
@@ -22,11 +22,13 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
     precision = get_precision()
     pipeline_init_kwargs = {
         "transformer": NunchakuFluxTransformer2dModel.from_pretrained(
-            f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
+            f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors", offload=cpu_offload
         )
     }
     if use_qencoder:
-        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
+        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
+            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
+        )
         pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
     pipeline = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
...
@@ -7,7 +7,7 @@ from .utils import run_test
 @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.136 if get_precision() == "int4" else 0.145)]
+    "height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.151 if get_precision() == "int4" else 0.145)]
 )
 def test_flux_schnell_qencoder(height: int, width: int, use_qencoder: bool, expected_lpips: float):
     run_test(
...
@@ -9,7 +9,7 @@ from .utils import run_test
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
+        (1024, 1024, "flashattn2", False, 0.141 if get_precision() == "int4" else 0.126),
         (1024, 1024, "nunchaku-fp16", False, 0.139 if get_precision() == "int4" else 0.126),
         (1920, 1080, "nunchaku-fp16", False, 0.190 if get_precision() == "int4" else 0.138),
         (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
...