"""Generate an image variation with FLUX.1-Redux-dev and a Nunchaku FLUX.1-dev transformer.

A ``FluxPriorReduxPipeline`` processes a reference image. Its output is unpacked
as keyword arguments into a ``FluxPipeline`` that is built without text encoders
and uses a Nunchaku transformer loaded from an ``svdq`` checkpoint. The first
generated image is saved as ``flux.1-redux-dev-<precision>.png``.
"""

import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

# Precision tag reported by nunchaku. It selects which quantized checkpoint
# file is loaded below and is also embedded in the output filename.
precision = get_precision()

# Prior pipeline that turns the reference image into conditioning for FLUX.
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Nunchaku transformer that replaces the stock FLUX.1-dev transformer.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)

# Both text encoders are disabled: no text prompt is passed to `pipe`.
# The conditioning comes entirely from the Redux prior output unpacked below.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
pipe_prior_output = pipe_prior_redux(image)
# The prior output's fields are forwarded as keyword arguments to the FLUX
# pipeline, standing in for the prompt the missing text encoders would produce.
images = pipe(guidance_scale=2.5, num_inference_steps=50, **pipe_prior_output).images
images[0].save(f"flux.1-redux-dev-{precision}.png")