Unverified Commit b0484ae0 authored by dengdong, committed by GitHub

feat: sdxl model support (#674)

* feat: sdxl model support

* code style auto-modified by pre-commit hook

* refine comments

* add tests and examples for sdxl

* refine sdxl tests code

* make linter happy

* mv the locations of the examples

* move the locations of the tests

* refine tests and examples

* add API documentation for unet_sdxl.py

* usage doc for sdxl

* update docs

* update

* refine pipeline initialization

* refine tests for sdxl/sdxl-turbo
parent 657863bb
@@ -20,6 +20,7 @@ Check out `DeepCompressor <github_deepcompressor_>`_ for the quantization library
    usage/qwen-image-edit.rst
    usage/lora.rst
    usage/kontext.rst
+   usage/sdxl.rst
    usage/controlnet.rst
    usage/qencoder.rst
    usage/offload.rst
...
@@ -5,6 +5,7 @@ nunchaku.models
    :maxdepth: 4

    nunchaku.models.transformers
+   nunchaku.models.unets
    nunchaku.models.text_encoders
    nunchaku.models.linear
    nunchaku.models.attention
...
nunchaku.models.unets
=====================

.. toctree::
   :maxdepth: 4

   nunchaku.models.unets.unet_sdxl

nunchaku.models.unets.unet\_sdxl
================================

.. automodule:: nunchaku.models.unets.unet_sdxl
   :members:
   :undoc-members:
   :show-inheritance:
Stable Diffusion XL
===================

The following examples show how to run the Nunchaku INT4 versions of the SDXL and SDXL-Turbo text-to-image pipelines.

.. tabs::

   .. tab:: SDXL

      .. literalinclude:: ../../../examples/v1/sdxl.py
         :language: python
         :caption: Running Nunchaku SDXL (`examples/v1/sdxl.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/v1/sdxl.py>`__)
         :linenos:

   .. tab:: SDXL-Turbo

      .. literalinclude:: ../../../examples/v1/sdxl-turbo.py
         :language: python
         :caption: Running Nunchaku SDXL-Turbo (`examples/v1/sdxl-turbo.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/v1/sdxl-turbo.py>`__)
         :linenos:

For more details, see :class:`~nunchaku.models.unets.unet_sdxl.NunchakuSDXLUNet2DConditionModel`.
import torch
from diffusers import StableDiffusionXLPipeline

from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel

if __name__ == "__main__":
    # Load the Nunchaku INT4-quantized SDXL-Turbo UNet from the Hugging Face Hub.
    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors"
    )
    # Build the standard diffusers pipeline around the quantized UNet.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/sdxl-turbo", unet=unet, torch_dtype=torch.bfloat16, variant="fp16"
    ).to("cuda")

    # SDXL-Turbo is distilled for few-step sampling without classifier-free guidance,
    # so guidance_scale is 0.0 and only a few denoising steps are needed.
    prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    image = pipeline(prompt=prompt, guidance_scale=0.0, num_inference_steps=4).images[0]
    image.save("sdxl-turbo.png")
import torch
from diffusers import StableDiffusionXLPipeline

from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel

if __name__ == "__main__":
    # Load the Nunchaku INT4-quantized SDXL base UNet from the Hugging Face Hub.
    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
    )
    # Build the standard diffusers pipeline around the quantized UNet.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        unet=unet,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")

    # Standard SDXL sampling: classifier-free guidance and 50 denoising steps.
    prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    image = pipeline(prompt=prompt, guidance_scale=5.0, num_inference_steps=50).images[0]
    image.save("sdxl.png")
from typing import Optional
import torch
from torch.nn import functional as F
class NunchakuSDXLFA2Processor:
    """Scaled-dot-product attention processor for Nunchaku SDXL attention modules.

    Self-attention goes through the module's fused ``to_qkv`` projection, while
    cross-attention falls back to the separate ``to_q``/``to_k``/``to_v`` projections.
    """
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
):
# Adapted from https://github.com/huggingface/diffusers/blob/50dea89dc6036e71a00bc3d57ac062a80206d9eb/src/diffusers/models/attention_processor.py#AttnProcessor2_0
# if len(args) > 0 or kwargs.get("scale", None) is not None:
# deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# deprecate("scale", "1.0.0", deprecation_message)
# residual = hidden_states
# if attn.spatial_norm is not None:
# hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# # scaled_dot_product_attention expects attention_mask shape to be
# # (batch, heads, source_length, target_length)
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
raise NotImplementedError("attention_mask is not supported")
# if attn.group_norm is not None:
# hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
############# qkv ################
# query = attn.to_q(hidden_states)
# if encoder_hidden_states is None:
# encoder_hidden_states = hidden_states
# elif attn.norm_cross:
# encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# key = attn.to_k(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states)
if not attn.is_cross_attention:
qkv = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)
# query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
else:
query, key, value = (
attn.to_q(hidden_states),
attn.to_k(encoder_hidden_states),
attn.to_v(encoder_hidden_states),
)
############# end of qkv ################
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# if attn.norm_q is not None:
# query = attn.norm_q(query)
# if attn.norm_k is not None:
# key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# if attn.residual_connection:
# hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
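A minimal sketch (not part of this commit) of routing a UNet's attention layers through the processor above. It assumes NunchakuSDXLUNet2DConditionModel inherits diffusers' standard set_attn_processor from UNet2DConditionModel, and the import path for NunchakuSDXLFA2Processor is hypothetical; in the released model the processor is presumably installed during construction.

from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel

# Hypothetical import path for the processor defined above.
from nunchaku.models.attention_processors import NunchakuSDXLFA2Processor

if __name__ == "__main__":
    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
    )
    # Route every attention module through the SDPA-based processor.
    unet.set_attn_processor(NunchakuSDXLFA2Processor())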
from .unet_sdxl import (
    NunchakuSDXLAttention,
    NunchakuSDXLConcatShiftedConv2d,
    NunchakuSDXLFeedForward,
    NunchakuSDXLShiftedConv2d,
    NunchakuSDXLTransformerBlock,
    NunchakuSDXLUNet2DConditionModel,
)

__all__ = [
    "NunchakuSDXLAttention",
    "NunchakuSDXLTransformerBlock",
    "NunchakuSDXLShiftedConv2d",
    "NunchakuSDXLConcatShiftedConv2d",
    "NunchakuSDXLUNet2DConditionModel",
    "NunchakuSDXLFeedForward",
]
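With these re-exports, the SDXL building blocks can also be imported from the package itself rather than from the unet_sdxl module (assuming this file is nunchaku/models/unets/__init__.py):

from nunchaku.models.unets import NunchakuSDXLUNet2DConditionModel

unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
    "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
)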
@@ -82,9 +82,9 @@ def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pa
 def fused_qkv_norm_rottary(
     x: torch.Tensor,
     proj: SVDQW4A4Linear,
-    norm_q: RMSNorm,
-    norm_k: RMSNorm,
-    rotary_emb: torch.Tensor,
+    norm_q: RMSNorm | None = None,
+    norm_k: RMSNorm | None = None,
+    rotary_emb: torch.Tensor | None = None,
     output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     attn_tokens: int = 0,
 ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -124,8 +124,8 @@ def fused_qkv_norm_rottary(
     - C_in: input features
     - C_out: output features
     """
-    assert isinstance(norm_q, RMSNorm)
-    assert isinstance(norm_k, RMSNorm)
+    assert norm_q is None or isinstance(norm_q, RMSNorm)
+    assert norm_k is None or isinstance(norm_k, RMSNorm)
     batch_size, seq_len, channels = x.shape
     x = x.view(batch_size * seq_len, channels)
@@ -148,8 +148,8 @@ def fused_qkv_norm_rottary(
             fp4=proj.precision == "nvfp4",
             alpha=proj.wtscale,
             wcscales=proj.wcscales,
-            norm_q=norm_q.weight,
-            norm_k=norm_k.weight,
+            norm_q=norm_q.weight if norm_q is not None else None,
+            norm_k=norm_k.weight if norm_k is not None else None,
             rotary_emb=rotary_emb,
             out_q=output_q,
             out_k=output_k,
@@ -170,8 +170,8 @@ def fused_qkv_norm_rottary(
             fp4=proj.precision == "nvfp4",
             alpha=proj.wtscale,
             wcscales=proj.wcscales,
-            norm_q=norm_q.weight,
-            norm_k=norm_k.weight,
+            norm_q=norm_q.weight if norm_q is not None else None,
+            norm_k=norm_k.weight if norm_k is not None else None,
             rotary_emb=rotary_emb,
         )
     output = output.view(batch_size, seq_len, -1)
...
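Making norm_q, norm_k, and rotary_emb optional lets SDXL's attention, which applies no query/key RMSNorm and uses no rotary position embedding, reuse the same fused QKV kernel. A minimal calling sketch, where attn and hidden_states are illustrative placeholders and the call still requires Nunchaku's quantized SVDQW4A4Linear layers and CUDA kernels:

# Illustrative only: attn.to_qkv stands for a fused SVDQW4A4Linear QKV projection.
qkv = fused_qkv_norm_rottary(
    hidden_states,    # (batch_size, seq_len, channels)
    attn.to_qkv,      # quantized fused QKV projection
    norm_q=None,      # SDXL attention applies no RMSNorm to queries ...
    norm_k=None,      # ... or to keys
    rotary_emb=None,  # ... and uses no rotary position embedding
)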
import gc
import os
from pathlib import Path
import pytest
import torch
from diffusers import StableDiffusionXLPipeline
from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel
from nunchaku.utils import get_precision, is_turing
from ...flux.utils import already_generate, compute_lpips, hash_str_to_int
from .test_sdxl_turbo import plot, run_benchmark
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_sdxl_lpips(expected_lpips: float):
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
results_dir_original = ref_root / "fp16" / "sdxl"
results_dir_nunchaku = ref_root / precision / "sdxl"
os.makedirs(results_dir_original, exist_ok=True)
os.makedirs(results_dir_nunchaku, exist_ok=True)
prompts = [
"Ilya Repin, Moebius, Yoshitaka Amano, 1980s nubian punk rock glam core fashion shoot, closeup, 35mm ",
"A honeybee sitting on a flower in a garden full of yellow flowers",
"Vibrant, tropical rainforest, teeming with wildlife, nature photography ",
"very realistic photo of barak obama in a wing eating contest",
"oil paint of colorful wildflowers in a meadow, Paul Signac divisionism style ",
]
if not already_generate(results_dir_original, 5):
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, use_safetensors=True, variant="fp16"
).to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=5.0, num_inference_steps=50, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_original, f"{seed}.png"))
del pipeline.unet
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After original generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
if not already_generate(results_dir_nunchaku, 5):
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
unet=quantized_unet,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
)
pipeline.unet = quantized_unet
pipeline = pipeline.to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=5.0, num_inference_steps=50, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_nunchaku, f"{seed}.png"))
del pipeline
del quantized_unet
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After Nunchaku generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_original, results_dir_nunchaku)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.15
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_latency", [7.455])
def test_sdxl_time_cost(expected_latency: float):
batch_size = 2
runs = 5
inference_steps = 50
guidance_scale = 5.0
device_name = torch.cuda.get_device_name(0)
results = {"Nunchaku INT4": []}
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
)
pipeline_quantized = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
unet=quantized_unet,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
)
pipeline_quantized = pipeline_quantized.to("cuda")
benchmark_quantized = run_benchmark(
pipeline_quantized, batch_size, guidance_scale, device_name, runs, inference_steps
)
avg_latency = benchmark_quantized.mean() * inference_steps
results["Nunchaku INT4"].append(avg_latency)
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
plot_save_path = ref_root / "time_cost" / "sdxl"
os.makedirs(plot_save_path, exist_ok=True)
plot([batch_size], results, device_name, runs, inference_steps, plot_save_path, "SDXL")
assert avg_latency < expected_latency * 1.1
import gc
import os
import time
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
import torch
from diffusers import StableDiffusionXLPipeline
from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel
from nunchaku.utils import get_precision, is_turing
from ...flux.utils import already_generate, compute_lpips, hash_str_to_int
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_sdxl_turbo_lpips(expected_lpips: float):
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
results_dir_original = ref_root / "fp16" / "sdxl-turbo"
results_dir_nunchaku = ref_root / precision / "sdxl-turbo"
os.makedirs(results_dir_original, exist_ok=True)
os.makedirs(results_dir_nunchaku, exist_ok=True)
prompts = [
"Ilya Repin, Moebius, Yoshitaka Amano, 1980s nubian punk rock glam core fashion shoot, closeup, 35mm ",
"A honeybee sitting on a flower in a garden full of yellow flowers",
"Vibrant, tropical rainforest, teeming with wildlife, nature photography ",
"very realistic photo of barak obama in a wing eating contest",
"oil paint of colorful wildflowers in a meadow, Paul Signac divisionism style ",
]
if not already_generate(results_dir_original, 5):
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/sdxl-turbo", torch_dtype=torch.bfloat16, variant="fp16"
).to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=0.0, num_inference_steps=4, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_original, f"{seed}.png"))
del pipeline.unet
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After original generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
if not already_generate(results_dir_nunchaku, 5):
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors"
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/sdxl-turbo", unet=quantized_unet, torch_dtype=torch.bfloat16, variant="fp16"
)
pipeline = pipeline.to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=0.0, num_inference_steps=4, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_nunchaku, f"{seed}.png"))
del pipeline
del quantized_unet
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After Nunchaku generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_original, results_dir_nunchaku)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.15
class PerfHook:
def __init__(self):
self.start = []
self.end = []
def pre_hook(self, module, input):
self.start.append(time.perf_counter())
def post_hook(self, module, input, output):
self.end.append(time.perf_counter())
def run_benchmark(pipeline, batch_size, guidance_scale, device, runs, inference_steps):
prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
# warmup
_ = pipeline(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
num_images_per_prompt=batch_size,
).images
time_cost = []
unet = pipeline.unet
perf_hook = PerfHook()
handle_pre = unet.register_forward_pre_hook(perf_hook.pre_hook)
handle_post = unet.register_forward_hook(perf_hook.post_hook)
# run
for _ in range(runs):
_ = pipeline(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
num_images_per_prompt=batch_size,
).images
time_cost = [perf_hook.end[i] - perf_hook.start[i] for i in range(len(perf_hook.start))]
# to numpy for stats
time_cost = np.array(time_cost)
print(f"device: {device}")
print(f"runs :{runs}")
print(f"batch_size: {batch_size}")
print(f"max :{time_cost.max():.4f}")
print(f"min :{time_cost.min():.4f}")
print(f"avg :{time_cost.mean():.4f}")
print(f"std :{time_cost.std():.4f}")
handle_pre.remove()
handle_post.remove()
return time_cost
def plot(batch_sizes, results, device_name, runs, inference_steps, plot_save_path, title):
x = np.arange(len(batch_sizes))
width = 0.35
fig, ax = plt.subplots()
rects2 = ax.bar(x + width / 2, results["Nunchaku INT4"], width, label="Nunchaku INT4")
ax.set_ylabel(f"Average time cost (seconds)\n{runs} runs of {inference_steps} inference steps each.")
ax.set_xlabel("Batch size")
ax.set_title(f"{title} diffusion time cost\n(GPU: {device_name})")
ax.set_xticks(x)
ax.set_xticklabels(batch_sizes)
ax.legend()
def autolabel(rects):
for rect in rects:
height = rect.get_height()
ax.annotate(
f"{height:.3f}",
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3),
textcoords="offset points",
ha="center",
va="bottom",
)
autolabel(rects2)
plt.tight_layout()
plt.savefig(plot_save_path / "plot.png", dpi=300, bbox_inches="tight")
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_latency", [0.306])
def test_sdxl_turbo_time_cost(expected_latency: float):
batch_size = 8
runs = 5
guidance_scale = 0.0
inference_steps = 4
device_name = torch.cuda.get_device_name(0)
results = {"Nunchaku INT4": []}
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors"
)
pipeline_quantized = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/sdxl-turbo", unet=quantized_unet, torch_dtype=torch.bfloat16, variant="fp16"
)
pipeline_quantized = pipeline_quantized.to("cuda")
benchmark_quantized = run_benchmark(
pipeline_quantized, batch_size, guidance_scale, device_name, runs, inference_steps
)
avg_latency = benchmark_quantized.mean() * inference_steps
results["Nunchaku INT4"].append(avg_latency)
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
plot_save_path = ref_root / "time_cost" / "sdxl-turbo"
os.makedirs(plot_save_path, exist_ok=True)
plot([batch_size], results, device_name, runs, inference_steps, plot_save_path, "SDXL-Turbo")
assert avg_latency < expected_latency * 1.1
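To run these LPIPS and latency checks locally, the tests can be driven through pytest. A hedged example follows; the test-file directory is illustrative (the commit moved the tests, so adjust the paths accordingly), and NUNCHAKU_TEST_CACHE_ROOT controls where reference images and plots are cached, as in the tests above.

import os

import pytest

# NUNCHAKU_TEST_CACHE_ROOT controls where reference images and plots are stored.
os.environ.setdefault("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))

# The directory below is a guess at where the SDXL tests live; adjust as needed.
raise SystemExit(pytest.main(["-v", "tests/v1/sdxl/test_sdxl.py", "tests/v1/sdxl/test_sdxl_turbo.py"]))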