from typing import Dict, List, Tuple
import migraphx as mgx
import torch

from .utils import load_migraphx_model


# MIGraphX shape type names (the strings produced by ``shape.type_string()``)
# paired with the torch dtype that has the same element layout.
mgx_to_torch_dtype_dict = dict(
    bool_type=torch.bool,
    uint8_type=torch.uint8,
    int8_type=torch.int8,
    int16_type=torch.int16,
    int32_type=torch.int32,
    int64_type=torch.int64,
    float_type=torch.float32,
    double_type=torch.float64,
    half_type=torch.float16,
)

# Reverse lookup (torch dtype -> MIGraphX type name); the table above is
# one-to-one, so swapping keys and values loses nothing.
torch_to_mgx_dtype_dict = dict(
    zip(mgx_to_torch_dtype_dict.values(), mgx_to_torch_dtype_dict.keys())
)

# Short precision tags (e.g. 'fp16') mapped to the corresponding torch dtype.
DTYPE_MAPPING = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
    int8=torch.int8,
)


class MIGraphXModel:
    def __init__(self):
        self.input_names: List[str] = None
        self.input_shapes: List[int] | Tuple[int] = None
        self.model_dtype: str = None
        self.mgx_model: mgx.program = None
        self.mxr_path: str = None
        self.tensors: Dict[str, torch.Tensor] = None
        self.mgx_arguments: Dict[str, mgx.argument] = None

    def load_migraphx_model(
            self, 
            model_dir: str, 
            input_shapes: List[int] | Tuple[int], 
            batch: int = 1, 
            img_size: int = 1024, 
            model_dtype: str = 'fp16', 
            force_compile: bool = False
        ) -> None:
        """
        Load the migraphx model from the given directory.
        """
        
        assert model_dtype in ['fp16', 'fp32'], "Only fp16 and fp32 are supported"
        use_fp16 = model_dtype == 'fp16'
        
        self.mgx_model, self.mxr_path = load_migraphx_model(
            model_dir, 
            input_shapes, 
            use_fp16=use_fp16, 
            force_compile=force_compile,
            offload_copy=False,
            batch=batch,
            img_size=img_size
        )
        
        self.input_names = self._get_input_names()
        self.model_dtype = torch.float16 if use_fp16 else torch.float32
        self.input_shapes = input_shapes
        self.tensors = self._allocate_torch_tensors()
        self.mgx_arguments = self._tensors_to_arguments()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("forward is not implemented.")

    def _allocate_torch_tensors(self):
        input_output_shapes = self.mgx_model.get_parameter_shapes()
        data_mapping = {
            name: torch.zeros(shape.lens()).to(
                mgx_to_torch_dtype_dict[shape.type_string()]).to(device="cuda")
            for name, shape in input_output_shapes.items()
        }
        return data_mapping

    @staticmethod
    def _tensor_to_argument(tensor):
        return mgx.argument_from_pointer(
            mgx.shape(
                type=torch_to_mgx_dtype_dict[tensor.dtype],
                lens=list(tensor.size()),
                strides=list(tensor.stride())
            ), tensor.data_ptr()
        )

    def _tensors_to_arguments(self):
        return {name: self._tensor_to_argument(tensor) 
                for name, tensor in self.tensors.items()}

    def _get_input_names(self):
        # return list(self.mgx_model.get_inputs().keys())
        input_indexes = sorted(self.mgx_model.get_input_indexes().items(), 
                               key=lambda x: x[1])
        return [x[0] for x in input_indexes]
        
    @staticmethod
    def _get_output_name(idx):
        return f"main:#output_{idx}"

    def set_input_data(self, name_or_index, data, sync=False):
        assert isinstance(name_or_index, (str, int)), \
            "name_or_index must be string or integer"
        name = self.input_names[name_or_index] \
            if isinstance(name_or_index, int) else name_or_index
        # print(self.input_names)
        self.tensors[name].copy_(data.to(self.tensors[name].dtype))
        if sync:
            torch.cuda.synchronize()

    def run_model(self, stream=None, mode='sync'):
        if mode == 'sync':
            self.mgx_model.run(self.mgx_arguments)
            mgx.gpu_sync()
        elif mode == 'async':
            self.mgx_model.run_async(self.mgx_arguments, stream, "ihipStream_t")
        else:
            raise ValueError(f"Invalid migraphx running mode: {mode}")
    
    def get_output_data(self, idx):
        return self.tensors[self._get_output_name(idx)]
