"""Generate images for every prompt of a tab-separated PartiPrompts file using
a diffusers pipeline accelerated with the MIGraphX backend.

Outputs are written to ``<save-dir>/<category>/prompt_<index>/`` as
``prompt_info.json`` plus ``00.png``, ``01.png``, ...
"""
from collections import namedtuple
import copy
import csv
import json
import os
import os.path as osp

from diffusers import DiffusionPipeline
import migraphx_diffusers
import torch


# One row of the PartiPrompts TSV (column order as read by parse_prompts).
Prompt = namedtuple("Prompt", ["prompt_text", "category", "challenge", "note"])

# Maps the pipeline class named in model_index.json to the key used in
# migraphx_diffusers.DEFAULT_ARGS.
_PIPELINE_NAMES = {
    "StableDiffusionXLPipeline": "sdxl",
    "StableDiffusionPipeline": "sd2.1",
}


def parse_args():
    """Parse and validate the command-line arguments of this script."""
    from argparse import ArgumentParser

    parser = ArgumentParser(description="SDXL inference with migraphx backend")

    # =========================== model load and compile =======================
    parser.add_argument(
        "-m",
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and override them",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=None,
        help="output image size",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="The number of images to generate per prompt.",
    )
    # --------------------------------------------------------------------------

    # =============================== generation ===============================
    parser.add_argument(
        "-p",
        "--parti-prompts-file",
        type=str,
        required=True,
        help="Path to the tab-separated PartiPrompts file.",
    )
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of iteration steps",
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        # Required: the output tree is rooted here and there is no sensible
        # default (a None value used to crash later in os.makedirs).
        required=True,
        help="Path to save images",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="resume image generation",
    )
    # --------------------------------------------------------------------------

    args = parser.parse_args()
    if args.num_images_per_prompt < 1:
        parser.error("--num-images-per-prompt must be >= 1")
    return args


def get_name_and_migraphx_config(model_dir):
    """Identify the pipeline type of ``model_dir`` and return its MIGraphX config.

    Args:
        model_dir: Directory containing a diffusers ``model_index.json``.

    Returns:
        ``(name, config)`` where ``name`` is ``'sdxl'`` or ``'sd2.1'`` and
        ``config`` is a private deep copy of
        ``migraphx_diffusers.DEFAULT_ARGS[name]`` that the caller may mutate.

    Raises:
        NotImplementedError: if the pipeline class is not supported.
    """
    model_index_json = osp.join(model_dir, "model_index.json")
    with open(model_index_json, "r", encoding="utf-8") as f:
        pipe_cfg = json.load(f)

    class_name = pipe_cfg["_class_name"]
    name = _PIPELINE_NAMES.get(class_name)
    if name is None:
        raise NotImplementedError(f"{class_name} has not been adapted yet")
    # Copy so callers that tweak the config do not silently modify the
    # library-wide defaults stored in migraphx_diffusers.DEFAULT_ARGS.
    return name, copy.deepcopy(migraphx_diffusers.DEFAULT_ARGS[name])


def parse_prompts(parti_prompts_file):
    """Read the prompts from a tab-separated PartiPrompts file.

    The first row is treated as a header and skipped; blank rows are ignored.
    Every other row must have exactly the four ``Prompt`` columns.

    Returns:
        list[Prompt] in file order.
    """
    # newline="" is what the csv module expects for files it parses.
    with open(parti_prompts_file, "r", encoding="utf-8", newline="") as f:
        csv_reader = csv.reader(f, delimiter="\t")
        next(csv_reader, None)  # header row
        return [Prompt(*row) for row in csv_reader if row]


def _prompt_dir(save_dir, index, prompt):
    """Return the output directory for the ``index``-th prompt."""
    category = prompt.category.replace(" ", "").replace("&", "_")
    return osp.join(save_dir, category, f"prompt_{index:0>4d}")


def _image_path(sub_dir, j):
    """Return the path of the ``j``-th image of a prompt.

    Used both when saving and when checking for --resume, so the two can
    never disagree on the file name.
    """
    return osp.join(sub_dir, f"{j:0>2d}.png")


def main():
    """Compile/load the pipeline and generate images for every prompt."""
    args = parse_args()

    _, migraphx_config = get_name_and_migraphx_config(args.model_dir)
    common_args = migraphx_config["common_args"]
    if args.img_size is not None:
        common_args["img_size"] = args.img_size
    common_args.update(
        batch=args.num_images_per_prompt,
        force_compile=args.force_compile,
    )

    pipe = DiffusionPipeline.from_pretrained(
        args.model_dir,
        torch_dtype=torch.float16,
        migraphx_config=migraphx_config,
    )
    pipe.to("cuda")

    os.makedirs(args.save_dir, exist_ok=True)

    # NOTE(review): one generator is shared by all prompts, so prompts skipped
    # via --resume do not advance it; images generated after a resume will
    # differ from those of a single uninterrupted run with the same seed.
    generator = torch.Generator("cuda").manual_seed(args.seed)

    print("Generating image...")
    for i, prompt in enumerate(parse_prompts(args.parti_prompts_file)):
        sub_dir = _prompt_dir(args.save_dir, i, prompt)
        prompt_json = osp.join(sub_dir, "prompt_info.json")

        # =========================== resume =========================
        if args.resume:
            expected_files = [
                _image_path(sub_dir, j)
                for j in range(args.num_images_per_prompt)
            ]
            expected_files.append(prompt_json)
            if all(osp.exists(path) for path in expected_files):
                print(f"Skipping prompt {i}: \"{prompt.prompt_text}\"")
                continue

        # =========================== generate image =========================
        print(f"Processing prompt {i}: \"{prompt.prompt_text}\"")
        os.makedirs(sub_dir, exist_ok=True)
        with open(prompt_json, "w", encoding="utf-8") as f:
            json.dump(prompt._asdict(), f)

        images = pipe(
            prompt=prompt.prompt_text,
            num_inference_steps=args.num_inference_steps,
            # Must match the compiled batch size set in common_args above.
            num_images_per_prompt=args.num_images_per_prompt,
            generator=generator,
        ).images

        for j, image in enumerate(images):
            save_path = _image_path(sub_dir, j)
            image.save(save_path)
            print(f"Generated image: {save_path}")


if __name__ == "__main__":
    main()