import argparse
import copy
import os
import time

import torch
from diffusers import DiffusionPipeline

import migraphx_diffusers


# Command-line interface: the model directory is required; generated images
# are written to --result-dir, which is created if it does not exist.
# "test sd2.1" is passed as the description: passed positionally it would be
# taken as `prog` and replace the program name shown in usage output.
parser = argparse.ArgumentParser(description="test sd2.1")
parser.add_argument('model_dir', type=str, help="path to sd2.1 models")
parser.add_argument(
    '--result-dir',
    type=str,
    default="./results",
    help="directory to save generated images (created if missing)",
)
args = parser.parse_args()
os.makedirs(args.result_dir, exist_ok=True)

# Base prompt used for every generation.
prompt = "An astronaut riding a green horse"

# Sweep parameters. widths and heights are paired element-wise (zip) and each
# pair must be square; every pair is combined with every batch size and step
# count.
widths = [512]
heights = [512]
steps_list = [20]
batch_sizes = [1, 2, 4, 8]

# Deep-copy the package defaults: the sweep writes "img_size" and "batch" into
# the nested "common_args" dict, and doing that on DEFAULT_ARGS itself would
# silently modify migraphx_diffusers' package-level defaults in this process.
mgx_config = copy.deepcopy(migraphx_diffusers.DEFAULT_ARGS['sd2.1'])

# Benchmark every (resolution, batch size, step count) combination:
# len(widths) x len(batch_sizes) x len(steps_list) runs in total
# (1 x 4 x 1 = 4 with the values configured above).
# img_size and batch are written into the MIGraphX config, so the pipeline is
# rebuilt for each (resolution, batch size) pair — presumably because the
# compiled graphs are shape-specific.
WARMUP_RUNS = 1  # untimed runs before measuring each pipeline
TIMED_RUNS = 1  # timed runs averaged per configuration

for width, height in zip(widths, heights):
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if width != height:
        raise ValueError("Only support generate images with square shape!")
    mgx_config["common_args"]["img_size"] = width

    for batch_size in batch_sizes:
        mgx_config["common_args"]["batch"] = batch_size

        # Load the fp16 pipeline with the MIGraphX config for this shape.
        pipe = DiffusionPipeline.from_pretrained(
            args.model_dir,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
            migraphx_config=mgx_config,
        )
        pipe.to("cuda")

        # Warm up with a single denoising step so one-time setup cost is not
        # included in the timed runs below.
        for _ in range(WARMUP_RUNS):
            pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=1,
                num_images_per_prompt=batch_size
            )

        for num_inference_steps in steps_list:

            print(f"\n生成配置: {width}x{height}, steps={num_inference_steps}, batch={batch_size}")
            time_list = []  # per-run latency in milliseconds
            for _ in range(TIMED_RUNS):
                # Fixed seed for reproducible images; built outside the timed
                # region so only pipeline execution is measured.
                generator = torch.Generator("cuda").manual_seed(42)

                # Synchronize on both sides so queued GPU work is included.
                torch.cuda.synchronize()
                time_start = time.perf_counter()

                result = pipe(
                    prompt=prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=batch_size,
                    generator=generator
                )

                torch.cuda.synchronize()
                time_list.append((time.perf_counter() - time_start) * 1000)

            print(f"time cost: {time_list}, avg: {sum(time_list)/len(time_list)}")
            # Save the images produced by the last timed run.
            print(len(result.images))
            for i, image in enumerate(result.images):
                filename = os.path.join(args.result_dir, f"output_{width}x{height}_steps{num_inference_steps}_batch{batch_size}_{i}.png")
                image.save(filename)
                print(f"保存图片: {filename}")

        # Drop this pipeline before the next one is loaded. Otherwise the old
        # pipeline stays referenced while from_pretrained builds the new one,
        # so both would be resident at the same time.
        del pipe
        torch.cuda.empty_cache()


print("所有配置组合生成完成！")
