import os
import time
from pathlib import Path
from typing import Optional

import torch
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting, \
    StableDiffusionXLInpaintPipeline, AutoPipelineForImage2Image, \
    StableDiffusionXLImg2ImgPipeline, DiffusionPipeline
from diffusers.utils import load_image


def _load_fp16(pipeline_cls, path: Optional[str], **shared_components):
    """Load ``pipeline_cls`` from ``path`` as fp16 safetensors and move it to CUDA.

    ``shared_components`` is forwarded to ``from_pretrained`` unchanged; it is
    used to hand the base pipeline's ``text_encoder_2`` / ``vae`` to the refiner.
    """
    return pipeline_cls.from_pretrained(
        path,
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
        **shared_components,
    ).to("cuda")


def get_pipeline(base_path: Optional[str] = None,
                 refiner_path: Optional[str] = None,
                 mode: str = "t2i"):
    """Build the list of pipelines required by ``mode``.

    Args:
        base_path: Location passed to ``from_pretrained`` for the base model.
        refiner_path: Location passed to ``from_pretrained`` for the refiner;
            only read when ``mode`` contains ``"wr"``.
        mode: One of ``t2i``, ``i2i``, ``inpainting``, ``t2i_wr``,
            ``inpainting_wr``. Matching is by substring, checked in the order
            ``t2i`` -> ``i2i`` -> ``inpainting``.

    Returns:
        ``[base_pipeline]``, or ``[base_pipeline, refiner_pipeline]`` for the
        ``t2i`` / ``inpainting`` modes that contain ``"wr"``.

    Raises:
        NotImplementedError: If ``mode`` matches none of the known modes.
    """
    use_refiner = "wr" in mode

    if "t2i" in mode:
        if not use_refiner:
            return [_load_fp16(AutoPipelineForText2Image, base_path)]
        base_cls = refiner_cls = DiffusionPipeline
    elif "i2i" in mode:
        # No refiner variant exists for image-to-image; "wr" is ignored here.
        return [_load_fp16(AutoPipelineForImage2Image, base_path)]
    elif "inpainting" in mode:
        if not use_refiner:
            return [_load_fp16(AutoPipelineForInpainting, base_path)]
        base_cls = refiner_cls = StableDiffusionXLInpaintPipeline
    else:
        # The original raised the ``NotImplemented`` constant, which is not an
        # exception and surfaces as a confusing ``TypeError``.
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    base_pipeline = _load_fp16(base_cls, base_path)
    # The refiner reuses the base model's second text encoder and VAE instead
    # of loading its own copies.
    refiner_pipeline = _load_fp16(
        refiner_cls,
        refiner_path,
        text_encoder_2=base_pipeline.text_encoder_2,
        vae=base_pipeline.vae,
    )
    return [base_pipeline, refiner_pipeline]


def save_img(img, save_root, mode):
    """Save ``img`` as ``<save_root>/<mode>/<unix-timestamp>.png``.

    The target directory is created if it does not exist yet.
    """
    out_dir = os.path.join(save_root, mode)
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, str(time.time()) + ".png")
    img.save(save_path)


def t2i(pipelines: list, prompt: str):
    """Generate one image from ``prompt``.

    With a single pipeline the base model runs on its own. With two pipelines
    the first runs 40 steps up to ``denoising_end=0.8`` and emits latents,
    which the second picks up at ``denoising_start=0.8``.
    """
    if len(pipelines) == 1:  # only base
        return pipelines[0](prompt=prompt).images[0]

    latents = pipelines[0](
        prompt=prompt,
        num_inference_steps=40,
        denoising_end=0.8,
        output_type="latent",
    ).images
    return pipelines[1](
        prompt=prompt,
        num_inference_steps=40,
        denoising_start=0.8,
        image=latents,
    ).images[0]


def i2i(pipelines: list, prompt: str, img_path: str):
    """Run image-to-image generation on the image at ``img_path``.

    Raises:
        NotImplementedError: If more than one pipeline is supplied.
    """
    if len(pipelines) != 1:
        raise NotImplementedError

    init_image = load_image(img_path)
    return pipelines[0](
        prompt=prompt,
        image=init_image,
        strength=0.8,
        guidance_scale=10.5,
    ).images[0]


def inpainting(pipelines: list, prompt: str, img_path: str, mask_path: str):
    """Inpaint the image at ``img_path`` using the mask at ``mask_path``.

    With a single pipeline the base model runs on its own. With two pipelines
    the first runs 75 steps up to ``denoising_end=0.7`` and emits latents,
    which the second picks up at ``denoising_start=0.7``.
    """
    init_image = load_image(img_path)
    mask_image = load_image(mask_path)

    if len(pipelines) == 1:
        return pipelines[0](
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            strength=0.85,
            guidance_scale=12.5,
        ).images[0]

    latents = pipelines[0](
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        num_inference_steps=75,
        denoising_end=0.7,
        output_type="latent",
    ).images
    return pipelines[1](
        prompt=prompt,
        image=latents,
        mask_image=mask_image,
        num_inference_steps=75,
        denoising_start=0.7,
    ).images[0]


def inference(args):
    """Load the pipelines for ``args.mode``, generate one image and save it.

    Reads ``base_path``, ``refiner_path``, ``mode``, ``prompt``, ``img_path``,
    ``mask_path`` and ``save_root`` from ``args``.
    """
    # get_pipeline raises for unknown modes, so one of the branches below
    # always runs and ``img`` is always bound.
    pipelines = get_pipeline(args.base_path, args.refiner_path, args.mode)

    if "t2i" in args.mode:
        img = t2i(pipelines, args.prompt)
    elif "i2i" in args.mode:
        img = i2i(pipelines, args.prompt, args.img_path)
    elif "inpainting" in args.mode:
        img = inpainting(pipelines, args.prompt, args.img_path, args.mask_path)

    save_img(img, args.save_root, args.mode)


if __name__ == "__main__":
    from argparse import ArgumentParser

    # Default checkpoints live next to this script under pretrained_models/.
    models_dir = Path(__file__).resolve().parent / "pretrained_models"
    default_base_path = str(models_dir / "stable-diffusion-xl-base-1.0")
    default_refiner_path = str(models_dir / "stable-diffusion-xl-refiner-1.0")

    parser = ArgumentParser()
    parser.add_argument("--base_path", default=default_base_path, type=str)
    parser.add_argument("--refiner_path", default=default_refiner_path, type=str)
    parser.add_argument("--prompt", default="a panda is playing a ball", type=str)
    parser.add_argument("--img_path", type=str, default="")
    parser.add_argument("--mask_path", type=str, default="")
    parser.add_argument("--save_root", default="./results", type=str)
    parser.add_argument("--mode", type=str, required=True)
    args = parser.parse_args()

    inference(args)