import argparse
import os.path as osp
from diffusers import AutoencoderKL, DiffusionPipeline
from migraphx_diffusers import ONNXModifier
import numpy as np
import onnx
import torch


def export_vae_decoder(pipeline_dir):
    """Export the decoder half of a diffusers ``AutoencoderKL`` to ONNX.

    The VAE is loaded from the ``vae`` subfolder of ``pipeline_dir``. Its
    ``decode`` method is traced with a random latent of shape
    ``(1, latent_channels, 128, 128)``. The batch and spatial dimensions of the
    input (``latent_sample``) and of the output (``sample``) are both marked
    dynamic.

    Args:
        pipeline_dir: Root directory of a diffusers pipeline that contains a
            ``vae`` subfolder.

    Returns:
        Path of the written model: ``<pipeline_dir>/vae_decoder/model.onnx``.

    Raises:
        RuntimeError: If no file exists at that path after the export.
    """
    import os

    save_path = osp.join(pipeline_dir, "vae_decoder", "model.onnx")
    vae = AutoencoderKL.from_pretrained(pipeline_dir, subfolder="vae")

    # ``vae_decoder`` is not part of the standard diffusers layout, so it may
    # not exist yet. torch.onnx.export cannot create parent directories.
    os.makedirs(osp.dirname(save_path), exist_ok=True)

    # Trace only the decoder path: calling the module now runs ``vae.decode``.
    vae.forward = vae.decode

    input_names = ["latent_sample"]
    output_names = ["sample"]
    # Every key must be one of ``input_names``/``output_names``. The exporter
    # warns about an unknown key and otherwise ignores it, which would leave
    # the ``sample`` output with a static shape.
    dynamic_axes = {
        "latent_sample": {0: "batch_size", 2: "latent_height", 3: "latent_width"},
        "sample": {0: "batch_size", 2: "image_height", 3: "image_width"},
    }

    dummy_latent = torch.randn(1, vae.config["latent_channels"], 128, 128)

    # Gradients are not needed for tracing; avoid building the autograd graph.
    with torch.no_grad():
        torch.onnx.export(
            vae,
            dummy_latent,
            save_path,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    if not osp.isfile(save_path):
        raise RuntimeError("Failed to export vae decoder to ONNX.")
    print(f"Successfully exported vae decoder to ONNX: {save_path}")
    return save_path


def modify_vae_decoder(onnx_path: str, save_path=None) -> None:
    """Rewire the mid-block attention subgraph of an exported VAE-decoder ONNX model.

    All edits target nodes and initializers under
    ``/decoder/mid_block/attentions.0/``, by the names ``torch.onnx.export``
    produced for this graph. The edits:

    * point ``Gather_1``/``Gather_2``/``Gather_3`` at the output of the single
      ``Shape`` node;
    * swap the operands of ``to_k/MatMul`` so it computes ``W @ x``, where:
        - ``W`` is the weight initializer transposed in place;
        - ``x`` is the ``group_norm`` output;
        - the bias dims are reshaped to ``[512, 1]``.
      NOTE(review): presumably this produces keys already in transposed
      layout so no separate Transpose is needed. Confirm.
    * make ``Mul_1``/``Mul_2`` scale the ``to_q``/``to_k`` outputs directly by
      a new constant ``sqrt(1 / sqrt(512))``;
    * feed the ``to_v`` output straight into ``MatMul_1``, and ``MatMul_1``'s
      output straight into ``to_out.0/MatMul``;
    * make ``Reshape_5`` use the ``Shape`` node's output as its target shape.

    NOTE(review): the names (including the initializer ``onnx::MatMul_918``)
    and the constant 512 are hard-coded. They likely depend on the exact
    torch/diffusers versions and model used for export. Verify before reusing
    this on a different export.

    Args:
        onnx_path: Path of the ONNX model to load.
        save_path: Where to write the modified model. If ``None`` (the
            default), ``onnx_path`` is overwritten.
    """
    om = ONNXModifier(onnx_path)

    # Make the C/H/W Gather nodes all read from the same Shape node.
    shape_node = om.get_node("/decoder/mid_block/attentions.0/Shape")
    C_gather = om.get_node("/decoder/mid_block/attentions.0/Gather_1")
    H_gather = om.get_node("/decoder/mid_block/attentions.0/Gather_2")
    W_gather = om.get_node("/decoder/mid_block/attentions.0/Gather_3")
    C_gather.set_input(0, shape_node.outputs[0])
    H_gather.set_input(0, shape_node.outputs[0])
    W_gather.set_input(0, shape_node.outputs[0])

    # Swap the to_k MatMul operands:
    #   - input 0 becomes what was input 1 (presumably the weight initializer);
    #   - input 1 becomes the group_norm output.
    matmul1 = om.get_node("/decoder/mid_block/attentions.0/to_k/MatMul")
    matmul1.set_input(0, matmul1.inputs[1])
    matmul1.set_input(1, "/decoder/mid_block/attentions.0/group_norm/Add_output_0")
    # Transpose that weight initializer in place so the swapped product is
    # consistent. Presumably "onnx::MatMul_918" is the to_k weight; confirm.
    # Assumes the values are stored in raw_data. If they are in a typed field
    # (e.g. float_data), frombuffer sees no bytes and the reshape fails.
    matmul1_ini = om.get_initializer("onnx::MatMul_918")
    orig_dtype = onnx.helper.tensor_dtype_to_np_dtype(matmul1_ini.data_type)
    # Copy dims into a plain list before they are overwritten below.
    orig_shape = matmul1_ini.dims[:]
    # .T yields a non-contiguous view; tobytes() serializes it in C order, so
    # the stored bytes are the transposed matrix.
    new_value = np.frombuffer(matmul1_ini.raw_data, 
                              dtype=orig_dtype).reshape(orig_shape).T
    matmul1_ini.raw_data = new_value.tobytes()
    matmul1_ini.dims[:] = orig_shape[::-1]
    # Only the bias metadata changes (data untouched). It becomes a column
    # vector so the Add broadcasts along the last axis of the swapped MatMul
    # output. Assumes the bias has exactly 512 elements.
    add_ini = om.get_initializer("decoder.mid_block.attentions.0.to_k.bias")
    add_ini.dims[:] = [512, 1]

    # Apply the scaling Muls directly to the to_q / to_k outputs.
    mul1 = om.get_node("/decoder/mid_block/attentions.0/Mul_1")
    mul2 = om.get_node("/decoder/mid_block/attentions.0/Mul_2")
    mul1.set_input(0, "/decoder/mid_block/attentions.0/to_q/Add_output_0")
    mul2.set_input(0, "/decoder/mid_block/attentions.0/to_k/Add_output_0")

    # Both operands are scaled by 512 ** -0.25, so their product is scaled by
    # 1 / sqrt(512). NOTE(review): 512 is presumably the attention
    # channel/head dim; confirm. [None] turns the float32 scalar into a
    # 1-element array.
    value_B = om.create_initializer("/decoder/mid_block/attentions.0/mul_B", 
                                    np.sqrt(1 / np.sqrt(512)).astype(np.float32)[None])
    mul1.set_input(1, value_B.name)
    mul2.set_input(1, value_B.name)

    # Feed the to_v output straight into the second attention MatMul.
    matmul2 = om.get_node("/decoder/mid_block/attentions.0/MatMul_1")
    matmul2.set_input(1, "/decoder/mid_block/attentions.0/to_v/Add_output_0")

    # Feed that MatMul's result straight into the output projection.
    matmul3 = om.get_node("/decoder/mid_block/attentions.0/to_out.0/MatMul")
    matmul3.set_input(0, matmul2.outputs[0])

    # Reshape back using the Shape node's output as the target shape.
    reshape_node = om.get_node("/decoder/mid_block/attentions.0/Reshape_5")
    reshape_node.set_input(1, shape_node.outputs[0])

    # NOTE(review): presumably these refresh ONNXModifier's internal lookup
    # tables and drop nodes left unreferenced by the rewiring above. Confirm
    # against ONNXModifier.
    om.update_map()
    om.remove_trash()
    save_path = onnx_path if save_path is None else save_path
    om.save(save_path)


def build_arg_parser():
    """Return the command-line parser for this script.

    Defines one required option, ``--pipeline-dir``.
    """
    # The text must go to ``description``. Passed positionally it becomes
    # ``prog`` and replaces the program name in the usage line.
    parser = argparse.ArgumentParser(description="Export vae decoder to ONNX")
    parser.add_argument(
        "--pipeline-dir",
        type=str,
        required=True,
        help="The path to the sdxl pipeline directory.",
    )
    return parser


def main(argv=None):
    """Export the VAE decoder to ONNX, then rewrite its attention subgraph.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    onnx_path = export_vae_decoder(args.pipeline_dir)
    # No save_path is given, so the modified graph overwrites the exported file.
    modify_vae_decoder(onnx_path)


if __name__ == "__main__":
    main()
