Commit 37c494a7 authored by Zhekai Zhang

Initial release
import os.path
import datasets
__all__ = ["get_dataset"]
def get_dataset(
name: str, config_name: str | None = None, split: str = "train", return_gt: bool = False
) -> datasets.Dataset:
prefix = os.path.dirname(__file__)
kwargs = {"name": config_name, "split": split, "trust_remote_code": True, "token": True}
    path = os.path.join(prefix, name)
    if name in ("DCI", "MJHQ"):
        dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
else:
raise ValueError(f"Unknown dataset name: {name}")
return dataset
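
# A minimal smoke test for get_dataset (a sketch: it assumes the MJHQ dataset files
# bundled next to this module are available and accessible with your HF token).
if __name__ == "__main__":
    _dataset = get_dataset(name="MJHQ", split="train")
    # The generation and evaluation scripts rely on the "filename" and "prompt" columns.
    print(len(_dataset), _dataset[0]["filename"], _dataset[0]["prompt"])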
import argparse
import os
import torch
from tqdm import tqdm
from data import get_dataset
from utils import get_pipeline, hash_str_to_int
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
)
parser.add_argument(
"-d", "--datasets", type=str, nargs="*", default=["MJHQ", "DCI"], help="The benchmark datasets to evaluate on."
)
parser.add_argument("-t", "--num-inference-steps", type=int, default=4, help="Number of inference steps")
parser.add_argument("-g", "--guidance-scale", type=float, default=0, help="Guidance scale.")
parser.add_argument("-o", "--output-root", type=str, default=None, help="Image output path")
parser.add_argument(
"--chunk-step",
type=int,
default=1,
help="You will generate images for the subset specified by [chunk-start::chunk-step].",
)
parser.add_argument(
"--chunk-start",
type=int,
default=0,
help="You will generate images for the subset specified by [chunk-start::chunk-step].",
)
known_args, _ = parser.parse_known_args()
if known_args.model == "dev":
parser.set_defaults(num_inference_steps=50, guidance_scale=3.5)
args = parser.parse_args()
return args
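# Example invocation (hedged: this script's filename is not shown in this commit).
# To split MJHQ across 4 workers, launch one process per chunk offset, e.g.:
#   CUDA_VISIBLE_DEVICES=0 python <this_script>.py -m schnell -p int4 -d MJHQ \
#       --chunk-start 0 --chunk-step 4
# Each process then renders dataset[chunk_start::chunk_step].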
def main():
args = get_args()
assert args.chunk_step > 0
assert 0 <= args.chunk_start < args.chunk_step
pipeline = get_pipeline(model_name=args.model, precision=args.precision, device="cuda")
pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
output_root = args.output_root
if output_root is None:
output_root = f"results/{args.model}/{args.precision}/"
for dataset_name in args.datasets:
output_dirname = os.path.join(output_root, dataset_name)
os.makedirs(output_dirname, exist_ok=True)
dataset = get_dataset(name=dataset_name)
if args.chunk_step > 1:
dataset = dataset.select(range(args.chunk_start, len(dataset), args.chunk_step))
for row in tqdm(dataset):
filename = row["filename"]
prompt = row["prompt"]
seed = hash_str_to_int(filename)
image = pipeline(
prompt,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator().manual_seed(seed),
).images[0]
image.save(os.path.join(output_dirname, f"{filename}.png"))
if __name__ == "__main__":
main()
import argparse
import os
import torch
from utils import get_pipeline
from vars import PROMPT_TEMPLATES
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
)
parser.add_argument(
"--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt for the image"
)
parser.add_argument("--seed", type=int, default=2333, help="Random seed (-1 for random)")
parser.add_argument("-t", "--num-inference-steps", type=int, default=4, help="Number of inference steps")
parser.add_argument("-o", "--output-path", type=str, default="output.png", help="Image output path")
parser.add_argument("-g", "--guidance-scale", type=float, default=0, help="Guidance scale.")
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
known_args, _ = parser.parse_known_args()
if known_args.model == "dev":
parser.set_defaults(num_inference_steps=50, guidance_scale=3.5)
parser.add_argument(
"-n",
"--lora-name",
type=str,
default="None",
choices=PROMPT_TEMPLATES.keys(),
help="Name of the LoRA layer",
)
parser.add_argument("-a", "--lora-weight", type=float, default=1, help="Weight of the LoRA layer")
args = parser.parse_args()
return args
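# Example invocation (hedged: this script's filename is not shown in this commit):
#   python <this_script>.py -m dev -p int4 --prompt "a cyberpunk cat" \
#       -n "GHIBSKY Illustration" -a 1 -o cat.png
# With -m dev, the defaults switch to 50 steps and guidance scale 3.5, and the prompt
# is wrapped with the selected LoRA's template from PROMPT_TEMPLATES.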
def main():
args = get_args()
pipeline = get_pipeline(
model_name=args.model,
precision=args.precision,
use_qencoder=args.use_qencoder,
lora_name=getattr(args, "lora_name", "None"),
lora_weight=getattr(args, "lora_weight", 1),
device="cuda",
)
if args.model == "dev":
prompt = PROMPT_TEMPLATES[args.lora_name].format(prompt=args.prompt)
else:
prompt = args.prompt
image = pipeline(
prompt=prompt,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator().manual_seed(args.seed) if args.seed >= 0 else None,
).images[0]
output_dir = os.path.dirname(os.path.abspath(os.path.expanduser(args.output_path)))
os.makedirs(output_dir, exist_ok=True)
image.save(args.output_path)
if __name__ == "__main__":
main()
import argparse
import json
import os
from data import get_dataset
from metrics.fid import compute_fid
from metrics.image_reward import compute_image_reward
from metrics.multimodal import compute_image_multimodal_metrics
from metrics.similarity import compute_image_similarity_metrics
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("input_roots", type=str, nargs="*")
parser.add_argument("-o", "--output-path", type=str, default="metrics.json", help="Image output path")
args = parser.parse_args()
return args
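# Example invocation (hedged: this script's filename is not shown in this commit):
#   python <this_script>.py results/schnell/int4 results/schnell/bf16 -o metrics.json
# The first root is scored against the ground-truth datasets (FID, CLIP-IQA/CLIP score,
# ImageReward); if a second root is given, PSNR/LPIPS/SSIM are also computed between
# the two sets of generated images.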
def main():
args = get_args()
assert len(args.input_roots) > 0
assert len(args.input_roots) <= 2
image_root1 = args.input_roots[0]
if len(args.input_roots) > 1:
image_root2 = args.input_roots[1]
else:
image_root2 = None
results = {}
dataset_names = sorted(os.listdir(image_root1))
for dataset_name in dataset_names:
print("##Results for dataset:", dataset_name)
results[dataset_name] = {}
dataset = get_dataset(name=dataset_name, return_gt=True)
fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
results[dataset_name]["fid"] = fid
print("FID:", fid)
multimodal_metrics = compute_image_multimodal_metrics(
ref_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name)
)
results[dataset_name].update(multimodal_metrics)
for k, v in multimodal_metrics.items():
print(f"{k}:", v)
image_reward = compute_image_reward(ref_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
results[dataset_name].update(image_reward)
for k, v in image_reward.items():
print(f"{k}:", v)
if image_root2 is not None and os.path.exists(os.path.join(image_root2, dataset_name)):
similarity_results = compute_image_similarity_metrics(
os.path.join(image_root1, dataset_name), os.path.join(image_root2, dataset_name)
)
results[dataset_name].update(similarity_results)
for k, v in similarity_results.items():
print(f"{k}:", v)
os.makedirs(os.path.dirname(os.path.abspath(args.output_path)), exist_ok=True)
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2, sort_keys=True)
if __name__ == "__main__":
main()
import argparse
import time
import torch
from tqdm import trange
from utils import get_pipeline
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
)
parser.add_argument("-t", "--num-inference-steps", type=int, default=4, help="Number of inference steps")
parser.add_argument("-g", "--guidance-scale", type=float, default=0, help="Guidance scale")
# Test related
parser.add_argument("--warmup-times", type=int, default=2, help="Number of warmup times")
parser.add_argument("--test-times", type=int, default=10, help="Number of test times")
parser.add_argument(
"--mode",
type=str,
default="end2end",
choices=["end2end", "step"],
help="Measure mode: end-to-end latency or per-step latency",
)
parser.add_argument(
"--ignore_ratio", type=float, default=0.2, help="Ignored ratio of the slowest and fastest steps"
)
known_args, _ = parser.parse_known_args()
if known_args.model == "dev":
parser.set_defaults(num_inference_steps=50, guidance_scale=3.5)
args = parser.parse_args()
return args
def main():
args = get_args()
pipeline = get_pipeline(model_name=args.model, precision=args.precision, device="cuda")
dummy_prompt = "A cat holding a sign that says hello world"
latency_list = []
if args.mode == "end2end":
pipeline.set_progress_bar_config(position=1, desc="Step", leave=False)
for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
pipeline(
prompt=dummy_prompt,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
)
torch.cuda.synchronize()
        for _ in trange(args.test_times, desc="Test", position=0, leave=False):
start_time = time.time()
pipeline(
prompt=dummy_prompt,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
)
torch.cuda.synchronize()
end_time = time.time()
latency_list.append(end_time - start_time)
elif args.mode == "step":
pass
latency_list = sorted(latency_list)
ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
if ignored_count > 0:
latency_list = latency_list[ignored_count:-ignored_count]
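    # Worked example: with --test-times 10 and --ignore_ratio 0.2, ignored_count is
    # int(0.2 * 10 / 2) = 1, so the single fastest and slowest runs are dropped and
    # the reported latency is the mean of the remaining 8 runs.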
print(f"Latency: {sum(latency_list) / len(latency_list):.5f} s")
if __name__ == "__main__":
main()
import logging
import os
from .fid import compute_fid
from .image_reward import compute_image_reward
from .multimodal import compute_image_multimodal_metrics
from .similarity import compute_image_similarity_metrics
from ..benchmarks import get_dataset
logging.getLogger("PIL").setLevel(logging.WARNING)
__all__ = ["compute_image_metrics"]
def compute_image_metrics(
gen_root: str,
benchmarks: str | tuple[str, ...] = ("DCI", "GenAIBench", "GenEval", "MJHQ", "T2ICompBench"),
max_dataset_size: int = -1,
chunk_start: int = 0,
chunk_step: int = 1,
chunk_only: bool = False,
ref_root: str = "",
gt_stats_root: str = "",
gt_metrics: tuple[str, ...] = ("clip_iqa", "clip_score", "image_reward", "fid"),
ref_metrics: tuple[str, ...] = ("psnr", "lpips", "ssim", "fid"),
) -> dict:
if chunk_start == 0 and chunk_step == 1:
chunk_only = False
assert chunk_start == 0 and chunk_step == 1, "Chunking is not supported for image benchmarks."
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if isinstance(benchmarks, str):
benchmarks = (benchmarks,)
gt_multimodal_metrics, gt_similarity_metrics, gt_other_metrics = categorize_metrics(gt_metrics)
_, ref_similarity_metrics, ref_other_metrics = categorize_metrics(ref_metrics)
results = {}
for benchmark in benchmarks:
benchmark_results = {}
dataset = get_dataset(benchmark, max_dataset_size=max_dataset_size, return_gt=True)
dirname = f"{dataset.config_name}-{dataset._unchunk_size}"
if dataset._chunk_start == 0 and dataset._chunk_step == 1:
filename = f"{dirname}.npz"
else:
filename = os.path.join(dirname, f"{dataset._chunk_start}-{dataset._chunk_step}.npz")
if chunk_only:
dirname += f".{dataset._chunk_start}.{dataset._chunk_step}"
gen_dirpath = os.path.join(gen_root, "samples", benchmark, dirname)
if gt_metrics:
gt_results = compute_image_multimodal_metrics(dataset, gen_dirpath, metrics=gt_multimodal_metrics)
if "image_reward" in gt_other_metrics:
gt_results.update(compute_image_reward(dataset, gen_dirpath))
if benchmark in ("COCO", "DCI", "MJHQ"):
gt_results.update(compute_image_similarity_metrics(dataset, gen_dirpath, metrics=gt_similarity_metrics))
if "fid" in gt_other_metrics:
gt_results["fid"] = compute_fid(
dataset,
gen_dirpath,
ref_cache_path=(os.path.join(gt_stats_root, benchmark, filename) if gt_stats_root else None),
gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename),
)
benchmark_results["with_gt"] = gt_results
if ref_root and ref_metrics:
assert os.path.exists(ref_root), f"Reference root directory {ref_root} does not exist."
ref_dirpath = os.path.join(ref_root, "samples", benchmark, dirname)
ref_results = compute_image_similarity_metrics(ref_dirpath, gen_dirpath, metrics=ref_similarity_metrics)
if "fid" in ref_other_metrics:
ref_results["fid"] = compute_fid(
ref_dirpath,
gen_dirpath,
ref_cache_path=os.path.join(ref_root, "fid_stats", benchmark, filename),
gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename),
)
benchmark_results["with_orig"] = ref_results
print(f"{dirname} results:")
print(benchmark_results)
results[dirname] = benchmark_results
return results
def categorize_metrics(metrics: tuple[str, ...]) -> tuple[list[str], list[str], list[str]]:
"""
Categorize metrics into multimodal, similarity, and other metrics.
Args:
metrics (tuple[str, ...]): List of metrics.
Returns:
tuple[list[str], list[str], list[str]]: Tuple of multimodal, similarity, and other metrics.
"""
metrics = tuple(set(metrics))
multimodal_metrics, similarity_metrics, other_metrics = [], [], []
for metric in metrics:
if metric in ("clip_iqa", "clip_score"):
multimodal_metrics.append(metric)
elif metric in ("psnr", "lpips", "ssim"):
similarity_metrics.append(metric)
else:
other_metrics.append(metric)
return multimodal_metrics, similarity_metrics, other_metrics
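
# Example of the categorization above (a sketch; list order may vary because the
# input tuple is de-duplicated through a set):
#   categorize_metrics(("clip_score", "psnr", "fid"))
#   -> (["clip_score"], ["psnr"], ["fid"])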
import os
from datetime import datetime
import numpy as np
import torch
import torchvision
from cleanfid import fid
from cleanfid.resize import build_resizer
from datasets import Dataset
from tqdm import tqdm
def get_dataset_features(
dataset: Dataset,
model,
mode: str = "clean",
batch_size: int = 128,
device: str | torch.device = "cuda",
) -> np.ndarray:
to_tensor = torchvision.transforms.ToTensor()
fn_resize = build_resizer(mode)
np_feats = []
for batch in tqdm(
dataset.iter(batch_size=batch_size, drop_last_batch=False),
desc=f"Extracting {dataset.config_name} features",
total=(len(dataset) + batch_size - 1) // batch_size,
):
resized_images = [fn_resize(np.array(image.convert("RGB"))) for image in batch["image"]]
image_tensors = []
for resized_image in resized_images:
if resized_image.dtype == "uint8":
image_tensor = to_tensor(resized_image) * 255
else:
image_tensor = to_tensor(resized_image)
image_tensors.append(image_tensor)
image_tensors = torch.stack(image_tensors, dim=0)
np_feats.append(fid.get_batch_features(image_tensors, model, device))
np_feats = np.concatenate(np_feats, axis=0)
return np_feats
def get_fid_features(
dataset_or_folder: str | Dataset | None = None,
cache_path: str | None = None,
num: int | None = None,
mode: str = "clean",
num_workers: int = 8,
batch_size: int = 64,
device: str | torch.device = "cuda",
force_overwrite: bool = False,
verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
if cache_path is not None and os.path.exists(cache_path) and not force_overwrite:
npz = np.load(cache_path)
mu, sigma = npz["mu"], npz["sigma"]
else:
feat_model = fid.build_feature_extractor(mode, device)
if isinstance(dataset_or_folder, str):
np_feats = fid.get_folder_features(
dataset_or_folder,
feat_model,
num_workers=num_workers,
num=num,
batch_size=batch_size,
device=device,
verbose=verbose,
mode=mode,
description=f"Extracting {dataset_or_folder} features",
)
else:
assert isinstance(dataset_or_folder, Dataset)
np_feats = get_dataset_features(
dataset_or_folder, model=feat_model, mode=mode, batch_size=batch_size, device=device
)
mu = np.mean(np_feats, axis=0)
sigma = np.cov(np_feats, rowvar=False)
if cache_path is not None:
os.makedirs(os.path.abspath(os.path.dirname(cache_path)), exist_ok=True)
np.savez(cache_path, mu=mu, sigma=sigma)
return mu, sigma
def compute_fid(
ref_dirpath_or_dataset: str | Dataset,
gen_dirpath: str,
ref_cache_path: str | None = None,
gen_cache_path: str | None = None,
use_symlink: bool = True,
timestamp: str | None = None,
) -> float:
sym_ref_dirpath, sym_gen_dirpath = None, None
if use_symlink:
if timestamp is None:
timestamp = datetime.now().strftime("%y%m%d.%H%M%S")
os.makedirs(".tmp", exist_ok=True)
if isinstance(ref_dirpath_or_dataset, str):
sym_ref_dirpath = os.path.join(".tmp", f"ref-{hash(str(ref_dirpath_or_dataset))}-{timestamp}")
os.symlink(os.path.abspath(ref_dirpath_or_dataset), os.path.abspath(sym_ref_dirpath))
ref_dirpath_or_dataset = sym_ref_dirpath
sym_gen_dirpath = os.path.join(".tmp", f"gen-{hash(str(gen_dirpath))}-{timestamp}")
os.symlink(os.path.abspath(gen_dirpath), os.path.abspath(sym_gen_dirpath))
gen_dirpath = sym_gen_dirpath
mu1, sigma1 = get_fid_features(dataset_or_folder=ref_dirpath_or_dataset, cache_path=ref_cache_path)
mu2, sigma2 = get_fid_features(dataset_or_folder=gen_dirpath, cache_path=gen_cache_path)
fid_score = fid.frechet_distance(mu1, sigma1, mu2, sigma2)
fid_score = float(fid_score)
if use_symlink:
if sym_ref_dirpath is not None:
os.remove(sym_ref_dirpath)
os.remove(sym_gen_dirpath)
return fid_score
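
# Example (a sketch, assuming two local image directories; the optional cache paths
# store the extracted feature statistics (mu, sigma) as .npz files for reuse):
#   score = compute_fid(
#       ref_dirpath_or_dataset="results/schnell/bf16/MJHQ",
#       gen_dirpath="results/schnell/int4/MJHQ",
#       ref_cache_path="fid_stats/bf16-MJHQ.npz",
#       gen_cache_path="fid_stats/int4-MJHQ.npz",
#   )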
import os
import ImageReward as RM
import datasets
import torch
from tqdm import tqdm
def compute_image_reward(
ref_dataset: datasets.Dataset,
gen_dirpath: str,
) -> dict[str, float]:
scores = []
model = RM.load("ImageReward-v1.0")
for batch in tqdm(
ref_dataset.iter(batch_size=1, drop_last_batch=False),
desc=f"{ref_dataset.config_name} image reward",
total=len(ref_dataset),
dynamic_ncols=True,
):
filename = batch["filename"][0]
path = os.path.join(gen_dirpath, f"{filename}.png")
prompt = batch["prompt"][0]
with torch.inference_mode():
score = model.score(prompt, path)
scores.append(score)
result = {"image_reward": sum(scores) / len(scores)}
return result
import os
import datasets
import numpy as np
import torch
import torchmetrics
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.multimodal import CLIPImageQualityAssessment, CLIPScore
from tqdm import tqdm
class PromptImageDataset(data.Dataset):
def __init__(self, ref_dataset: datasets.Dataset, gen_dirpath: str):
        super().__init__()
self.ref_dataset, self.gen_dirpath = ref_dataset, gen_dirpath
self.transform = torchvision.transforms.ToTensor()
def __len__(self):
return len(self.ref_dataset)
def __getitem__(self, idx: int):
row = self.ref_dataset[idx]
gen_image = Image.open(os.path.join(self.gen_dirpath, row["filename"] + ".png")).convert("RGB")
gen_tensor = torch.from_numpy(np.array(gen_image)).permute(2, 0, 1)
prompt = row["prompt"]
return [gen_tensor, prompt]
def compute_image_multimodal_metrics(
ref_dataset: datasets.Dataset,
gen_dirpath: str,
metrics: tuple[str, ...] = ("clip_iqa", "clip_score"),
batch_size: int = 64,
num_workers: int = 8,
device: str | torch.device = "cuda",
) -> dict[str, float]:
if len(metrics) == 0:
return {}
metric_names = metrics
metrics: dict[str, torchmetrics.Metric] = {}
for metric_name in metric_names:
if metric_name == "clip_iqa":
metric = CLIPImageQualityAssessment(model_name_or_path="openai/clip-vit-large-patch14").to(device)
elif metric_name == "clip_score":
metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
else:
raise NotImplementedError(f"Metric {metric_name} is not implemented")
metrics[metric_name] = metric
dataset = PromptImageDataset(ref_dataset, gen_dirpath)
dataloader = data.DataLoader(
dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
)
with torch.no_grad():
for i, batch in enumerate(tqdm(dataloader, desc=f"{ref_dataset.config_name} multimodal metrics")):
batch[0] = batch[0].to(device)
for metric_name, metric in metrics.items():
if metric_name == "clip_iqa":
metric.update(batch[0].to(torch.float32))
else:
prompts = list(batch[1])
metric.update(batch[0], prompts)
result = {metric_name: metric.compute().mean().item() for metric_name, metric in metrics.items()}
return result
import os
import datasets
import torch
import torchmetrics
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.image import (
LearnedPerceptualImagePatchSimilarity,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
from tqdm import tqdm
class MultiImageDataset(data.Dataset):
def __init__(self, gen_dirpath: str, ref_dirpath_or_dataset: str | datasets.Dataset):
        super().__init__()
self.gen_names = sorted(
[name for name in os.listdir(gen_dirpath) if name.endswith(".png") or name.endswith(".jpg")]
)
self.gen_dirpath, self.ref_dirpath_or_dataset = gen_dirpath, ref_dirpath_or_dataset
if isinstance(ref_dirpath_or_dataset, str):
self.ref_names = sorted(
[name for name in os.listdir(ref_dirpath_or_dataset) if name.endswith(".png") or name.endswith(".jpg")]
)
assert len(self.ref_names) == len(self.gen_names)
else:
assert isinstance(ref_dirpath_or_dataset, datasets.Dataset)
self.ref_names = self.gen_names
assert len(ref_dirpath_or_dataset) == len(self.gen_names)
self.transform = torchvision.transforms.ToTensor()
def __len__(self):
return len(self.ref_names)
def __getitem__(self, idx: int):
if isinstance(self.ref_dirpath_or_dataset, str):
name = self.ref_names[idx]
assert name == self.gen_names[idx]
ref_image = Image.open(os.path.join(self.ref_dirpath_or_dataset, name)).convert("RGB")
else:
row = self.ref_dirpath_or_dataset[idx]
ref_image = row["image"].convert("RGB")
name = row["filename"] + ".png"
gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB")
gen_size = gen_image.size
ref_size = ref_image.size
if ref_size != gen_size:
ref_image = ref_image.resize(gen_size, Image.Resampling.BICUBIC)
gen_tensor = self.transform(gen_image)
ref_tensor = self.transform(ref_image)
return [gen_tensor, ref_tensor]
def compute_image_similarity_metrics(
ref_dirpath_or_dataset: str | datasets.Dataset,
gen_dirpath: str,
metrics: tuple[str, ...] = ("psnr", "lpips", "ssim"),
batch_size: int = 64,
num_workers: int = 8,
device: str | torch.device = "cuda",
) -> dict[str, float]:
if len(metrics) == 0:
return {}
metric_names = metrics
metrics: dict[str, torchmetrics.Metric] = {}
for metric_name in metric_names:
if metric_name == "psnr":
metric = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to(device)
elif metric_name == "lpips":
metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
elif metric_name == "ssim":
metric = StructuralSimilarityIndexMeasure(data_range=(0, 1)).to(device)
else:
raise NotImplementedError(f"Metric {metric_name} is not implemented")
metrics[metric_name] = metric
dataset = MultiImageDataset(gen_dirpath, ref_dirpath_or_dataset)
dataloader = data.DataLoader(
dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
)
with torch.no_grad():
desc = (
ref_dirpath_or_dataset.config_name
if isinstance(ref_dirpath_or_dataset, datasets.Dataset)
else os.path.basename(ref_dirpath_or_dataset)
) + " similarity metrics"
for i, batch in enumerate(tqdm(dataloader, desc=desc)):
batch = [tensor.to(device) for tensor in batch]
for metric in metrics.values():
metric.update(batch[0], batch[1])
result = {metric_name: metric.compute().item() for metric_name, metric in metrics.items()}
return result
# Adapted from https://huggingface.co/spaces/playgroundai/playground-v2.5/blob/main/app.py
import argparse
import random
import time
import GPUtil
import gradio as gr
import spaces
import torch
from peft.tuners import lora
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p",
"--precisions",
type=str,
default=["int4"],
nargs="*",
choices=["int4", "bf16"],
help="Which precisions to use",
)
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
parser.add_argument("--no-safety-checker", action="store_true", help="Disable safety checker")
return parser.parse_args()
args = get_args()
pipelines = []
for i, precision in enumerate(args.precisions):
pipeline = get_pipeline(
model_name=args.model,
precision=precision,
use_qencoder=args.use_qencoder,
device=f"cuda:{i}",
lora_name="All",
)
pipeline.cur_lora_name = "None"
pipeline.cur_lora_weight = 0
pipelines.append(pipeline)
safety_checker = SafetyChecker("cuda", disabled=args.no_safety_checker)
@spaces.GPU(enable_queue=True)
def generate(
prompt: str = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 4,
guidance_scale: float = 0,
lora_name: str = "None",
lora_weight: float = 1,
seed: int = 0,
):
is_unsafe_prompt = False
if not safety_checker(prompt):
is_unsafe_prompt = True
prompt = "A peaceful world."
prompt = PROMPT_TEMPLATES[lora_name].format(prompt=prompt)
images, latency_strs = [], []
for i, pipeline in enumerate(pipelines):
precision = args.precisions[i]
progress = gr.Progress(track_tqdm=True)
if pipeline.cur_lora_name != lora_name:
if precision == "bf16":
for m in pipeline.transformer.modules():
if isinstance(m, lora.LoraLayer):
if pipeline.cur_lora_name != "None":
if pipeline.cur_lora_name in m.scaling:
m.scaling[pipeline.cur_lora_name] = 0
if lora_name != "None":
if lora_name in m.scaling:
m.scaling[lora_name] = lora_weight
else:
assert precision == "int4"
if lora_name != "None":
pipeline.transformer.nunchaku_update_params(SVDQ_LORA_PATHS[lora_name])
pipeline.transformer.nunchaku_set_lora_scale(lora_weight)
else:
pipeline.transformer.nunchaku_set_lora_scale(0)
elif lora_name != "None":
if precision == "bf16":
if pipeline.cur_lora_weight != lora_weight:
for m in pipeline.transformer.modules():
if isinstance(m, lora.LoraLayer):
if lora_name in m.scaling:
m.scaling[lora_name] = lora_weight
else:
assert precision == "int4"
pipeline.transformer.nunchaku_set_lora_scale(lora_weight)
pipeline.cur_lora_name = lora_name
pipeline.cur_lora_weight = lora_weight
start_time = time.time()
image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator().manual_seed(seed),
).images[0]
end_time = time.time()
latency = end_time - start_time
if latency < 1:
latency = latency * 1000
latency_str = f"{latency:.2f}ms"
else:
latency_str = f"{latency:.2f}s"
images.append(image)
latency_strs.append(latency_str)
if is_unsafe_prompt:
for i in range(len(latency_strs)):
latency_strs[i] += " (Unsafe prompt detected)"
torch.cuda.empty_cache()
return *images, *latency_strs
with open("./assets/description.html", "r") as f:
DESCRIPTION = f.read()
gpus = GPUtil.getGPUs()
if len(gpus) > 0:
gpu = gpus[0]
memory = gpu.memoryTotal / 1024
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else:
device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks(
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
title=f"SVDQuant FLUX.1-{args.model} Demo",
) as demo:
gr.HTML(DESCRIPTION.format(model=args.model, device_info=device_info, notice=notice))
with gr.Row():
image_results, latency_results = [], []
for i, precision in enumerate(args.precisions):
with gr.Column():
gr.Markdown(f"# {precision.upper()}", elem_id="image_header")
with gr.Group():
image_result = gr.Image(
format="png",
image_mode="RGB",
label="Result",
show_label=False,
show_download_button=True,
interactive=False,
)
latency_result = gr.Text(label="Inference Latency", show_label=True)
image_results.append(image_result)
latency_results.append(latency_result)
with gr.Row():
prompt = gr.Text(
label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, scale=4
)
run_button = gr.Button("Run", scale=1)
if args.model == "dev":
with gr.Row():
lora_name = gr.Dropdown(label="LoRA Name", choices=PROMPT_TEMPLATES.keys(), value="None", scale=1)
prompt_template = gr.Textbox(
label="LoRA Prompt Template", value=PROMPT_TEMPLATES["None"], scale=1, max_lines=1
)
else:
lora_name = "None"
with gr.Row():
seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")
with gr.Accordion("Advanced options", open=False):
with gr.Group():
if args.model == "schnell":
num_inference_steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=8, step=1, value=4)
guidance_scale = 0
lora_weight = 0
elif args.model == "dev":
num_inference_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, step=1, value=25)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, step=0.1, value=3.5)
lora_weight = gr.Slider(label="LoRA Weight", minimum=0, maximum=2, step=0.1, value=1)
else:
raise NotImplementedError(f"Model {args.model} not implemented")
if args.model == "schnell":
def generate_func(prompt, num_inference_steps, seed):
return generate(
prompt, DEFAULT_HEIGHT, DEFAULT_WIDTH, num_inference_steps, guidance_scale, lora_name, lora_weight, seed
)
input_args = [prompt, num_inference_steps, seed]
elif args.model == "dev":
def generate_func(prompt, num_inference_steps, guidance_scale, lora_name, lora_weight, seed):
return generate(
prompt, DEFAULT_HEIGHT, DEFAULT_WIDTH, num_inference_steps, guidance_scale, lora_name, lora_weight, seed
)
input_args = [prompt, num_inference_steps, guidance_scale, lora_name, lora_weight, seed]
gr.Examples(
examples=EXAMPLES[args.model], inputs=input_args, outputs=[*image_results, *latency_results], fn=generate_func
)
gr.on(
triggers=[prompt.submit, run_button.click],
fn=generate_func,
inputs=input_args,
outputs=[*image_results, *latency_results],
api_name="run",
)
randomize_seed.click(
lambda: random.randint(0, MAX_SEED), inputs=[], outputs=seed, api_name=False, queue=False
).then(fn=generate_func, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False)
if args.model == "dev":
lora_name.change(
lambda x: PROMPT_TEMPLATES[x],
inputs=[lora_name],
outputs=[prompt_template],
api_name=False,
queue=False,
).then(
fn=generate_func, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False
)
if __name__ == "__main__":
demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from nunchaku.pipelines import flux as nunchaku_flux
from vars import LORA_PATHS, SVDQ_LORA_PATHS
def hash_str_to_int(s: str) -> int:
"""Hash a string to an integer."""
modulus = 10**9 + 7 # Large prime modulus
hash_int = 0
for char in s:
hash_int = (hash_int * 31 + ord(char)) % modulus
return hash_int
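
# Worked example: hash_str_to_int("abc") = ((97 * 31 + 98) * 31 + 99) % (10**9 + 7)
# = 96354, so the same filename always maps to the same sampling seed.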
def get_pipeline(
model_name: str,
precision: str,
use_qencoder: bool = False,
lora_name: str = "None",
lora_weight: float = 1,
device: str | torch.device = "cuda",
) -> FluxPipeline:
if model_name == "schnell":
if precision == "int4":
assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
qmodel_device=device,
)
else:
assert precision == "bf16"
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
elif model_name == "dev":
if precision == "int4":
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors",
qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
qmodel_device=device,
)
if lora_name not in ["All", "None"]:
pipeline.transformer.nunchaku_update_params(SVDQ_LORA_PATHS[lora_name])
pipeline.transformer.nunchaku_set_lora_scale(lora_weight)
else:
assert precision == "bf16"
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
if lora_name == "All":
# Pre-load all the LoRA weights for demo use
for name, path in LORA_PATHS.items():
pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
for m in pipeline.transformer.modules():
if isinstance(m, lora.LoraLayer):
m.set_adapter(m.scaling.keys())
for name in m.scaling.keys():
m.scaling[name] = 0
elif lora_name != "None":
path = LORA_PATHS[lora_name]
pipeline.load_lora_weights(
path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
)
for m in pipeline.transformer.modules():
if isinstance(m, lora.LoraLayer):
for name in m.scaling.keys():
m.scaling[name] = lora_weight
else:
raise NotImplementedError(f"Model {model_name} not implemented")
pipeline = pipeline.to(device)
return pipeline
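
# Example (a sketch; the int4 path needs a CUDA device and downloads the SVDQuant
# weights on first use):
#   pipe = get_pipeline(model_name="dev", precision="int4", lora_name="Anime", lora_weight=1)
#   image = pipe("a dog wearing a wizard hat", num_inference_steps=25, guidance_scale=3.5).images[0]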
MAX_IMAGE_SIZE = 2048
MAX_SEED = 1000000000
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
PROMPT_TEMPLATES = {
"None": "{prompt}",
"Anime": "{prompt}, nm22 style",
"GHIBSKY Illustration": "GHIBSKY style, {prompt}",
"Realism": "{prompt}",
"Yarn Art": "{prompt}, yarn art style",
"Children Sketch": "sketched style, {prompt}",
}
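
# Example: PROMPT_TEMPLATES["Anime"].format(prompt="a dog wearing a wizard hat")
# yields "a dog wearing a wizard hat, nm22 style", the trigger phrase paired with
# the Anime LoRA listed in LORA_PATHS below.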
LORA_PATHS = {
"Anime": {
"name_or_path": "alvdansen/sonny-anime-fixed",
"weight_name": "araminta_k_sonnyanime_fluxd_fixed.safetensors",
},
"GHIBSKY Illustration": {
"name_or_path": "aleksa-codes/flux-ghibsky-illustration",
"weight_name": "lora.safetensors",
},
"Realism": {
"name_or_path": "mit-han-lab/FLUX.1-dev-LoRA-Collections",
"weight_name": "realism.safetensors",
},
"Yarn Art": {
"name_or_path": "linoyts/yarn_art_Flux_LoRA",
"weight_name": "pytorch_lora_weights.safetensors",
},
"Children Sketch": {
"name_or_path": "mit-han-lab/FLUX.1-dev-LoRA-Collections",
"weight_name": "sketch.safetensors",
},
}
SVDQ_LORA_PATH_FORMAT = "mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{name}.safetensors"
SVDQ_LORA_PATHS = {
"Anime": SVDQ_LORA_PATH_FORMAT.format(name="anime"),
"GHIBSKY Illustration": SVDQ_LORA_PATH_FORMAT.format(name="ghibsky"),
"Realism": SVDQ_LORA_PATH_FORMAT.format(name="realism"),
"Yarn Art": SVDQ_LORA_PATH_FORMAT.format(name="yarn"),
"Children Sketch": SVDQ_LORA_PATH_FORMAT.format(name="sketch"),
}
EXAMPLES = {
"schnell": [
[
"An elegant, art deco-style cat with sleek, geometric fur patterns reclining next to a polished sign that "
"reads 'MIT HAN Lab' in bold, stylized typography. The sign, framed in gold and silver, "
"exudes a sophisticated, 1920s flair, with ambient light casting a warm glow around it.",
4,
1,
],
[
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, "
"volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, "
"art nouveau style, illustration art artwork by SenseiJaye, intricate detail.",
4,
2,
],
[
"A worker that looks like a mixture of cow and horse is working hard to type code.",
4,
3,
],
[
"A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. "
"She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. "
"She wears sunglasses and red lipstick. She walks confidently and casually. "
"The street is damp and reflective, creating a mirror effect of the colorful lights. "
"Many pedestrians walk about.",
4,
4,
],
[
"Cozy bedroom with vintage wooden furniture and a large circular window covered in lush green vines, "
"opening to a misty forest. Soft, ambient lighting highlights the bed with crumpled blankets, a bookshelf, "
"and a desk. The atmosphere is serene and natural. 8K resolution, highly detailed, photorealistic, "
"cinematic lighting, ultra-HD.",
4,
5,
],
[
"A photo of a Eurasian lynx in a sunlit forest, with tufted ears and a spotted coat. The lynx should be "
"sharply focused, gazing into the distance, while the background is softly blurred for depth. Use cinematic "
"lighting with soft rays filtering through the trees, and capture the scene with a shallow depth of field "
"for a natural, peaceful atmosphere. 8K resolution, highly detailed, photorealistic, "
"cinematic lighting, ultra-HD.",
4,
6,
],
],
"dev": [
[
'a cyberpunk cat holding a huge neon sign that says "SVDQuant is lite and fast", wearing fancy goggles and '
"a black leather jacket.",
25,
3.5,
"None",
0,
2,
],
["a dog wearing a wizard hat", 28, 3.5, "Anime", 1, 23],
[
"a fisherman casting a line into a peaceful village lake surrounded by quaint cottages",
28,
3.5,
"GHIBSKY Illustration",
1,
233,
],
["a man in armor with a beard and a sword", 25, 3.5, "Realism", 0.9, 2333],
["a panda playing in the snow", 28, 3.5, "Yarn Art", 1, 23333],
["A squirrel wearing glasses and reading a tiny book under an oak tree", 24, 3.5, "Children Sketch", 1, 233333],
],
}
<?xml version="1.0" encoding="UTF-8"?>
<svg id="_图层_1" data-name="图层 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 632.8 69.84">
<defs>
<style>
.cls-1 {
fill: #8a8b8d;
}
.cls-2 {
fill: #a41e35;
}
</style>
</defs>
<polygon class="cls-2" points="443.27 37.07 447.42 50.74 433.75 50.74 437.9 37.07 443.27 37.07"/>
<path class="cls-2" d="M64.81,29.39c-2.1-1.52-4.55-2.66-7.34-3.44-2.8-.77-5.83-1.16-9.09-1.16h-23.29c-1.8-.19-3.32-.71-4.55-1.58-1.23-.87-1.85-2.1-1.85-3.71,0-1.92.83-3.3,2.5-4.13,1.66-.84,3.63-1.25,5.9-1.25h21.36V0h-24.46c-3.33,0-6.38.4-9.14,1.21-2.77.81-5.18,1.97-7.25,3.48-2.07,1.52-3.75,3.41-5.05,5.67-1.3,2.26-2.15,4.85-2.55,7.76v2.69c.27,2.91,1.06,5.5,2.4,7.76,1.33,2.26,3.05,4.16,5.15,5.71,2.1,1.55,4.55,2.71,7.35,3.48,2.8.77,5.83,1.16,9.09,1.16h23.18c1.8.25,3.33.77,4.6,1.58,1.26.81,1.9,2.01,1.9,3.62,0,1.92-.83,3.3-2.5,4.13-1.67.84-3.63,1.25-5.9,1.25H0v14.12h48.37c3.26,0,6.3-.4,9.09-1.21s5.23-1.96,7.3-3.48c2.06-1.52,3.75-3.4,5.05-5.67,1.3-2.26,2.15-4.84,2.55-7.76v-2.69c-.33-2.91-1.15-5.49-2.45-7.76-1.3-2.26-3-4.15-5.1-5.67Z"/>
<rect class="cls-1" x="53.26" width="19.09" height="14.12"/>
<polygon class="cls-2" points="152.3 0 128.72 63.62 103.93 63.62 80.45 0 98.84 0 116.33 48.2 134.22 0 152.3 0"/>
<path class="cls-2" d="M231,11.93c-1.17-2.57-2.8-4.74-4.9-6.5-2.1-1.76-4.61-3.11-7.55-4.04-2.93-.93-6.16-1.39-9.69-1.39h-26.14v14.12h23.05c2.8,0,4.95.68,6.45,2.04,1.5,1.36,2.25,3.38,2.25,6.04v19.22c0,2.66-.75,4.68-2.25,6.04-1.5,1.36-3.65,2.04-6.45,2.04h-26.98v-24.7h-18.39v38.82h48.47c3.53,0,6.76-.45,9.69-1.35,2.93-.9,5.45-2.23,7.55-3.99,2.1-1.77,3.73-3.95,4.9-6.55,1.16-2.6,1.75-5.57,1.75-8.92v-22.01c0-3.34-.58-6.3-1.75-8.87Z"/>
<rect class="cls-1" x="160.4" width="18.63" height="14.12"/>
<path class="cls-1" d="M315.7,69.84h-19.09l-5.6-6.22h-23.79c-3.53,0-6.76-.45-9.69-1.35-2.93-.9-5.45-2.23-7.55-3.99-2.1-1.77-3.73-3.95-4.9-6.55-1.17-2.6-1.75-5.57-1.75-8.92v-22.01c0-3.34.58-6.3,1.75-8.87,1.16-2.57,2.8-4.74,4.9-6.5,2.1-1.76,4.61-3.11,7.55-4.04,2.93-.93,6.16-1.39,9.69-1.39h24.48c3.53,0,6.76.46,9.69,1.39,2.93.93,5.46,2.28,7.6,4.04,2.13,1.77,3.78,3.93,4.95,6.5,1.16,2.57,1.75,5.53,1.75,8.87v22.01c0,3.9-.8,7.28-2.4,10.12-1.6,2.85-3.8,5.17-6.6,6.97l8.99,9.94ZM297.31,41.42v-19.22c0-2.66-.75-4.67-2.25-6.04-1.5-1.36-3.65-2.04-6.45-2.04h-18.19c-2.8,0-4.95.68-6.45,2.04-1.5,1.36-2.25,3.38-2.25,6.04v19.22c0,2.66.75,4.68,2.25,6.04,1.5,1.36,3.65,2.04,6.45,2.04h7.89l-11.29-12.54h19.09l9.19,10.22c1.33-1.42,2-3.34,2-5.76Z"/>
<path class="cls-1" d="M380.26,41.42V0h18.39v42.82c0,3.34-.58,6.31-1.75,8.92-1.17,2.6-2.82,4.78-4.95,6.55-2.13,1.76-4.66,3.1-7.6,3.99-2.93.9-6.16,1.35-9.69,1.35h-24.48c-3.53,0-6.76-.45-9.69-1.35-2.93-.9-5.45-2.23-7.55-3.99-2.1-1.77-3.73-3.95-4.9-6.55-1.17-2.6-1.75-5.57-1.75-8.92V0h18.39v41.42c0,2.66.75,4.68,2.25,6.04,1.5,1.36,3.65,2.04,6.45,2.04h18.19c2.8,0,4.95-.68,6.45-2.04,1.5-1.36,2.25-3.37,2.25-6.04Z"/>
<path class="cls-1" d="M458.01,63.62l-17.59-48.2-17.49,48.2h-18.39L428.03,0h24.79l23.58,63.62h-18.39Z"/>
<polygon class="cls-1" points="538.46 21.84 538.46 39.47 503.18 0 484.49 0 484.49 63.62 502.88 63.62 502.88 24.15 538.26 63.62 556.85 63.62 556.85 21.84 538.46 21.84"/>
<rect class="cls-2" x="538.46" width="18.39" height="14.25"/>
<path class="cls-1" d="M565.55,14.12V0h67.25v14.12h-23.48v49.5h-18.39V14.12h-25.38Z"/>
</svg>