"vscode:/vscode.git/clone" did not exist on "4dc5518e4d2ae89a687709bcbe05d2f3f80e00ad"
Commit df981d24 authored by muyangli

[major] fix the evaluation scripts; no need to download the entire model

parent 25ce8942
@@ -60,9 +60,9 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activations
 Then build the package from source:
 ```shell
 git clone https://github.com/mit-han-lab/nunchaku.git
 cd nunchaku
 git submodule init
 git submodule update
 pip install -e .
 ```
@@ -78,7 +78,7 @@ from nunchaku.pipelines import flux as nunchaku_flux
 pipeline = nunchaku_flux.from_pretrained(
     "black-forest-labs/FLUX.1-schnell",
     torch_dtype=torch.bfloat16,
-    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors", # download from Huggingface
+    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell", # download from Huggingface
 ).to("cuda")
 image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
 image.save("example.png")
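
Note that `qmodel_path` now takes a Hugging Face repo id (or a local directory) rather than a path to a single `.safetensors` file inside a monolithic model repo. A minimal sketch of the resolution logic this commit adopts, assuming only that `huggingface_hub` is installed; `resolve_qmodel_path` is a hypothetical helper name, the real change inlines this in `from_pretrained`:

```python
import os

from huggingface_hub import snapshot_download


def resolve_qmodel_path(qmodel_path: str) -> str:  # hypothetical helper
    # An existing local directory is used as-is; otherwise the string is
    # treated as a Hugging Face repo id, and only that repo is fetched into
    # the local cache (no need to download the entire model collection).
    if not os.path.exists(qmodel_path):
        qmodel_path = snapshot_download(qmodel_path)
    return qmodel_path
```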
......
@@ -4,11 +4,11 @@ from typing import Any, Callable, Optional, Union
 import torch
 import torchvision.transforms.functional as F
 import torchvision.utils
-from PIL import Image
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download
 from peft.tuners import lora
+from PIL import Image
 from torch import nn

 from nunchaku.models.flux import inject_pipeline, load_quantized_model
@@ -145,9 +145,7 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         self.erosion_kernel = erosion_kernel

         torchvision.utils.save_image(image_t[0], "before.png")
-        image_t = (
-            nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
-        )
+        image_t = nn.functional.conv2d(image_t[:, :1], erosion_kernel, padding=kernel_size // 2) > kernel_size**2 - 0.1
         image_t = torch.concat([image_t, image_t, image_t], dim=1).to(self.dtype)
         torchvision.utils.save_image(image_t[0], "after.png")
@@ -219,6 +217,8 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
         qmodel_device = torch.device(qmodel_device)
+        if qmodel_device.type != "cuda":
+            raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")
         qmodel_path = kwargs.pop("qmodel_path", None)
         qencoder_path = kwargs.pop("qencoder_path", None)
@@ -229,11 +229,12 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         if qmodel_path is not None:
             assert isinstance(qmodel_path, str)
             if not os.path.exists(qmodel_path):
-                hf_repo_id = os.path.dirname(qmodel_path)
-                filename = os.path.basename(qmodel_path)
-                qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
-            m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-            inject_pipeline(pipeline, m)
+                qmodel_path = snapshot_download(qmodel_path)
+            m = load_quantized_model(
+                os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+                0 if qmodel_device.index is None else qmodel_device.index,
+            )
+            inject_pipeline(pipeline, m, qmodel_device)
             pipeline.precision = "int4"
         if qencoder_path is not None:
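
The added device check fails fast when `qmodel_device` is not CUDA, since the quantized kernels only run on CUDA devices. A small self-contained sketch of the same pattern; the helper name is hypothetical, the diff inlines this logic:

```python
import torch


def parse_qmodel_device(device) -> torch.device:  # hypothetical helper
    # Normalize strings like "cuda:1" to torch.device and reject anything
    # that is not a CUDA device before the quantized model is loaded.
    device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(f"qmodel_device = {device} is not a CUDA device")
    return device


# torch.device("cuda").index is None, which is why the loader above falls
# back to device index 0 when no explicit index is given.
```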
......
@@ -5,15 +5,17 @@ import tempfile
 import time

 import GPUtil
-import gradio as gr
+import numpy as np
 import torch
 from PIL import Image

 from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
 from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_args
-from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLES, STYLE_NAMES
-import numpy as np
+from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
+
+# import gradio last to avoid conflicts with other imports
+import gradio as gr


 blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))
@@ -30,7 +32,7 @@ else:
     pipeline = FluxPix2pixTurboPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell",
         torch_dtype=torch.bfloat16,
-        qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
+        qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
         qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if args.use_qencoder else None,
     )
     pipeline = pipeline.to("cuda")
......
-import cv2
-import numpy as np
-from PIL import Image
 import argparse
......
@@ -26,7 +26,7 @@ _HOMEPAGE = "https://github.com/facebookresearch/DCI"
 _LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"
-IMAGE_URL = "https://scontent.xx.fbcdn.net/m1/v/t6/An_zz_Te0EtVC_cHtUwnyNKODapWXuNNPeBgZn_3XY8yDFzwHrNb-zwN9mYCbAeWUKQooCI9mVbwvzZDZzDUlscRjYxLKsw.tar?ccb=10-5&oh=00_AYBnKR-fSIir-E49Q7-qO2tjmY0BGJhCciHS__B5QyiBAg&oe=673FFA8A&_nc_sid=0fdd51"
+IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz"
 PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}
......
@@ -31,7 +31,9 @@ def main():
     results = {}
     dataset_names = sorted(os.listdir(image_root1))
     for dataset_name in dataset_names:
-        print("##Results for dataset:", dataset_name)
+        if image_root2 is not None and dataset_name not in os.listdir(image_root2):
+            continue
+        print("Results for dataset:", dataset_name)
         results[dataset_name] = {}
         dataset = get_dataset(name=dataset_name, return_gt=True)
         fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
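
The added guard only scores datasets that exist under both image roots, so comparing two partially generated result directories no longer crashes on missing folders. Equivalent standalone logic, as a sketch; the helper name is hypothetical:

```python
import os


def common_datasets(image_root1: str, image_root2=None) -> list:  # hypothetical
    # When a second root is given, keep only dataset folders present in both.
    names = sorted(os.listdir(image_root1))
    if image_root2 is None:
        return names
    present = set(os.listdir(image_root2))
    return [name for name in names if name in present]
```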
......
@@ -2,6 +2,7 @@ import argparse
 import time

 import torch
+from torch import nn
 from tqdm import trange

 from utils import get_pipeline
@@ -51,23 +52,38 @@ def main():
         pipeline.set_progress_bar_config(position=1, desc="Step", leave=False)
         for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
             pipeline(
-                prompt=dummy_prompt,
-                num_inference_steps=args.num_inference_steps,
-                guidance_scale=args.guidance_scale,
+                prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
             )
         torch.cuda.synchronize()
         for _ in trange(args.test_times, desc="Test", position=0, leave=False):
             start_time = time.time()
             pipeline(
-                prompt=dummy_prompt,
-                num_inference_steps=args.num_inference_steps,
-                guidance_scale=args.guidance_scale,
+                prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
             )
             torch.cuda.synchronize()
             end_time = time.time()
             latency_list.append(end_time - start_time)
     elif args.mode == "step":
-        pass
+        inputs = {}
+
+        def get_input_hook(module: nn.Module, input_args, input_kwargs):
+            inputs["args"] = input_args
+            inputs["kwargs"] = input_kwargs
+
+        pipeline.transformer.register_forward_pre_hook(get_input_hook, with_kwargs=True)
+        pipeline(prompt=dummy_prompt, num_inference_steps=1, guidance_scale=args.guidance_scale, output_type="latent")
+        for _ in trange(args.warmup_times, desc="Warmup", position=0, leave=False):
+            pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+        torch.cuda.synchronize()
+        for _ in trange(args.test_times, desc="Test", position=0, leave=False):
+            start_time = time.time()
+            pipeline.transformer(*inputs["args"], **inputs["kwargs"])
+            torch.cuda.synchronize()
+            end_time = time.time()
+            latency_list.append(end_time - start_time)
     latency_list = sorted(latency_list)
     ignored_count = int(args.ignore_ratio * len(latency_list) / 2)
     if ignored_count > 0:
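
The new `step` mode captures the exact arguments the pipeline passes to its transformer during a single 1-step run, then replays only the transformer call to time one denoising step in isolation. A toy, self-contained demonstration of the capture-and-replay trick (requires PyTorch 2.0+ for `with_kwargs=True`):

```python
import torch
from torch import nn

captured = {}


def capture_hook(module: nn.Module, args, kwargs):
    # With with_kwargs=True the pre-hook sees both positional and keyword
    # arguments; stash them so the exact call can be replayed later.
    captured["args"] = args
    captured["kwargs"] = kwargs


layer = nn.Linear(8, 8)
handle = layer.register_forward_pre_hook(capture_hook, with_kwargs=True)
layer(torch.randn(2, 8))  # first call populates `captured`
handle.remove()
layer(*captured["args"], **captured["kwargs"])  # replay the identical call
```

In the benchmark above the hook is never removed, which is harmless: each replay just overwrites the stash with the same tensors.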
......
@@ -5,7 +5,6 @@ import random
 import time

 import GPUtil
-import gradio as gr
 import spaces
 import torch
 from peft.tuners import lora
@@ -14,6 +13,9 @@ from nunchaku.models.safety_checker import SafetyChecker
 from utils import get_pipeline
 from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS

+# import gradio last to avoid conflicts with other imports
+import gradio as gr
+

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
......
@@ -29,7 +29,7 @@ def get_pipeline(
         pipeline = nunchaku_flux.from_pretrained(
             "black-forest-labs/FLUX.1-schnell",
             torch_dtype=torch.bfloat16,
-            qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",
+            qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
             qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
             qmodel_device=device,
         )
@@ -41,7 +41,7 @@ def get_pipeline(
         pipeline = nunchaku_flux.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             torch_dtype=torch.bfloat16,
-            qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-dev.safetensors",
+            qmodel_path="mit-han-lab/svdq-int4-flux.1-dev",
             qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
             qmodel_device=device,
         )
......
@@ -14,10 +14,11 @@ SVD_RANK = 32


 class NunchakuFluxModel(nn.Module):
-    def __init__(self, m: QuantizedFluxModel):
+    def __init__(self, m: QuantizedFluxModel, device: torch.device):
         super().__init__()
         self.m = m
         self.dtype = torch.bfloat16
+        self.device = device

     def forward(
         self,
@@ -33,10 +34,12 @@ class NunchakuFluxModel(nn.Module):
         img_tokens = hidden_states.shape[1]

         original_dtype = hidden_states.dtype
+        original_device = hidden_states.device

-        hidden_states = hidden_states.to(self.dtype)
-        encoder_hidden_states = encoder_hidden_states.to(self.dtype)
-        temb = temb.to(self.dtype)
+        hidden_states = hidden_states.to(self.dtype).to(self.device)
+        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
+        temb = temb.to(self.dtype).to(self.device)
+        image_rotary_emb = image_rotary_emb.to(self.device)

         assert image_rotary_emb.ndim == 6
         assert image_rotary_emb.shape[0] == 1
@@ -52,7 +55,7 @@ class NunchakuFluxModel(nn.Module):
             hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
         )

-        hidden_states = hidden_states.to(original_dtype)
+        hidden_states = hidden_states.to(original_dtype).to(original_device)

         encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
         hidden_states = hidden_states[:, txt_tokens:, ...]
@@ -110,11 +113,11 @@ def load_quantized_model(path: str, device: str | torch.device) -> QuantizedFluxModel:
     return m


-def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel) -> FluxPipeline:
+def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel, device: torch.device) -> FluxPipeline:
     net: FluxTransformer2DModel = pipe.transformer
     net.pos_embed = EmbedND(dim=net.inner_dim, theta=10000, axes_dim=[16, 56, 56])
-    net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m)])
+    net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m, device)])
     net.single_transformer_blocks = torch.nn.ModuleList([])

     def update_params(self: FluxTransformer2DModel, path: str):
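
`NunchakuFluxModel` now pins the quantized blocks to a specific device and round-trips activations: inputs are cast to `bfloat16` and moved to the model's device on entry, and the output is restored to the caller's original dtype and device on exit, so the rest of the pipeline can live elsewhere. A generic sketch of that cast-and-restore pattern; the wrapper class is hypothetical, not the actual model:

```python
import torch
from torch import nn


class DeviceRoundTrip(nn.Module):  # hypothetical wrapper for illustration
    def __init__(self, inner: nn.Module, device: torch.device, dtype=torch.bfloat16):
        super().__init__()
        self.inner = inner
        self.device = device
        self.dtype = dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype, original_device = x.dtype, x.device
        # Run the wrapped module with its own dtype and on its own device...
        y = self.inner(x.to(self.dtype).to(self.device))
        # ...then hand the result back where the caller expects it.
        return y.to(original_dtype).to(original_device)
```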
......
 import os

 import torch
-from diffusers import FluxPipeline
-from huggingface_hub import hf_hub_download
+from diffusers import __version__
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from torch import nn

 from ..models.flux import inject_pipeline, load_quantized_model
@@ -45,13 +48,34 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
     qencoder_path = kwargs.pop("qencoder_path", None)

     if not os.path.exists(qmodel_path):
-        hf_repo_id = os.path.dirname(qmodel_path)
-        filename = os.path.basename(qmodel_path)
-        qmodel_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
+        qmodel_path = snapshot_download(qmodel_path)

-    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
-    m = load_quantized_model(qmodel_path, 0 if qmodel_device.index is None else qmodel_device.index)
-    inject_pipeline(pipeline, m)
+    assert kwargs.pop("transformer", None) is None
+
+    config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
+        pretrained_model_name_or_path,
+        cache_dir=kwargs.get("cache_dir", None),
+        return_unused_kwargs=True,
+        return_commit_hash=True,
+        force_download=kwargs.get("force_download", False),
+        proxies=kwargs.get("proxies", None),
+        local_files_only=kwargs.get("local_files_only", None),
+        token=kwargs.get("token", None),
+        revision=kwargs.get("revision", None),
+        subfolder="transformer",
+        user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+        **kwargs,
+    )
+    transformer: nn.Module = FluxTransformer2DModel.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+    state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+    transformer.load_state_dict(state_dict, strict=False)
+
+    pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
+    m = load_quantized_model(
+        os.path.join(qmodel_path, "transformer_blocks.safetensors"),
+        0 if qmodel_device.index is None else qmodel_device.index,
+    )
+    inject_pipeline(pipeline, m, qmodel_device)

     if qencoder_path is not None:
         assert isinstance(qencoder_path, str)
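
The quantized checkpoint is now split into two files: `unquantized_layers.safetensors`, loaded into a freshly constructed `FluxTransformer2DModel` with `strict=False` (the quantized blocks are intentionally absent from the PyTorch state dict), and `transformer_blocks.safetensors`, handled by `load_quantized_model`. Because the transformer is built from its config rather than from pretrained weights, the base repo's full transformer never has to be downloaded. A toy demonstration of the `strict=False` partial-load pattern:

```python
import torch
from torch import nn

# Build a module, then load a checkpoint that covers only part of it.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
partial = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}

# strict=False fills in the provided keys and reports the rest instead of
# raising; in the loader above, the "missing" keys are exactly the blocks
# served by the quantized runtime.
missing, unexpected = model.load_state_dict(partial, strict=False)
print(missing)     # ['1.weight', '1.bias']
print(unexpected)  # []
```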
......