Commit df981d24 authored by muyangli

[major] fix the evaluation scripts; no need to download the entire model

parent 25ce8942
......@@ -60,9 +60,9 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
Then build the package from source:
```shell
git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku
git submodule init
git submodule update
pip install -e .
```
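
After the editable install completes, a quick smoke test is worth running. This is just a sketch; it assumes nothing beyond the `nunchaku` package that the build above installs:

```python
# verify the freshly built extension is importable and see where it landed
import nunchaku
print(nunchaku.__file__)
```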
......@@ -78,7 +78,7 @@ from nunchaku.pipelines import flux as nunchaku_flux
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors", # download from Huggingface
qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell", # download from Huggingface
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
......
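
The user-facing change in this snippet: `qmodel_path` now takes a Hugging Face repo ID (or a local directory) rather than a path to a single `.safetensors` file, and the loader resolves it with `snapshot_download`. A minimal sketch of that resolution, mirroring the hunks below (`transformer_blocks.safetensors` is the file name this commit introduces):

```python
import os
from huggingface_hub import snapshot_download

def resolve_qmodel(qmodel_path: str) -> str:
    # local directories are used as-is; anything else is treated as a Hub repo ID
    if not os.path.exists(qmodel_path):
        qmodel_path = snapshot_download(qmodel_path)
    # the quantized blocks live in one file inside the downloaded snapshot
    return os.path.join(qmodel_path, "transformer_blocks.safetensors")
```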
......@@ -4,11 +4,11 @@ from typing import Any, Callable, Optional, Union
import torch
import torchvision.transforms.functional as F
import torchvision.utils
-from PIL import Image
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
from einops import rearrange
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download
from peft.tuners import lora
+from PIL import Image
from torch import nn
from nunchaku.models.flux import inject_pipeline, load_quantized_model
......@@ -145,9 +145,7 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
self.erosion_kernel = erosion_kernel
torchvision.utils.save_image(image_t[0], "before.png")
-image_t = (
-    nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
-)
+image_t = nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
image_t = torch.concat([image_t, image_t, image_t], dim=1).to(self.dtype)
torchvision.utils.save_image(image_t[0], "after.png")
......@@ -219,6 +217,8 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
qmodel_device = torch.device(qmodel_device)
+if qmodel_device.type != "cuda":
+    raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")
qmodel_path = kwargs.pop("qmodel_path", None)
qencoder_path = kwargs.pop("qencoder_path", None)
......@@ -229,11 +229,12 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
if qmodel_path is not None:
assert isinstance(qmodel_path, str)
if not os.path.exists(qmodel_path):
-    hf_repo_id = os.path.dirname(qmodel_path)
-    filename = os.path.basename(qmodel_path)
-    qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
-m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-inject_pipeline(pipeline, m)
+    qmodel_path = snapshot_download(qmodel_path)
+m = load_quantized_model(
+    os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+    0 if qmodel_device.index is None else qmodel_device.index,
+)
+inject_pipeline(pipeline, m, qmodel_device)
pipeline.precision = "int4"
if qencoder_path is not None:
......
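
With the new guard, a non-CUDA `qmodel_device` fails fast instead of erroring deep in the quantized kernels. A hedged usage sketch (class and keyword names are taken from the diff above):

```python
import torch
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline

# passing qmodel_device="cpu" would now raise ValueError up front
pipeline = FluxPix2pixTurboPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
    qmodel_device="cuda:0",
)
```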
......@@ -5,15 +5,17 @@ import tempfile
import time
import GPUtil
-import gradio as gr
-import numpy as np
import torch
from PIL import Image
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_args
-from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLES, STYLE_NAMES
+import numpy as np
+from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
......@@ -30,7 +32,7 @@ else:
pipeline = FluxPix2pixTurboPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if args.use_qencoder else None,
)
pipeline = pipeline.to("cuda")
......
import cv2
import numpy as np
from PIL import Image
import argparse
......
......@@ -26,7 +26,7 @@ _HOMEPAGE = "https://github.com/facebookresearch/DCI"
_LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"
IMAGE_URL = "https://scontent.xx.fbcdn.net/m1/v/t6/An_zz_Te0EtVC_cHtUwnyNKODapWXuNNPeBgZn_3XY8yDFzwHrNb-zwN9mYCbAeWUKQooCI9mVbwvzZDZzDUlscRjYxLKsw.tar?ccb=10-5&oh=00_AYBnKR-fSIir-E49Q7-qO2tjmY0BGJhCciHS__B5QyiBAg&oe=673FFA8A&_nc_sid=0fdd51"
IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz"
PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}
......
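
The sDCI images and prompts now live in the `mit-han-lab/svdquant-datasets` dataset repo on the Hub instead of an expiring fbcdn URL. For reference, the same files can be fetched with the standard `huggingface_hub` API (a sketch; file names come from the URLs above):

```python
from huggingface_hub import hf_hub_download

# resolves to a local cached copy of the sDCI prompt file
prompts_path = hf_hub_download(
    repo_id="mit-han-lab/svdquant-datasets",
    filename="sDCI.yaml",
    repo_type="dataset",
)
```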
......@@ -31,7 +31,9 @@ def main():
results = {}
dataset_names = sorted(os.listdir(image_root1))
for dataset_name in dataset_names:
print("##Results for dataset:", dataset_name)
if image_root2 is not None and dataset_name not in os.listdir(image_root2):
continue
print("Results for dataset:", dataset_name)
results[dataset_name] = {}
dataset = get_dataset(name=dataset_name, return_gt=True)
fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
......
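
The added guard makes paired evaluation tolerant of partial runs: when a second image root is given, datasets present only under `image_root1` are skipped rather than failing later in the metric computation. The same intersection logic as a standalone sketch:

```python
import os

def common_datasets(image_root1: str, image_root2: str | None) -> list[str]:
    # evaluate every dataset under root1 that also exists under root2 (if given)
    names = sorted(os.listdir(image_root1))
    if image_root2 is None:
        return names
    available = set(os.listdir(image_root2))
    return [name for name in names if name in available]
```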
......@@ -2,6 +2,7 @@ import argparse
import time
import torch
+from torch import nn
from tqdm import trange
from utils import get_pipeline
......@@ -51,23 +52,38 @@ def main():
pipeline.set_progress_bar_config(position=1, desc="Step", leave=False)
for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
pipeline(
-    prompt=dummy_prompt,
-    num_inference_steps=args.num_inference_steps,
-    guidance_scale=args.guidance_scale,
+    prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
)
torch.cuda.synchronize()
for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
start_time = time.time()
pipeline(
-    prompt=dummy_prompt,
-    num_inference_steps=args.num_inference_steps,
-    guidance_scale=args.guidance_scale,
+    prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
)
torch.cuda.synchronize()
end_time = time.time()
latency_list.append(end_time - start_time)
elif args.mode == "step":
-    pass
+    inputs = {}
+    def get_input_hook(module: nn.Module, input_args, input_kwargs):
+        inputs["args"] = input_args
+        inputs["kwargs"] = input_kwargs
+    pipeline.transformer.register_forward_pre_hook(get_input_hook, with_kwargs=True)
+    pipeline(prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent")
+    for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
+        pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+    torch.cuda.synchronize()
+    for _ in trange(args.test_times, desc="Warmup", position=0, leave=False):
+        start_time = time.time()
+        pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+        torch.cuda.synchronize()
+        end_time = time.time()
+        latency_list.append(end_time - start_time)
latency_list = sorted(latency_list)
ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
if ignored_count > 0:
......
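
The new `step` mode isolates the transformer from the rest of the pipeline: a forward pre-hook registered with `with_kwargs=True` (PyTorch ≥ 2.0) captures the exact positional and keyword inputs of one denoising step, which are then replayed in the timing loops. The capture-and-replay pattern, as a self-contained sketch on a toy module:

```python
import torch
from torch import nn

inputs = {}

def get_input_hook(module: nn.Module, args, kwargs):
    # stash the inputs of the next forward call for later replay
    inputs["args"] = args
    inputs["kwargs"] = kwargs

model = nn.Linear(8, 8)
handle = model.register_forward_pre_hook(get_input_hook, with_kwargs=True)
model(torch.randn(2, 8))  # one real call populates `inputs`
handle.remove()

out = model(*inputs["args"], **inputs["kwargs"])  # replay, e.g. inside a timed loop
```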
......@@ -5,7 +5,6 @@ import random
import time
import GPUtil
-import gradio as gr
import spaces
import torch
from peft.tuners import lora
......@@ -14,6 +13,9 @@ from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
+# import gradio last to avoid conflicts with other imports
+import gradio as gr
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
......
......@@ -29,7 +29,7 @@ def get_pipeline(
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
qmodel_device=device,
)
......@@ -41,7 +41,7 @@ def get_pipeline(
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors",
qmodel_path="mit-han-lab/svdq-int4-flux.1-dev",
qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
qmodel_device=device,
)
......
......@@ -14,10 +14,11 @@ SVD_RANK = 32
class NunchakuFluxModel(nn.Module):
-def __init__(self, m: QuantizedFluxModel):
+def __init__(self, m: QuantizedFluxModel, device: torch.device):
super().__init__()
self.m = m
self.dtype = torch.bfloat16
+self.device = device
def forward(
self,
......@@ -33,10 +34,12 @@ class NunchakuFluxModel(nn.Module):
img_tokens = hidden_states.shape[1]
original_dtype = hidden_states.dtype
original_device = hidden_states.device
-hidden_states = hidden_states.to(self.dtype)
-encoder_hidden_states = encoder_hidden_states.to(self.dtype)
-temb = temb.to(self.dtype)
+hidden_states = hidden_states.to(self.dtype).to(self.device)
+encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
+temb = temb.to(self.dtype).to(self.device)
+image_rotary_emb = image_rotary_emb.to(self.device)
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
......@@ -52,7 +55,7 @@ class NunchakuFluxModel(nn.Module):
hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
)
-hidden_states = hidden_states.to(original_dtype)
+hidden_states = hidden_states.to(original_dtype).to(original_device)
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
hidden_states = hidden_states[:, txt_tokens:, ...]
......@@ -110,11 +113,11 @@ def load_quantized_model(path: str, device: str | torch.device) -> QuantizedFlux
return m
-def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel) -> FluxPipeline:
+def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel, device: torch.device) -> FluxPipeline:
net: FluxTransformer2DModel = pipe.transformer
net.pos_embed = EmbedND(dim=net.inner_dim, theta=10000, axes_dim=[16, 56, 56])
-net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m)])
+net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m, device)])
net.single_transformer_blocks = torch.nn.ModuleList([])
def update_params(self: FluxTransformer2DModel, path: str):
......
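
`NunchakuFluxModel` now remembers the quantized engine's device and round-trips activations through it, so the rest of the pipeline can sit on a different device. The dtype/device round-trip as a minimal sketch (toy tensor; the quantized forward is elided, and a CUDA device is assumed):

```python
import torch

device = torch.device("cuda:0")  # where the quantized blocks live
x = torch.randn(1, 64, 3072)     # e.g. hidden_states from the host pipeline

original_dtype, original_device = x.dtype, x.device
x = x.to(torch.bfloat16).to(device)  # cast and move for the quantized kernels
# ... quantized transformer blocks would run here ...
x = x.to(original_dtype).to(original_device)  # hand back in the caller's dtype/device
```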
import os
import torch
-from diffusers import FluxPipeline
-from huggingface_hub import hf_hub_download
+from diffusers import __version__
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from torch import nn
from ..models.flux import inject_pipeline, load_quantized_model
......@@ -45,13 +48,34 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
qencoder_path = kwargs.pop("qencoder_path", None)
if not os.path.exists(qmodel_path):
-    hf_repo_id = os.path.dirname(qmodel_path)
-    filename = os.path.basename(qmodel_path)
-    qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
+    qmodel_path = snapshot_download(qmodel_path)
-pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
-m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-inject_pipeline(pipeline, m)
+assert kwargs.pop("transformer", None) is None
+config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
+    pretrained_model_name_or_path,
+    cache_dir=kwargs.get("cache_dir", None),
+    return_unused_kwargs=True,
+    return_commit_hash=True,
+    force_download=kwargs.get("force_download", False),
+    proxies=kwargs.get("proxies", None),
+    local_files_only=kwargs.get("local_files_only", None),
+    token=kwargs.get("token", None),
+    revision=kwargs.get("revision", None),
+    subfolder="transformer",
+    user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+    **kwargs,
+)
+transformer: nn.Module = FluxTransformer2DModel.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+transformer.load_state_dict(state_dict, strict=False)
+pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
+m = load_quantized_model(
+    os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+    0 if qmodel_device.index is None else qmodel_device.index,
+)
+inject_pipeline(pipeline, m, qmodel_device)
if qencoder_path is not None:
assert isinstance(qencoder_path, str)
......
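
After this commit, a quantized FLUX checkpoint is a repo with two weight files: `unquantized_layers.safetensors`, loaded into a freshly configured `FluxTransformer2DModel` with `strict=False`, and `transformer_blocks.safetensors`, consumed by `load_quantized_model`. A hedged sketch of inspecting that layout (repo ID from the README change above):

```python
import os
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

local_dir = snapshot_download("mit-han-lab/svdq-int4-flux.1-schnell")
print(sorted(os.listdir(local_dir)))  # expect the two .safetensors files named above

# the unquantized remainder loads as an ordinary state dict
state_dict = load_file(os.path.join(local_dir, "unquantized_layers.safetensors"))
print(f"{len(state_dict)} unquantized tensors")
```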