import os.path as osp
from typing import Union, Tuple, Optional
from typing_extensions import Self
import warnings

import torch
from transformers import T5EncoderModel
from transformers.models.t5.configuration_t5 import T5Config 
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import ContextManagers
from transformers.modeling_utils import no_init_weights
from transformers.integrations.accelerate import init_empty_weights

from .migraphx_model import DTYPE_MAPPING, MIGraphXModel


class MIGraphXT5EncoderModel(T5EncoderModel, MIGraphXModel):
    """T5 text encoder whose forward pass is executed by a MIGraphX program.

    ``T5EncoderModel`` is instantiated only for its config and public API.
    :meth:`from_pretrained` deletes all of its PyTorch child modules and
    inference is delegated to the helpers inherited from ``MIGraphXModel``
    (``load_migraphx_model``, ``set_input_data``, ``run_model``,
    ``get_output_data``).
    """

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        subfolder: Optional[str] = None,
        batch: Optional[int] = 1,
        img_size: Optional[int] = 1024,
        model_dtype: Optional[str] = 'fp16',
        force_compile: Optional[bool] = False,
        **kwargs
    ) -> Self:
        """Build the encoder from a local directory and load its MIGraphX program.

        Args:
            model_dir: Root directory of the pretrained model.
            subfolder: Optional sub-directory of ``model_dir`` that holds the
                encoder. Surrounding whitespace is ignored; a blank value
                means "use ``model_dir`` directly".
            batch: Batch size used for the MIGraphX input shape
                ``[batch, 512]``.
            img_size: Forwarded to ``load_migraphx_model``. Not used by the
                encoder itself (presumably part of the compiled-artifact
                lookup -- TODO confirm).
            model_dtype: Key looked up in ``DTYPE_MAPPING`` (e.g. ``'fp16'``);
                also forwarded to ``load_migraphx_model``.
            force_compile: Forwarded to ``load_migraphx_model``.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            A model instance without PyTorch submodules, ready for ``forward``.
        """
        _subfolder = None if subfolder is None else subfolder.strip()
        # Test the *stripped* value so a whitespace-only subfolder falls back
        # to ``model_dir`` instead of joining an empty path component.
        _model_dir = osp.join(model_dir, _subfolder) if _subfolder else model_dir

        model_cfg = T5Config.from_pretrained(_model_dir)
        model_cfg.name_or_path = _model_dir

        # Build the module skeleton without initializing or materializing
        # weights; the PyTorch layers are discarded right below anyway.
        init_context = [no_init_weights(), init_empty_weights()]
        with ContextManagers(init_context):
            model = cls(model_cfg)

        # Inference runs in MIGraphX, so remove every unneeded PyTorch child
        # module. Names are collected first because deleting while iterating
        # ``named_children()`` would mutate the container being iterated.
        for module_name in [name for name, _ in model.named_children()]:
            delattr(model, module_name)

        # Values reported by the ``dtype`` / ``device`` properties. An unknown
        # ``model_dtype`` maps to None, which makes ``dtype`` fall back to the
        # inherited lookup.
        model._dtype = DTYPE_MAPPING.get(model_dtype)
        model._device = torch.device("cuda")

        # The program has a single input: input_ids of shape [batch, 512].
        input_shapes = [
            [batch, 512],  # input_ids
        ]
        model.load_migraphx_model(_model_dir,
                                  input_shapes,
                                  batch=batch,
                                  img_size=img_size,
                                  model_dtype=model_dtype,
                                  force_compile=force_compile)

        # NOTE(review): presumably the compiled program only exposes more than
        # two outputs when it was exported with hidden states -- confirm.
        model.output_hidden_states = len(model.mgx_model.get_outputs()) > 2
        model.num_hidden_layers = model_cfg.num_hidden_layers
        return model

    @property
    def dtype(self):
        """Dtype recorded by ``from_pretrained``, else the inherited ``dtype``."""
        recorded = getattr(self, "_dtype", None)
        return recorded if recorded is not None else super().dtype

    @property
    def device(self):
        """Device recorded by ``from_pretrained``, else the inherited ``device``."""
        recorded = getattr(self, "_device", None)
        return recorded if recorded is not None else super().device

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Encode ``input_ids`` by running the MIGraphX program.

        The signature mirrors ``T5EncoderModel.forward``
        (reference: transformers/blob/4.54.1/src/transformers/models/t5/modeling_t5.py#L1895),
        but only ``input_ids`` and ``return_dict`` are honored. Every other
        argument emits a ``RuntimeWarning`` when set and is otherwise ignored.

        Args:
            input_ids: Token ids. Presumably they must match the compiled
                shape ``[batch, 512]`` -- TODO confirm how other shapes behave.
            return_dict: If falsy, return a tuple. ``None`` defers to
                ``self.config.use_return_dict``.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` with only
            ``last_hidden_state`` populated, or the 1-tuple
            ``(last_hidden_state,)``.

        Raises:
            ValueError: If ``input_ids`` is None; ``inputs_embeds`` cannot be
                used as a substitute by this backend.
        """
        # stacklevel=2 makes the warning point at the caller, not this method.
        if attention_mask is not None:
            warnings.warn("An unused argument `attention_mask` was passed in!", RuntimeWarning, stacklevel=2)
        if head_mask is not None:
            warnings.warn("An unused argument `head_mask` was passed in!", RuntimeWarning, stacklevel=2)
        if inputs_embeds is not None:
            warnings.warn("An unused argument `inputs_embeds` was passed in!", RuntimeWarning, stacklevel=2)
        if output_attentions:
            warnings.warn("Not supported output_attentions!", RuntimeWarning, stacklevel=2)
        if output_hidden_states:
            warnings.warn("Not supported output_hidden_states!", RuntimeWarning, stacklevel=2)

        if input_ids is None:
            raise ValueError(
                "`input_ids` must be provided; `inputs_embeds` is not supported "
                "by the MIGraphX T5 encoder."
            )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        self.set_input_data(0, input_ids)
        self.run_model(mode='sync')

        last_hidden_state = self.get_output_data(0)

        if not return_dict:
            # The trailing comma matters: HF tuple outputs are read as
            # ``outputs[0]``, which on a bare tensor would index the batch
            # dimension instead of returning ``last_hidden_state``.
            return (last_hidden_state,)
        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=last_hidden_state)
