Commit 2ede5f01 authored by Muyang Li, committed by Zhekai Zhang

Clean up some code and refactor the tests

parent 83b7542d
...@@ -4,7 +4,6 @@ from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
...@@ -29,4 +28,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
...@@ -23,4 +23,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m-int4.png")
...@@ -24,4 +24,4 @@ image = pipe(
pag_scale=2.0,
num_inference_steps=20,
).images[0]
image.save("sana_1600m_pag-int4.png")
import functools
import unittest
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn
from ...caching import utils
...@@ -11,7 +11,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
if getattr(transformer, "_is_cached", False):
return transformer
cached_transformer_blocks = nn.ModuleList(
[
utils.FluxCachedTransformerBlocks(
transformer=transformer,
...@@ -20,7 +20,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
)
]
)
dummy_single_transformer_blocks = nn.ModuleList()
original_forward = transformer.forward
...
...@@ -94,7 +94,6 @@ def apply_prev_hidden_states_residual(
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
...@@ -109,6 +108,7 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=Fals
)
return can_use_cache
class SanaCachedTransformerBlocks(nn.Module):
def __init__(
self,
...@@ -123,7 +123,8 @@ class SanaCachedTransformerBlocks(nn.Module):
self.residual_diff_threshold = residual_diff_threshold
self.verbose = verbose
def forward(
self,
hidden_states,
attention_mask,
encoder_hidden_states,
...@@ -135,8 +136,7 @@ class SanaCachedTransformerBlocks(nn.Module):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 2:
if batch_size > 2:
print("Batch size > 2 (for SANA CFG)" " currently not supported")
first_transformer_block = self.transformer_blocks[0]
hidden_states = first_transformer_block(
...@@ -199,15 +199,15 @@ class SanaCachedTransformerBlocks(nn.Module):
return hidden_states
def call_remaining_transformer_blocks(
self,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=None,
timestep=None,
post_patch_height=None,
post_patch_width=None,
):
first_transformer_block = self.transformer_blocks[0]
original_hidden_states = hidden_states
...@@ -219,7 +219,7 @@ class SanaCachedTransformerBlocks(nn.Module):
timestep=timestep,
height=post_patch_height,
width=post_patch_width,
skip_first_layer=True,
)
hidden_states_residual = hidden_states - original_hidden_states
...
...@@ -13,11 +13,11 @@ from packaging.version import Version
from safetensors.torch import load_file, save_file
from torch import nn
from .utils import NunchakuModelLoaderMixin, pad_tensor
from ..._C import QuantizedFluxModel, utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
from ...utils import get_precision, load_state_dict_in_safetensors
SVD_RANK = 32
...@@ -127,7 +127,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
):
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
...@@ -159,8 +159,14 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
hidden_states, encoder_hidden_states = self.m.forward_layer(
idx,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
controlnet_block_samples,
controlnet_single_block_samples,
)
hidden_states = hidden_states.to(original_dtype).to(original_device)
...@@ -578,7 +584,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
...
...@@ -8,7 +8,8 @@ from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from .utils import NunchakuModelLoaderMixin
from ...utils import get_precision
from ..._C import QuantizedSanaModel, utils as cutils
SVD_RANK = 32
...
import os
from typing import Any, Optional
import torch
...@@ -82,21 +81,3 @@ def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: A
result.fill_(fill)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
import os
import warnings
import safetensors
import torch
...@@ -69,3 +70,38 @@ def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str =
filtered state dict.
"""
return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
def get_precision(
precision: str = "auto", device: str | torch.device = "cuda", pretrained_model_name_or_path: str | None = None
) -> str:
assert precision in ("auto", "int4", "fp4")
if precision == "auto":
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
if pretrained_model_name_or_path is not None:
if precision == "int4":
if "fp4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
elif precision == "fp4":
if "int4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
return precision
def is_turing(device: str | torch.device = "cuda") -> bool:
"""Check if the current GPU is a Turing GPU.
Returns:
`bool`:
True if the current GPU is a Turing GPU, False otherwise.
"""
if isinstance(device, str):
device = torch.device(device)
device_id = 0 if device.index is None else device.index
capability = torch.cuda.get_device_capability(device_id)
sm = f"{capability[0]}{capability[1]}"
return sm == "75"
...@@ -3,6 +3,7 @@ import os
import random
import datasets
import yaml
from PIL import Image
_CITATION = """\
...@@ -32,6 +33,8 @@ IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
CONTROL_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/MJHQ-5000.zip"
class MJHQConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
...@@ -46,11 +49,14 @@ class MJHQConfig(datasets.BuilderConfig):
self.return_gt = return_gt
class MJHQ(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = MJHQConfig
BUILDER_CONFIGS = [
MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset"),
MJHQConfig(name="MJHQ-control", version=VERSION, description="MJHQ-5K with controls"),
]
DEFAULT_CONFIG_NAME = "MJHQ" DEFAULT_CONFIG_NAME = "MJHQ"
def _info(self): def _info(self):
...@@ -64,6 +70,10 @@ class DCI(datasets.GeneratorBasedBuilder): ...@@ -64,6 +70,10 @@ class DCI(datasets.GeneratorBasedBuilder):
"image_root": datasets.Value("string"), "image_root": datasets.Value("string"),
"image_path": datasets.Value("string"), "image_path": datasets.Value("string"),
"split": datasets.Value("string"), "split": datasets.Value("string"),
"canny_image_path": datasets.Value("string"),
"cropped_image_path": datasets.Value("string"),
"depth_image_path": datasets.Value("string"),
"mask_image_path": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
...@@ -71,6 +81,7 @@ class DCI(datasets.GeneratorBasedBuilder):
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
if self.config.name == "MJHQ":
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
...@@ -78,9 +89,19 @@ class DCI(datasets.GeneratorBasedBuilder):
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
else:
assert self.config.name == "MJHQ-control"
control_root = dl_manager.download_and_extract(CONTROL_URL)
control_root = os.path.join(control_root, "MJHQ-5000")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"meta_path": os.path.join(control_root, "prompts.yaml"), "image_root": control_root},
),
]
def _generate_examples(self, meta_path: str, image_root: str):
if self.config.name == "MJHQ":
with open(meta_path, "r") as f:
meta = json.load(f)
...@@ -103,4 +124,32 @@ class DCI(datasets.GeneratorBasedBuilder):
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
"canny_image_path": None,
"cropped_image_path": None,
"depth_image_path": None,
"mask_image_path": None,
}
else:
assert self.config.name == "MJHQ-control"
meta = yaml.safe_load(open(meta_path, "r"))
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
prompt = meta[name]
yield i, {
"filename": name,
"category": None,
"image": None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": os.path.join(image_root, "images", f"{name}.png"),
"split": self.config.name,
"canny_image_path": os.path.join(image_root, "canny_images", f"{name}.png"),
"cropped_image_path": os.path.join(image_root, "cropped_images", f"{name}.png"),
"depth_image_path": os.path.join(image_root, "depth_images", f"{name}.png"),
"mask_image_path": os.path.join(image_root, "mask_images", f"{name}.png"),
}
...@@ -3,9 +3,16 @@ import random
import datasets
import yaml
from huggingface_hub import snapshot_download
from nunchaku.utils import fetch_or_download
__all__ = ["get_dataset", "load_dataset_yaml", "download_hf_dataset"]
def download_hf_dataset(repo_id: str = "mit-han-lab/nunchaku-test", local_dir: str | None = None) -> str:
path = snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
return path
def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
...@@ -46,10 +53,13 @@ def get_dataset(
path = os.path.join(prefix, f"{name}")
if name == "MJHQ":
dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
elif name == "MJHQ-control":
kwargs["name"] = "MJHQ-control"
dataset = datasets.load_dataset(os.path.join(prefix, "MJHQ"), return_gt=return_gt, **kwargs)
else:
dataset = datasets.Dataset.from_dict(
load_dataset_yaml(
fetch_or_download(f"mit-han-lab/svdquant-datasets/{name}.yaml", repo_type="dataset"),
max_dataset_size=max_dataset_size,
repeat=1,
),
...
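For context, a hedged sketch of how the refactored tests consume these helpers; the tests.data module path and the local cache directory are assumptions inferred from the imports in tests/flux/utils.py further below:

import os
from tests.data import download_hf_dataset, get_dataset

# snapshot the reference images once, then read the 5K control split truncated for quick runs
ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
dataset = get_dataset(name="MJHQ-control", max_dataset_size=16)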
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
(0.12, 1024, 1024, 30, None, 1, 0.26),
(0.12, 512, 2048, 30, "anime", 1, 0.4),
],
)
def test_flux_dev_loras(
cache_threshold: float,
height: int,
width: int,
num_inference_steps: int,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=cache_threshold,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, 50, "flashattn2", False, 0.226),
(2048, 512, 25, "nunchaku-fp16", False, 0.243),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, True, 0.178),
(25, "ghibsky", 1, False, 0.164),
(28, "anime", 1, False, 0.284),
(24, "sketch", 1, True, 0.223),
(28, "yarn", 1, False, 0.211),
(25, "haunted_linework", 1, True, 0.317),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name=lora_name,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=0,
expected_lpips=expected_lpips,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_hypersd8_1536x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=1536,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="hypersd8",
lora_strengths=0.125,
cache_threshold=0,
expected_lpips=0.291,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_2048x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=2048,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="turbo8",
lora_strengths=1,
cache_threshold=0,
expected_lpips=0.189,
)
# lora composition
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_2048x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="yarn",
height=2048,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["turbo8", "yarn"],
lora_strengths=[1, 1],
cache_threshold=0,
expected_lpips=0.252,
)
# large rank loras
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_1024x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="ghibsky",
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
lora_strengths=[0, 1, 0, 0, 0, 0, 1],
cache_threshold=0,
expected_lpips=0.44,
)
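The composition tests above pass several adapters at once (lora_names=["turbo8", "yarn"], and so on). Internally this goes through compose_lora before the weights are loaded into the quantized transformer; below is a minimal sketch of that path, assuming the same repo-relative LoRA paths that LORA_PATH_MAP lists further down in this commit:

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.lora.flux.compose import compose_lora
from nunchaku.utils import get_precision

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{get_precision()}-flux.1-dev"
)
# compose_lora takes (path, strength) pairs; these paths come from LORA_PATH_MAP below
composed = compose_lora(
    [
        ("alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors", 1.0),  # "turbo8"
        ("linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors", 1.0),  # "yarn"
    ]
)
transformer.update_lora_params(composed)  # single-LoRA callers pass a path and set_lora_strength instead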
...@@ -3,8 +3,10 @@ import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
...@@ -15,10 +17,12 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
precision = get_precision()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
...@@ -26,10 +30,12 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
...
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, "flashattn2", False, 0.250),
(1024, 1024, "nunchaku-fp16", False, 0.255),
(1024, 1024, "flashattn2", True, 0.250),
(1920, 1080, "nunchaku-fp16", False, 0.253),
(2048, 2048, "flashattn2", True, 0.274),
],
)
def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
import torch
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_canny_dev():
run_test(
precision=get_precision(),
model_name="flux.1-canny-dev",
dataset_name="MJHQ-control",
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." # noqa: E501 task="canny",
control_image = load_image( dtype=torch.bfloat16,
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.103 if get_precision() == "int4" else 0.164,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_depth_dev():
run_test(
precision=get_precision(),
model_name="flux.1-depth-dev",
dataset_name="MJHQ-control",
task="depth",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.103 if get_precision() == "int4" else 0.120,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev():
run_test(
precision=get_precision(),
model_name="flux.1-fill-dev",
dataset_name="MJHQ-control",
task="fill",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.045,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_canny_lora():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="canny",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
lora_names="canny",
lora_strengths=0.85,
cache_threshold=0,
expected_lpips=0.103,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_depth_lora():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="depth",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
lora_names="depth",
lora_strengths=0.85,
expected_lpips=0.163,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev_turbo():
run_test(
precision=get_precision(),
model_name="flux.1-fill-dev",
dataset_name="MJHQ-control",
task="fill",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
lora_names="turbo8",
lora_strengths=1,
expected_lpips=0.048,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_redux():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="redux",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=2.5,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.187 if get_precision() == "int4" else 0.55, # redux seems to generate different images on 5090
)
import pytest
from .utils import run_test
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips",
[(1024, 1024, "flashattn2", False, 0.25), (2048, 512, "nunchaku-fp16", False, 0.25)],
)
def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
model_name="shuttle-jaguar",
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
import pytest
from nunchaku.utils import get_precision
from .utils import run_test
@pytest.mark.skipif(get_precision() == "fp4", reason="Blackwell GPUs. Skip tests for Turing.")
@pytest.mark.parametrize(
"height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips",
[
(1024, 1024, 50, True, None, 0.253),
(1024, 1024, 50, True, "enabled", 0.258),
(1024, 1024, 50, True, "always", 0.257),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, cpu_offload: bool, i2f_mode: str | None, expected_lpips: float
):
run_test(
precision=get_precision(),
dtype="fp16",
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl="nunchaku-fp16",
cpu_offload=cpu_offload,
i2f_mode=i2f_mode,
expected_lpips=expected_lpips,
)
import os
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from ..data import download_hf_dataset, get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
ORIGINAL_REPO_MAP = {
"flux.1-schnell": "black-forest-labs/FLUX.1-schnell",
"flux.1-dev": "black-forest-labs/FLUX.1-dev",
"shuttle-jaguar": "shuttleai/shuttle-jaguar",
"flux.1-canny-dev": "black-forest-labs/FLUX.1-Canny-dev",
"flux.1-depth-dev": "black-forest-labs/FLUX.1-Depth-dev",
"flux.1-fill-dev": "black-forest-labs/FLUX.1-Fill-dev",
}
NUNCHAKU_REPO_PATTERN_MAP = {
"flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
"flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
"shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
"flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
"flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
"flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
}
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"turbo8": "alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
"canny": "black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors",
"depth": "black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors",
}
def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
if task == "canny":
processor = CannyDetector()
elif task == "depth":
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
elif task == "redux":
processor = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
else:
assert task in ["t2i", "fill"]
processor = None
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
_forward_kwargs = {k: v for k, v in forward_kwargs.items()}
if task == "canny":
assert forward_kwargs.get("height", 1024) == 1024
assert forward_kwargs.get("width", 1024) == 1024
control_image = load_image(row["canny_image_path"])
control_image = processor(
control_image,
low_threshold=50,
high_threshold=200,
detect_resolution=1024,
image_resolution=1024,
)
_forward_kwargs["control_image"] = control_image
elif task == "depth":
control_image = load_image(row["depth_image_path"])
control_image = processor(control_image)[0].convert("RGB")
_forward_kwargs["control_image"] = control_image
elif task == "fill":
image = load_image(row["image_path"])
mask_image = load_image(row["mask_image_path"])
_forward_kwargs["image"] = image
_forward_kwargs["mask_image"] = mask_image
elif task == "redux":
image = load_image(row["image_path"])
_forward_kwargs.update(processor(image))
seed = hash_str_to_int(filename)
if task == "redux":
image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
else:
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
torch.cuda.empty_cache()
def run_test(
precision: str = "int4",
model_name: str = "flux.1-schnell",
dataset_name: str = "MJHQ",
task: str = "t2i",
dtype: str | torch.dtype = torch.bfloat16, # the full precision dtype
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 4,
guidance_scale: float = 3.5,
use_qencoder: bool = False,
attention_impl: str = "flashattn2", # "flashattn2" or "nunchaku-fp16"
cpu_offload: bool = False,
cache_threshold: float = 0,
lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 20,
i2f_mode: str | None = None,
expected_lpips: float = 0.5,
):
if isinstance(dtype, str):
dtype_str = dtype
if dtype == "bf16":
dtype = torch.bfloat16
else:
assert dtype == "fp16"
dtype = torch.float16
else:
if dtype == torch.bfloat16:
dtype_str = "bf16"
else:
assert dtype == torch.float16
dtype_str = "fp16"
dataset = get_dataset(name=dataset_name, max_dataset_size=max_dataset_size)
model_id_16bit = ORIGINAL_REPO_MAP[model_name]
folder_name = f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
if lora_names is None:
lora_names = []
elif isinstance(lora_names, str):
lora_names = [lora_names]
if len(lora_names) > 0:
if isinstance(lora_strengths, (int, float)):
lora_strengths = [lora_strengths]
assert len(lora_names) == len(lora_strengths)
for lora_name, lora_strength in zip(lora_names, lora_strengths):
folder_name += f"-{lora_name}_{lora_strength}"
if not os.path.exists(os.path.join("test_results", "ref")):
ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
else:
ref_root = os.path.join("test_results", "ref")
save_dir_16bit = os.path.join(ref_root, dtype_str, model_name, folder_name)
if task in ["t2i", "redux"]:
pipeline_cls = FluxPipeline
elif task in ["canny", "depth"]:
pipeline_cls = FluxControlPipeline
elif task == "fill":
pipeline_cls = FluxFillPipeline
else:
raise NotImplementedError(f"Unknown task {task}!")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
pipeline = pipeline.to("cuda")
if len(lora_names) > 0:
for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
lora_path = LORA_PATH_MAP[lora_name]
pipeline.load_lora_weights(
os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name=f"lora_{i}"
)
pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
precision_str = precision
if use_qencoder:
precision_str += "-qe"
if attention_impl == "flashattn2":
precision_str += "-fa2"
else:
assert attention_impl == "nunchaku-fp16"
precision_str += "-nfp16"
if cpu_offload:
precision_str += "-co"
if cache_threshold > 0:
precision_str += f"-cache{cache_threshold}"
if i2f_mode is not None:
precision_str += f"-i2f{i2f_mode}"
save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
model_id_4bit = NUNCHAKU_REPO_PATTERN_MAP[model_name].format(precision=precision)
if i2f_mode is not None:
nunchaku._C.utils.set_faster_i2f_mode(i2f_mode)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
model_id_4bit, offload=cpu_offload, torch_dtype=dtype
)
transformer.set_attention_impl(attention_impl)
if len(lora_names) > 0:
if len(lora_names) == 1: # directly load the lora
lora_path = LORA_PATH_MAP[lora_names[0]]
lora_strength = lora_strengths[0]
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_strength)
else:
composed_lora = compose_lora(
[
(LORA_PATH_MAP[lora_name], lora_strength)
for lora_name, lora_strength in zip(lora_names, lora_strengths)
]
)
transformer.update_lora_params(composed_lora)
pipeline_init_kwargs["transformer"] = transformer
if task == "redux":
pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
elif use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del transformer
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05