import numpy as np
import tensorrt as trt
import torch
from cuda import cudart
import common as common
# import common as common
from colored import fg, stylize
import copy
import time
import json

# Random seed helper (for reproducible test data)
def set_random_seed(num: int):
    """Seed NumPy's global random generator with ``num``.

    Seeding torch is intentionally left disabled.
    """
    rng_seed = num
    np.random.seed(rng_seed)

def _head_preview(arr: np.ndarray) -> np.ndarray:
    """Return up to the first 3 values along the last axis at index 0 of every leading axis."""
    if arr.ndim == 0 or arr.size == 0:
        return arr
    # For a 4-D array this is exactly arr[0, 0, 0, :3].
    return arr[(0,) * (arr.ndim - 1) + (slice(None, 3),)]


def compare_value(pre_numpy: np.ndarray, true_numpy: np.ndarray, atol: float = 1e-5):
    """Print and return the maximum absolute difference between two arrays.

    A short preview of both arrays is printed first. It shows the first three
    values along the last axis at index 0 of every leading axis, which for
    4-D inputs is ``[0, 0, 0, :3]``. Then one coloured line reports
    ``failed`` (red) when the difference exceeds ``atol`` and ``OK`` (green)
    otherwise.

    :param pre_numpy: computed array.
    :param true_numpy: reference array; must have the same shape.
    :param atol: pass threshold on the max absolute difference (defaults to
        the previously hard-coded 1e-5).
    :return: the max absolute difference (0.0 for empty arrays).
    :raises ValueError: if the shapes differ.
    """
    if pre_numpy.shape != true_numpy.shape:
        raise ValueError(f"shape mismatch: {pre_numpy.shape} vs {true_numpy.shape}")

    print(f"{_head_preview(pre_numpy)} == {_head_preview(true_numpy)}")

    lhs, rhs = pre_numpy, true_numpy
    if np.result_type(lhs, rhs).kind in "biu":
        # Integer subtraction can wrap around and bool subtraction is rejected
        # by NumPy, so compare those dtypes in float64.
        lhs = lhs.astype(np.float64)
        rhs = rhs.astype(np.float64)
    # .max() raises on zero-size arrays; two empty arrays are trivially equal.
    diff = np.abs(lhs - rhs).max() if lhs.size else 0.0

    if diff > atol:
        print(stylize(f"diff: {diff} is_pass: failed", fg("red")))
    else:
        print(stylize(f"diff: {diff} is_pass: OK", fg("green")))
    return diff


def load_tensor_from_npy_file(file_name, dir_path):
    """Load ``<dir_path>/<file_name>.npy`` and return it as a torch tensor.

    The tensor is created with ``torch.from_numpy`` and therefore shares
    memory with the freshly loaded array.
    """
    array = np.load(f"{dir_path}/{file_name}.npy")
    return torch.from_numpy(array)
    

def load_numpy_from_npy_file(file_name, dir_path):
    """Return the NumPy array stored in ``<dir_path>/<file_name>.npy``."""
    return np.load(f"{dir_path}/{file_name}.npy")


def load_numpy_from_tensor(tensor):
    """Return an independent NumPy copy of ``tensor``.

    The tensor is detached from autograd and moved to host memory first, so
    the result never aliases the tensor's storage.
    """
    host_view = tensor.detach().cpu().numpy()
    independent_copy = copy.deepcopy(host_view)
    return independent_copy


def get_tensor_from_numpy(data):
    """Wrap a NumPy array as a torch tensor that shares its memory (no copy)."""
    wrapped = torch.from_numpy(data)
    return wrapped


def get_data_type(trt_data_type):
    """Map a TensorRT ``trt.DataType`` to ``(torch dtype, bytes per element)``.

    Returns ``("unknown", 0)`` for data types without a mapping. Both the
    TensorRT enum members and the torch dtypes are looked up by name. This
    lets the function work on TensorRT / PyTorch builds that lack some of
    them (FP8, BF16, INT64 are version-dependent) instead of raising
    AttributeError.
    """
    # (TensorRT DataType member, torch dtype attribute, element size in bytes)
    table = (
        ("FLOAT", "float32", 4),
        ("HALF", "float16", 2),
        ("INT8", "int8", 1),
        ("INT32", "int32", 4),
        ("BOOL", "bool", 1),
        ("UINT8", "uint8", 1),
        # TensorRT FP8 is the E4M3 format; there is no `torch.float8`.
        ("FP8", "float8_e4m3fn", 1),
        ("BF16", "bfloat16", 2),
        ("INT64", "int64", 8),
    )
    for trt_name, torch_name, nbytes in table:
        trt_member = getattr(trt.DataType, trt_name, None)
        if trt_member is None or not trt_member == trt_data_type:
            continue
        torch_dtype = getattr(torch, torch_name, None)
        if torch_dtype is None:
            # TensorRT knows this type but this torch build does not.
            break
        return torch_dtype, nbytes
    return "unknown", 0


class trtInfer:
    """Run a serialized TensorRT engine through the tensor-name (v3) API.

    Each I/O tensor is described by a binding dict with the keys ``index``,
    ``name``, ``dtype`` (torch dtype, or the string "unknown"), ``shape``
    (list) and ``allocation`` (device address as an int). Inputs are kept in
    ``self.inputs`` and outputs in ``self.outputs``, each in engine I/O
    order. One device buffer per I/O tensor is allocated up front
    (``self.allocations``). A binding uses that buffer until ``set_input`` /
    ``set_output`` point it somewhere else.
    """

    def __init__(self, plan_path, batch_size=1):
        """Deserialize ``plan_path`` and allocate one device buffer per I/O tensor.

        :param plan_path: path to a serialized TensorRT engine file.
        :param batch_size: size used for a dynamic (-1) leading dimension
            when sizing the internal buffers; other dynamic dims count as 1.
        :raises RuntimeError: if the engine cannot be deserialized.
        """
        self.init_plugin()
        with open(plan_path, "rb") as f:
            buffer = f.read()
        # Keep the runtime referenced for as long as the engine it created.
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.runtime.deserialize_cuda_engine(buffer)
        if self.engine is None:
            raise RuntimeError(f"failed to deserialize TensorRT engine: {plan_path}")
        self.nIO = self.engine.num_io_tensors
        self.ITensorName = [self.engine.get_tensor_name(i) for i in range(self.nIO)]
        self.nInput = [self.engine.get_tensor_mode(name) for name in self.ITensorName].count(
            trt.TensorIOMode.INPUT
        )
        # cuda_call checks the error code instead of silently dropping it.
        self.stream = common.cuda_call(cudart.cudaStreamCreate())
        self.context = self.engine.create_execution_context()
        assert self.context
        self.inputs = []
        self.outputs = []
        self.allocations = []
        self.IOBindings = []
        for i, name in enumerate(self.ITensorName):
            mode = self.engine.get_tensor_mode(name)
            shape = self.engine.get_tensor_shape(name)
            t_type, itemsize = get_data_type(self.engine.get_tensor_dtype(name))
            nbytes = itemsize
            for dim in self._concrete_shape(shape, batch_size):
                nbytes *= dim
            # Size the buffer for the whole tensor: a fixed 1024 bytes
            # overflows for any output that is bound to this buffer.
            allocation = common.cuda_call(cudart.cudaMalloc(max(nbytes, 1)))
            self.allocations.append(allocation)
            binding = {
                "index": i,
                "name": name,
                "dtype": t_type,
                "shape": list(shape),
                "allocation": allocation,
            }
            if trt.TensorIOMode.INPUT == mode:
                # NOTE(review): overwritten per input, so this ends up as the
                # last input's declared leading dim (may be -1). Kept as-is.
                self.batch_size = shape[0]
                self.inputs.append(binding)
            else:
                self.outputs.append(binding)
        device = torch.device("cuda:0")
        # Pre-allocated output tensors in the engine's own dtype (a float32
        # buffer would not match e.g. FP16 output data). Dynamic dims are
        # resolved the same way as for the internal buffers.
        self.output_buffer = [
            torch.zeros(self._concrete_shape(shape, batch_size), dtype=dtype, device=device)
            for shape, dtype in self.output_spec()
        ]

    @staticmethod
    def _concrete_shape(shape, batch_size):
        """Replace -1 dims: the leading one with ``batch_size``, any other with 1."""
        return [
            (batch_size if axis == 0 else 1) if dim == -1 else dim
            for axis, dim in enumerate(shape)
        ]

    def init_plugin(self):
        """Create the ERROR-level logger and register TensorRT's built-in plugins."""
        self.logger = trt.Logger(trt.Logger.ERROR)
        trt.init_libnvinfer_plugins(self.logger, "")

    def input_spec(self):
        """
        Get the specs for the input tensors of the network.
        :return: A list of (shape, dtype) pairs, one per input, where dtype is a
            torch dtype (or "unknown"). Each shape is a fresh list, so callers
            may modify it without touching the binding.
        """
        return [(list(o['shape']), o['dtype']) for o in self.inputs]

    def output_spec(self):
        """
        Get the specs for the output tensors of the network. Useful to prepare memory allocations.
        :return: A list of (shape, dtype) pairs, one per output, where dtype is a
            torch dtype (or "unknown"). Each shape is a fresh list, so callers
            may modify it without touching the binding.
        """
        return [(list(o['shape']), o['dtype']) for o in self.outputs]

    def set_Bindding(self):
        """Push binding shapes (inputs only) and addresses into the execution context.

        :return: False as soon as TensorRT rejects a shape or an address,
            otherwise True.
        """
        self.IOBindings = [*self.inputs, *self.outputs]
        for i, item in enumerate(self.IOBindings):
            if i < self.nInput and not self.context.set_input_shape(item["name"], item["shape"]):
                return False
            if not self.context.set_tensor_address(item["name"], item["allocation"]):
                return False
        return True

    def set_input(self, binding_buffering):
        """Point the input bindings at new data, in engine input order.

        Tensors are cast to the binding's dtype and made contiguous. The
        resulting tensor is stored in the binding under ``"tensor"`` so the
        address handed to TensorRT stays valid. Non-tensor items are taken as
        raw device addresses; the binding's shape is left unchanged for them.
        """
        for i, item in enumerate(binding_buffering):
            binding = self.inputs[i]
            if torch.is_tensor(item):
                wanted = binding["dtype"]
                if isinstance(wanted, torch.dtype) and item.dtype != wanted:
                    item = item.to(wanted)
                # reshape(-1) on a non-contiguous tensor returns a temporary
                # copy that is freed immediately; keep the dense tensor alive.
                item = item.contiguous()
                binding["tensor"] = item
                binding["shape"] = list(item.shape)
                binding["allocation"] = item.data_ptr()
            else:
                binding.pop("tensor", None)
                binding["allocation"] = item

    def set_output(self, binding_buffering):
        """Point the output bindings at caller-owned tensors, in engine output order.

        :raises ValueError: if a tensor is not contiguous. TensorRT would
            otherwise write into a temporary copy and the results would be lost.
        """
        for i, item in enumerate(binding_buffering):
            if not item.is_contiguous():
                raise ValueError(f"output buffer {i} must be contiguous")
            self.outputs[i]['shape'] = list(item.shape)
            self.outputs[i]['allocation'] = item.data_ptr()

    def release(self):
        """Destroy the CUDA stream and free the buffers allocated in ``__init__``.

        Safe to call more than once. Addresses taken from ``self.allocations``
        are invalid afterwards.
        """
        if getattr(self, "stream", None) is not None:
            cudart.cudaStreamDestroy(self.stream)
            self.stream = None
        for allocation in getattr(self, "allocations", []):
            cudart.cudaFree(allocation)
        self.allocations = []


class DM_TRT(trtInfer):
    """Runner for the UNet ("DM") engine with inputs ``x, timesteps, context, *control``."""

    def __init__(self, plan_path, bs=1):
        super().__init__(plan_path, bs)

    def __call__(self, x, timesteps, context, control, only_mid_control=False):
        """Run the engine once and return its first output as a float32 tensor.

        :param x: first engine input; its leading dim sets the output batch.
        :param timesteps: cast to int32 before binding.
        :param context: third engine input.
        :param control: iterable of tensors, bound as the remaining inputs in order.
        :param only_mid_control: accepted for interface compatibility; unused.
        :raises RuntimeError: if binding or enqueueing fails.
        """
        device = x.device
        # Order must match the engine's input order (set_input binds by index).
        input_buffer = [x, timesteps.int(), context, *control]

        current_batch = x.shape[0]
        output_buffer = []
        for shape, dtype in self.output_spec():
            shape = list(shape)  # do not mutate the stored binding shape
            shape[0] = current_batch
            # Allocate in the engine's output dtype: a float32 buffer would
            # receive e.g. FP16 bytes and read back as garbage.
            output_buffer.append(torch.zeros(shape, dtype=dtype, device=device))

        self.set_input(input_buffer)  # set shape, allocate
        self.set_output(output_buffer)
        if not self.set_Bindding():
            raise RuntimeError("DM_TRT: failed to set TensorRT input shapes / tensor addresses")
        if not self.context.execute_async_v3(self.stream):
            raise RuntimeError("DM_TRT: execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.stream))
        # Callers keep receiving float32 (a no-op when the output already is).
        return output_buffer[0].float()


class CM_TRT(trtInfer):
    """Runner for the control ("CM") engine with inputs ``x, hint, timesteps, context``.

    Outputs are not returned as tensors. They stay in device buffers owned by
    this object.
    """

    def __init__(self, plan_path, bs=1):
        super().__init__(plan_path, bs)

    def __call__(self, x, hint, timesteps, context, **kwargs):
        """Run the engine once and return the output device addresses.

        ``timesteps`` is cast to int32. ``set_output`` is not called here, so
        outputs go to whatever addresses the output bindings hold (the
        buffers allocated at construction unless changed elsewhere). Those
        buffers are overwritten by the next call.

        :return: list of device addresses (ints), one per output, in engine
            output order. This equals ``self.allocations[self.nInput:self.nIO]``
            when the engine lists every input before every output.
        :raises RuntimeError: if binding or enqueueing fails.
        """
        self.set_input([x, hint, timesteps.int(), context])  # set shape, allocate
        if not self.set_Bindding():
            raise RuntimeError("CM_TRT: failed to set TensorRT input shapes / tensor addresses")
        if not self.context.execute_async_v3(self.stream):
            raise RuntimeError("CM_TRT: execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.stream))
        # Return the addresses TensorRT actually wrote to. Slicing
        # self.allocations assumes inputs come first in the engine's I/O order.
        return [binding["allocation"] for binding in self.outputs]


class CM_DM_FUSE_TRT:
    """Chain a control engine (CM_TRT) and a UNet engine (DM_TRT) on one CUDA stream.

    The UNet reads the control engine's x / timesteps / context inputs and
    every control output directly by device address, so nothing is copied
    between the two engines.
    """

    def __init__(self, control_path, unet_path):
        self.control = CM_TRT(control_path)
        self.unet = DM_TRT(unet_path)

    def __call__(self, x, hint, timesteps, context, **kwargs):
        """Run control then UNet and return the UNet's first output as float32.

        :raises RuntimeError: if binding or enqueueing fails for either engine.
        """
        device = x.device

        timesteps = timesteps.int()
        self.control.set_input([x, hint, timesteps, context])   # set shape, allocate
        # Control outputs go to its own bound buffers (set_output is not called).

        # UNet inputs: x, timesteps, context from the control bindings, then
        # every address the control engine writes its outputs to.
        input_unet_buffer = [
            self.control.inputs[0]["allocation"],
            self.control.inputs[2]["allocation"],
            self.control.inputs[3]["allocation"],
            *(binding["allocation"] for binding in self.control.outputs),
        ]

        current_batch = x.shape[0]
        output_unet_buffer = []
        for shape, dtype in self.unet.output_spec():
            shape = list(shape)  # do not mutate the stored binding shape
            shape[0] = current_batch
            # Engine dtype, not float32: the buffer layout must match what TensorRT writes.
            output_unet_buffer.append(torch.zeros(shape, dtype=dtype, device=device))

        self.unet.set_input(input_unet_buffer)   # set shape, allocate
        self.unet.set_output(output_unet_buffer)

        if not self.control.set_Bindding():
            raise RuntimeError("CM_DM_FUSE_TRT: failed to bind control engine tensors")
        if not self.unet.set_Bindding():
            raise RuntimeError("CM_DM_FUSE_TRT: failed to bind unet engine tensors")
        # Both engines are enqueued on the same stream, so the UNet starts only
        # after the control engine has finished writing its outputs.
        if not self.control.context.execute_async_v3(self.control.stream):
            raise RuntimeError("CM_DM_FUSE_TRT: control execute_async_v3 failed to enqueue")
        if not self.unet.context.execute_async_v3(self.control.stream):
            raise RuntimeError("CM_DM_FUSE_TRT: unet execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.control.stream))

        return output_unet_buffer[0].float()


def memcpy_tensor_to_dev(data, address):
    """Copy tensors back-to-back into the device buffer starting at ``address``.

    Each item is written immediately after the previous one, so equal-sized
    items end up stacked along a new leading axis (e.g. ``[x, x]`` doubles
    the batch). Items of different sizes are also packed correctly, each at
    its own running byte offset. The destination must hold at least the sum
    of the items' byte sizes. The copies are device-to-device, so every item
    is presumably a CUDA tensor — confirm at call sites.

    :param data: iterable of torch tensors; iterating a single tensor copies
        its leading-axis slices. An empty iterable is a no-op.
    :param address: destination device address (int).
    """
    offset = 0
    for item in data:
        # Copy from a dense source; a non-contiguous item would otherwise be
        # read with the wrong layout.
        src = item.contiguous()
        nbytes = src.numel() * src.element_size()
        common.cuda_call(cudart.cudaMemcpy(
            address + offset, src.data_ptr(), nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice))
        offset += nbytes


class CM_DM_BATCH_TRT:
    """Chain a control engine (CM_TRT) and a UNet engine (DM_TRT), both sized for ``batch_size``.

    As in the fused variant, the UNet reads the control engine's x / timesteps
    / context inputs and its output buffers directly by device address.
    """

    def __init__(self, control_path, unet_path, batch_size):
        self.control = CM_TRT(control_path, batch_size)
        self.unet = DM_TRT(unet_path, batch_size)

    def __call__(self, x, hint, timesteps, context, **kwargs):
        """Run control then UNet and return the UNet's first output as float32.

        :raises RuntimeError: if binding or enqueueing fails for either engine.
        """
        device = x.device

        timesteps = timesteps.int()
        self.control.set_input([x, hint, timesteps, context])   # set shape, allocate
        if not self.control.set_Bindding():
            raise RuntimeError("CM_DM_BATCH_TRT: failed to bind control engine tensors")

        # UNet inputs: x, timesteps, context from the control bindings, then
        # every address the control engine writes its outputs to.
        input_unet_buffer = [
            self.control.inputs[0]["allocation"],
            self.control.inputs[2]["allocation"],
            self.control.inputs[3]["allocation"],
            *(binding["allocation"] for binding in self.control.outputs),
        ]

        # Size each output from the engine's declared shape instead of a
        # hard-coded batch of 2. A dynamic (-1) leading dim uses x's batch.
        output_unet_buffer = []
        for binding in self.unet.outputs:
            shape = list(self.unet.engine.get_tensor_shape(binding["name"]))
            if shape and shape[0] == -1:
                shape[0] = x.shape[0]
            # Engine dtype, not float32: the buffer layout must match what TensorRT writes.
            output_unet_buffer.append(torch.zeros(shape, dtype=binding["dtype"], device=device))

        self.unet.set_input(input_unet_buffer)   # set shape, allocate
        self.unet.set_output(output_unet_buffer)
        if not self.unet.set_Bindding():
            raise RuntimeError("CM_DM_BATCH_TRT: failed to bind unet engine tensors")

        # Same stream for both engines: the UNet runs after the control engine completes.
        if not self.control.context.execute_async_v3(self.control.stream):
            raise RuntimeError("CM_DM_BATCH_TRT: control execute_async_v3 failed to enqueue")
        if not self.unet.context.execute_async_v3(self.control.stream):
            raise RuntimeError("CM_DM_BATCH_TRT: unet execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.control.stream))

        return output_unet_buffer[0].float()


class Decoder_TRT(trtInfer):
    """Runner for a single-input decoder engine."""

    def __init__(self, plan_path):
        super().__init__(plan_path)

    def __call__(self, z):
        """Run the engine on ``z`` and return its first output as a float32 tensor.

        :param z: the engine's only input; its leading dim sets the output batch.
        :raises RuntimeError: if binding or enqueueing fails.
        """
        device = z.device

        current_batch = z.shape[0]
        output_buffer = []
        for shape, dtype in self.output_spec():
            shape = list(shape)  # do not mutate the stored binding shape
            shape[0] = current_batch
            # Engine dtype, not float32: the buffer layout must match what TensorRT writes.
            output_buffer.append(torch.zeros(shape, dtype=dtype, device=device))

        self.set_input([z])  # set shape, allocate
        self.set_output(output_buffer)
        if not self.set_Bindding():
            raise RuntimeError("Decoder_TRT: failed to set TensorRT input shapes / tensor addresses")
        if not self.context.execute_async_v3(self.stream):
            raise RuntimeError("Decoder_TRT: execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.stream))

        return output_buffer[0].float()


class ClipModelOutputs:
    """Small result holder exposing a single ``last_hidden_state`` attribute.

    NOTE(review): presumably mirrors the attribute name of a Hugging Face
    CLIP text-model output so it can be used as a drop-in — confirm.
    """

    def __init__(self, last_hidden_state):
        # Stored by reference; no copy or validation.
        setattr(self, "last_hidden_state", last_hidden_state)


class CL_TRT(trtInfer):
    """Runner for the text-encoder ("CL") engine with a single ``input_ids`` input."""

    def __init__(self, plan_path):
        super().__init__(plan_path)

    def __call__(self, input_ids, **kwargs):
        """Encode ``input_ids`` (cast to int32) and wrap the first output.

        :return: ``ClipModelOutputs`` whose ``last_hidden_state`` is the
            engine's first output as float32. Further outputs (if the engine
            has any) are computed but not returned.
        :raises RuntimeError: if binding or enqueueing fails.
        """
        device = input_ids.device
        input_ids = input_ids.int()

        current_batch = input_ids.shape[0]
        output_buffer = []
        for shape, dtype in self.output_spec():
            shape = list(shape)  # do not mutate the stored binding shape
            shape[0] = current_batch
            # Engine dtype, not float32: the buffer layout must match what TensorRT writes.
            output_buffer.append(torch.zeros(shape, dtype=dtype, device=device))

        self.set_input([input_ids])  # set shape, allocate
        self.set_output(output_buffer)
        if not self.set_Bindding():
            raise RuntimeError("CL_TRT: failed to set TensorRT input shapes / tensor addresses")
        if not self.context.execute_async_v3(self.stream):
            raise RuntimeError("CL_TRT: execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.stream))

        # ClipModelOutputs takes exactly one value; unpacking every output
        # raised TypeError for engines that also emit e.g. a pooled output.
        return ClipModelOutputs(output_buffer[0].float())
        
        
class EXP_TRT(trtInfer):
    """Benchmark runner: binds caller-provided inputs, leaves outputs in the internal buffers."""

    def __init__(self, plan_path, batch_size):
        super().__init__(plan_path, batch_size)

    def __call__(self, input_datas):
        """Run one inference.

        :param input_datas: tensors or raw device addresses, in engine input order.
        :return: 0. Outputs are not returned; they are written to the
            addresses held by ``self.outputs``.
        :raises RuntimeError: if binding or enqueueing fails.
        """
        self.set_input(input_datas)
        if not self.set_Bindding():
            raise RuntimeError("EXP_TRT: failed to set TensorRT input shapes / tensor addresses")
        if not self.context.execute_async_v3(self.stream):
            raise RuntimeError("EXP_TRT: execute_async_v3 failed to enqueue")
        common.cuda_call(cudart.cudaStreamSynchronize(self.stream))
        return 0


if __name__ == "__main__":
    # Throughput benchmark over static-batch engines: for each batch size,
    # load the matching input JSON and engine, warm up, then time a fixed
    # number of synchronous inferences.
    set_random_seed(2)
    device = torch.device("cuda:0")
    warmup_iters = 100
    timed_iters = 1000
    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]:

        input_data_json_path = f'../new_models/model_1/dataset/input_tensor_datas_{batch_size}.json'
        with open(input_data_json_path, 'r') as f:
            input_datas = list(json.load(f).values())
        model_path = f"../new_models/model_1/trt/model-static-batch-size-{batch_size}.trt"
        dm_trt = EXP_TRT(model_path, batch_size)
        try:
            # JSON values are matched to engine inputs by position.
            dtypes = [dtype for _, dtype in dm_trt.input_spec()]
            input_datas = [torch.tensor(value, dtype=dtype).to(device) for value, dtype in zip(input_datas, dtypes)]

            for _ in range(warmup_iters):
                dm_trt(input_datas)
            # Start the clock only after warm-up so exactly `timed_iters`
            # calls are measured (previously 1001 calls were divided by 1000).
            start = time.perf_counter()
            for _ in range(timed_iters):
                dm_trt(input_datas)
            elapsed = time.perf_counter() - start
        finally:
            # Free the stream and device buffers before loading the next engine.
            dm_trt.release()
        print(f"*******batch_size: {batch_size} *******QPS: {timed_iters / elapsed * batch_size}")
        time.sleep(10)
