import os
import json
import logging

import torch._tensor
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import migraphx
import tensorflow as tf
import torch
import onnxruntime
import time
import numpy as np

import onnx

from onnx import shape_inference
from general_perf.backends import runtime_backend

# Module-level logger used by the DCU backend.
log = logging.getLogger("BackendDCU")

# Maps the type names that appear in the comma-separated ``input_type``
# config string to the torch dtype used when building input tensors for
# TorchScript models. It covers the same names that INPUT_TYPE accepts for
# NumPy data, so an integer or boolean input no longer raises KeyError on
# the Pytorch path.
pt_dtype_map = {
    "FLOAT32": torch.float32,
    "FLOAT16": torch.float16,
    "INT8": torch.int8,
    "UINT8": torch.uint8,
    "INT32": torch.int32,
    "INT64": torch.int64,
    "LONG": torch.long,
    "BOOL": torch.bool,
}

# Maps the type names that appear in the comma-separated ``input_type``
# config string to the NumPy scalar type used when generating fake inputs.
#
# ``np.long`` and ``np.bool`` were deprecated aliases of the builtins
# ``int`` and ``bool`` and were removed in NumPy 1.24, so referencing them
# made this module fail at import time on those releases. ``np.int64`` keeps
# "LONG" in line with ``torch.long`` and ``np.bool_`` is NumPy's own boolean
# scalar type; both exist in every NumPy release.
INPUT_TYPE = {
    "INT8": np.int8,
    "UINT8": np.uint8,
    "FLOAT32": np.float32,
    "FLOAT16": np.float16,
    "LONG": np.int64,
    "INT32": np.int32,
    "INT64": np.int64,
    "BOOL": np.bool_,
}


class RuntimeBackendDCU(runtime_backend.RuntimeBackend):
    def __init__(self):
        """Create a DCU backend with no configuration and no model loaded."""
        super().__init__()
        # Nothing is loaded yet: the config is filled in by the caller and
        # the runtimes / batch size are populated by ``load``.
        self.configs = None
        self.model_runtimes = []
        self.batch_size = -1
        self.need_reload = False
        self.hardware_type = 'DCU'
    
    def predict(self, feeds):
        """Run one inference pass on the loaded runtimes and time it.

        The work is dispatched on ``self.framework`` ("Tensorflow",
        "Pytorch", "Migraphx"; anything else is treated as ONNX Runtime).

        Args:
            feeds: Mapping of input name to input data. For Tensorflow and
                Pytorch the values are consumed in iteration order; for
                Migraphx and ONNX Runtime they are looked up by key.

        Returns:
            A ``(use_time, results)`` tuple. ``use_time`` is the measured
            duration of the model execution in seconds. ``results`` is a
            dict keyed by output name for Tensorflow and Pytorch, a list of
            NumPy arrays for Migraphx, and whatever the session's ``run``
            returns for ONNX Runtime.

        Raises:
            RuntimeError: If no runtime has been loaded yet. Previously this
                surfaced as an UnboundLocalError/IndexError deeper inside.
        """
        if not self.model_runtimes:
            raise RuntimeError(
                "No model runtime is loaded; call load() before predict().")

        if self.framework == "Tensorflow":
            return self._predict_tensorflow(feeds)
        if self.framework == "Pytorch":
            return self._predict_pytorch(feeds)
        if self.framework == "Migraphx":
            return self._predict_migraphx(feeds)
        return self._predict_onnxruntime(feeds)

    def _predict_tensorflow(self, feeds):
        """Call the ``serving_default`` signature of every SavedModel."""
        entry_signature = self.model_runtimes[0].signatures['serving_default']
        signature_inputs = tf.nest.flatten(
            entry_signature.structured_input_signature, True)
        # Feeds are matched to the signature inputs by position, not by key.
        real_feeds = {
            signature_inputs[index].name: tf.constant(value)
            for index, value in enumerate(feeds.values())
        }

        start_time = time.perf_counter()
        for model_runtime in self.model_runtimes:
            with tf.device('GPU'):
                raw_results = model_runtime.signatures['serving_default'](
                    **real_feeds)
        use_time = time.perf_counter() - start_time

        # Only the outputs of the last runtime are returned.
        results = {key: val.numpy() for key, val in raw_results.items()}
        assert len(results) != 0
        return use_time, results

    def _predict_pytorch(self, feeds):
        """Run every TorchScript module and convert its outputs to NumPy."""
        input_dtypes = self.input_type.split(',')
        # The n-th feed uses the n-th entry of the ``input_type`` list.
        input_tensors = [
            torch.tensor(value,
                         dtype=pt_dtype_map[input_dtypes[index]]).to(
                             self.device)
            for index, value in enumerate(feeds.values())
        ]
        use_autocast = (self.configs['compile_precision'] == "FP16"
                        and "bert" in self.configs['model'])

        def run_all_runtimes():
            outputs = {}
            for model_runtime in self.model_runtimes:
                outputs = model_runtime(*input_tensors)
            return outputs

        start_time = time.perf_counter()
        with torch.no_grad():
            if use_autocast:
                with torch.cuda.amp.autocast():
                    results = run_all_runtimes()
            else:
                results = run_all_runtimes()
        # GPU kernels are launched asynchronously; without waiting for them
        # the timer would stop before the device has finished the work.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        use_time = time.perf_counter() - start_time

        if isinstance(results, dict):
            results = {
                key: val.cpu().detach().numpy()
                for key, val in results.items()
            }
        elif isinstance(results, (tuple, list)):
            # The original code built this mapping but then discarded it,
            # returning the raw tuple of device tensors instead.
            results = {
                key: (val.cpu().detach().numpy()
                      if isinstance(val, torch.Tensor) else val)
                for key, val in zip(self.outputs, results)
            }
        else:
            results = {self.outputs[0]: results.cpu().numpy()}
        return use_time, results

    def _predict_migraphx(self, feeds):
        """Run every MIGraphX program with inputs copied to the GPU."""
        for model_runtime in self.model_runtimes:
            model_data = self.AllocateteOutputMemory(model_runtime)
            for key in feeds:
                # The converted array is stored back into ``feeds`` (as the
                # original code did), so the caller's dict ends up holding
                # NumPy arrays.
                feeds[key] = np.array(feeds[key])
                model_data[key] = migraphx.to_gpu(
                    migraphx.argument(feeds[key]))

            start_time = time.perf_counter()
            raw_results = model_runtime.run(model_data)
            use_time = time.perf_counter() - start_time
            results = [np.array(output) for output in raw_results]
        # The time and outputs of the last runtime are the ones returned.
        return use_time, results

    def _predict_onnxruntime(self, feeds):
        """Run every ONNX Runtime session, requesting all outputs."""
        for model_runtime in self.model_runtimes:
            start_time = time.perf_counter()
            results = model_runtime.run(None, feeds)
            use_time = time.perf_counter() - start_time
        # The time and outputs of the last session are the ones returned.
        return use_time, results

    def benchmark(self, dataloader):
        """Measure latency and throughput of the loaded model on fake data.

        Thirty untimed warm-up predictions are run first, followed by
        ``self.workload['iterations']`` timed ones. Latencies come from the
        ``use_time`` value returned by ``predict``.

        Args:
            dataloader: Unused; the benchmark always runs on randomly
                generated samples from ``_get_fake_samples``.

        Returns:
            A dict with the keys ``'BS'`` (batch size), ``'QPS'``,
            ``'AVG Latency'`` and ``'P99 Latency'`` (both in milliseconds,
            rounded to two decimals).

        Raises:
            ValueError: If the configured iteration count is not positive.
        """
        warmup_iterations = 30

        iterations = self.workload['iterations']
        if iterations <= 0:
            raise ValueError(
                "workload['iterations'] must be positive, got {}".format(
                    iterations))

        batch_size = self.get_loaded_batch_size()
        report = {}
        report['BS'] = batch_size

        test_data = self._get_fake_samples(
            batch_size, self.configs['segments'][0]['input_tensor_map'],
            self.configs['input_type'])

        for _ in range(warmup_iterations):
            self.predict(test_data)

        latencies = sorted(
            self.predict(test_data)[0] for _ in range(iterations))

        mean_latency = sum(latencies) / len(latencies)
        tail_latency = round(latencies[int(len(latencies) * 0.99)] * 1000, 2)
        avg_latency = round(mean_latency * 1000, 2)
        # Derive QPS from the unrounded mean: dividing by the rounded
        # millisecond value loses precision and fails outright when a very
        # fast model rounds to 0.0 ms.
        qps = int(batch_size / mean_latency) if mean_latency > 0 else 0

        log.info(
            'Batch size is {}, QPS: {}, Avg Latency:{}, Tail Latency:{}'.
            format(batch_size, qps, avg_latency, tail_latency))

        report['QPS'] = qps
        report['AVG Latency'] = avg_latency
        report['P99 Latency'] = tail_latency

        return report

    def get_loaded_batch_size(self):
        """Return the batch size recorded by the most recent ``load`` call.

        The value is ``-1`` until ``load`` has run, because ``__init__``
        initialises ``self.batch_size`` to that sentinel.
        """
        return self.batch_size

    def load(self, batch_size) -> None:
        """Load every configured model segment for the given batch size.

        Reads ``self.configs`` and replaces ``self.model_runtimes`` with one
        runtime per entry of ``configs['segments']``. Also records
        ``self.batch_size``, ``self.input_type``, ``self.framework``,
        ``self.model_name`` and, per segment, ``self.input_shapes`` and
        ``self.outputs`` (the values of the last segment remain set).

        Args:
            batch_size: Batch size to load. For Migraphx and ONNX Runtime it
                is part of the model file name
                (``<compiled_obj>-<batch_size>.onnx`` / ``.mrx``).

        Raises:
            ValueError: If a segment has no input or output tensor map, or
                if the Tensorflow precision is not FP16, INT8 or FP32.
        """
        self.batch_size = batch_size
        self.model_runtimes = []
        self.input_type = self.configs['input_type']
        self.framework = self.configs['framework']
        self.model_name = self.configs['model']

        for i, segment in enumerate(self.configs['segments']):
            # There is no input/output metadata in the graph, so it has to
            # come from the config.
            if not segment['input_tensor_map']:
                raise ValueError("Segment " + str(i) + " needs inputs")
            if not segment['output_tensor_map']:
                raise ValueError("Segment " + str(i) + " needs outputs")

            self.input_shapes = segment['input_tensor_map']
            self.outputs = segment['output_tensor_map'].split(",")

            model_path = segment['compiled_model'][0]['compiled_obj']
            if self.framework == "Tensorflow":
                model = self._load_tensorflow_model(model_path)
            elif self.framework == "Pytorch":
                model = self._load_pytorch_model(model_path)
            elif self.framework == "Migraphx":
                model = self._load_migraphx_model(model_path)
            else:
                model = self._load_onnxruntime_model(model_path)

            self.model_runtimes.append(model)

    def _load_tensorflow_model(self, model_path):
        """Load a SavedModel on the GPU and cast its variables if needed."""
        precision = self.configs['compile_precision']
        # Target variable dtype per precision; FP32 leaves the model as is.
        variable_dtypes = {"FP16": tf.float16, "INT8": tf.int8, "FP32": None}
        if precision not in variable_dtypes:
            # Previously an unknown precision left ``model`` unbound (or
            # silently re-appended the previous segment's model).
            raise ValueError(
                "Unsupported compile_precision for Tensorflow: " +
                str(precision))

        with tf.device('GPU'):
            model = tf.saved_model.load(model_path)

        target_dtype = variable_dtypes[precision]
        if target_dtype is not None:
            for var in model.variables:
                var.assign(tf.cast(var, target_dtype))
        return model

    def _load_pytorch_model(self, model_path):
        """Load a TorchScript module onto the GPU in eval mode."""
        self.device = "cuda"

        # For BERT models the TorchScript tensor-expression fuser is
        # switched off before the module is loaded.
        if "bert" in self.configs['model']:
            torch._C._jit_set_texpr_fuser_enabled(False)

        model = torch.jit.load(model_path, torch.device('cuda'))
        if self.configs['compile_precision'] == "FP16":
            model = model.half()
        model.eval()
        return model

    def _load_migraphx_model(self, model_path):
        """Load (and, unless precompiled, compile) a MIGraphX program."""
        self.device = "cuda"

        if self.configs['model'] == 'bert-migraphx-fp16':
            # This model is stored as a saved ``.mrx`` program and is
            # returned without a compile step.
            return migraphx.load(model_path + f'-{self.batch_size}.mrx')

        model = migraphx.parse_onnx(model_path + f'-{self.batch_size}.onnx')
        if self.configs['compile_precision'] == "INT8":
            print("=======================INT8====================")

            # Quantization is calibrated with one batch of random inputs.
            fake_data = self._get_fake_samples(
                self.batch_size,
                self.configs['segments'][0]['input_tensor_map'],
                self.configs['input_type'])
            calibration = [{
                key: migraphx.argument(value)
                for key, value in fake_data.items()
            }]
            migraphx.quantize_int8(model, migraphx.get_target("gpu"),
                                   calibration)

        model.compile(migraphx.get_target("gpu"), offload_copy=False,
                      device_id=0)
        return model

    def _load_onnxruntime_model(self, model_path):
        """Create an ONNX Runtime session on device 0."""
        # Only non-INT8 resnet50 models use the MIGraphX execution provider;
        # everything else runs on the ROCm provider.
        use_migraphx_provider = (
            "resnet50" in self.configs['model']
            and self.configs['compile_precision'] != 'INT8')
        if use_migraphx_provider:
            providers = ['MIGraphXExecutionProvider']
        else:
            providers = ['ROCMExecutionProvider']
        provider_options = [{'device_id': '0'}]

        return onnxruntime.InferenceSession(
            model_path + f'-{self.batch_size}.onnx',
            providers=providers, provider_options=provider_options)

    def _get_fake_samples(self, batch_size, shape, input_type):
        """Generate random input arrays matching the configured shapes.

        Args:
            batch_size: Multiplier applied to the first dimension of every
                input except the one named ``"text"``, whose shape is used
                unchanged.
            shape: Mapping of input name to its list of dimensions.
            input_type: Comma-separated type names (keys of ``INPUT_TYPE``),
                one per input, in the same order as ``shape``. A sequence of
                names is accepted as well.

        Returns:
            A dict mapping each input name to an array drawn from
            ``np.random.random`` and cast to the requested type.

        Raises:
            ValueError: If ``input_type`` is empty or names fewer types than
                there are inputs.
        """
        if not input_type:
            raise ValueError("Please provide input type")

        # Use the argument itself rather than ``self.input_type``: the
        # attribute only exists after ``load`` has run, and the visible
        # callers pass the same config value anyway.
        if isinstance(input_type, str):
            type_names = input_type.split(',')
        else:
            type_names = list(input_type)
        if len(type_names) < len(shape):
            raise ValueError(
                "input_type names {} types but there are {} inputs".format(
                    len(type_names), len(shape)))

        data = {}
        for type_name, (key, dims) in zip(type_names, shape.items()):
            dims = list(dims)
            if key != "text":
                dims = [dims[0] * batch_size] + dims[1:]
            data[key] = np.random.random(size=dims).astype(
                INPUT_TYPE[type_name])
        return data
    
    def AllocateteOutputMemory(self, model):
        """Allocate one GPU buffer per output reported by ``model``.

        Args:
            model: Object whose ``get_outputs()`` returns a mapping from
                output name to the value passed to ``migraphx.allocate_gpu``
                as ``s`` (presumably a MIGraphX shape -- confirm against the
                MIGraphX build in use).

        Returns:
            A dict mapping each output name to its allocated GPU buffer.
        """
        # Query the outputs once instead of once per key.
        output_shapes = model.get_outputs()
        return {
            name: migraphx.allocate_gpu(s=output_shapes[name])
            for name in output_shapes.keys()
        }
    
    def GetMIGraphXType(self, type):
        typeMap = {
            'double_type': np.float64,
            'float_type': np.float32,
            'half_type': np.half,
            'int64_type': np.int64,
            'uint64_type': np.uint64,
            'int32_type': np.int32,
            'uint32_type': np.uint32,
            'int16_type': np.int16,
            'uint16_type': np.uint16,
            'int8_type': np.int8,
            'uint8_type': np.uint8,
            'bool_type': bool
        }
        return typeMap[type]
    
