test_flux_txt2img_cache_controlnet.py 3.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gc

import torch
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxPipeline,
)
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from nunchaku import NunchakuT5EncoderModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision


def test_flux_txt2img_cache_controlnet():
    """Smoke-test FLUX.1-dev generation with and without a ControlNet.

    Builds a ``FluxPipeline`` and a ``FluxControlNetPipeline`` from the same
    component objects (including one shared Nunchaku transformer), enables
    sequential CPU offload on both, applies the Nunchaku cache to the
    ControlNet pipeline, and then runs each pipeline once with the same
    prompt and sampling settings.

    The test contains no assertions: it passes when both runs complete
    without raising. Each result is saved as a JPEG in the current working
    directory (``flux.1-dev-txt2img.jpg`` and ``flux.1-dev-cn-txt2img.jpg``).

    The target device is hard-coded to ``"cuda"``; model weights and the
    control image are fetched by repository id / URL.
    """
    bfl_repo = "black-forest-labs/FLUX.1-dev"
    dtype = torch.bfloat16
    device = "cuda"

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(bfl_repo, subfolder="text_encoder", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(
        bfl_repo, subfolder="tokenizer", torch_dtype=dtype, clean_up_tokenization_spaces=True
    )
    tokenizer_2 = T5TokenizerFast.from_pretrained(
        bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, clean_up_tokenization_spaces=True
    )
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype)

    # The checkpoint file name depends on the precision reported by nunchaku.
    precision = get_precision()
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
    )
    transformer.set_attention_impl("nunchaku-fp16")

    # Only the Nunchaku T5 encoder is loaded. The stock T5EncoderModel that
    # used to be loaded first was overwritten before any use, so loading it
    # only cost time and memory.
    text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")

    controlnet_union = FluxControlNetModel.from_pretrained(
        "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0", torch_dtype=dtype
    )
    # The single ControlNet is wrapped in the multi-ControlNet container, which
    # is why the control inputs below are passed as one-element lists.
    controlnet = FluxMultiControlNetModel([controlnet_union])

    # Both pipelines are built from the same component objects, so they share
    # one transformer, VAE and set of text encoders.
    components = {
        "scheduler": scheduler,
        "vae": vae,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "transformer": transformer,
    }
    pipe = FluxPipeline(**components).to(device, dtype=dtype)
    pipe_cn = FluxControlNetPipeline(**components, controlnet=controlnet).to(device, dtype=dtype)

    # offload
    pipe.enable_sequential_cpu_offload(device=device)
    pipe_cn.enable_sequential_cpu_offload(device=device)

    # Cache settings are kept in one place because they are applied twice.
    cache_kwargs = {
        "use_double_fb_cache": True,
        "residual_diff_threshold_multi": 0.09,
        "residual_diff_threshold_single": 0.12,
    }
    apply_cache_on_pipe(pipe_cn, **cache_kwargs)

    generation_kwargs = {
        "prompt": "A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.",
        "height": 1152,
        "width": 768,
        "num_inference_steps": 30,
        "guidance_scale": 3.5,
    }

    # Plain text-to-image run.
    txt2img_res = pipe(**generation_kwargs).images[0]
    txt2img_res.save("flux.1-dev-txt2img.jpg")

    # Release cached GPU memory before the second run.
    gc.collect()
    torch.cuda.empty_cache()

    # The cache is applied to the ControlNet pipeline again with the same
    # settings, as in the original test.
    # NOTE(review): possibly redundant with the call above - confirm against
    # apply_cache_on_pipe before removing.
    apply_cache_on_pipe(pipe_cn, **cache_kwargs)

    # ControlNet run: same prompt and sampling settings plus the control inputs.
    control_image = load_image(
        "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/openpose.jpg"
    )
    controlnet_kwargs = {
        **generation_kwargs,
        "control_image": [control_image],
        "controlnet_conditioning_scale": [0.9],
        "control_guidance_end": [0.65],
    }
    cn_res = pipe_cn(**controlnet_kwargs).images[0]
    cn_res.save("flux.1-dev-cn-txt2img.jpg")