# controlnet-flux-cache.py (1.9 KB)
# Committed by Hyunsung Lee.
import random

import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np

from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe


# Seed every random number generator used by this script (Python, NumPy,
# torch CPU and torch CUDA) with the same fixed value.
SEED = 42
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Model identifiers handed to the respective ``from_pretrained`` loaders.
base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model_union = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

# Load the union ControlNet in bfloat16 and wrap it in a multi-ControlNet
# container; loading via FluxMultiControlNetModel is always recommended,
# even when only a single ControlNet is used.
controlnet_union = FluxControlNetModel.from_pretrained(
    controlnet_model_union,
    torch_dtype=torch.bfloat16,
)
controlnet = FluxMultiControlNetModel([controlnet_union])

# Load the nunchaku transformer checkpoint, then move it to the GPU.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev",
    torch_dtype=torch.bfloat16,
)
transformer = transformer.to("cuda")

# Build the ControlNet pipeline around the pre-loaded transformer and
# ControlNet, apply nunchaku's caching adapter to it, and only then move the
# pipeline to the GPU.
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
pipe.to("cuda")

# Text prompt passed to the pipeline.
prompt = "A anime style girl with messy beach waves."

# Integer control-mode values passed alongside the depth and canny images.
control_mode_depth = 2
control_mode_canny = 0

# Conditioning images, fetched from the ControlNet repository's assets
# (depth first, then canny).
control_image_depth = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
control_image_canny = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)

# The requested output resolution is taken from the depth control image.
width, height = control_image_depth.size

# Keyword arguments for the pipeline call. The list-valued entries are given
# in the same order: depth first, canny second. The generator is created by
# re-seeding torch with SEED right before the call.
_generation_kwargs = {
    "control_image": [control_image_depth, control_image_canny],
    "control_mode": [control_mode_depth, control_mode_canny],
    "width": width,
    "height": height,
    "controlnet_conditioning_scale": [0.3, 0.1],
    "num_inference_steps": 28,
    "guidance_scale": 3.5,
    "generator": torch.manual_seed(SEED),
}

# Run the pipeline and keep the first image of the returned batch.
_output = pipe(prompt, **_generation_kwargs)
image = _output.images[0]

# Write the result to the current working directory.
image.save("nunchaku-controlnet-flux.1-dev.png")