import json
import os
import os.path as osp

from diffusers import DiffusionPipeline
import migraphx_diffusers
from migraphx_diffusers import get_name_and_migraphx_config
import torch


def parse_args():
    from argparse import ArgumentParser
    parser = ArgumentParser(description="SDXL inference with migraphx backend")
    # =========================== model load and compile ========================
    parser.add_argument(
        "-m",
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and override them",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="The number of images to generate per prompt.",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=None,
        help="Output image size",
    )
    # --------------------------------------------------------------------------
    # =============================== generation ===============================
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=int,
        default=None,
        help="Number of iteration steps",
    )
    parser.add_argument(
        "--true-cfg-scale",
        default=None,
        type=float,
        help="Only for the flux pipeline. When > 1.0 and a `negative_prompt` is "
             "provided, enables true classifier-free guidance.",
    )
    parser.add_argument(
        "--guidance-scale",
        default=None,
        type=float,
        help="Guidance scale is enabled by setting `guidance_scale > 1`. A higher "
             "guidance scale encourages generating images that are closely linked "
             "to the text `prompt`, usually at the expense of lower image quality.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    # --------------------------------------------------------------------------
    parser.add_argument(
        "--examples-json",
        type=str,
        default="./examples/prompts_and_negative_prompts.json",
        help="Prompts and negative prompts data path",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Path to save images",
    )
    args = parser.parse_args()
    return args


def parse_prompts(examples_json):
    with open(examples_json, 'r') as f:
        prompt_data = json.load(f)
    return prompt_data


def main():
    args = parse_args()

    # Resolve the pipeline name and the migraphx compile configuration from the model directory.
    pipe_name, migraphx_config = get_name_and_migraphx_config(args.model_dir)
    if args.output_dir is None:
        args.output_dir = f"./examples/{pipe_name}-images-{args.img_size}"
    if args.img_size is not None:
        migraphx_config['common_args']['img_size'] = args.img_size
    migraphx_config['common_args'].update(dict(
        batch=args.num_images_per_prompt,
        force_compile=args.force_compile,
    ))

    pipe = DiffusionPipeline.from_pretrained(
        args.model_dir,
        torch_dtype=torch.float16,
        migraphx_config=migraphx_config,
    )
    pipe.to("cuda")

    # Only forward generation arguments that were explicitly set on the CLI.
    call_kwargs = {}
    if args.num_inference_steps is not None:
        call_kwargs['num_inference_steps'] = args.num_inference_steps
    if args.guidance_scale is not None:
        call_kwargs['guidance_scale'] = args.guidance_scale
    if args.true_cfg_scale is not None:
        assert pipe_name == 'flux.1-dev', \
            "`true_cfg_scale` is only valid for the flux.1-dev pipeline!"
        call_kwargs['true_cfg_scale'] = args.true_cfg_scale
    if args.seed is not None:
        call_kwargs['generator'] = torch.Generator("cuda").manual_seed(args.seed)

    prompt_data = parse_prompts(args.examples_json)

    # Iterate over themed prompt groups; each theme gets its own output sub-directory.
    cnt = 0
    for i, d in enumerate(prompt_data):
        theme = d["theme"]
        pairs = d["examples"]
        sub_dir = osp.join(args.output_dir, f"{i}-{theme.title().replace(' ', '')}")
        os.makedirs(sub_dir, exist_ok=True)
        for j, pair in enumerate(pairs):
            print(f"Generating image {cnt}...")
            prompt = pair["prompt"]
            negative_prompt = pair["negative_prompt"]
            print(f"Prompt: {prompt}")
            print(f"Negative prompt: {negative_prompt}")
            images = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                **call_kwargs
            ).images
            for k, image in enumerate(images):
                save_path = osp.join(
                    sub_dir, f"theme_{i}_example_{j}_image_{k}.png")
                image.save(save_path)
                print(f"Image saved: {save_path}")
            cnt += 1

    print(f"Total {cnt} images generated!")


if __name__ == "__main__":
    main()
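
# Example invocation (illustrative only: the script filename `generate_images.py`
# and the model path below are placeholders, not part of this repository; all
# flags shown are defined in parse_args() above):
#
#   python generate_images.py \
#       --model-dir /path/to/stable-diffusion-xl-base-1.0 \
#       --num-inference-steps 30 \
#       --guidance-scale 5.0 \
#       --num-images-per-prompt 2 \
#       --img-size 1024 \
#       --seed 42 \
#       --output-dir ./examples/sdxl-images-1024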