from collections import namedtuple
import os
import os.path as osp
from typing import Union, Tuple, Optional, Dict, Any
from typing_extensions import Self

import accelerate
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.models.modeling_utils import ContextManagers, no_init_weights
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
import torch

from .migraphx_model import DTYPE_MAPPING, MIGraphXModel


class MIGraphXUNet2DConditionModel(UNet2DConditionModel, MIGraphXModel):
    """``UNet2DConditionModel`` whose forward pass is executed by a MIGraphX model.

    ``from_pretrained`` builds the diffusers UNet from its config without
    materialising weights, removes every PyTorch child module, and then loads a
    MIGraphX model with fixed input shapes. ``forward`` feeds its inputs to
    that model through ``set_input_data`` / ``run_model`` / ``get_output_data``
    (presumably provided by ``MIGraphXModel`` -- confirm in ``migraphx_model``).
    """

    @classmethod
    def from_pretrained(
            cls,
            model_dir: str,
            subfolder: Optional[str] = None,
            batch: Optional[int] = 1,
            img_size: Optional[int] = 1024,
            model_dtype: Optional[str] = 'fp16',
            force_compile: Optional[bool] = False,
            **kwargs
        ) -> Self:
        """Create the wrapper and load its MIGraphX model.

        Args:
            model_dir: Root directory of the model.
            subfolder: Optional sub-directory of ``model_dir`` holding the UNet
                config; surrounding whitespace is stripped.
            batch: Batch size used to fix the MIGraphX input shapes. The
                leading dimension of the batched inputs is ``2 * batch``
                (presumably for classifier-free guidance -- TODO confirm).
            img_size: Image side length in pixels; the latent height and width
                are both ``img_size // 8`` (square inputs only).
            model_dtype: Key looked up in ``DTYPE_MAPPING`` to pick the torch
                dtype; also forwarded to ``load_migraphx_model``.
            force_compile: Forwarded to ``load_migraphx_model``.
            **kwargs: Only ``pipeline_class`` is read. When it is exactly
                ``StableDiffusionXLPipeline``, the ``text_embeds`` and
                ``time_ids`` input shapes are appended.

        Returns:
            The model instance with its PyTorch children removed and the
            MIGraphX model loaded.
        """
        _subfolder = None if subfolder is None else subfolder.strip()
        # Test the stripped value so the condition and the join agree.
        _model_dir = osp.join(model_dir, _subfolder) if _subfolder else model_dir

        model_cfg = cls.load_config(_model_dir)

        # Temporarily switch the torch default dtype while building the model
        # (skipped when the dtype is unknown or is float8_e4m3fn).
        dtype_orig = None
        torch_dtype = DTYPE_MAPPING.get(model_dtype, None)
        float8_dtype = getattr(torch, "float8_e4m3fn", None)
        if torch_dtype is not None and torch_dtype != float8_dtype:
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

        try:
            # Build the module structure only: no weight init, empty weights.
            init_contexts = [no_init_weights(), accelerate.init_empty_weights()]
            with ContextManagers(init_contexts):
                model = cls.from_config(model_cfg)
        finally:
            # Always restore the process-wide default dtype, even if model
            # construction raised.
            if dtype_orig is not None:
                torch.set_default_dtype(dtype_orig)

        # Remember the only piece of the PyTorch graph that is read later.
        # Chained getattr tolerates a missing `add_embedding` as well as an
        # `add_embedding` that has no `linear_1` child.
        add_embedding_module = getattr(model, "add_embedding", None)
        linear_1_module = getattr(add_embedding_module, "linear_1", None)
        linear_1_in_features = getattr(linear_1_module, "in_features", None)

        # Inference runs through MIGraphX, so drop the unused PyTorch modules.
        # Names are collected first so the module dict is not mutated while
        # it is being iterated.
        module_names = [module_name for module_name, _ in model.named_children()]
        for module_name in module_names:
            delattr(model, module_name)

        # set attributes
        # 1. diffusers/blob/v0.34.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L487
        # 2. diffusers/blob/v0.34.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L501
        # 3. diffusers/blob/v0.34.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L742
        if linear_1_in_features is not None:
            # Lightweight stand-in so `model.add_embedding.linear_1.in_features`
            # keeps working after the real module has been deleted.
            linear_stub_cls = namedtuple("Linear", "in_features")
            add_embedding_stub_cls = namedtuple("AddEmbedding", "linear_1")
            model.add_embedding = add_embedding_stub_cls(
                linear_1=linear_stub_cls(in_features=linear_1_in_features)
            )
        # Values reported by the `dtype` / `device` properties below.
        model._dtype = torch_dtype
        model._device = torch.device("cuda")

        # fix input shape of MIGraphX model
        in_channels = model_cfg['in_channels']
        img_sz_scale = 8  # TODO: check multi-scale
        h = img_size // img_sz_scale
        w = img_size // img_sz_scale
        input_shapes = [
            [2 * batch, in_channels, h, w],  # sample
            [1],  # timestep
            [2 * batch, 77, model_cfg['cross_attention_dim']],  # encoder_hidden_states
        ]
        pipeline_class = kwargs.get("pipeline_class", None)
        if pipeline_class is StableDiffusionXLPipeline:
            input_shapes.extend([
                [2 * batch, 1280],  # text_embeds
                [2 * batch, 6],  # time_ids
            ])

        # load MIGraphX model
        model.load_migraphx_model(_model_dir,
                                  input_shapes,
                                  batch=batch,
                                  img_size=img_size,
                                  model_dtype=model_dtype,
                                  force_compile=force_compile)

        return model

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Run one UNet step on the MIGraphX model.

        reference: diffusers/blob/v0.34.0/src/diffusers/models/unets/unet_2d_condition.py#L1038

        Only ``sample``, ``timestep``, ``encoder_hidden_states``,
        ``added_cond_kwargs`` and ``return_dict`` are used. The remaining
        parameters are accepted for signature compatibility and ignored.

        Args:
            sample: Bound to model input 0.
            timestep: Bound to model input 1.
            encoder_hidden_states: Bound to model input 2.
            added_cond_kwargs: When not ``None``, its ``'text_embeds'`` and
                ``'time_ids'`` entries are bound to model inputs 3 and 4.
            return_dict: Selects the return type.

        Returns:
            ``UNet2DConditionOutput(sample=...)`` when ``return_dict`` is true,
            otherwise the one-element tuple ``(sample,)``.

        Raises:
            KeyError: If ``added_cond_kwargs`` is given but lacks
                ``'text_embeds'`` or ``'time_ids'``.
        """
        self.set_input_data(0, sample)
        self.set_input_data(1, timestep)
        self.set_input_data(2, encoder_hidden_states)
        if added_cond_kwargs is not None:
            self.set_input_data(3, added_cond_kwargs['text_embeds'])
            self.set_input_data(4, added_cond_kwargs['time_ids'])

        self.run_model(mode="sync")
        sample = self.get_output_data(0)

        if not return_dict:
            return (sample,)
        return UNet2DConditionOutput(sample=sample)

    @property
    def dtype(self):
        """Return the dtype recorded by ``from_pretrained``, else the parent's dtype."""
        recorded_dtype = getattr(self, "_dtype", None)
        if recorded_dtype is not None:
            return recorded_dtype
        return super().dtype

    @property
    def device(self):
        """Return the device recorded by ``from_pretrained``, else the parent's device."""
        recorded_device = getattr(self, "_device", None)
        if recorded_device is not None:
            return recorded_device
        return super().device
