from collections import namedtuple
import os
import os.path as osp
from typing import Union, Tuple, Optional, Dict, Any
from typing_extensions import Self
import warnings

import accelerate
from diffusers import FluxTransformer2DModel
from diffusers.models.modeling_utils import ContextManagers, no_init_weights
from diffusers.models.modeling_outputs import Transformer2DModelOutput
import torch

from .migraphx_model import DTYPE_MAPPING, MIGraphXModel


class MIGraphXFluxTransformer2DModel(FluxTransformer2DModel, MIGraphXModel):
    """Flux transformer whose forward pass is executed by a MIGraphX program.

    The diffusers ``FluxTransformer2DModel`` is instantiated with empty weights
    only to obtain its config and public interface. All of its PyTorch child
    modules are then removed, and ``forward`` feeds its inputs to the program
    loaded through the ``MIGraphXModel`` mixin (``load_migraphx_model``,
    ``set_input_data``, ``run_model``, ``get_output_data``).
    """

    @classmethod
    def from_pretrained(
            cls,
            model_dir: str,
            subfolder: Optional[str] = None,
            batch: Optional[int] = 1,
            img_size: Optional[int] = 1024,
            model_dtype: Optional[str] = 'fp16',
            force_compile: Optional[bool] = False,
            **kwargs
        ) -> Self:
        """Build the model from its diffusers config and attach a MIGraphX program.

        Args:
            model_dir: Directory containing the model config. This directory,
                or ``model_dir/subfolder`` when a subfolder is given, is also
                passed to ``load_migraphx_model``.
            subfolder: Optional sub-directory of ``model_dir``. Surrounding
                whitespace is stripped, and an empty result is ignored.
            batch: Batch size used in the fixed MIGraphX input shapes.
            img_size: Square image side length in pixels. It is used to derive
                the fixed number of image tokens.
            model_dtype: Key into ``DTYPE_MAPPING``. It is also forwarded to
                ``load_migraphx_model``.
            force_compile: Forwarded to ``load_migraphx_model``.
            **kwargs: Currently unused.

        Returns:
            A model with no PyTorch child modules, whose ``forward`` runs the
            MIGraphX program.
        """
        _subfolder = subfolder.strip() if subfolder is not None else None
        # Test the stripped value so a whitespace-only subfolder does not join
        # an empty path component onto model_dir.
        _model_dir = osp.join(model_dir, _subfolder) if _subfolder else model_dir

        model_cfg = cls.load_config(_model_dir)

        # Temporarily switch torch's default dtype so the model skeleton is
        # created in the requested precision. float8 is skipped, presumably
        # because torch does not accept it as a default dtype.
        dtype_orig = None
        torch_dtype = DTYPE_MAPPING.get(model_dtype, None)
        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

        # Build the module tree with empty weights and weight initialization
        # disabled. Restore the default dtype even if construction fails, so
        # torch's global state is never left modified.
        init_contexts = [no_init_weights(), accelerate.init_empty_weights()]
        try:
            with ContextManagers(init_contexts):
                model = cls.from_config(model_cfg)
        finally:
            if dtype_orig is not None:
                torch.set_default_dtype(dtype_orig)

        # The MIGraphX program replaces the PyTorch computation, so drop every
        # child module. Snapshot the names first, because deleting while
        # iterating named_children() would mutate the container being iterated.
        for module_name in [name for name, _ in model.named_children()]:
            delattr(model, module_name)

        # Fixed input shapes for the MIGraphX program, in the same order that
        # forward() feeds them.
        joint_attention_dim = model_cfg['joint_attention_dim']
        in_channels = model_cfg['in_channels']
        pooled_projection_dim = model_cfg['pooled_projection_dim']
        # Text token count. Assumed to match the sequence length the MIGraphX
        # program was built for -- TODO confirm.
        sequence_length = 512
        # 2 ** (len(vae_block_out_channels) - 1) with VAE block_out_channels
        # [128, 256, 512, 512]. Hard-coded rather than read from the VAE config.
        vae_scale_factor = 8
        # Latent side length, rounded down to a multiple of 2.
        height = 2 * (int(img_size) // (vae_scale_factor * 2))
        width = 2 * (int(img_size) // (vae_scale_factor * 2))
        # Number of image tokens: presumably one per 2x2 latent patch.
        dim1 = (height // 2) * (width // 2)
        input_shapes = [
            [batch, dim1, in_channels],  # hidden_states
            [batch, sequence_length, joint_attention_dim],  # encoder_hidden_states
            [batch, pooled_projection_dim],  # pooled_projections
            [batch],  # timestep
            [dim1, 3],  # img_ids
            [sequence_length, 3],  # txt_ids
            [batch],  # guidance
        ]

        model.load_migraphx_model(_model_dir,
                                  input_shapes,
                                  batch=batch,
                                  img_size=img_size,
                                  model_dtype=model_dtype,
                                  force_compile=force_compile)

        return model

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """Run the MIGraphX program in place of the PyTorch Flux transformer.

        The signature mirrors the diffusers implementation:
        reference: diffusers/blob/v0.34.0/src/diffusers/models/transformers/transformer_flux.py#L389

        ``joint_attention_kwargs`` and the ControlNet arguments are ignored.
        Passing a non-default value for any of them emits a ``RuntimeWarning``.
        The seven tensor inputs are bound to program inputs 0-6, in the same
        order as the shapes built in ``from_pretrained``.

        Returns:
            ``(output,)`` when ``return_dict`` is False. Otherwise a
            ``Transformer2DModelOutput`` whose ``sample`` is program output 0.
        """
        ignored_arguments = (
            ("joint_attention_kwargs", joint_attention_kwargs is not None),
            ("controlnet_block_samples", controlnet_block_samples is not None),
            ("controlnet_single_block_samples", controlnet_single_block_samples is not None),
            # This flag defaults to False rather than None, so test its truth
            # value; an `is not None` check would warn on every call.
            ("controlnet_blocks_repeat", bool(controlnet_blocks_repeat)),
        )
        for arg_name, was_passed in ignored_arguments:
            if was_passed:
                warnings.warn(f"An unused argument `{arg_name}` was passed in!", RuntimeWarning)

        program_inputs = (
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
            guidance,
        )
        for input_index, tensor in enumerate(program_inputs):
            self.set_input_data(input_index, tensor)

        self.run_model(mode="sync")
        out_hidden_states = self.get_output_data(0)

        if not return_dict:
            return (out_hidden_states,)
        return Transformer2DModelOutput(sample=out_hidden_states)

    @property
    def dtype(self):
        """Return ``self._dtype`` when it is set; otherwise use the parent's ``dtype``.

        NOTE(review): ``_dtype`` is presumably populated by the MIGraphXModel
        mixin. The override is likely needed because all child modules are
        deleted in ``from_pretrained`` -- confirm.
        """
        explicit_dtype = getattr(self, "_dtype", None)
        if explicit_dtype is not None:
            return explicit_dtype
        return super().dtype

    @property
    def device(self):
        """Return ``self._device`` when it is set; otherwise use the parent's ``device``.

        NOTE(review): ``_device`` is presumably populated by the MIGraphXModel
        mixin -- confirm.
        """
        explicit_device = getattr(self, "_device", None)
        if explicit_device is not None:
            return explicit_device
        return super().device
