import os.path as osp
from typing import Union, Tuple, Optional
from typing_extensions import Self

import torch
from transformers import CLIPTextConfig, CLIPPreTrainedModel, CLIPTextModel, CLIPTextModelWithProjection
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from transformers.utils import ContextManagers
from transformers.modeling_utils import no_init_weights
from transformers.integrations.accelerate import init_empty_weights

from .migraphx_model import DTYPE_MAPPING, MIGraphXModel


class MIGraphXCLIPPreTrainedModel(CLIPPreTrainedModel, MIGraphXModel):
    """Base class for CLIP text encoders that run through a MIGraphX model.

    Instances are meant to be created with :meth:`from_pretrained`, which
    builds the Transformers model skeleton without initializing weights,
    removes its PyTorch sub-modules and loads a MIGraphX model instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_pretrained(
            cls, 
            model_dir: str, 
            subfolder: Optional[str] = None,
            batch: Optional[int] = 1, 
            img_size: Optional[int] = 1024, 
            model_dtype: Optional[str] = 'fp16', 
            force_compile: Optional[bool] = False,
            **kwargs
        ) -> Self:
        """Create the model and load its MIGraphX counterpart.

        Args:
            model_dir: Directory holding the model files.
            subfolder: Optional sub-directory of ``model_dir``. Surrounding
                whitespace is stripped; ``None`` or a blank string means
                ``model_dir`` itself is used.
            batch: Batch size; used as the leading dimension of the
                ``input_ids`` shape and forwarded to
                ``load_migraphx_model``. ``None`` falls back to ``1``.
            img_size: Forwarded unchanged to ``load_migraphx_model``.
            model_dtype: Key looked up in ``DTYPE_MAPPING`` to obtain the
                dtype reported by :attr:`dtype`; also forwarded to
                ``load_migraphx_model``. An unknown key leaves the stored
                dtype as ``None`` so :attr:`dtype` defers to the parent class.
            force_compile: Forwarded unchanged to ``load_migraphx_model``.
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            The model instance with the MIGraphX model loaded.
        """
        # Decide on the *stripped* value so that a whitespace-only subfolder
        # is treated exactly like "no subfolder" instead of being joined as
        # an empty path component.
        _subfolder = subfolder.strip() if subfolder is not None else ""
        _model_dir = osp.join(model_dir, _subfolder) if _subfolder else model_dir

        # The signature allows ``batch=None``; fall back to the default
        # rather than failing on ``1 * None`` below.
        if batch is None:
            batch = 1

        model_cfg = CLIPTextConfig.from_pretrained(_model_dir)
        model_cfg.name_or_path = _model_dir

        # Build the module structure only: skip weight initialization and
        # create the parameters as empty weights.
        init_context = [no_init_weights(), init_empty_weights()]
        with ContextManagers(init_context):
            model = cls(model_cfg)

        # Use the MIGraphX model instead of PyTorch, so remove the unused
        # PyTorch modules. Names are collected first because the children
        # are deleted while walking them.
        module_names = [module_name for module_name, _ in model.named_children()]
        for module_name in module_names:
            delattr(model, module_name)

        # Values reported by the ``dtype`` / ``device`` properties.
        model._dtype = DTYPE_MAPPING.get(model_dtype, None)
        model._device = torch.device("cuda")

        input_shapes = [
            [1 * batch, 77]  # input_ids
        ]
        model.load_migraphx_model(_model_dir, 
                                  input_shapes,
                                  batch=batch, 
                                  img_size=img_size, 
                                  model_dtype=model_dtype, 
                                  force_compile=force_compile)

        # More than two outputs means the loaded model exposes outputs beyond
        # the first two; ``forward`` reads those as hidden states.
        model.output_hidden_states = len(model.mgx_model.get_outputs()) > 2
        model.num_hidden_layers = model_cfg.num_hidden_layers
        return model

    @property
    def dtype(self):
        """Dtype set by :meth:`from_pretrained`, else the parent's dtype."""
        stored = getattr(self, "_dtype", None)
        return stored if stored is not None else super().dtype

    @property
    def device(self):
        """Device set by :meth:`from_pretrained`, else the parent's device."""
        stored = getattr(self, "_device", None)
        return stored if stored is not None else super().device


class MIGraphXCLIPTextModel(CLIPTextModel, MIGraphXCLIPPreTrainedModel):
    """CLIP text encoder whose forward pass is executed by a MIGraphX model."""

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Run the MIGraphX model on ``input_ids`` and package its outputs.

        reference: transformers/blob/4.54.1/src/transformers/models/clip/modeling_clip.py#L691

        Args:
            input_ids: Bound to model input 0.
            attention_mask: Accepted for signature compatibility; ignored.
            position_ids: Accepted for signature compatibility; ignored.
            output_attentions: Accepted for signature compatibility; ignored.
            output_hidden_states: Whether to return hidden states. ``None``
                falls back to ``config.output_hidden_states``. Hidden states
                are only returned when the loaded model exposes them
                (``self.output_hidden_states``).
            return_dict: Return a ``BaseModelOutputWithPooling`` when true,
                a plain tuple otherwise. ``None`` falls back to
                ``config.return_dict``.

        Returns:
            ``(last_hidden_state, pooler_output, *hidden_states)`` as a tuple,
            or the same values wrapped in ``BaseModelOutputWithPooling``.
            ``hidden_states`` is an empty tuple when not returned.
        """
        # ``None`` means "use the configured default" (Transformers
        # convention) rather than "disabled".
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        self.set_input_data(0, input_ids)
        self.run_model(mode='sync')

        last_hidden_state = self.get_output_data(0)
        pooled_output = self.get_output_data(1)

        if output_hidden_states and self.output_hidden_states:
            # Outputs 2.. are read as the hidden states,
            # ``num_hidden_layers + 1`` of them.
            start_idx = 2
            end_idx = start_idx + self.config.num_hidden_layers + 1
            hidden_states = tuple(
                self.get_output_data(i) for i in range(start_idx, end_idx)
            )
        else:
            hidden_states = ()

        if not return_dict:
            return (last_hidden_state, pooled_output) + hidden_states

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states
        )


class MIGraphXCLIPTextModelWithProjection(CLIPTextModelWithProjection, MIGraphXCLIPPreTrainedModel):
    """CLIP text encoder with projection, executed by a MIGraphX model."""

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        """
        Run the MIGraphX model on ``input_ids`` and package its outputs.

        reference: transformers/blob/4.54.1/src/transformers/models/clip/modeling_clip.py#L1076

        Args:
            input_ids: Bound to model input 0.
            attention_mask: Accepted for signature compatibility; ignored.
            position_ids: Accepted for signature compatibility; ignored.
            output_attentions: Accepted for signature compatibility; ignored.
            output_hidden_states: Whether to return hidden states. ``None``
                falls back to ``config.output_hidden_states``. Hidden states
                are only returned when the loaded model exposes them
                (``self.output_hidden_states``).
            return_dict: Return a ``CLIPTextModelOutput`` when true, a plain
                tuple otherwise. ``None`` falls back to
                ``config.return_dict``.

        Returns:
            ``(text_embeds, last_hidden_state, *hidden_states)`` as a tuple,
            or the same values wrapped in ``CLIPTextModelOutput``.
            ``hidden_states`` is an empty tuple when not returned.
        """
        # ``None`` means "use the configured default" (Transformers
        # convention) rather than "disabled".
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        self.set_input_data(0, input_ids)
        self.run_model(mode="sync")

        text_embeds = self.get_output_data(0)
        last_hidden_state = self.get_output_data(1)

        if output_hidden_states and self.output_hidden_states:
            # Outputs 2.. are read as the hidden states,
            # ``num_hidden_layers + 1`` of them.
            start_idx = 2
            end_idx = start_idx + self.config.num_hidden_layers + 1
            hidden_states = tuple(
                self.get_output_data(i) for i in range(start_idx, end_idx)
            )
        else:
            hidden_states = ()

        if not return_dict:
            return (text_embeds, last_hidden_state) + hidden_states

        return CLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
        )
