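"""Regression test for Nunchaku's 4-bit FLUX.1-Kontext transformer.

Generates the same edited images with the 16-bit reference pipeline and the
Nunchaku 4-bit (SVDQuant) pipeline, then checks that the LPIPS distance
between the two result sets stays below a precision-dependent threshold.
"""
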
import gc
import os
from pathlib import Path

import pytest
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing

from .utils import already_generate, compute_lpips, hash_str_to_int, offload_pipeline


@pytest.mark.skipif(is_turing(), reason="Skip test on Turing GPUs")
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_flux_kontext(expected_lpips: float):
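    # Start from a clean CUDA state so the memory readings below are meaningful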
    gc.collect()
    torch.cuda.empty_cache()

    precision = get_precision()

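    # 16-bit reference images are cached under NUNCHAKU_TEST_CACHE_ROOT (if set) so they can be reused across runs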
    ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
    results_dir_16_bit = ref_root / "bf16" / "flux.1-kontext-dev" / "kontext"
    results_dir_4_bit = Path("test_results") / precision / "flux.1-kontext-dev" / "kontext"

    os.makedirs(results_dir_16_bit, exist_ok=True)
    os.makedirs(results_dir_4_bit, exist_ok=True)

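    # Reference input: a yarn-art Pikachu image from the diffusers documentation dataset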
    image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
    ).convert("RGB")
    prompts = [
        "Make Pikachu hold a sign that says 'Nunchaku is awesome', yarn art style, detailed, vibrant colors",
        "Convert the image to ghibli style",
        "help me convert it to manga style",
        "Convert it to a realistic photo",
    ]

    # First, generate results with the 16-bit model
    if not already_generate(results_dir_16_bit, 4):
        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
        )

        # Possibly offload the model to CPU when GPU memory is scarce
        pipeline = offload_pipeline(pipeline)

        for prompt in prompts:
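            # hash_str_to_int derives a deterministic per-prompt seed so both precisions start from identical noise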
            seed = hash_str_to_int(prompt)
            result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
            result.save(os.path.join(results_dir_16_bit, f"{seed}.png"))

        # Clean up the 16-bit model
        del pipeline.transformer
        del pipeline.text_encoder
        del pipeline.text_encoder_2
        del pipeline.vae
        del pipeline
        del result
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    free, total = torch.cuda.mem_get_info()  # bytes
    print(f"After 16-bit generation: Free: {free/1024**2:.0f} MB  /  Total: {total/1024**2:.0f} MB")

    # Then, generate results with the 4-bit model
    if not already_generate(results_dir_4_bit, 4):
        transformer = NunchakuFluxTransformer2dModel.from_pretrained(
            f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{precision}_r32-flux.1-kontext-dev.safetensors"
        )
        pipeline = FluxKontextPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
        ).to("cuda")
        for prompt in prompts:
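            # Reuse the same prompt-derived seeds so each image pairs with its 16-bit reference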
            seed = hash_str_to_int(prompt)
            result = pipeline(image=image, prompt=prompt, generator=torch.Generator().manual_seed(seed)).images[0]
            result.save(os.path.join(results_dir_4_bit, f"{seed}.png"))

        # Clean up the 4-bit model
        del pipeline
        del transformer
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    free, total = torch.cuda.mem_get_info()  # bytes
    print(f"After 4-bit generation: Free: {free/1024**2:.0f} MB  /  Total: {total/1024**2:.0f} MB")

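    # LPIPS is a perceptual distance (lower = more similar); the 1.15 factor
    # gives 15% headroom over the expected score as tolerance for minor variation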
    lpips = compute_lpips(results_dir_16_bit, results_dir_4_bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.15