import json
import os.path as osp

from diffusers import DiffusionPipeline
import migraphx_diffusers
from migraphx_diffusers import get_name_and_migraphx_config
import torch


def parse_args():
    """Parse command-line arguments for MIGraphX-backed diffusion inference.

    Returns:
        argparse.Namespace with two groups of options:

        * model load / compile: ``model_dir`` (required), ``force_compile``,
          ``img_size``, ``num_images_per_prompt``;
        * generation: ``prompt`` (required), ``negative_prompt``,
          ``num_inference_steps``, ``true_cfg_scale``, ``guidance_scale``,
          ``seed`` (default 42), ``save_prefix``.

        Options whose default is ``None`` mean "not specified by the user".
    """
    from argparse import ArgumentParser
    parser = ArgumentParser(
        description="Diffusion pipeline inference with migraphx backend")

    # =========================== model load and compile ========================
    parser.add_argument(
        "-m",
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and override them",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=None,
        help="output image size",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="The number of images to generate per prompt."
    )
    # --------------------------------------------------------------------------

    # =============================== generation ===============================
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=True,
        help="Prompt describing image content, style and so on."
    )
    parser.add_argument(
        "-n",
        "--negative-prompt",
        type=str,
        default=None,
        help="Negative prompt",
    )
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=int,
        default=None,
        help="Number of iteration steps",
    )
    parser.add_argument(
        "--true-cfg-scale",
        default=None,
        type=float,
        # Adjacent string literals inside parentheses are concatenated
        # implicitly; no backslash continuation is needed.
        help="Only for flux pipeline. When > 1.0 and a provided `negative_prompt`, "
             "enables true classifier-free guidance."
    )
    parser.add_argument(
        "--guidance-scale",
        default=None,
        type=float,
        help="Guidance scale is enabled by setting `guidance_scale > 1`. Higher "
             "guidance scale encourages to generate images that are closely linked to "
             "the text `prompt`, usually at the expense of lower image quality."
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--save-prefix",
        type=str,
        default=None,
        help="Prefix of path for saving results",
    )
    # --------------------------------------------------------------------------

    return parser.parse_args()


def main():
    """Build a MIGraphX-backed diffusion pipeline from CLI args and save its images.

    Steps:
      1. Parse CLI arguments and resolve the pipeline name and MIGraphX config
         for ``--model-dir`` via ``get_name_and_migraphx_config``.
      2. Override compile options (image size, batch, force-compile) in the
         config's ``common_args`` section.
      3. Load the pipeline in float16, move it to ``"cuda"``, and run it with
         only the generation options the user actually set.
      4. Save every returned image as ``{save_prefix}_{i}.png``, creating the
         prefix's parent directory if needed.

    Raises:
        ValueError: if ``--true-cfg-scale`` is given for a pipeline other than
            ``flux.1-dev``.
    """
    import os

    args = parse_args()
    pipe_name, migraphx_config = get_name_and_migraphx_config(args.model_dir)

    # Validate user input before the (potentially expensive) load/compile.
    # An explicit exception is used instead of `assert`, which is removed
    # when Python runs with -O.
    if args.true_cfg_scale is not None and pipe_name != 'flux.1-dev':
        raise ValueError(
            "`true_cfg_scale` is only valid for flux.1-dev pipeline!")

    if args.save_prefix is None:
        args.save_prefix = f"./{pipe_name}_output"

    # NOTE(review): assumes the returned config always contains a
    # 'common_args' dict; confirm against get_name_and_migraphx_config.
    common_args = migraphx_config['common_args']
    if args.img_size is not None:
        common_args['img_size'] = args.img_size
    common_args.update(
        batch=args.num_images_per_prompt,
        force_compile=args.force_compile,
    )

    pipe = DiffusionPipeline.from_pretrained(
        args.model_dir,
        torch_dtype=torch.float16,
        migraphx_config=migraphx_config
    )
    pipe.to("cuda")

    # Forward only the options the user set, so unset ones keep the
    # pipeline's own defaults.
    call_kwargs = {}
    if args.num_inference_steps is not None:
        call_kwargs['num_inference_steps'] = args.num_inference_steps
    if args.guidance_scale is not None:
        call_kwargs['guidance_scale'] = args.guidance_scale
    if args.true_cfg_scale is not None:
        call_kwargs['true_cfg_scale'] = args.true_cfg_scale
    if args.seed is not None:
        call_kwargs['generator'] = torch.Generator("cuda").manual_seed(args.seed)

    print("Generating image...")
    images = pipe(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        **call_kwargs
    ).images

    # A prefix such as "results/run1" would make image.save fail if
    # "results/" does not exist yet.
    save_dir = os.path.dirname(args.save_prefix)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for i, image in enumerate(images):
        save_path = f"{args.save_prefix}_{i}.png"
        image.save(save_path)
        print(f"Generated image: {save_path}")


# Script entry point: runs main() only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
