import copy

from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from migraphx_diffusers import (MIGraphXAutoencoderKL, MIGraphXCLIPTextModel,
                                MIGraphXCLIPTextModelWithProjection,
                                MIGraphXUNet2DConditionModel)
from transformers import AutoTokenizer
import numpy as np
import torch


def parse_args():
    from argparse import ArgumentParser
    parser = ArgumentParser(description="SDXL inference with migraphx backend")
    # =========================== model load and compile =======================
    parser.add_argument(
        "-m",
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and overwrite them.",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=1024,
        help="Output image size (height and width).",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="The number of images to generate per prompt.",
    )
    # ---------------------------------------------------------------------------
    # =============================== generation ===============================
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=True,
        help="Prompt describing image content, style, and so on.",
    )
    parser.add_argument(
        "-n",
        "--negative-prompt",
        type=str,
        default=None,
        help="Negative prompt.",
    )
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of inference (denoising) steps.",
    )
    parser.add_argument(
        "--save-prefix",
        type=str,
        default="sdxl_output",
        help="Path prefix for saving the generated images.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Random seed.",
    )
    # ---------------------------------------------------------------------------
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    pipeline_dir = args.model_dir

    # Compile-time arguments shared by all MIGraphX models; the text encoders
    # are compiled with batch 1 since prompt embeddings are duplicated later.
    common_args = dict(
        batch=args.num_images_per_prompt,
        img_size=args.img_size,
        model_dtype='fp16',
        force_compile=args.force_compile,
    )
    text_encoder_args = copy.deepcopy(common_args)
    text_encoder_args['batch'] = 1

    # ========================== load migraphx models ==========================
    text_encoder = MIGraphXCLIPTextModel.from_pretrained(
        pipeline_dir, subfolder="text_encoder", **text_encoder_args)
    text_encoder_2 = MIGraphXCLIPTextModelWithProjection.from_pretrained(
        pipeline_dir, subfolder="text_encoder_2", **text_encoder_args)
    unet = MIGraphXUNet2DConditionModel.from_pretrained(
        pipeline_dir, subfolder="unet", **common_args,
        pipeline_class=StableDiffusionXLPipeline)
    vae = MIGraphXAutoencoderKL.from_pretrained(
        pipeline_dir, subfolder="vae_decoder", **common_args)
    # ---------------------------------------------------------------------------

    # ============================ load torch models ===========================
    scheduler = EulerDiscreteScheduler.from_pretrained(
        pipeline_dir, subfolder="scheduler")
    tokenizer = AutoTokenizer.from_pretrained(
        pipeline_dir, subfolder="tokenizer")
    tokenizer_2 = AutoTokenizer.from_pretrained(
        pipeline_dir, subfolder="tokenizer_2")
    # ---------------------------------------------------------------------------

    # create pipeline
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=True,
        add_watermarker=None,
    )
    pipe.to("cuda")
    pipe.to(torch.float16)

    # register configuration
    pipe.register_to_config(
        _mgx_models=["text_encoder", "text_encoder_2", "unet", "vae"])
    pipe.register_to_config(_batch=args.num_images_per_prompt)
    pipe.register_to_config(_img_height=args.img_size)
    pipe.register_to_config(_img_width=args.img_size)

    # generate images
    print("Generating images...")
    images = pipe(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.num_images_per_prompt,
        generator=torch.Generator("cuda").manual_seed(args.seed),
    ).images

    # save results
    for i, image in enumerate(images):
        save_path = f"{args.save_prefix}_{i}.png"
        image.save(save_path)
        print(f"Generated image: {save_path}")


if __name__ == "__main__":
    main()
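
# Example invocation (illustrative only; the script name and model directory
# are placeholders, and --model-dir must point at a locally exported SDXL
# pipeline containing the text_encoder, text_encoder_2, unet, vae_decoder,
# scheduler, and tokenizer subfolders used above):
#
#   python sdxl_migraphx_txt2img.py \
#       --model-dir ./path/to/sdxl-model \
#       --prompt "an astronaut riding a horse on the moon" \
#       --negative-prompt "blurry, low quality" \
#       --num-inference-steps 30 \
#       --num-images-per-prompt 2 \
#       --seed 42 \
#       --save-prefix astronaut
#
# The first run compiles the models to .mxr files inside the model directory;
# later runs reuse them unless --force-compile is given.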