Unverified Commit 889efafd authored by Muyang Li, committed by GitHub

[major] Adding some test workflows; support multiple batches

parents 68dafdfa 16979769
@@ -303,10 +303,10 @@ Tensor gemv_awq(
     constexpr int GROUP_SIZE = 64;
-    assert(m > 0 && m < 8);
+    assert(m > 0 && m <= 8);
     assert(group_size == GROUP_SIZE);
-    dispatchVal(m, std::make_integer_sequence<int, 8>(), [&]<int M>() {
+    dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
         if constexpr (M == 0) {
             assert(false);
             return;
...
@@ -180,7 +180,7 @@ std::array<Tensor, N> split_mod(Tensor input) {
     auto stream = getCurrentCUDAStream();
-    auto shapeOut = input.shape;
+    auto shapeOut = TensorShape(input.shape.dataExtent);
     shapeOut[-1] /= N;
     std::array<Tensor, N> out;
...
@@ -7,12 +7,7 @@ from huggingface_hub import snapshot_download
 from nunchaku.utils import fetch_or_download

-__all__ = ["get_dataset", "load_dataset_yaml", "download_hf_dataset"]
+__all__ = ["get_dataset", "load_dataset_yaml"]

-def download_hf_dataset(repo_id: str = "mit-han-lab/nunchaku-test", local_dir: str | None = None) -> str:
-    path = snapshot_download(repo_id=repo_id, repo_type="dataset", local_dir=local_dir)
-    return path

 def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
...
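With `download_hf_dataset` removed, the reference images are no longer fetched automatically; `run_test` now reads the reference root from the `NUNCHAKU_TEST_CACHE_ROOT` environment variable (see the `run_test` change further down). A minimal sketch of fetching the references out of band, assuming the same `mit-han-lab/nunchaku-test` dataset repo that the removed helper defaulted to:

```python
import os

from huggingface_hub import snapshot_download

# Assumption: the reference images still live in the mit-han-lab/nunchaku-test
# dataset repo; adjust repo_id if they have moved.
ref_root = snapshot_download(
    repo_id="mit-han-lab/nunchaku-test",
    repo_type="dataset",
    local_dir=os.path.join("test_results", "ref"),
)
os.environ["NUNCHAKU_TEST_CACHE_ROOT"] = ref_root  # picked up by run_test below
```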
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(
is_turing() or torch.cuda.device_count() <= 1, reason="Skip tests due to using Turing GPUs or single GPU"
)
def test_device_id():
    precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
torch_dtype = torch.float16 if is_turing("cuda:1") else torch.bfloat16
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
).to("cuda:1")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
)
@@ -4,15 +4,14 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test

-@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
     [
-        (0.12, 1024, 1024, 30, None, 1, 0.26),
-        (0.12, 512, 2048, 30, "anime", 1, 0.4),
+        (0.12, 1024, 1024, 30, None, 1, 0.212),
     ],
 )
-def test_flux_dev_loras(
+def test_flux_dev_cache(
     cache_threshold: float,
     height: int,
     width: int,
...
@@ -4,12 +4,12 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test

-@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, 50, "flashattn2", False, 0.226),
-        (2048, 512, 25, "nunchaku-fp16", False, 0.243),
+        (1024, 1024, 50, "flashattn2", False, 0.139),
+        (2048, 512, 25, "nunchaku-fp16", False, 0.168),
     ],
 )
 def test_flux_dev(
...
@@ -4,16 +4,16 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
     [
-        (25, "realism", 0.9, True, 0.178),
-        (25, "ghibsky", 1, False, 0.164),
-        (28, "anime", 1, False, 0.284),
-        (24, "sketch", 1, True, 0.223),
-        (28, "yarn", 1, False, 0.211),
-        (25, "haunted_linework", 1, True, 0.317),
+        (25, "realism", 0.9, True, 0.136),
+        # (25, "ghibsky", 1, False, 0.186),
+        # (28, "anime", 1, False, 0.284),
+        (24, "sketch", 1, True, 0.291),
+        # (28, "yarn", 1, False, 0.211),
+        # (25, "haunted_linework", 1, True, 0.317),
     ],
 )
 def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
@@ -26,6 +26,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
         num_inference_steps=num_inference_steps,
         guidance_scale=3.5,
         use_qencoder=False,
+        attention_impl="nunchaku-fp16",
         cpu_offload=cpu_offload,
         lora_names=lora_name,
         lora_strengths=lora_strength,
@@ -34,73 +35,13 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offload, expected_lpips):
     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
-def test_flux_dev_hypersd8_1536x2048():
-    run_test(
-        precision=get_precision(),
-        model_name="flux.1-dev",
-        dataset_name="MJHQ",
-        height=1536,
-        width=2048,
-        num_inference_steps=8,
-        guidance_scale=3.5,
-        use_qencoder=False,
-        attention_impl="nunchaku-fp16",
-        cpu_offload=True,
-        lora_names="hypersd8",
-        lora_strengths=0.125,
-        cache_threshold=0,
-        expected_lpips=0.291,
-    )
-
-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
-def test_flux_dev_turbo8_2048x2048():
-    run_test(
-        precision=get_precision(),
-        model_name="flux.1-dev",
-        dataset_name="MJHQ",
-        height=2048,
-        width=2048,
-        num_inference_steps=8,
-        guidance_scale=3.5,
-        use_qencoder=False,
-        attention_impl="nunchaku-fp16",
-        cpu_offload=True,
-        lora_names="turbo8",
-        lora_strengths=1,
-        cache_threshold=0,
-        expected_lpips=0.189,
-    )
-
-# lora composition
-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
-def test_flux_dev_turbo8_yarn_2048x1024():
-    run_test(
-        precision=get_precision(),
-        model_name="flux.1-dev",
-        dataset_name="yarn",
-        height=2048,
-        width=1024,
-        num_inference_steps=8,
-        guidance_scale=3.5,
-        use_qencoder=False,
-        cpu_offload=True,
-        lora_names=["turbo8", "yarn"],
-        lora_strengths=[1, 1],
-        cache_threshold=0,
-        expected_lpips=0.252,
-    )
-
-# large rank loras
-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
-def test_flux_dev_turbo8_yarn_1024x1024():
+# lora composition & large rank loras
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
+def test_flux_dev_turbo8_ghibsky_1024x1024():
     run_test(
         precision=get_precision(),
         model_name="flux.1-dev",
-        dataset_name="ghibsky",
+        dataset_name="haunted_linework",
         height=1024,
         width=1024,
         num_inference_steps=8,
@@ -110,5 +51,5 @@ def test_flux_dev_turbo8_yarn_1024x1024():
         lora_names=["realism", "ghibsky", "anime", "sketch", "yarn", "haunted_linework", "turbo8"],
         lora_strengths=[0, 1, 0, 0, 0, 0, 1],
         cache_threshold=0,
-        expected_lpips=0.44,
+        expected_lpips=0.310,
     )
import os
import subprocess
import pytest
EXAMPLES_DIR = "./examples"
example_scripts = [f for f in os.listdir(EXAMPLES_DIR) if f.endswith(".py") and f.startswith("flux")]
@pytest.mark.parametrize("script_name", example_scripts)
def test_example_script_runs(script_name):
script_path = os.path.join(EXAMPLES_DIR, script_name)
result = subprocess.run(["python", script_path], capture_output=True, text=True)
print(f"Running {script_path} -> Return code: {result.returncode}")
assert result.returncode == 0, f"{script_path} failed with code {result.returncode}"
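The example-script runner above drops the captured stderr, so a failing script only reports its exit code. A hedged variant (not part of this commit) that bounds runtime and surfaces the tail of stderr on failure:

```python
import os
import subprocess

import pytest

EXAMPLES_DIR = "./examples"
example_scripts = [f for f in os.listdir(EXAMPLES_DIR) if f.endswith(".py") and f.startswith("flux")]


@pytest.mark.parametrize("script_name", example_scripts)
def test_example_script_runs(script_name):
    script_path = os.path.join(EXAMPLES_DIR, script_name)
    # The 1800 s timeout is an assumption; tune it to the slowest example on the CI GPU.
    result = subprocess.run(["python", script_path], capture_output=True, text=True, timeout=1800)
    assert result.returncode == 0, f"{script_path} failed ({result.returncode}):\n{result.stderr[-2000:]}"
```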
@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
 from nunchaku.utils import get_precision, is_turing

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "use_qencoder,cpu_offload,memory_limit",
     [
@@ -38,7 +38,7 @@ def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit
     pipeline = pipeline.to("cuda")
     pipeline(
-        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
+        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
     )
     memory = torch.cuda.max_memory_reserved(0) / 1024**3
     assert memory < memory_limit
...
@@ -4,15 +4,14 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
     "height,width,attention_impl,cpu_offload,expected_lpips",
     [
-        (1024, 1024, "flashattn2", False, 0.250),
-        (1024, 1024, "nunchaku-fp16", False, 0.255),
-        (1024, 1024, "flashattn2", True, 0.250),
-        (1920, 1080, "nunchaku-fp16", False, 0.253),
-        (2048, 2048, "flashattn2", True, 0.274),
+        (1024, 1024, "flashattn2", False, 0.126),
+        (1024, 1024, "nunchaku-fp16", False, 0.126),
+        (1920, 1080, "nunchaku-fp16", False, 0.158),
+        (2048, 2048, "nunchaku-fp16", True, 0.166),
     ],
 )
 def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
...
@@ -5,7 +5,7 @@ from nunchaku.utils import get_precision, is_turing
 from .utils import run_test

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_canny_dev():
     run_test(
         precision=get_precision(),
@@ -15,16 +15,16 @@ def test_flux_canny_dev():
         dtype=torch.bfloat16,
         height=1024,
         width=1024,
-        num_inference_steps=50,
+        num_inference_steps=30,
         guidance_scale=30,
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.103 if get_precision() == "int4" else 0.164,
+        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_depth_dev():
     run_test(
         precision=get_precision(),
@@ -39,11 +39,11 @@ def test_flux_depth_dev():
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.170 if get_precision() == "int4" else 0.120,
+        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_fill_dev():
     run_test(
         precision=get_precision(),
@@ -53,37 +53,37 @@ def test_flux_fill_dev():
         dtype=torch.bfloat16,
         height=1024,
         width=1024,
-        num_inference_steps=50,
+        num_inference_steps=30,
         guidance_scale=30,
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.045,
+        expected_lpips=0.046,
     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
-def test_flux_dev_canny_lora():
-    run_test(
-        precision=get_precision(),
-        model_name="flux.1-dev",
-        dataset_name="MJHQ-control",
-        task="canny",
-        dtype=torch.bfloat16,
-        height=1024,
-        width=1024,
-        num_inference_steps=50,
-        guidance_scale=30,
-        attention_impl="nunchaku-fp16",
-        cpu_offload=False,
-        lora_names="canny",
-        lora_strengths=0.85,
-        cache_threshold=0,
-        expected_lpips=0.103,
-    )
+# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
+# def test_flux_dev_canny_lora():
+#     run_test(
+#         precision=get_precision(),
+#         model_name="flux.1-dev",
+#         dataset_name="MJHQ-control",
+#         task="canny",
+#         dtype=torch.bfloat16,
+#         height=1024,
+#         width=1024,
+#         num_inference_steps=30,
+#         guidance_scale=30,
+#         attention_impl="nunchaku-fp16",
+#         cpu_offload=False,
+#         lora_names="canny",
+#         lora_strengths=0.85,
+#         cache_threshold=0,
+#         expected_lpips=0.081,
+#     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_depth_lora():
     run_test(
         precision=get_precision(),
@@ -100,11 +100,11 @@ def test_flux_dev_depth_lora():
         cache_threshold=0,
         lora_names="depth",
         lora_strengths=0.85,
-        expected_lpips=0.163,
+        expected_lpips=0.181,
     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_fill_dev_turbo():
     run_test(
         precision=get_precision(),
@@ -121,11 +121,11 @@ def test_flux_fill_dev_turbo():
         cache_threshold=0,
         lora_names="turbo8",
         lora_strengths=1,
-        expected_lpips=0.048,
+        expected_lpips=0.036,
     )

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 def test_flux_dev_redux():
     run_test(
         precision=get_precision(),
@@ -135,10 +135,10 @@ def test_flux_dev_redux():
         dtype=torch.bfloat16,
         height=1024,
         width=1024,
-        num_inference_steps=50,
+        num_inference_steps=20,
         guidance_scale=2.5,
         attention_impl="nunchaku-fp16",
         cpu_offload=False,
         cache_threshold=0,
-        expected_lpips=0.198 if get_precision() == "int4" else 0.55,  # redux seems to generate different images on 5090
+        expected_lpips=(0.162 if get_precision() == "int4" else 0.5),  # not sure why the fp4 model is so different
     )
+
+# skip this test
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
[
(1024, 1024, "nunchaku-fp16", False, 0.140, 2),
(1920, 1080, "flashattn2", False, 0.160, 4),
],
)
def test_int4_schnell(
height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float, batch_size: int
):
run_test(
precision=get_precision(),
height=height,
width=width,
attention_impl=attention_impl,
cpu_offload=cpu_offload,
expected_lpips=expected_lpips,
batch_size=batch_size,
)
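The new `batch_size` parameter is threaded through `run_test` into `run_pipeline` (see the changes further down), where each batch becomes a single pipeline call over a list of prompts. A standalone sketch of what such a batched call looks like with diffusers, assuming FLUX.1-schnell and one generator per prompt so seeds stay independent of the batch size:

```python
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

prompts = ["a watercolor fox", "a neon city at night"]  # one prompt per sample in the batch
generators = [torch.Generator().manual_seed(seed) for seed in (0, 1)]  # one generator per prompt

images = pipeline(prompts, generator=generators, num_inference_steps=4, guidance_scale=0).images
for i, image in enumerate(images):
    image.save(f"sample_{i}.png")
```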
@@ -4,10 +4,9 @@ from .utils import run_test
 from nunchaku.utils import get_precision, is_turing

-@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
+@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
 @pytest.mark.parametrize(
-    "height,width,attention_impl,cpu_offload,expected_lpips",
-    [(1024, 1024, "flashattn2", False, 0.25), (2048, 512, "nunchaku-fp16", False, 0.25)],
+    "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.209)]
 )
 def test_shuttle_jaguar(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
     run_test(
...
 import pytest

-from nunchaku.utils import get_precision
+from nunchaku.utils import get_precision, is_turing

 from .utils import run_test

-@pytest.mark.skipif(get_precision() == "fp4", reason="Blackwell GPUs. Skip tests for Turing.")
+@pytest.mark.skipif(not is_turing(), reason="Not turing GPUs. Skip tests.")
 @pytest.mark.parametrize(
     "height,width,num_inference_steps,cpu_offload,i2f_mode,expected_lpips",
     [
-        (1024, 1024, 50, True, None, 0.253),
         (1024, 1024, 50, True, "enabled", 0.258),
-        (1024, 1024, 50, True, "always", 0.257),
     ],
 )
-def test_flux_dev(
+def test_flux_dev_on_turing(
     height: int, width: int, num_inference_steps: int, cpu_offload: bool, i2f_mode: str | None, expected_lpips: float
 ):
     run_test(
...
+import gc
+import math
 import os

 import torch
 from controlnet_aux import CannyDetector
 from diffusers import FluxControlPipeline, FluxFillPipeline, FluxPipeline, FluxPriorReduxPipeline
+from diffusers.hooks import apply_group_offloading
 from diffusers.utils import load_image
 from image_gen_aux import DepthPreprocessor
 from tqdm import tqdm
@@ -10,7 +13,7 @@ from tqdm import tqdm
 import nunchaku
 from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
 from nunchaku.lora.flux.compose import compose_lora
-from ..data import download_hf_dataset, get_dataset
+from ..data import get_dataset
 from ..utils import already_generate, compute_lpips, hash_str_to_int

 ORIGINAL_REPO_MAP = {
@@ -45,7 +48,7 @@ LORA_PATH_MAP = {
 }

-def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
+def run_pipeline(dataset, batch_size: int, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
     os.makedirs(save_dir, exist_ok=True)
     pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
@@ -61,43 +64,65 @@ def run_pipeline(dataset, task: str, pipeline: FluxPipeline, save_dir: str, forward_kwargs: dict = {}):
         assert task in ["t2i", "fill"]
         processor = None

-    for row in tqdm(dataset):
-        filename = row["filename"]
-        prompt = row["prompt"]
+    for row in tqdm(
+        dataset.iter(batch_size=batch_size, drop_last_batch=False),
+        desc="Batch",
+        total=math.ceil(len(dataset) // batch_size),
+        position=0,
+        leave=False,
+    ):
+        filenames = row["filename"]
+        prompts = row["prompt"]
         _forward_kwargs = {k: v for k, v in forward_kwargs.items()}
         if task == "canny":
             assert forward_kwargs.get("height", 1024) == 1024
             assert forward_kwargs.get("width", 1024) == 1024
-            control_image = load_image(row["canny_image_path"])
-            control_image = processor(
-                control_image,
-                low_threshold=50,
-                high_threshold=200,
-                detect_resolution=1024,
-                image_resolution=1024,
-            )
-            _forward_kwargs["control_image"] = control_image
+            control_images = []
+            for canny_image_path in row["canny_image_path"]:
+                control_image = load_image(canny_image_path)
+                control_image = processor(
+                    control_image,
+                    low_threshold=50,
+                    high_threshold=200,
+                    detect_resolution=1024,
+                    image_resolution=1024,
+                )
+                control_images.append(control_image)
+            _forward_kwargs["control_image"] = control_images
         elif task == "depth":
-            control_image = load_image(row["depth_image_path"])
-            control_image = processor(control_image)[0].convert("RGB")
-            _forward_kwargs["control_image"] = control_image
+            control_images = []
+            for depth_image_path in row["depth_image_path"]:
+                control_image = load_image(depth_image_path)
+                control_image = processor(control_image)[0].convert("RGB")
+                control_images.append(control_image)
+            _forward_kwargs["control_image"] = control_images
         elif task == "fill":
-            image = load_image(row["image_path"])
-            mask_image = load_image(row["mask_image_path"])
-            _forward_kwargs["image"] = image
-            _forward_kwargs["mask_image"] = mask_image
+            images, mask_images = [], []
+            for image_path, mask_image_path in zip(row["image_path"], row["mask_image_path"]):
+                image = load_image(image_path)
+                mask_image = load_image(mask_image_path)
+                images.append(image)
+                mask_images.append(mask_image)
+            _forward_kwargs["image"] = images
+            _forward_kwargs["mask_image"] = mask_images
         elif task == "redux":
-            image = load_image(row["image_path"])
-            _forward_kwargs.update(processor(image))
+            images = []
+            for image_path in row["image_path"]:
+                image = load_image(image_path)
+                images.append(image)
+            _forward_kwargs.update(processor(images))

-        seed = hash_str_to_int(filename)
+        seeds = [hash_str_to_int(filename) for filename in filenames]
+        generators = [torch.Generator().manual_seed(seed) for seed in seeds]
         if task == "redux":
-            image = pipeline(generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
+            images = pipeline(generator=generators, **_forward_kwargs).images
         else:
-            image = pipeline(prompt, generator=torch.Generator().manual_seed(seed), **_forward_kwargs).images[0]
-        image.save(os.path.join(save_dir, f"{filename}.png"))
+            images = pipeline(prompts, generator=generators, **_forward_kwargs).images
+        for i, image in enumerate(images):
+            filename = filenames[i]
+            image.save(os.path.join(save_dir, f"{filename}.png"))
         torch.cuda.empty_cache()
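The rewritten loop iterates the Hugging Face dataset with `Dataset.iter(batch_size=..., drop_last_batch=False)`, which yields each column as a list, so filenames, prompts, and control images are handled as parallel lists and one `torch.Generator` is seeded per filename, keeping outputs independent of the batch size. One observation: `math.ceil(len(dataset) // batch_size)` floors before the ceiling is applied, so the tqdm `total` can undercount by one batch when the dataset size is not a multiple of `batch_size`; `math.ceil(len(dataset) / batch_size)` would be exact. A condensed sketch of the same batching pattern on a toy dataset (the column names mirror the test data; the seed helper is only a stand-in for `hash_str_to_int`):

```python
import math

import torch
from datasets import Dataset

dataset = Dataset.from_dict(
    {
        "filename": ["cat", "dog", "fox"],
        "prompt": ["a cat on a sofa", "a dog in the rain", "a fox in the snow"],
    }
)

batch_size = 2
print("batches:", math.ceil(len(dataset) / batch_size))

for batch in dataset.iter(batch_size=batch_size, drop_last_batch=False):
    filenames, prompts = batch["filename"], batch["prompt"]  # each column arrives as a list
    seeds = [abs(hash(name)) % (2**31) for name in filenames]  # stand-in for hash_str_to_int (not stable across runs)
    generators = [torch.Generator().manual_seed(seed) for seed in seeds]
    # A real run would call: pipeline(prompts, generator=generators, **forward_kwargs).images
    print(filenames, seeds)
```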
@@ -105,6 +130,7 @@ def run_test(
     precision: str = "int4",
     model_name: str = "flux.1-schnell",
     dataset_name: str = "MJHQ",
+    batch_size: int = 1,
     task: str = "t2i",
     dtype: str | torch.dtype = torch.bfloat16,  # the full precision dtype
     height: int = 1024,
@@ -117,10 +143,12 @@ def run_test(
     cache_threshold: float = 0,
     lora_names: str | list[str] | None = None,
     lora_strengths: float | list[float] = 1.0,
-    max_dataset_size: int = 20,
+    max_dataset_size: int = 4,
     i2f_mode: str | None = None,
     expected_lpips: float = 0.5,
 ):
+    gc.collect()
+    torch.cuda.empty_cache()
     if isinstance(dtype, str):
         dtype_str = dtype
         if dtype == "bf16":
@@ -153,10 +181,7 @@ def run_test(
     for lora_name, lora_strength in zip(lora_names, lora_strengths):
         folder_name += f"-{lora_name}_{lora_strength}"

-    if not os.path.exists(os.path.join("test_results", "ref")):
-        ref_root = download_hf_dataset(local_dir=os.path.join("test_results", "ref"))
-    else:
-        ref_root = os.path.join("test_results", "ref")
+    ref_root = os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))
     save_dir_16bit = os.path.join(ref_root, dtype_str, model_name, folder_name)

     if task in ["t2i", "redux"]:
@@ -171,7 +196,8 @@ def run_test(
     if not already_generate(save_dir_16bit, max_dataset_size):
         pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
         pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
-        pipeline = pipeline.to("cuda")
+        gpu_properties = torch.cuda.get_device_properties(0)
+        gpu_memory = gpu_properties.total_memory / (1024**2)

         if len(lora_names) > 0:
             for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
@@ -181,7 +207,30 @@ def run_test(
                 )
             pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)

+        if gpu_memory > 36 * 1024:
+            pipeline = pipeline.to("cuda")
+        elif gpu_memory < 26 * 1024:
+            pipeline.transformer.enable_group_offload(
+                onload_device=torch.device("cuda"),
+                offload_device=torch.device("cpu"),
+                offload_type="leaf_level",
+                use_stream=True,
+            )
+            if pipeline.text_encoder is not None:
+                pipeline.text_encoder.to("cuda")
+            if pipeline.text_encoder_2 is not None:
+                apply_group_offloading(
+                    pipeline.text_encoder_2,
+                    onload_device=torch.device("cuda"),
+                    offload_type="block_level",
+                    num_blocks_per_group=2,
+                )
+            pipeline.vae.to("cuda")
+        else:
+            pipeline.enable_model_cpu_offload()
+
         run_pipeline(
+            batch_size=batch_size,
             dataset=dataset,
             task=task,
             pipeline=pipeline,
@@ -195,6 +244,7 @@ def run_test(
         )
         del pipeline
         # release the gpu memory
+        gc.collect()
         torch.cuda.empty_cache()

     precision_str = precision
@@ -211,6 +261,8 @@ def run_test(
         precision_str += f"-cache{cache_threshold}"
     if i2f_mode is not None:
         precision_str += f"-i2f{i2f_mode}"
+    if batch_size > 1:
+        precision_str += f"-bs{batch_size}"
     save_dir_4bit = os.path.join("test_results", dtype_str, precision_str, model_name, folder_name)

     if not already_generate(save_dir_4bit, max_dataset_size):
@@ -252,6 +304,7 @@ def run_test(
         else:
             pipeline = pipeline.to("cuda")
         run_pipeline(
+            batch_size=batch_size,
             dataset=dataset,
             task=task,
             pipeline=pipeline,
@@ -266,7 +319,8 @@ def run_test(
         del transformer
         del pipeline
         # release the gpu memory
+        gc.collect()
         torch.cuda.empty_cache()

     lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
     print(f"lpips: {lpips}")
-    assert lpips < expected_lpips * 1.05
+    assert lpips < expected_lpips * 1.1
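The reference-pipeline placement above is chosen from total GPU memory: above roughly 36 GiB everything moves to the GPU, below roughly 26 GiB the transformer is group-offloaded at leaf level with the T5 encoder block-offloaded, and anything in between falls back to `enable_model_cpu_offload()`. A condensed sketch of that policy as a standalone helper; the thresholds are the heuristics used above, not hard requirements:

```python
import torch
from diffusers.hooks import apply_group_offloading


def place_pipeline(pipeline):
    """Pick a device-placement strategy from total GPU memory (sketch of the policy above)."""
    gpu_memory_mib = torch.cuda.get_device_properties(0).total_memory / (1024**2)
    if gpu_memory_mib > 36 * 1024:  # plenty of memory: run the whole pipeline on the GPU
        return pipeline.to("cuda")
    if gpu_memory_mib < 26 * 1024:  # tight memory: stream transformer leaves between CPU and GPU
        pipeline.transformer.enable_group_offload(
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=True,
        )
        if pipeline.text_encoder is not None:
            pipeline.text_encoder.to("cuda")
        if pipeline.text_encoder_2 is not None:
            apply_group_offloading(
                pipeline.text_encoder_2,
                onload_device=torch.device("cuda"),
                offload_type="block_level",
                num_blocks_per_group=2,
            )
        pipeline.vae.to("cuda")
        return pipeline
    pipeline.enable_model_cpu_offload()  # middle ground: per-model CPU offload
    return pipeline
```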
import os
import subprocess
import pytest
from nunchaku.utils import get_precision, is_turing
EXAMPLES_DIR = "./examples"
example_scripts = [f for f in os.listdir(EXAMPLES_DIR) if f.endswith(".py") and f.startswith("sana")]
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="SANA does not support Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("script_name", example_scripts)
def test_example_script_runs(script_name):
script_path = os.path.join(EXAMPLES_DIR, script_name)
result = subprocess.run(["python", script_path], capture_output=True, text=True)
print(f"Running {script_path} -> Return code: {result.returncode}")
assert result.returncode == 0, f"{script_path} failed with code {result.returncode}"
import pytest
import torch
from diffusers import SanaPAGPipeline, SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip tests due to Turing GPUs")
def test_sana():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
variant="bf16",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
prompt = "A cute 🐼 eating 🎋, ink drawing style"
image = pipe(
prompt=prompt,
height=1024,
width=1024,
guidance_scale=4.5,
num_inference_steps=20,
generator=torch.Generator().manual_seed(42),
).images[0]
image.save("sana_1600m.png")
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip tests due to Turing GPUs")
def test_sana_pag():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
transformer=transformer,
variant="bf16",
torch_dtype=torch.bfloat16,
pag_applied_layers="transformer_blocks.8",
).to("cuda")
pipe._set_pag_attn_processor = lambda *args, **kwargs: None
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.bfloat16)
image = pipe(
prompt="A cute 🐼 eating 🎋, ink drawing style",
height=1024,
width=1024,
guidance_scale=5.0,
pag_scale=2.0,
num_inference_steps=20,
).images[0]
image.save("sana_1600m_pag.png")
@@ -59,7 +59,7 @@ class MultiImageDataset(data.Dataset):
 def compute_lpips(
-    ref_dirpath: str, gen_dirpath: str, batch_size: int = 4, num_workers: int = 8, device: str | torch.device = "cuda"
+    ref_dirpath: str, gen_dirpath: str, batch_size: int = 4, num_workers: int = 0, device: str | torch.device = "cuda"
 ) -> float:
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
...
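`compute_lpips` compares the 16-bit reference folder with the 4-bit outputs using torchmetrics' LPIPS; with only a handful of images per test (`max_dataset_size` now defaults to 4), `num_workers=0` keeps the DataLoader in-process and avoids worker startup overhead. A minimal sketch of the underlying metric call, assuming two already-loaded `(N, 3, H, W)` image batches scaled to [0, 1]:

```python
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Assumption: ref and gen are float tensors in [0, 1], e.g. produced by
# torchvision.transforms.ToTensor() on the saved reference and generated PNGs.
ref = torch.rand(4, 3, 256, 256)
gen = torch.rand(4, 3, 256, 256)

metric = LearnedPerceptualImagePatchSimilarity(normalize=True)  # normalize=True expects [0, 1] inputs
metric.update(gen, ref)
print(float(metric.compute()))  # lower LPIPS means closer to the reference images
```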