import copy
import os
import os.path as osp

import migraphx as mgx
import numpy as np
import prettytable
import torch


def load_migraphx_model(model_dir,
                        shapes,
                        use_fp16=False,
                        force_compile=False,
                        exhaustive_tune=False,
                        offload_copy=True,
                        batch=1,
                        img_size=1024):
    """Load a compiled MIGraphX program, compiling it from ONNX if needed.

    A cached ``.mxr`` file inside ``model_dir`` is loaded when it exists and
    ``force_compile`` is false. Otherwise ``model_dir/model.onnx`` is parsed,
    optionally quantized to fp16, compiled for the GPU target and saved to
    the cache path before being returned.

    Args:
        model_dir: Directory holding ``model.onnx`` and the cached ``.mxr``.
        shapes: Passed to ``mgx.parse_onnx`` as ``map_input_dims``.
        use_fp16: Quantize the parsed program to fp16 before compiling; also
            selects the ``fp16``/``fp32`` tag in the cache file name.
        force_compile: Ignore any cached ``.mxr`` and recompile from ONNX.
        exhaustive_tune: Forwarded to ``model.compile``.
        offload_copy: Forwarded to ``model.compile``; also selects the
            ``oc``/``gpu`` suffix in the cache file name.
        batch: Only used to build the cache file name.
        img_size: Only used to build the cache file name.

    Returns:
        A ``(model, mxr_file)`` tuple: the loaded or compiled program and the
        path of the cache file it corresponds to.

    Raises:
        RuntimeError: If no usable ``.mxr`` and no ``model.onnx`` exist in
            ``model_dir``.
    """
    model_dtype = 'fp16' if use_fp16 else 'fp32'
    suffix = 'gpu' if not offload_copy else 'oc'
    onnx_file = osp.join(model_dir, 'model.onnx')
    mxr_file = osp.join(model_dir,
                        f"model_{model_dtype}_b{batch}_{img_size}_{suffix}.mxr")

    # Fast path: reuse a previously compiled program.
    if not force_compile and osp.isfile(mxr_file):
        return mgx.load(mxr_file, format="msgpack"), mxr_file

    if not osp.isfile(onnx_file):
        raise RuntimeError(f"No model found at {model_dir}.")

    # These flags are forced to "1" only while parsing/compiling/saving.
    # NOTE(review): presumably they toggle MIGraphX's MIOpen group-norm and
    # NHWC layout paths -- confirm against the MIGraphX documentation.
    compile_env = {
        "MIGRAPHX_ENABLE_MIOPEN_GROUPNORM": "1",
        "MIGRAPHX_ENABLE_NHWC": "1",
    }
    saved_env = {name: os.environ.get(name) for name in compile_env}
    os.environ.update(compile_env)
    try:
        model = mgx.parse_onnx(onnx_file, map_input_dims=shapes)
        if use_fp16:
            mgx.quantize_fp16(model)
        model.compile(mgx.get_target("gpu"),
                      exhaustive_tune=exhaustive_tune,
                      offload_copy=offload_copy)
        print(f"Saving mxr model to {mxr_file}")
        mgx.save(model, mxr_file, format="msgpack")
    finally:
        # Restore the caller's environment even when parsing, compiling or
        # saving fails, so the overrides never leak into the process.
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    return model, mxr_file


def _summary_row(label, arr, ndigits):
    """Build one statistics-table row (label, dtype, shape, stats) for ``arr``."""
    lowest = round(arr.min().item(), ndigits)
    highest = round(arr.max().item(), ndigits)
    # Mean/std are taken in float64, independent of the array's own dtype.
    average = round(arr.astype('float64').mean().item(), ndigits)
    spread = round(arr.astype('float64').std().item(), ndigits)
    contains_nan = np.isnan(arr).any().item()
    contains_inf = np.isinf(arr).any().item()
    return [label, arr.dtype, arr.shape, lowest, highest, average, spread,
            contains_nan, contains_inf]


def compare_data(outputs1, outputs2):
    """Print per-output statistics and difference metrics for two result dicts.

    Only keys present in both mappings are compared. For each shared key two
    tables are printed: a per-array summary (rows labelled ``PyTorch`` for
    ``outputs1`` and ``MIGraphX`` for ``outputs2``) and a comparison row with
    cosine similarity, ``np.allclose`` result and absolute-difference stats.

    Args:
        outputs1: Mapping from output name to array (labelled ``PyTorch``).
        outputs2: Mapping from output name to array (labelled ``MIGraphX``).
            Values are assumed to expose the NumPy ndarray API.
    """
    ndigits = 8

    shared_names = set(outputs1) & set(outputs2)
    for name in shared_names:
        print(f"compare output {name}")
        ref = outputs1[name]
        cand = outputs2[name]

        # Table 1: side-by-side summary of each array.
        stats_table = prettytable.PrettyTable()
        stats_table.field_names = ['framework', 'dtype', 'shape', 'min', 'max',
                                   'mean', 'std', 'has-nan', 'has-inf']
        stats_table.add_row(_summary_row('PyTorch', ref, ndigits))
        stats_table.add_row(_summary_row('MIGraphX', cand, ndigits))
        print(stats_table)

        # Table 2: element-wise agreement, computed on flattened float64 copies.
        ref_flat = ref.flatten().astype(np.float64)
        cand_flat = cand.flatten().astype(np.float64)
        norm_product = np.linalg.norm(ref_flat) * np.linalg.norm(cand_flat)
        cosine = np.dot(ref_flat, cand_flat) / norm_product

        abs_err = np.abs(ref_flat - cand_flat)
        worst_err = np.max(abs_err)
        mean_err = np.mean(abs_err)
        total_err = np.sum(abs_err)

        is_close = np.allclose(ref, cand, rtol=1e-03, atol=1e-06,
                               equal_nan=True)

        diff_table = prettytable.PrettyTable()
        diff_table.field_names = ['cosine-similarity', 'all-close', 'max-diff',
                                  'mean-diff', 'sum-diff']
        diff_table.add_row([
            round(cosine.item(), ndigits),
            is_close,
            round(worst_err.item(), ndigits),
            round(mean_err.item(), ndigits),
            round(total_err.item(), ndigits),
        ])
        print(diff_table)


def fix_random_seed(seed=0):
    """Seed the PyTorch and NumPy RNGs and request deterministic cuDNN.

    Also stores ``seed`` in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU generator, current CUDA device, then all CUDA devices.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in torch_seeders:
        apply_seed(seed)

    np.random.seed(seed)

    # Prefer reproducible cuDNN kernels over auto-tuned ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_model_args(config_dict):
    """Build per-model argument dicts from a configuration mapping.

    The input is deep-copied and never modified. A ``'vae_decoder'`` entry in
    ``config_dict['model_args']`` is treated as ``'vae'``, and the same
    renaming is applied to names listed in ``use_migraphx_models``.

    Args:
        config_dict: Mapping with the keys ``'model_args'``,
            ``'use_migraphx_models'`` and ``'common_args'`` (the latter must
            contain ``'batch'`` and ``'img_size'``).

    Returns:
        A dict with a ``'pipeline'`` entry holding ``batch`` and ``img_size``,
        plus one entry per selected model: a copy of ``common_args`` updated
        with that model's own overrides when present.
    """
    config = copy.deepcopy(config_dict)

    overrides = config['model_args']
    if 'vae_decoder' in overrides:
        overrides['vae'] = overrides.pop('vae_decoder')

    selected = set(config['use_migraphx_models'])
    common = config['common_args']

    result = {
        'pipeline': {
            'batch': common['batch'],
            'img_size': common['img_size'],
        },
    }
    for requested in selected:
        key = requested if requested != 'vae_decoder' else 'vae'
        # Start from the shared defaults, then layer model-specific values.
        merged = copy.deepcopy(common)
        if key in overrides:
            merged.update(overrides[key])
        result[key] = merged

    return result
