test_flux_dev_IPA.py 2.89 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gc

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe
from nunchaku.models.ip_adapter.utils import resize_numpy_image_long
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision, is_turing


def _release_cuda_memory():
    """Run the garbage collector, then return PyTorch's cached CUDA blocks."""
    gc.collect()
    torch.cuda.empty_cache()


def _id_embedding(pulid_pipeline, image):
    """Return the PuLID identity embedding of a PIL ``image``.

    The image is converted to RGB and to a numpy array. It is then resized with
    ``resize_numpy_image_long(..., 1024)`` (long side, per the helper's name)
    before being passed to ``pulid_pipeline.pulid_model.get_id_embedding``.
    Only the first element of that call's returned pair is kept.
    """
    image_numpy = np.array(image.convert("RGB"))
    resized = resize_numpy_image_long(image_numpy, 1024)
    embedding, _ = pulid_pipeline.pulid_model.get_id_embedding(resized)
    return embedding


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_IPA():
    """Check that the XLabs IP-Adapter preserves identity with a Nunchaku FLUX.1-dev transformer.

    Stage 1 generates an image with ``FluxPipeline`` and the quantized Nunchaku
    transformer, conditioned on a reference portrait through the IP-Adapter.
    Stage 2 loads ``PuLIDFluxPipeline`` only to use its ``pulid_model`` as an
    identity encoder. It asserts that the mean per-row cosine similarity between
    the reference and generated identity embeddings exceeds 0.80.
    """
    precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
    model_path = f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"

    # Convert once: the same RGB reference feeds both the IP-Adapter and the identity encoder.
    id_image = load_image(
        "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
    ).convert("RGB")

    # ---- Stage 1: generate an image conditioned on the reference via the IP-Adapter ----
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path)
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    pipeline.load_ip_adapter(
        pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
        weight_name="ip_adapter.safetensors",
        image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    )
    # NOTE(review): presumably wires the loaded adapter into the Nunchaku transformer
    # at the given scale; see nunchaku's apply_IPA_on_pipe for the exact behavior.
    apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.15, repo_id="XLabs-AI/flux-ip-adapter-v2")

    image = pipeline(
        prompt="holding an sign saying 'SVDQuant is fast!'",
        ip_adapter_image=id_image,
        num_inference_steps=50,
        # Fixed CPU-side seed so the threshold below is checked against a
        # reproducible image instead of a fresh random sample on every run.
        generator=torch.Generator().manual_seed(0),
    ).images[0]

    del pipeline, transformer
    _release_cuda_memory()

    # ---- Stage 2: use the PuLID pipeline only to get the id embedding ----
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path, offload=True)
    pipeline = PuLIDFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    id_embeddings = _id_embedding(pipeline, id_image)
    output_id_embeddings = _id_embedding(pipeline, image)

    # view(32, 2048) assumes each embedding holds exactly 32 * 2048 elements (TODO confirm
    # against PuLID). Similarity is computed per row, then averaged to one scalar.
    cosine_similarity = (
        F.cosine_similarity(id_embeddings.view(32, 2048), output_id_embeddings.view(32, 2048), dim=1).mean().item()
    )
    print(cosine_similarity)

    # Free the models BEFORE asserting: a failing assert keeps this frame (and any
    # locals still bound in it) alive via the traceback, holding GPU memory.
    del pipeline, transformer
    _release_cuda_memory()

    assert cosine_similarity > 0.80, f"identity cosine similarity too low: {cosine_similarity:.4f}"


if __name__ == "__main__":
    # Allow `python <this file>` to run the test once without pytest. Calling the
    # function directly bypasses pytest, so the skipif marker on the test is not
    # evaluated here.
    entry_point = test_flux_dev_IPA
    entry_point()