Unverified Commit 37a27712 authored by Muyang Li, committed by GitHub

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
......
from types import MethodType
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.pulid.utils import resize_numpy_image_long
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
    precision = get_precision()  # auto-detects "int4" or "fp4" based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
    pipeline = PuLIDFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # Route generation through the PuLID-aware forward pass
    pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

    id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
    image = pipeline(
        "A woman holding a sign that says hello world",
        id_image=id_image,
        id_weight=1,
        num_inference_steps=12,
        guidance_scale=3.5,
    ).images[0]

    # Compare ID embeddings of the reference and generated images to check identity preservation
    id_image = id_image.convert("RGB")
    id_image_numpy = np.array(id_image)
    id_image = resize_numpy_image_long(id_image_numpy, 1024)
    id_embeddings, _ = pipeline.pulid_model.get_id_embedding(id_image)

    output_image = image.convert("RGB")
    output_image_numpy = np.array(output_image)
    output_image = resize_numpy_image_long(output_image_numpy, 1024)
    output_id_embeddings, _ = pipeline.pulid_model.get_id_embedding(output_image)

    cosine_similarities = (
        F.cosine_similarity(id_embeddings.view(32, 2048), output_id_embeddings.view(32, 2048), dim=1).mean().item()
    )
    print(cosine_similarities)
    assert cosine_similarities > 0.93
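As a usage note: stripped of the pytest scaffolding and the identity-similarity assertion, the new PuLID path boils down to the following sketch (same models and arguments as the test above; the output filename is only illustrative):

# Minimal PuLID sketch distilled from the test above (output path is illustrative).
from types import MethodType

import torch
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision

transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{get_precision()}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
# Route generation through the PuLID-aware forward pass
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
image = pipeline(
    "A woman holding a sign that says hello world",
    id_image=id_image,  # reference face to preserve
    id_weight=1,        # strength of the identity conditioning
    num_inference_steps=12,
    guidance_scale=3.5,
).images[0]
image.save("pulid_output.png")  # illustrative path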
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"use_double_fb_cache,residual_diff_threshold_multi,residual_diff_threshold_single,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[
(True, 0.09, 0.12, 1024, 1024, 30, None, 1, 0.24 if get_precision() == "int4" else 0.165),
(True, 0.09, 0.12, 1024, 1024, 50, None, 1, 0.24 if get_precision() == "int4" else 0.161),
],
)
def test_flux_dev_double_fb_cache(
use_double_fb_cache: bool,
residual_diff_threshold_multi: float,
residual_diff_threshold_single: float,
height: int,
width: int,
num_inference_steps: int,
lora_name: str,
lora_strength: float,
expected_lpips: float,
):
run_test(
precision=get_precision(),
model_name="flux.1-dev",
dataset_name="MJHQ" if lora_name is None else lora_name,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=3.5,
use_qencoder=False,
cpu_offload=False,
lora_names=lora_name,
lora_strengths=lora_strength,
use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
expected_lpips=expected_lpips,
)
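Outside run_test, the same Double FBCache path can be enabled on an ordinary pipeline via apply_cache_on_pipe, which this commit adds to tests/flux/utils.py. A minimal sketch, with thresholds copied from the parametrization above (the per-argument comments are my reading of the multi- vs. single-stream split, not documented behavior):

# Sketch: Double FBCache on a plain Flux pipeline, mirroring run_test's call below.
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision

transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{get_precision()}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,   # caching threshold for the double-stream blocks (assumed)
    residual_diff_threshold_single=0.12,  # caching threshold for the single-stream blocks (assumed)
)
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=30).images[0]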
+import gc
import os
import subprocess
import pytest
+import torch
EXAMPLES_DIR = "./examples"
...
@@ -10,6 +12,8 @@ example_scripts = [f for f in os.listdir(EXAMPLES_DIR) if f.endswith(".py") and
@pytest.mark.parametrize("script_name", example_scripts) @pytest.mark.parametrize("script_name", example_scripts)
def test_example_script_runs(script_name): def test_example_script_runs(script_name):
gc.collect()
torch.cuda.empty_cache()
script_path = os.path.join(EXAMPLES_DIR, script_name) script_path = os.path.join(EXAMPLES_DIR, script_name)
result = subprocess.run(["python", script_path], capture_output=True, text=True) result = subprocess.run(["python", script_path], capture_output=True, text=True)
print(f"Running {script_path} -> Return code: {result.returncode}") print(f"Running {script_path} -> Return code: {result.returncode}")
......
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
...
@@ -8,13 +9,13 @@ from .utils import run_test
@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips",
    [
-        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
-        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.113),
+        (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.126),
        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
        (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
    ],
)
-def test_int4_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
+def test_flux_schnell(height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float):
    run_test(
        precision=get_precision(),
        height=height,
......
import gc
import os
import pytest
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision, is_turing
from .utils import already_generate, compute_lpips, offload_pipeline
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
"height,width,num_inference_steps,prompt,name,seed,threshold,expected_lpips",
[
(
1024,
1024,
30,
"A cat holding a sign that says hello world",
"cat",
0,
0.6,
0.363 if get_precision() == "int4" else 0.363,
),
(
512,
2048,
25,
"The brown fox jumps over the lazy dog",
"fox",
1234,
0.7,
0.349 if get_precision() == "int4" else 0.349,
),
(
1024,
768,
50,
"A scene from the Titanic movie featuring the Muppets",
"muppets",
42,
0.3,
0.360 if get_precision() == "int4" else 0.495,
),
(
1024,
768,
50,
"A crystal ball showing a waterfall",
"waterfall",
23,
0.6,
0.226 if get_precision() == "int4" else 0.226,
),
],
)
def test_flux_teacache(
height: int,
width: int,
num_inference_steps: int,
prompt: str,
name: str,
seed: int,
threshold: float,
expected_lpips: float,
):
gc.collect()
torch.cuda.empty_cache()
device = torch.device("cuda")
precision = get_precision()
ref_root = os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref"))
results_dir_16_bit = os.path.join(ref_root, "bf16", "flux.1-dev", "teacache", name)
results_dir_4_bit = os.path.join("test_results", precision, "flux.1-dev", "teacache", name)
os.makedirs(results_dir_16_bit, exist_ok=True)
os.makedirs(results_dir_4_bit, exist_ok=True)
# First, generate results with the 16-bit model
if not already_generate(results_dir_16_bit, 1):
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Possibly offload the model to CPU when GPU memory is scarce
pipeline = offload_pipeline(pipeline)
result = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
height=height,
width=width,
generator=torch.Generator(device=device).manual_seed(seed),
).images[0]
result.save(os.path.join(results_dir_16_bit, f"{name}_{seed}.png"))
# Clean up the 16-bit model
del pipeline.transformer
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info() # bytes
print(f"After 16-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
# Then, generate results with the 4-bit model
if not already_generate(results_dir_4_bit, 1):
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
with torch.inference_mode():
with TeaCache(
model=pipeline.transformer, num_steps=num_inference_steps, rel_l1_thresh=threshold, enabled=True
):
result = pipeline(
prompt=prompt,
num_inference_steps=num_inference_steps,
height=height,
width=width,
generator=torch.Generator(device=device).manual_seed(seed),
).images[0]
result.save(os.path.join(results_dir_4_bit, f"{name}_{seed}.png"))
# Clean up the 4-bit model
del pipeline
del transformer
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info() # bytes
print(f"After 4-bit generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.1
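For reference, the TeaCache usage this test exercises reduces to a context manager around the denoising call. A sketch using the test's "cat" configuration (prompt and threshold copied from the parametrization above; a higher rel_l1_thresh means more aggressive caching, hence faster but lossier sampling):

# Sketch: TeaCache around a Nunchaku Flux pipeline, as in the test above.
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision

num_inference_steps = 30
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{get_precision()}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
with torch.inference_mode(), TeaCache(
    model=pipeline.transformer, num_steps=num_inference_steps, rel_l1_thresh=0.6, enabled=True
):
    image = pipeline(
        "A cat holding a sign that says hello world", num_inference_steps=num_inference_steps
    ).images[0]
image.save("teacache_output.png")  # illustrative path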
...
@@ -2,6 +2,7 @@ import pytest
import torch
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
......
@@ -2,6 +2,7 @@
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
...
@@ -9,11 +10,11 @@ from .utils import run_test
@pytest.mark.parametrize(
    "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
    [
-        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.118, 2),
+        (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
        (1920, 1080, "flashattn2", False, 0.160 if get_precision() == "int4" else 0.123, 4),
    ],
)
-def test_int4_schnell(
+def test_flux_schnell(
    height: int, width: int, attention_impl: str, cpu_offload: bool, expected_lpips: float, batch_size: int
):
    run_test(
......
import pytest
-from .utils import run_test
from nunchaku.utils import get_precision, is_turing
+from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
......
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
......
@@ -12,7 +12,9 @@ from tqdm import tqdm
import nunchaku
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
+from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.lora.flux.compose import compose_lora
from ..data import get_dataset
from ..utils import already_generate, compute_lpips, hash_str_to_int
...
@@ -141,6 +143,9 @@ def run_test(
    attention_impl: str = "flashattn2",  # "flashattn2" or "nunchaku-fp16"
    cpu_offload: bool = False,
    cache_threshold: float = 0,
+    use_double_fb_cache: bool = False,
+    residual_diff_threshold_multi: float = 0,
+    residual_diff_threshold_single: float = 0,
    lora_names: str | list[str] | None = None,
    lora_strengths: float | list[float] = 1.0,
    max_dataset_size: int = 4,
...
@@ -196,8 +201,6 @@ def run_test(
    if not already_generate(save_dir_16bit, max_dataset_size):
        pipeline_init_kwargs = {"text_encoder": None, "text_encoder2": None} if task == "redux" else {}
        pipeline = pipeline_cls.from_pretrained(model_id_16bit, torch_dtype=dtype, **pipeline_init_kwargs)
-        gpu_properties = torch.cuda.get_device_properties(0)
-        gpu_memory = gpu_properties.total_memory / (1024**2)
        if len(lora_names) > 0:
            for i, (lora_name, lora_strength) in enumerate(zip(lora_names, lora_strengths)):
...
@@ -207,27 +210,7 @@
                )
            pipeline.set_adapters([f"lora_{i}" for i in range(len(lora_names))], lora_strengths)
-        if gpu_memory > 36 * 1024:
-            pipeline = pipeline.to("cuda")
-        elif gpu_memory < 26 * 1024:
-            pipeline.transformer.enable_group_offload(
-                onload_device=torch.device("cuda"),
-                offload_device=torch.device("cpu"),
-                offload_type="leaf_level",
-                use_stream=True,
-            )
-            if pipeline.text_encoder is not None:
-                pipeline.text_encoder.to("cuda")
-            if pipeline.text_encoder_2 is not None:
-                apply_group_offloading(
-                    pipeline.text_encoder_2,
-                    onload_device=torch.device("cuda"),
-                    offload_type="block_level",
-                    num_blocks_per_group=2,
-                )
-            pipeline.vae.to("cuda")
-        else:
-            pipeline.enable_model_cpu_offload()
+        pipeline = offload_pipeline(pipeline)
        run_pipeline(
            batch_size=batch_size,
...
@@ -259,6 +242,12 @@ def run_test(
        precision_str += "-co"
    if cache_threshold > 0:
        precision_str += f"-cache{cache_threshold}"
+    if use_double_fb_cache:
+        precision_str += "-dfb"
+    if residual_diff_threshold_multi > 0:
+        precision_str += f"-rdm{residual_diff_threshold_multi}"
+    if residual_diff_threshold_single > 0:
+        precision_str += f"-rds{residual_diff_threshold_single}"
    if i2f_mode is not None:
        precision_str += f"-i2f{i2f_mode}"
    if batch_size > 1:
...
@@ -303,6 +292,15 @@ def run_test(
        pipeline.enable_sequential_cpu_offload()
    else:
        pipeline = pipeline.to("cuda")
+    if use_double_fb_cache:
+        apply_cache_on_pipe(
+            pipeline,
+            use_double_fb_cache=use_double_fb_cache,
+            residual_diff_threshold_multi=residual_diff_threshold_multi,
+            residual_diff_threshold_single=residual_diff_threshold_single,
+        )
    run_pipeline(
        batch_size=batch_size,
        dataset=dataset,
...
@@ -324,3 +322,34 @@ def run_test(
    lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.1
def offload_pipeline(pipeline: FluxPipeline) -> FluxPipeline:
    """Move the pipeline to the GPU, or enable offloading when GPU memory is limited."""
    gpu_properties = torch.cuda.get_device_properties(0)
    gpu_memory = gpu_properties.total_memory / (1024**2)  # MiB
    device = torch.device("cuda")
    cpu = torch.device("cpu")
    if gpu_memory > 36 * 1024:
        # Plenty of memory: keep the whole pipeline on the GPU
        pipeline = pipeline.to(device)
    elif gpu_memory < 26 * 1024:
        # Tight memory: stream the transformer's leaf modules from the CPU
        pipeline.transformer.enable_group_offload(
            onload_device=device,
            offload_device=cpu,
            offload_type="leaf_level",
            use_stream=True,
        )
        if pipeline.text_encoder is not None:
            pipeline.text_encoder.to(device)
        if pipeline.text_encoder_2 is not None:
            apply_group_offloading(
                pipeline.text_encoder_2,
                onload_device=device,
                offload_type="block_level",
                num_blocks_per_group=2,
            )
        pipeline.vae.to(device)
    else:
        # In between: fall back to diffusers' model-level CPU offload
        pipeline.enable_model_cpu_offload()
    return pipeline
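Factoring this logic into offload_pipeline removes the copy that previously lived inline in run_test and lets the new TeaCache test reuse the same memory-tier policy: above 36 GiB everything stays on the GPU, below 26 GiB the transformer is group-offloaded leaf-by-leaf with streaming, and anything in between falls back to diffusers' model-level CPU offload.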
...
@@ -5,4 +5,8 @@ torchmetrics
mediapipe
controlnet_aux
peft
-git+https://github.com/asomoza/image_gen_aux.git
\ No newline at end of file
+git+https://github.com/asomoza/image_gen_aux.git
+insightface
+opencv-python
+facexlib
+onnxruntime