"vscode:/vscode.git/clone" did not exist on "c459536b0f1c66e259258cccf95039580fd43f37"
flux.1-dev-controlnet-union-pro.py 2.14 KB
Newer Older
1
2
3
4
5
6
7
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
from nunchaku.utils import get_gpu_memory, get_precision

# Model repository identifiers for the base FLUX pipeline and the union ControlNet.
base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model_union = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

# Load the union ControlNet in bfloat16, then wrap it in the multi-ControlNet
# container that is handed to the pipeline.
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend loading via FluxMultiControlNetModel

# The precision string is interpolated into the quantized checkpoint name below
# and into the output file name at the end of the script.
precision = get_precision()
# Turn on offloading when get_gpu_memory() reports less than 36
# (presumably GiB -- confirm against nunchaku.utils.get_gpu_memory).
need_offload = get_gpu_memory() < 36
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/svdq-{precision}-flux.1-dev", torch_dtype=torch.bfloat16, offload=need_offload
)
transformer.set_attention_impl("nunchaku-fp16")

# Build the ControlNet pipeline from the base model, substituting the
# `transformer` and `controlnet` objects created earlier in this script.
pipeline = FluxControlNetPipeline.from_pretrained(
    base_model, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
)

# When offloading was requested, enable sequential CPU offload on the pipeline;
# otherwise move the whole pipeline to the CUDA device.
if need_offload:
    pipeline.enable_sequential_cpu_offload()
else:
    pipeline = pipeline.to("cuda")

# Optional: uncomment the call below to enable the first-block cache, which
# speeds up generation.
# apply_cache_on_pipe(
#     pipeline, residual_diff_threshold=0.1
# )


# Text prompt used for the generation.
prompt = "A anime style girl with messy beach waves."

# Fetch the two conditioning images (depth first, then canny).
depth_image = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
)
canny_image = load_image(
    "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg"
)

# Parallel lists: entry i of each list belongs to the same control input.
control_images = [depth_image, canny_image]
control_modes = [2, 0]  # mode index paired with the depth image, then the canny image
conditioning_scales = [0.3, 0.1]

# The output resolution follows the depth image.
out_width, out_height = depth_image.size

# Fixed seed (233) for the generator handed to the pipeline.
seeded_generator = torch.manual_seed(233)

result = pipeline(
    prompt,
    control_image=control_images,
    control_mode=control_modes,
    width=out_width,
    height=out_height,
    controlnet_conditioning_scale=conditioning_scales,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=seeded_generator,
)
image = result.images[0]

# The output file name records which precision variant produced the image.
image.save(f"flux.1-dev-controlnet-union-pro-{precision}.png")