import copy
import json
import os
import os.path as osp

from diffusers import DiffusionPipeline
import migraphx_diffusers
import torch


def parse_args(argv=None):
    """Parse the command-line options for MIGraphX-backed diffusers inference.

    Args:
        argv: Argument list to parse, excluding the program name. The default
            ``None`` makes argparse read ``sys.argv[1:]``, exactly as when the
            function took no arguments; passing a list allows programmatic use.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model_dir``,
        ``force_compile``, ``img_size``, ``num_images_per_prompt``,
        ``prompt``, ``negative_prompt``, ``num_inference_steps``,
        ``save_prefix`` and ``seed``.

    argparse exits via ``SystemExit`` when a required option is missing or a
    value is invalid.
    """
    from argparse import ArgumentParser, ArgumentTypeError

    def positive_int(text):
        # Reject 0 / negative sizes and counts at parse time; otherwise they
        # would presumably only fail later, during compilation or generation.
        value = int(text)
        if value < 1:
            raise ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    parser = ArgumentParser(description="SDXL inference with migraphx backend")

    # =========================== model load and compile =======================
    parser.add_argument(
        "-m",
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and override them",
    )
    parser.add_argument(
        "--img-size",
        type=positive_int,
        default=None,
        help="output image size",
    )
    # main() uses this both as the compile-time batch and as the pipeline's
    # num_images_per_prompt, which is why it lives in the compile section.
    parser.add_argument(
        "--num-images-per-prompt",
        type=positive_int,
        default=1,
        help="The number of images to generate per prompt."
    )
    # --------------------------------------------------------------------------

    # =============================== generation ===============================
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=True,
        help="Prompt for describe image content, style and so on."
    )
    parser.add_argument(
        "-n",
        "--negative-prompt",
        type=str,
        default=None,
        help="Negative prompt",
    )
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=positive_int,
        default=50,
        help="Number of iteration steps",
    )
    parser.add_argument(
        "--save-prefix",
        type=str,
        default=None,
        help="Prefix of path for saving results",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    # --------------------------------------------------------------------------

    return parser.parse_args(argv)


def get_name_and_migraphx_config(model_dir):
    """Identify the pipeline type stored in *model_dir* and return its config.

    Reads ``<model_dir>/model_index.json`` and dispatches on its
    ``_class_name`` entry.

    Args:
        model_dir: Path to a local diffusers model directory.

    Returns:
        A ``(name, config)`` tuple. ``name`` is ``'sdxl'`` or ``'sd2.1'``.
        ``config`` is a deep copy of ``migraphx_diffusers.DEFAULT_ARGS[name]``,
        so callers may modify it freely without altering the library defaults.

    Raises:
        OSError: If ``model_index.json`` cannot be opened.
        KeyError: If ``model_index.json`` has no ``_class_name`` entry.
        NotImplementedError: If the pipeline class is not supported.
    """
    # diffusers pipeline class name -> key into migraphx_diffusers.DEFAULT_ARGS
    supported = {
        "StableDiffusionXLPipeline": "sdxl",
        "StableDiffusionPipeline": "sd2.1",
    }

    model_index_json = osp.join(model_dir, "model_index.json")
    with open(model_index_json, "r", encoding="utf-8") as f:
        pipe_cfg = json.load(f)

    class_name = pipe_cfg["_class_name"]
    name = supported.get(class_name)
    if name is None:
        raise NotImplementedError(f"{class_name} has not been adapted yet")

    # DEFAULT_ARGS is shared module-level state and callers update the
    # returned config in place, so hand out an independent copy.
    return name, copy.deepcopy(migraphx_diffusers.DEFAULT_ARGS[name])


def main():
    """Load a diffusers pipeline on the MIGraphX backend and generate images.

    Options come from ``parse_args()``. Each generated image is written to
    ``<save_prefix>_<i>.png``. When ``--save-prefix`` is omitted, the prefix
    is ``./<name>_output``, where ``name`` is the model type reported by
    ``get_name_and_migraphx_config``.
    """
    args = parse_args()
    name, migraphx_config = get_name_and_migraphx_config(args.model_dir)
    if args.save_prefix is None:
        args.save_prefix = f"./{name}_output"

    # Create the output directory up front, so that a --save-prefix pointing
    # into a missing directory does not fail only after compile + generation.
    save_dir = osp.dirname(args.save_prefix)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    common_args = migraphx_config['common_args']
    if args.img_size is not None:
        common_args['img_size'] = args.img_size
    # The models are compiled with `batch`. NOTE(review): this presumably has
    # to agree with num_images_per_prompt passed to the pipeline call below.
    common_args.update(
        batch=args.num_images_per_prompt,
        force_compile=args.force_compile,
    )

    pipe = DiffusionPipeline.from_pretrained(
        args.model_dir,
        torch_dtype=torch.float16,
        migraphx_config=migraphx_config,
    )
    pipe.to("cuda")

    print("Generating image...")
    images = pipe(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        num_inference_steps=args.num_inference_steps,
        # Without this the pipeline falls back to its default of one image
        # per prompt, ignoring --num-images-per-prompt.
        num_images_per_prompt=args.num_images_per_prompt,
        generator=torch.Generator("cuda").manual_seed(args.seed),
    ).images

    for i, image in enumerate(images):
        save_path = f"{args.save_prefix}_{i}.png"
        image.save(save_path)
        print(f"Generated image: {save_path}")


if __name__ == "__main__":
    main()
