Commit 2ede5f01 authored by Muyang Li, committed by Zhekai Zhang

Clean up some code and refactor the tests

parent 83b7542d
......@@ -4,7 +4,6 @@ from diffusers import SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
......@@ -29,4 +28,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
image.save("sana_1600m-int4.png")
......@@ -23,4 +23,4 @@ image = pipe(
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
image.save("sana_1600m-int4.png")
......@@ -24,4 +24,4 @@ image = pipe(
pag_scale=2.0,
num_inference_steps=20,
).images[0]
image.save("sana_1600m_pag.png")
image.save("sana_1600m_pag-int4.png")
import functools
import unittest
import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn
from ...caching import utils
......@@ -11,7 +11,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
if getattr(transformer, "_is_cached", False):
return transformer
cached_transformer_blocks = torch.nn.ModuleList(
cached_transformer_blocks = nn.ModuleList(
[
utils.FluxCachedTransformerBlocks(
transformer=transformer,
......@@ -20,7 +20,7 @@ def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_
)
]
)
dummy_single_transformer_blocks = torch.nn.ModuleList()
dummy_single_transformer_blocks = nn.ModuleList()
original_forward = transformer.forward
......
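For reference, the caching wired up by this adapter is enabled at the pipeline level through apply_cache_on_pipe. A minimal usage sketch; the checkpoint IDs and the 0.12 threshold are illustrative values taken from elsewhere in this commit, not requirements:

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

# Illustrative checkpoints; any supported FLUX model should work the same way.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Reuse cached block outputs whenever the first-block residual changes little between steps.
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=30).images[0]
image.save("flux.1-dev-cached.png")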
......@@ -94,7 +94,6 @@ def apply_prev_hidden_states_residual(
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
......@@ -109,6 +108,7 @@ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=Fals
)
return can_use_cache
class SanaCachedTransformerBlocks(nn.Module):
def __init__(
self,
......@@ -123,20 +123,20 @@ class SanaCachedTransformerBlocks(nn.Module):
self.residual_diff_threshold = residual_diff_threshold
self.verbose = verbose
def forward(self,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=None,
timestep=None,
post_patch_height=None,
post_patch_width=None,
):
def forward(
self,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=None,
timestep=None,
post_patch_height=None,
post_patch_width=None,
):
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 2:
if batch_size > 2:
print("Batch size > 2 (for SANA CFG)"
" currently not supported")
print("Batch size > 2 (for SANA CFG)" " currently not supported")
first_transformer_block = self.transformer_blocks[0]
hidden_states = first_transformer_block(
......@@ -199,15 +199,15 @@ class SanaCachedTransformerBlocks(nn.Module):
return hidden_states
def call_remaining_transformer_blocks(self,
def call_remaining_transformer_blocks(
self,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=None,
timestep=None,
post_patch_height=None,
post_patch_width=None
post_patch_width=None,
):
first_transformer_block = self.transformer_blocks[0]
original_hidden_states = hidden_states
......@@ -219,7 +219,7 @@ class SanaCachedTransformerBlocks(nn.Module):
timestep=timestep,
height=post_patch_height,
width=post_patch_width,
skip_first_layer=True
skip_first_layer=True,
)
hidden_states_residual = hidden_states - original_hidden_states
......
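The control flow behind these cached blocks: run the first transformer block, compare its output residual with the residual cached from the previous step, and either reuse the cached residuals (apply_prev_hidden_states_residual) or run the remaining blocks and refresh the cache. Below is a self-contained sketch of the gating check only; the real metric lives in get_can_use_cache, and the relative-difference formula here is an assumption for illustration:

import torch

def can_use_cache_sketch(first_residual: torch.Tensor, prev_residual: torch.Tensor | None, threshold: float) -> bool:
    # Assumed metric: mean relative change of the first-block residual between diffusion steps.
    if prev_residual is None or threshold <= 0:
        return False
    rel_diff = (first_residual - prev_residual).abs().mean() / prev_residual.abs().mean()
    return rel_diff.item() < threshold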
......@@ -13,11 +13,11 @@ from packaging.version import Version
from safetensors.torch import load_file, save_file
from torch import nn
from .utils import get_precision, NunchakuModelLoaderMixin, pad_tensor
from .utils import NunchakuModelLoaderMixin, pad_tensor
from ..._C import QuantizedFluxModel, utils as cutils
from ...lora.flux.nunchaku_converter import fuse_vectors, to_nunchaku
from ...lora.flux.utils import is_nunchaku_format
from ...utils import load_state_dict_in_safetensors
from ...utils import get_precision, load_state_dict_in_safetensors
SVD_RANK = 32
......@@ -127,7 +127,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None
controlnet_single_block_samples=None,
):
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
......@@ -159,8 +159,14 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
hidden_states, encoder_hidden_states = self.m.forward_layer(
idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt,
controlnet_block_samples, controlnet_single_block_samples
idx,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
controlnet_block_samples,
controlnet_single_block_samples,
)
hidden_states = hidden_states.to(original_dtype).to(original_device)
......@@ -578,7 +584,7 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples
controlnet_single_block_samples=controlnet_single_block_samples,
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
......
......@@ -8,7 +8,8 @@ from safetensors.torch import load_file
from torch import nn
from torch.nn import functional as F
from .utils import get_precision, NunchakuModelLoaderMixin
from .utils import NunchakuModelLoaderMixin
from ...utils import get_precision
from ..._C import QuantizedSanaModel, utils as cutils
SVD_RANK = 32
......
import os
import warnings
from typing import Any, Optional
import torch
......@@ -82,21 +81,3 @@ def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: A
result.fill_(fill)
result[[slice(0, extent) for extent in tensor.shape]] = tensor
return result
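A small usage sketch of pad_tensor, which the FLUX blocks above use to pad rotary embeddings to a multiple of 256 along the token dimension (the import path is assumed from the surrounding hunk):

import torch
from nunchaku.models.utils import pad_tensor  # module path assumed from this file's location

x = torch.randn(2, 100, 64)
padded = pad_tensor(x, 256, 1)  # pad dim 1 up to the next multiple of 256, filling the new slots
print(padded.shape)  # torch.Size([2, 256, 64])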
def get_precision(precision: str, device: str | torch.device, pretrained_model_name_or_path: str | None = None) -> str:
assert precision in ("auto", "int4", "fp4")
if precision == "auto":
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
if pretrained_model_name_or_path is not None:
if precision == "int4":
if "fp4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
elif precision == "fp4":
if "int4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
return precision
import os
import warnings
import safetensors
import torch
......@@ -69,3 +70,38 @@ def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str =
filtered state dict.
"""
return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
def get_precision(
precision: str = "auto", device: str | torch.device = "cuda", pretrained_model_name_or_path: str | None = None
) -> str:
assert precision in ("auto", "int4", "fp4")
if precision == "auto":
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
sm = f"{capability[0]}{capability[1]}"
precision = "fp4" if sm == "120" else "int4"
if pretrained_model_name_or_path is not None:
if precision == "int4":
if "fp4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to fp4, but you are loading it with int4 precision.")
elif precision == "fp4":
if "int4" in pretrained_model_name_or_path:
warnings.warn("The model may be quantized to int4, but you are loading it with fp4 precision.")
return precision
def is_turing(device: str | torch.device = "cuda") -> bool:
"""Check if the current GPU is a Turing GPU.
Returns:
`bool`:
True if the current GPU is a Turing GPU, False otherwise.
"""
if isinstance(device, str):
device = torch.device(device)
device_id = 0 if device.index is None else device.index
capability = torch.cuda.get_device_capability(device_id)
sm = f"{capability[0]}{capability[1]}"
return sm == "75"
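get_precision and is_turing are what the refactored tests import from nunchaku.utils; a small usage sketch, with the repo ID pattern matching the maps used in the test utilities:

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing

assert not is_turing(), "the tests in this commit skip Turing (SM 7.5) GPUs"
precision = get_precision()  # "int4" by default, "fp4" on SM 120 (Blackwell) GPUs
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")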
......@@ -3,6 +3,7 @@ import os
import random
import datasets
import yaml
from PIL import Image
_CITATION = """\
......@@ -32,6 +33,8 @@ IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
CONTROL_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/MJHQ-5000.zip"
class MJHQConfig(datasets.BuilderConfig):
def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
......@@ -46,11 +49,14 @@ class MJHQConfig(datasets.BuilderConfig):
self.return_gt = return_gt
class DCI(datasets.GeneratorBasedBuilder):
class MJHQ(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
BUILDER_CONFIG_CLASS = MJHQConfig
BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
BUILDER_CONFIGS = [
MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset"),
MJHQConfig(name="MJHQ-control", version=VERSION, description="MJHQ-5K with controls"),
]
DEFAULT_CONFIG_NAME = "MJHQ"
def _info(self):
......@@ -64,6 +70,10 @@ class DCI(datasets.GeneratorBasedBuilder):
"image_root": datasets.Value("string"),
"image_path": datasets.Value("string"),
"split": datasets.Value("string"),
"canny_image_path": datasets.Value("string"),
"cropped_image_path": datasets.Value("string"),
"depth_image_path": datasets.Value("string"),
"mask_image_path": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
......@@ -71,36 +81,75 @@ class DCI(datasets.GeneratorBasedBuilder):
)
def _split_generators(self, dl_manager: datasets.download.DownloadManager):
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
if self.config.name == "MJHQ":
meta_path = dl_manager.download(META_URL)
image_root = dl_manager.download_and_extract(IMAGE_URL)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
),
]
else:
assert self.config.name == "MJHQ-control"
control_root = dl_manager.download_and_extract(CONTROL_URL)
control_root = os.path.join(control_root, "MJHQ-5000")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={"meta_path": os.path.join(control_root, "prompts.yaml"), "image_root": control_root},
),
]
def _generate_examples(self, meta_path: str, image_root: str):
with open(meta_path, "r") as f:
meta = json.load(f)
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
category = meta[name]["category"]
prompt = meta[name]["prompt"]
image_path = os.path.join(image_root, category, f"{name}.jpg")
yield i, {
"filename": name,
"category": category,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
}
if self.config.name == "MJHQ":
with open(meta_path, "r") as f:
meta = json.load(f)
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
category = meta[name]["category"]
prompt = meta[name]["prompt"]
image_path = os.path.join(image_root, category, f"{name}.jpg")
yield i, {
"filename": name,
"category": category,
"image": Image.open(image_path) if self.config.return_gt else None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": image_path,
"split": self.config.name,
"canny_image_path": None,
"cropped_image_path": None,
"depth_image_path": None,
"mask_image_path": None,
}
else:
assert self.config.name == "MJHQ-control"
meta = yaml.safe_load(open(meta_path, "r"))
names = list(meta.keys())
if self.config.max_dataset_size > 0:
random.Random(0).shuffle(names)
names = names[: self.config.max_dataset_size]
names = sorted(names)
for i, name in enumerate(names):
prompt = meta[name]
yield i, {
"filename": name,
"category": None,
"image": None,
"prompt": prompt,
"meta_path": meta_path,
"image_root": image_root,
"image_path": os.path.join(image_root, "images", f"{name}.png"),
"split": self.config.name,
"canny_image_path": os.path.join(image_root, "canny_images", f"{name}.png"),
"cropped_image_path": os.path.join(image_root, "cropped_images", f"{name}.png"),
"depth_image_path": os.path.join(image_root, "depth_images", f"{name}.png"),
"mask_image_path": os.path.join(image_root, "mask_images", f"{name}.png"),
}
......@@ -3,9 +3,16 @@ import random
import datasets
import yaml
from huggingface_hub import snapshot_download
from nunchaku.utils import fetch_or_download
__all__ = ["get_dataset"]
__all__ = ["get_dataset", "load_dataset_yaml", "download_hf_dataset"]
def download_hf_dataset(repo_id: str = "mit-han-lab/nunchaku-test", local_dir: str | None = None) -> str:
path = snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
return path
def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
......@@ -46,10 +53,13 @@ def get_dataset(
path = os.path.join(prefix, f"{name}")
if name == "MJHQ":
dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
elif name == "MJHQ-control":
kwargs["name"] = "MJHQ-control"
dataset = datasets.load_dataset(os.path.join(prefix, "MJHQ"), return_gt=return_gt, **kwargs)
else:
dataset = datasets.Dataset.from_dict(
load_dataset_yaml(
fetch_or_download(f"mit-han-lab/nunchaku-test/{name}.yaml", repo_type="dataset"),
fetch_or_download(f"mit-han-lab/svdquant-datasets/{name}.yaml", repo_type="dataset"),
max_dataset_size=max_dataset_size,
repeat=1,
),
......
import pytest
from .test_flux_dev import run_test_flux_dev
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,cache_threshold,lora_name,use_qencoder,cpu_offload,expected_lpips",
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
# (1024, 1024, 50, 0, None, False, False, 0.5), # 13min20s 5min55s 0.19539418816566467
# (1024, 1024, 50, 0.05, None, False, True, 0.5), # 7min11s 0.21917256712913513
# (1024, 1024, 50, 0.12, None, False, True, 0.5), # 2min58s, 0.24101486802101135
# (1024, 1024, 50, 0.2, None, False, True, 0.5), # 2min23s, 0.3101634383201599
# (1024, 1024, 50, 0.5, None, False, True, 0.5), # 1min44s 0.6543852090835571
# (1024, 1024, 30, 0, None, False, False, 0.5), # 8min2s 3min40s 0.2141970843076706
# (1024, 1024, 30, 0.05, None, False, True, 0.5), # 4min57 0.21297718584537506
# (1024, 1024, 30, 0.12, None, False, True, 0.5), # 2min34 0.25963714718818665
# (1024, 1024, 30, 0.2, None, False, True, 0.5), # 1min51 0.31409069895744324
# (1024, 1024, 20, 0, None, False, False, 0.5), # 5min25 2min29 0.18987375497817993
# (1024, 1024, 20, 0.05, None, False, True, 0.5), # 3min3 0.17194810509681702
# (1024, 1024, 20, 0.12, None, False, True, 0.5), # 2min15 0.19407868385314941
# (1024, 1024, 20, 0.2, None, False, True, 0.5), # 1min48 0.2832985818386078
(1024, 1024, 30, 0.12, None, False, False, 0.26),
(512, 2048, 30, 0.12, "anime", True, False, 0.4),
(0.12, 1024, 1024, 30, None, 1, 0.26),
(0.12, 512, 2048, 30, "anime", 1, 0.4),
],
)
def test_flux_dev_base(
def test_flux_dev_loras(
cache_threshold: float,
height: int,
width: int,
num_inference_steps: int,
cache_threshold: float,
lora_name: str | None,
use_qencoder: bool,
cpu_offload: bool,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test_flux_dev(
precision="int4",
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=use_qencoder,
cpu_offload=cpu_offload,
lora_name=lora_name,
lora_scale=1,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=cache_threshold,
max_dataset_size=16,
expected_lpips=expected_lpips,
)
import os
import pytest
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.lora.flux import convert_to_nunchaku_flux_lowrank_dict, is_nunchaku_format, to_diffusers
from .utils import run_pipeline
from ..data import get_dataset
from ..utils import already_generate, compute_lpips
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
}
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
def run_test_flux_dev(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
lora_name: str | None,
lora_scale: float,
cache_threshold: float,
max_dataset_size: int,
expected_lpips: float,
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[
(1024, 1024, 50, "flashattn2", False, 0.226),
(2048, 512, 25, "nunchaku-fp16", False, 0.243),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, attention_impl: str, cpu_offload: bool, expected_lpips: float
):
save_root = os.path.join(
"results",
"dev",
f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
+ (f"-{lora_name}_{lora_scale:.1f}" if lora_name else ""),
)
dataset = get_dataset(
name="MJHQ" if lora_name in [None, "hypersd8"] else lora_name, max_dataset_size=max_dataset_size
)
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
if lora_name is not None:
pipeline.load_lora_weights(
os.path.dirname(LORA_PATH_MAP[lora_name]),
weight_name=os.path.basename(LORA_PATH_MAP[lora_name]),
adapter_name="lora",
)
for n, m in pipeline.transformer.named_modules():
if isinstance(m, lora.LoraLayer):
for name in m.scaling.keys():
m.scaling[name] = lora_scale
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
name = precision
name += "-qencoder" if use_qencoder else ""
name += "-offload" if cpu_offload else ""
name += f"-cache{cache_threshold:.2f}" if cache_threshold > 0 else ""
save_dir_4bit = os.path.join(save_root, name)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-dev", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-dev", precision="fp4", offload=cpu_offload
)
if lora_name is not None:
lora_path = LORA_PATH_MAP[lora_name]
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_scale)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
if cache_threshold > 0:
apply_cache_on_pipe(pipeline, residual_diff_threshold=cache_threshold)
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
@pytest.mark.parametrize("cpu_offload", [False, True])
def test_flux_dev_base(cpu_offload: bool):
run_test_flux_dev(
precision="int4",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=3.5,
use_qencoder=False,
run_test(
precision=get_precision(),
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
lora_name=None,
lora_scale=0,
cache_threshold=0,
max_dataset_size=8,
expected_lpips=0.16,
)
def test_flux_dev_qencoder_800x600():
run_test_flux_dev(
precision="int4",
height=800,
width=600,
num_inference_steps=50,
guidance_scale=3.5,
use_qencoder=True,
cpu_offload=False,
lora_name=None,
lora_scale=0,
cache_threshold=0,
max_dataset_size=8,
expected_lpips=0.36,
expected_lpips=expected_lpips,
)
import pytest
from tests.flux.test_flux_dev import run_test_flux_dev
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_scale,cpu_offload,expected_lpips",
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[
(25, "realism", 0.9, False, 0.17),
(25, "ghibsky", 1, False, 0.16),
(28, "anime", 1, False, 0.27),
(24, "sketch", 1, False, 0.35),
(28, "yarn", 1, False, 0.22),
(25, "haunted_linework", 1, False, 0.34),
(25, "realism", 0.9, True, 0.178),
(25, "ghibsky", 1, False, 0.164),
(28, "anime", 1, False, 0.284),
(24, "sketch", 1, True, 0.223),
(28, "yarn", 1, False, 0.211),
(25, "haunted_linework", 1, True, 0.317),
],
)
def test_flux_dev_loras(num_inference_steps, lora_name, lora_scale, cpu_offload, expected_lpips):
run_test_flux_dev(
precision="int4",
def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name=lora_name,
height=1024,
width=1024,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=cpu_offload,
lora_name=lora_name,
lora_scale=lora_scale,
lora_names=lora_name,
lora_strengths=lora_strength,
cache_threshold=0,
max_dataset_size=8,
expected_lpips=expected_lpips,
)
def test_flux_dev_hypersd8_1080x1920():
run_test_flux_dev(
precision="int4",
height=1080,
width=1920,
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_hypersd8_1536x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=1536,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_name="hypersd8",
lora_scale=0.125,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="hypersd8",
lora_strengths=0.125,
cache_threshold=0,
expected_lpips=0.291,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_2048x2048():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ",
height=2048,
width=2048,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
attention_impl="nunchaku-fp16",
cpu_offload=True,
lora_names="turbo8",
lora_strengths=1,
cache_threshold=0,
expected_lpips=0.189,
)
# lora composition
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_2048x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="yarn",
height=2048,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["turbo8", "yarn"],
lora_strengths=[1, 1],
cache_threshold=0,
expected_lpips=0.252,
)
# large rank loras
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_turbo8_yarn_1024x1024():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="ghibsky",
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=True,
lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
lora_strengths=[0, 1, 0, 0, 0, 0, 1],
cache_threshold=0,
max_dataset_size=8,
expected_lpips=0.44,
)
......@@ -3,8 +3,10 @@ import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit",
[
......@@ -15,10 +17,12 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
precision = get_precision()
pipeline_init_kwargs = {
"transformer": NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
)
}
if use_qencoder:
......@@ -26,10 +30,12 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
).to("cuda")
)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
......
import os
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from tests.data import get_dataset
from tests.flux.utils import run_pipeline
from tests.utils import already_generate, compute_lpips
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
"height,width,attention_impl,cpu_offload,expected_lpips",
[
("int4", 1024, 1024, 4, 0, False, False, 16, 0.258),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1024, 1024, 4, 0, True, False, 16, 0.41),
("int4", 1920, 1080, 4, 0, False, False, 16, 0.258),
("int4", 600, 800, 4, 0, False, False, 16, 0.29),
(1024, 1024, "flashattn2", False, 0.250),
(1024, 1024, "nunchaku-fp16", False, 0.255),
(1024, 1024, "flashattn2", True, 0.250),
(1920, 1080, "nunchaku-fp16", False, 0.253),
(2048, 2048, "flashattn2", True, 0.274),
],
)
def test_flux_schnell(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
max_dataset_size: int,
expected_lpips: float,
):
dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
save_root = os.path.join("results", "schnell", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(
save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else "")
def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-flux.1-schnell", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4", offload=cpu_offload
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
import pytest
import torch
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
def test_flux_dev_canny():
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-canny-dev")
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." # noqa: E501
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_canny_dev():
run_test(
precision=get_precision(),
model_name="flux.1-canny-dev",
dataset_name="MJHQ-control",
task="canny",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.103 if get_precision() == "int4" else 0.164,
)
processor = CannyDetector()
control_image = processor(
control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
)
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=50, guidance_scale=30.0
).images[0]
image.save("flux.1-canny-dev.png")
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_depth_dev():
run_test(
precision=get_precision(),
model_name="flux.1-depth-dev",
dataset_name="MJHQ-control",
task="depth",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.103 if get_precision() == "int4" else 0.120,
)
def test_flux_dev_depth():
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-depth-dev")
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev():
run_test(
precision=get_precision(),
model_name="flux.1-fill-dev",
dataset_name="MJHQ-control",
task="fill",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.045,
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Depth-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." # noqa: E501
control_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_canny_lora():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="canny",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=30,
attention_impl="nunchaku-fp16",
cpu_offload=False,
lora_names="canny",
lora_strengths=0.85,
cache_threshold=0,
expected_lpips=0.103,
)
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")
image = pipe(
prompt=prompt, control_image=control_image, height=1024, width=1024, num_inference_steps=30, guidance_scale=10.0
).images[0]
image.save("flux.1-depth-dev.png")
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_depth_lora():
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="depth",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=10,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
lora_names="depth",
lora_strengths=0.85,
expected_lpips=0.163,
)
def test_flux_dev_fill():
image = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/example.png")
mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/resolve/main/mask.png")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
prompt="A wooden basket of a cat.",
image=image,
mask_image=mask,
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_fill_dev_turbo():
run_test(
precision=get_precision(),
model_name="flux.1-fill-dev",
dataset_name="MJHQ-control",
task="fill",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=30,
num_inference_steps=50,
max_sequence_length=512,
).images[0]
image.save("flux.1-fill-dev.png")
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
lora_names="turbo8",
lora_strengths=1,
expected_lpips=0.048,
)
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_dev_redux():
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=None,
text_encoder_2=None,
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
pipe_prior_output = pipe_prior_redux(image)
images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
images[0].save("flux.1-redux-dev.png")
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ-control",
task="redux",
dtype=torch.bfloat16,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=2.5,
attention_impl="nunchaku-fp16",
cpu_offload=False,
cache_threshold=0,
expected_lpips=0.187 if get_precision() == "int4" else 0.55, # redux seems to generate different images on 5090
)
import os
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from tests.data import get_dataset
from tests.flux.utils import run_pipeline
from tests.utils import already_generate, compute_lpips
from .utils import run_test
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
"precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
[
("int4", 1024, 1024, 4, 3.5, False, False, 16, 0.25),
("int4", 2048, 512, 4, 3.5, False, False, 16, 0.21),
],
"height,width,attention_impl,cpu_offload,expected_lpips",
[(1024, 1024, "flashattn2", False, 0.25), (2048, 512, "nunchaku-fp16", False, 0.25)],
)
def test_shuttle_jaguar(
precision: str,
height: int,
width: int,
num_inference_steps: int,
guidance_scale: float,
use_qencoder: bool,
cpu_offload: bool,
max_dataset_size: int,
expected_lpips: float,
):
dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
save_root = os.path.join("results", "shuttle-jaguar", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
save_dir_16bit = os.path.join(save_root, "bf16")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline = FluxPipeline.from_pretrained("shuttleai/shuttle-jaguar", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
save_dir_4bit = os.path.join(
save_root, f"{precision}-qencoder" if use_qencoder else f"{precision}" + ("-cpuoffload" if cpu_offload else "")
def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
run_test(
precision=get_precision(),
model_name="shuttle-jaguar",
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-int4-shuttle-jaguar", offload=cpu_offload
)
else:
assert precision == "fp4"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
"mit-han-lab/svdq-fp4-shuttle-jaguar", precision="fp4", offload=cpu_offload
)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
raise NotImplementedError
# text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
# pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"shuttleai/shuttle-jaguar", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
run_pipeline(
dataset,
pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
import pytest
from nunchaku.utils import get_precision
from .utils import run_test
@pytest.mark.skipif(get_precision() == "fp4", reason="Skip on Blackwell (fp4) GPUs; this test targets int4 precision.")
@pytest.mark.parametrize(
"height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips",
[
(1024, 1024, 50, True, None, 0.253),
(1024, 1024, 50, True, "enabled", 0.258),
(1024, 1024, 50, True, "always", 0.257),
],
)
def test_flux_dev(
height: int, width: int, num_inference_steps: int, cpu_offload: bool, i2f_mode: str | None, expected_lpips: float
):
run_test(
precision=get_precision(),
dtype="fp16",
model_name="flux.1-dev",
height=height,
width=width,
num_inference_steps=num_inference_steps,
attention_impl="nunchaku-fp16",
cpu_offload=cpu_offload,
i2f_mode=i2f_mode,
expected_lpips=expected_lpips,
)
import os
import torch
from diffusers import FluxPipeline
from controlnet_aux import CannyDetector
from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from tqdm import tqdm
from ..utils import hash_str_to_int
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.lora.flux.compose import compose_lora
from ..data import download_hf_dataset, get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
ORIGINAL_REPO_MAP = {
"flux.1-schnell": "black-forest-labs/FLUX.1-schnell",
"flux.1-dev": "black-forest-labs/FLUX.1-dev",
"shuttle-jaguar": "shuttleai/shuttle-jaguar",
"flux.1-canny-dev": "black-forest-labs/FLUX.1-Canny-dev",
"flux.1-depth-dev": "black-forest-labs/FLUX.1-Depth-dev",
"flux.1-fill-dev": "black-forest-labs/FLUX.1-Fill-dev",
}
def run_pipeline(dataset, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
NUNCHAKU_REPO_PATTERN_MAP = {
"flux.1-schnell": "mit-han-lab/svdq-{precision}-flux.1-schnell",
"flux.1-dev": "mit-han-lab/svdq-{precision}-flux.1-dev",
"shuttle-jaguar": "mit-han-lab/svdq-{precision}-shuttle-jaguar",
"flux.1-canny-dev": "mit-han-lab/svdq-{precision}-flux.1-canny-dev",
"flux.1-depth-dev": "mit-han-lab/svdq-{precision}-flux.1-depth-dev",
"flux.1-fill-dev": "mit-han-lab/svdq-{precision}-flux.1-fill-dev",
}
LORA_PATH_MAP = {
"hypersd8": "ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors",
"turbo8": "alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors",
"realism": "XLabs-AI/flux-RealismLora/lora.safetensors",
"ghibsky": "aleksa-codes/flux-ghibsky-illustration/lora.safetensors",
"anime": "alvdansen/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors",
"sketch": "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/FLUX-dev-lora-children-simple-sketch.safetensors",
"yarn": "linoyts/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors",
"haunted_linework": "alvdansen/haunted_linework_flux/hauntedlinework_flux_araminta_k.safetensors",
"canny": "black-forest-labs/FLUX.1-Canny-dev-lora/flux1-canny-dev-lora.safetensors",
"depth": "black-forest-labs/FLUX.1-Depth-dev-lora/flux1-depth-dev-lora.safetensors",
}
def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
os.makedirs(save_dir, exist_ok=True)
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
if task == "canny":
processor = CannyDetector()
elif task == "depth":
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
elif task == "redux":
processor = FluxPriorReduxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")
else:
assert task in ["t2i", "fill"]
processor = None
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
_forward_kwargs = {k: v for k, v in forward_kwargs.items()}
if task == "canny":
assert forward_kwargs.get("height", 1024) == 1024
assert forward_kwargs.get("width", 1024) == 1024
control_image = load_image(row["canny_image_path"])
control_image = processor(
control_image,
low_threshold=50,
high_threshold=200,
detect_resolution=1024,
image_resolution=1024,
)
_forward_kwargs["control_image"] = control_image
elif task == "depth":
control_image = load_image(row["depth_image_path"])
control_image = processor(control_image)[0].convert("RGB")
_forward_kwargs["control_image"] = control_image
elif task == "fill":
image = load_image(row["image_path"])
mask_image = load_image(row["mask_image_path"])
_forward_kwargs["image"] = image
_forward_kwargs["mask_image"] = mask_image
elif task == "redux":
image = load_image(row["image_path"])
_forward_kwargs.update(processor(image))
seed = hash_str_to_int(filename)
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **forward_kwargs).images[0]
if task == "redux":
image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
else:
image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
image.save(os.path.join(save_dir, f"{filename}.png"))
torch.cuda.empty_cache()
def run_test(
precision: str = "int4",
model_name: str = "flux.1-schnell",
dataset_name: str = "MJHQ",
task: str = "t2i",
dtype: str | torch.dtype = torch.bfloat16, # the full precision dtype
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 4,
guidance_scale: float = 3.5,
use_qencoder: bool = False,
attention_impl: str = "flashattn2", # "flashattn2" or "nunchaku-fp16"
cpu_offload: bool = False,
cache_threshold: float = 0,
lora_names: str | list[str] | None = None,
lora_strengths: float | list[float] = 1.0,
max_dataset_size: int = 20,
i2f_mode: str | None = None,
expected_lpips: float = 0.5,
):
if isinstance(dtype, str):
dtype_str = dtype
if dtype == "bf16":
dtype = torch.bfloat16
else:
assert dtype == "fp16"
dtype = torch.float16
else:
if dtype == torch.bfloat16:
dtype_str = "bf16"
else:
assert dtype == torch.float16
dtype_str = "fp16"
dataset = get_dataset(name=dataset_name, max_dataset_size=max_dataset_size)
model_id_16bit = ORIGINAL_REPO_MAP[model_name]
folder_name = f"w{width}h{height}t{num_inference_steps}g{guidance_scale}"
if lora_names is None:
lora_names = []
elif isinstance(lora_names, str):
lora_names = [lora_names]
if len(lora_names) > 0:
if isinstance(lora_strengths, (int, float)):
lora_strengths = [lora_strengths]
assert len(lora_names) == len(lora_strengths)
for lora_name, lora_strength in zip(lora_names, lora_strengths):
folder_name += f"-{lora_name}_{lora_strength}"
if not os.path.exists(os.path.join("test_results", "ref")):
ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
else:
ref_root = os.path.join("test_results", "ref")
save_dir_16bit = os.path.join(ref_root, dtype_str, model_name, folder_name)
if task in ["t2i", "redux"]:
pipeline_cls = FluxPipeline
elif task in ["canny", "depth"]:
pipeline_cls = FluxControlPipeline
elif task == "fill":
pipeline_cls = FluxFillPipeline
else:
raise NotImplementedError(f"Unknown task {task}!")
if not already_generate(save_dir_16bit, max_dataset_size):
pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
pipeline = pipeline.to("cuda")
if len(lora_names) > 0:
for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
lora_path = LORA_PATH_MAP[lora_name]
pipeline.load_lora_weights(
os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name=f"lora_{i}"
)
pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_16bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
precision_str = precision
if use_qencoder:
precision_str += "-qe"
if attention_impl == "flashattn2":
precision_str += "-fa2"
else:
assert attention_impl == "nunchaku-fp16"
precision_str += "-nfp16"
if cpu_offload:
precision_str += "-co"
if cache_threshold > 0:
precision_str += f"-cache{cache_threshold}"
if i2f_mode is not None:
precision_str += f"-i2f{i2f_mode}"
save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)
if not already_generate(save_dir_4bit, max_dataset_size):
pipeline_init_kwargs = {}
model_id_4bit = NUNCHAKU_REPO_PATTERN_MAP[model_name].format(precision=precision)
if i2f_mode is not None:
nunchaku._C.utils.set_faster_i2f_mode(i2f_mode)
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
model_id_4bit, offload=cpu_offload, torch_dtype=dtype
)
transformer.set_attention_impl(attention_impl)
if len(lora_names) > 0:
if len(lora_names) == 1: # directly load the lora
lora_path = LORA_PATH_MAP[lora_names[0]]
lora_strength = lora_strengths[0]
transformer.update_lora_params(lora_path)
transformer.set_lora_strength(lora_strength)
else:
composed_lora = compose_lora(
[
(LORA_PATH_MAP[lora_name], lora_strength)
for lora_name, lora_strength in zip(lora_names, lora_strengths)
]
)
transformer.update_lora_params(composed_lora)
pipeline_init_kwargs["transformer"] = transformer
if task == "redux":
pipeline_init_kwargs.update({"text_encoder": None, "text_encoder_2": None})
elif use_qencoder:
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
if cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to("cuda")
run_pipeline(
dataset=dataset,
task=task,
pipeline=pipeline,
save_dir=save_dir_4bit,
forward_kwargs={
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
},
)
del transformer
del pipeline
# release the gpu memory
torch.cuda.empty_cache()
lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.05
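Adding a new case on top of this helper is then a single parametrized or plain call; a hedged sketch (the LPIPS threshold is illustrative and would need to be calibrated against real runs):

import pytest

from nunchaku.utils import get_precision, is_turing
from .utils import run_test

@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
def test_flux_schnell_512x512_example():
    # Illustrative only: expected_lpips here is a placeholder, not a measured value.
    run_test(
        precision=get_precision(),
        model_name="flux.1-schnell",
        height=512,
        width=512,
        num_inference_steps=4,
        guidance_scale=0,
        expected_lpips=0.3,
    )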