import time

from diffusers import DiffusionPipeline

import migraphx_diffusers  # noqa: F401 (side-effect import; presumably registers the MIGraphX backend)
from migraphx_diffusers import AutoTimer, get_name_and_migraphx_config
import torch


def parse_args():
    date_str = time.strftime("%Y%m%d-%H%M%S", time.localtime())
    from argparse import ArgumentParser
    parser = ArgumentParser(
        description="Diffusion pipeline latency benchmark with the MIGraphX "
                    "backend (SDXL / SD2.1 / Flux.1-dev)")

    # =========================== model load and compile ========================
    parser.add_argument(
        "-m",
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and overwrite them.",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=None,
        help="Output image size.",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="The number of images to generate per prompt.",
    )
    # --------------------------------------------------------------------------

    # =============================== generation ===============================
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of inference (denoising) steps.",
    )
    parser.add_argument(
        "--out-csv-file",
        type=str,
        default=f"./perf-{date_str}.csv",
        help="Path of the CSV file the timing results are appended to.",
    )
    # --------------------------------------------------------------------------

    # =============================== time count ===============================
    parser.add_argument(
        "--count-submodels",
        action="store_true",
        help="Also time each submodel, not just the end-to-end pipeline.",
    )
    parser.add_argument(
        "--num-warmup-loops",
        type=int,
        default=1,
        help="Number of warmup loops (excluded from timing).",
    )
    parser.add_argument(
        "--num-count-loops",
        type=int,
        default=100,
        help="Number of timed loops.",
    )
    # --------------------------------------------------------------------------

    return parser.parse_args()


def get_default_prompt(pipe_name):
    """Return a default (prompt, negative_prompt) pair for the given pipeline."""
    negative_prompt = "ugly"
    if pipe_name == 'sd2.1':
        prompt = "a photo of an astronaut riding a horse on mars"
    elif pipe_name == 'sdxl':
        prompt = "An astronaut riding a green horse"
    elif pipe_name == 'flux.1-dev':
        prompt = "A cat holding a sign that says hello world"
    else:
        raise ValueError(f"{pipe_name} is not supported!")
    return prompt, negative_prompt


def set_timer(timer, pipe, pipe_name, count_submodels=False):
    """Register the pipeline (and optionally its submodels) with the timer."""
    timer.add_target(pipe, key="end2end")
    if not count_submodels:
        return
    if pipe_name == 'sd2.1':
        timer.add_targets([
            (pipe.text_encoder, "text_encoder"),
            (pipe.unet, "unet"),
            (pipe.vae.decode, "vae_decoder"),
        ])
    elif pipe_name == 'sdxl':
        timer.add_targets([
            (pipe.text_encoder, "text_encoder"),
            (pipe.text_encoder_2, "text_encoder_2"),
            (pipe.unet, "unet"),
            (pipe.vae.decode, "vae_decoder"),
        ])
    elif pipe_name == 'flux.1-dev':
        timer.add_targets([
            (pipe.text_encoder, "text_encoder"),
            (pipe.text_encoder_2, "text_encoder_2"),
            (pipe.transformer, "transformer"),
            (pipe.vae.decode, "vae_decoder"),
        ])
    else:
        raise ValueError(f"{pipe_name} is not supported!")


def test_latency(pipe,
                 timer,
                 prompt,
                 negative_prompt=None,
                 batch=1,
                 num_inference_steps=50,
                 num_warmup_loops=1,
                 num_count_loops=100,
                 title=None,
                 out_csv_file=None,
                 **call_kwargs):
    """Run warmup loops, then timed loops, and append the summary to a CSV file."""
    date_str = time.strftime("%Y%m%d-%H%M%S", time.localtime())
    if not out_csv_file:
        out_csv_file = f"./perf-{date_str}.csv"
    for i in range(num_warmup_loops + num_count_loops):
        # Start counting only once the warmup loops have finished.
        if i == num_warmup_loops:
            timer.start_work()
        pipe(prompt=prompt,
             negative_prompt=negative_prompt,
             num_inference_steps=num_inference_steps,
             **call_kwargs)
    table = timer.summary(batchsize=batch, title=title)
    with open(out_csv_file, 'a') as f:
        f.write(table.get_csv_string())
    timer.clear()
    timer.finish_work()


def main():
    args = parse_args()

    pipe_name, migraphx_config = get_name_and_migraphx_config(args.model_dir)
    assert pipe_name in ['sdxl', 'sd2.1', 'flux.1-dev'], \
        "Only support (1)SDXL (2)SD2.1 (3)Flux.1-dev!"
    if args.img_size is not None:
        migraphx_config['common_args']['img_size'] = args.img_size
    migraphx_config['common_args'].update(dict(
        batch=args.num_images_per_prompt,
        force_compile=args.force_compile,
    ))

    pipe = DiffusionPipeline.from_pretrained(
        args.model_dir,
        torch_dtype=torch.float16,
        migraphx_config=migraphx_config,
    )
    pipe.to("cuda")

    t = AutoTimer()
    set_timer(t, pipe, pipe_name, count_submodels=args.count_submodels)

    prompt, negative_prompt = get_default_prompt(pipe_name)
    test_latency(pipe, t, prompt,
                 batch=args.num_images_per_prompt,
                 num_inference_steps=args.num_inference_steps,
                 num_warmup_loops=args.num_warmup_loops,
                 num_count_loops=args.num_count_loops,
                 title=f"{pipe_name} Latency (Only Prompt)",
                 out_csv_file=args.out_csv_file)

    if pipe_name == 'flux.1-dev':
        # Flux.1-dev applies a negative prompt only with true CFG enabled
        # (true_cfg_scale > 1), so measure that configuration separately.
        test_latency(pipe, t, prompt,
                     negative_prompt=negative_prompt,
                     batch=args.num_images_per_prompt,
                     num_inference_steps=args.num_inference_steps,
                     num_warmup_loops=args.num_warmup_loops,
                     num_count_loops=args.num_count_loops,
                     title=f"{pipe_name} Latency (Prompt + NegativePrompt)",
                     out_csv_file=args.out_csv_file,
                     true_cfg_scale=2.0)


if __name__ == "__main__":
    main()
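# Example invocation (illustrative sketch only; the script name and model
# path below are assumptions -- substitute your own local paths):
#
#   python benchmark_diffusers_migraphx.py \
#       -m ./stable-diffusion-xl-base-1.0 \
#       --num-inference-steps 50 \
#       --num-warmup-loops 1 \
#       --num-count-loops 100 \
#       --count-submodels \
#       --out-csv-file ./perf-sdxl.csv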