import os
import json
import logging

import torch
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import torch
import onnxruntime
import time

import numpy as np

from general_perf.backends import runtime_backend

log = logging.getLogger("BackendDCU")

# Maps the type names used in a model config's comma-separated
# ``input_type`` string to torch dtypes (used when building PyTorch inputs).
pt_dtype_map = {
    "FLOAT32": torch.float32,
    "FLOAT16": torch.float16,
    "INT8": torch.int8,
    "UINT8": torch.uint8,
    "INT32": torch.int32,
    "INT64": torch.int64,
    "LONG": torch.long,
    "BOOL": torch.bool,
}

# Maps the same config type names to numpy dtypes (used to build fake
# benchmark inputs). ``np.long`` / ``np.bool`` were removed in NumPy 1.24,
# so the explicit ``np.int64`` / ``np.bool_`` spellings are used instead.
INPUT_TYPE = {
    "INT8": np.int8,
    "UINT8": np.uint8,
    "FLOAT32": np.float32,
    "FLOAT16": np.float16,
    "LONG": np.int64,
    "INT32": np.int32,
    "INT64": np.int64,
    "BOOL": np.bool_,
}

class RuntimeBackendDCU(runtime_backend.RuntimeBackend):
    """Runtime backend that executes compiled models on a DCU device.

    ``load`` selects the framework from ``configs['framework']``:
    ``"Tensorflow"`` (SavedModel), ``"Pytorch"`` (TorchScript), or — for any
    other value — ONNX Runtime with the ``ROCMExecutionProvider``.
    """

    # Untimed predictions executed before the timed loop in ``benchmark``.
    _WARMUP_ITERATIONS = 30

    def __init__(self):
        super(RuntimeBackendDCU, self).__init__()
        self.hardware_type = 'DCU'
        self.need_reload = False
        self.model_runtimes = []
        self.configs = None
        self.batch_size = -1

    # ------------------------------------------------------------------ #
    # Inference
    # ------------------------------------------------------------------ #
    def predict(self, feeds):
        """Run one inference pass through every loaded model runtime.

        Args:
            feeds: mapping of input name -> array-like. For TensorFlow and
                PyTorch the values are consumed in iteration order (matched
                by position to the signature inputs / ``self.input_type``).

        Returns:
            TensorFlow / PyTorch: dict of output name -> numpy array.
            ONNX Runtime: the value returned by the last session's ``run``.

        Raises:
            ValueError: if there are more feeds than declared inputs/types,
                or a PyTorch model returns fewer outputs than configured.
            RuntimeError: if a TensorFlow model produces no outputs.
        """
        if self.framework == "Tensorflow":
            return self._predict_tensorflow(feeds)
        if self.framework == "Pytorch":
            return self._predict_pytorch(feeds)
        return self._predict_onnx(feeds)

    def _predict_tensorflow(self, feeds):
        """Run each SavedModel's ``serving_default`` signature on ``feeds``."""
        entry_rt = self.model_runtimes[0].signatures['serving_default']
        sn_inputs = tf.nest.flatten(entry_rt.structured_input_signature, True)
        if len(feeds) > len(sn_inputs):
            raise ValueError(
                "Got {} feeds but the serving signature has only {} inputs"
                .format(len(feeds), len(sn_inputs)))
        # Feeds are bound to signature inputs by position, so the iteration
        # order of ``feeds`` must match the flattened signature order.
        real_feeds = {
            spec.name: tf.constant(val)
            for spec, val in zip(sn_inputs, feeds.values())
        }

        outputs = {}
        # Every runtime receives the same feeds; only the last runtime's
        # outputs are returned.
        for model_runtime in self.model_runtimes:
            with tf.device('GPU'):
                outputs = model_runtime.signatures['serving_default'](
                    **real_feeds)

        results = {key: val.numpy() for key, val in outputs.items()}
        if not results:
            raise RuntimeError("TensorFlow model produced no outputs")
        return results

    def _predict_pytorch(self, feeds):
        """Convert ``feeds`` to device tensors and run the TorchScript models."""
        dtypes = self.input_type.split(',')
        if len(feeds) > len(dtypes):
            raise ValueError(
                "Got {} feeds but input_type lists only {} types"
                .format(len(feeds), len(dtypes)))
        input_tensors = [
            torch.tensor(val, dtype=pt_dtype_map[dtype]).to(self.device)
            for val, dtype in zip(feeds.values(), dtypes)
        ]

        if self.configs["model"] == "bert-torch-fp16":
            # This model is run under CUDA automatic mixed precision.
            with torch.cuda.amp.autocast():
                results = self._run_torch_models(input_tensors)
        else:
            results = self._run_torch_models(input_tensors)

        return self._torch_outputs_to_numpy(results)

    def _run_torch_models(self, input_tensors):
        """Call every TorchScript runtime without autograd; return the last output."""
        results = {}
        with torch.no_grad():
            for model_runtime in self.model_runtimes:
                results = model_runtime(*input_tensors)
        return results

    def _torch_outputs_to_numpy(self, results):
        """Convert a TorchScript model's output into {output name: numpy array}.

        Dict outputs keep their keys; tuple/list outputs are named by
        position from ``self.outputs``; a single tensor is stored under
        ``self.outputs[0]``.
        """
        if isinstance(results, dict):
            return {
                key: val.cpu().detach().numpy()
                for key, val in results.items()
            }
        if isinstance(results, (tuple, list)):
            if len(results) < len(self.outputs):
                raise ValueError(
                    "Model returned {} outputs but {} are configured"
                    .format(len(results), len(self.outputs)))
            return {
                key: val.cpu().detach().numpy()
                for key, val in zip(self.outputs, results)
            }
        return {self.outputs[0]: results.cpu().detach().numpy()}

    def _predict_onnx(self, feeds):
        """Run every ONNX Runtime session on ``feeds``; return the last result."""
        if self.configs["model"] == "resnet50-onnxruntime-fp16":
            # This model expects a float16 input. Work on a shallow copy so
            # the caller's dict is not mutated, and skip the copy when the
            # array is already float16.
            feeds = dict(feeds)
            feeds["input_1.1"] = feeds["input_1.1"].astype(
                np.float16, copy=False)
        results = {}
        for model_runtime in self.model_runtimes:
            results = model_runtime.run(None, feeds)
        return results

    # ------------------------------------------------------------------ #
    # Benchmarking
    # ------------------------------------------------------------------ #
    def benchmark(self, dataloader):
        """Measure QPS and latency of ``predict`` on random inputs.

        ``dataloader`` is not used; inputs come from ``_get_fake_samples``
        using the first segment's input tensor map.

        Returns:
            dict with keys 'BS', 'QPS', 'AVG Latency' and 'P99 Latency'
            (latencies in milliseconds, rounded to 2 decimals).

        Raises:
            ValueError: if the workload's 'iterations' is not positive.
        """
        iterations = self.workload['iterations']
        if iterations <= 0:
            raise ValueError("workload 'iterations' must be positive")
        batch_size = self.get_loaded_batch_size()
        report = {'BS': batch_size}

        test_data = self._get_fake_samples(
            batch_size, self.configs['segments'][0]['input_tensor_map'],
            self.configs['input_type'])

        for _ in range(self._WARMUP_ITERATIONS):
            self.predict(test_data)

        # perf_counter is monotonic and high-resolution, unlike time.time.
        times_range = []
        for _ in range(iterations):
            start_time = time.perf_counter()
            self.predict(test_data)
            times_range.append(time.perf_counter() - start_time)

        times_range.sort()
        tail_latency = round(
            times_range[int(len(times_range) * 0.99)] * 1000, 2)
        avg_latency_ms = sum(times_range) / len(times_range) * 1000
        avg_latency = round(avg_latency_ms, 2)
        # Derive QPS from the unrounded average so sub-0.005 ms latencies
        # neither skew the result nor divide by zero.
        qps = int(1000.0 * batch_size / avg_latency_ms) if avg_latency_ms > 0 else 0

        log.info(
            'Batch size is {}, QPS: {}, Avg Latency:{}, Tail Latency:{}'.
            format(batch_size, qps, avg_latency, tail_latency))

        report['QPS'] = qps
        report['AVG Latency'] = avg_latency
        report['P99 Latency'] = tail_latency

        return report

    def get_loaded_batch_size(self):
        """Return the batch size passed to the last ``load`` call (-1 if none)."""
        return self.batch_size

    # ------------------------------------------------------------------ #
    # Model loading
    # ------------------------------------------------------------------ #
    def load(self, batch_size) -> None:
        """Load every compiled segment listed in ``self.configs``.

        Args:
            batch_size: stored and reported by ``get_loaded_batch_size``.

        Raises:
            ValueError: if a segment lacks an input or output tensor map.
        """
        self.batch_size = batch_size
        self.model_runtimes = []
        self.input_type = self.configs['input_type']
        self.framework = self.configs['framework']
        self.model_name = self.configs['model']

        for i, segment in enumerate(self.configs['segments']):
            # The graphs carry no input/output metadata, so it must come
            # from the config.
            if not segment['input_tensor_map']:
                raise ValueError("Segment " + str(i) + " needs inputs")
            if not segment['output_tensor_map']:
                raise ValueError("Segment " + str(i) + " needs outputs")

            # Overwritten per segment: the last segment's maps are kept.
            self.input_shapes = segment['input_tensor_map']
            self.outputs = segment['output_tensor_map'].split(",")

            model_path = segment['compiled_model'][0]['compiled_obj']
            if self.framework == "Tensorflow":
                model = self._load_tensorflow(model_path)
            elif self.framework == "Pytorch":
                model = self._load_pytorch(model_path)
            else:
                model = self._load_onnx(model_path)
            self.model_runtimes.append(model)

    def _load_tensorflow(self, model_path):
        """Load a TensorFlow SavedModel onto the GPU device."""
        with tf.device('GPU'):
            model = tf.saved_model.load(model_path)

        if self.configs['compile_precision'] == "FP16":
            # Cast all variables to float16.
            # NOTE(review): tf.Variable.assign keeps the variable's own dtype,
            # so assigning a float16 tensor to a float32 variable may raise —
            # confirm this path works for float32 SavedModels.
            for var in model.variables:
                var.assign(tf.cast(var, tf.float16))
        return model

    def _load_pytorch(self, model_path):
        """Load a TorchScript module onto CUDA in eval mode."""
        self.device = "cuda"

        if self.model_name.split("-")[0] in ("bert", "roberta"):
            # Disable TorchScript kernel fusion for these models, see
            # https://github.com/pytorch/pytorch/issues/62962
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_override_can_fuse_on_gpu(False)

        model = torch.jit.load(model_path, torch.device('cuda'))
        if self.configs['compile_precision'] == "FP16":
            model = model.half()
        model.eval()
        return model

    def _load_onnx(self, model_path):
        """Create an ONNX Runtime session on DCU device 0 via ROCm."""
        providers = [
            ('ROCMExecutionProvider', {
                'device_id': 0,
                'arena_extend_strategy': 'kNextPowerOfTwo',
                'do_copy_in_default_stream': True,
            }),
        ]

        options = onnxruntime.SessionOptions()
        if self.configs['compile_precision'] == "FP16":
            # Enable FP16.
            # NOTE(review): confirm this session config key is honored by the
            # installed onnxruntime build.
            options.add_session_config_entry("session.enable_fp16", "1")
        return onnxruntime.InferenceSession(
            model_path,
            providers=providers,
            sess_options=options)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _get_fake_samples(self, batch_size, shape, input_type):
        """Build random inputs for benchmarking.

        Args:
            batch_size: multiplier applied to the first dimension of every
                input except the one named "text".
            shape: mapping of input name -> configured shape (list/tuple).
            input_type: comma-separated type names, one per input in
                ``shape`` order (keys of ``INPUT_TYPE``).

        Returns:
            dict of input name -> numpy array of the requested dtype.
            Values come from ``np.random.random`` ([0, 1)), so integer
            dtypes end up as all zeros after the cast.

        Raises:
            ValueError: if ``input_type`` is empty.
        """
        if not input_type:
            raise ValueError("Please provide input type")

        dtypes = input_type.split(',')
        data = {}
        for i, (key, val) in enumerate(shape.items()):
            if key != "text":
                # Scale the leading dimension of the configured shape.
                val = [val[0] * batch_size] + list(val[1:])
            data[key] = np.random.random(size=val).astype(
                INPUT_TYPE[dtypes[i]])
        return data
