import os
import os.path as osp
from typing import Optional, Union
from typing_extensions import Self
import warnings

import accelerate
from diffusers import AutoencoderKL
from diffusers.models.modeling_utils import ContextManagers, no_init_weights
from diffusers.models.autoencoders.vae import DecoderOutput
import torch

from .migraphx_model import DTYPE_MAPPING, MIGraphXModel


class MIGraphXAutoencoderKL(AutoencoderKL, MIGraphXModel):

    @classmethod
    def from_pretrained(
            cls,
            model_dir: str,
            subfolder: Optional[str] = None,
            batch: Optional[int] = 1,
            img_size: Optional[int] = 1024,
            model_dtype: Optional[str] = 'fp16',
            force_compile: Optional[bool] = False,
            **kwargs
        ) -> Self:
        """Create a VAE whose computation is delegated to a MIGraphX model.

        The diffusers config is read from ``model_dir`` (or
        ``model_dir/subfolder``), a weight-less PyTorch skeleton is built from
        it, every child module of that skeleton is removed, and the MIGraphX
        model is then loaded through ``load_migraphx_model``.

        Args:
            model_dir: Directory holding the model config and the files
                ``load_migraphx_model`` reads.
            subfolder: Optional sub-directory of ``model_dir``; surrounding
                whitespace is stripped.
            batch: Batch dimension of the ``"latent_sample"`` input shape.
            img_size: Image edge length in pixels; the latent edge length is
                ``img_size // 8``.
            model_dtype: Key looked up in ``DTYPE_MAPPING`` to pick the torch
                dtype; also forwarded to ``load_migraphx_model``.
            force_compile: Forwarded unchanged to ``load_migraphx_model``.
            **kwargs: Accepted for call compatibility; currently unused.

        Returns:
            The configured model instance.
        """
        # Decide on the stripped value so a whitespace-only subfolder does not
        # end up joined as an empty path component.
        sub_dir = subfolder.strip() if subfolder else ""
        _model_dir = osp.join(model_dir, sub_dir) if sub_dir else model_dir

        model_cfg = cls.load_config(_model_dir)

        torch_dtype = DTYPE_MAPPING.get(model_dtype)

        if model_cfg.get("force_upcast"):
            warnings.warn(
                f"Not support force_upcast, run vae main model still with dtype {torch_dtype}.",
                RuntimeWarning
            )
            model_cfg["force_upcast"] = False

        # Temporarily switch the global default dtype while the skeleton is
        # built; float8_e4m3fn is skipped (and may be absent in older torch).
        dtype_orig = None
        float8_dtype = getattr(torch, "float8_e4m3fn", None)
        if torch_dtype is not None and torch_dtype != float8_dtype:
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

        try:
            # Build the module tree without initialising or allocating weights.
            with ContextManagers([no_init_weights(), accelerate.init_empty_weights()]):
                model = cls.from_config(model_cfg)
        finally:
            # Always restore the process-wide default, even if construction
            # failed, so unrelated code is not left with a changed dtype.
            if dtype_orig is not None:
                torch.set_default_dtype(dtype_orig)

        # MIGraphX replaces the PyTorch computation, so drop the unnecessary
        # PyTorch sub-modules. Names are collected first because deleting
        # while iterating named_children() would mutate the module dict.
        module_names = [module_name for module_name, _ in model.named_children()]
        for module_name in module_names:
            delattr(model, module_name)

        # Consumed by the ``dtype`` / ``device`` properties of this class.
        model._dtype = torch_dtype
        model._device = torch.device("cuda")

        # NOTE(review): the 8x image-to-latent scale and the 4 latent channels
        # are hard-coded; presumably they match the VAE this wrapper was
        # written for -- confirm before using a different VAE.
        img_sz_scale = 1024 // 128
        latent_size = img_size // img_sz_scale
        shapes = {
            "latent_sample": [1 * batch, 4, latent_size, latent_size]
        }

        # NOTE: MIGRAPHX_ENABLE_MIOPEN_GROUPNORM / MIGRAPHX_ENABLE_NHWC were
        # previously toggled around this call; the environment is now left
        # untouched.
        model.load_migraphx_model(_model_dir,
                                  shapes,
                                  batch=batch,
                                  img_size=img_size,
                                  model_dtype=model_dtype,
                                  force_compile=force_compile)

        return model

    @property
    def dtype(self):
        """Return the dtype stored in ``_dtype``, else the parent class's dtype."""
        stored = getattr(self, "_dtype", None)
        if stored is None:
            # Nothing recorded (attribute missing or None): defer to the parent.
            return super().dtype
        return stored

    @property
    def device(self):
        """Return the device stored in ``_device``, else the parent class's device."""
        stored = getattr(self, "_device", None)
        if stored is None:
            # Nothing recorded (attribute missing or None): defer to the parent.
            return super().device
        return stored

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode the latent ``z`` by running the loaded MIGraphX model.

        When tiling is enabled and either of the last two dimensions of ``z``
        exceeds ``tile_latent_min_size``, the call is handed to
        ``tiled_decode``. Otherwise ``z`` is bound to the ``"latent_sample"``
        input, the model is run synchronously and output 0 is returned.

        Returns a ``DecoderOutput`` when ``return_dict`` is true, otherwise a
        one-element tuple holding the decoded output.

        reference: diffusers/blob/v0.34.0/src/diffusers/models/autoencoders/autoencoder_kl.py#L287
        """
        if self.use_tiling:
            limit = self.tile_latent_min_size
            if z.shape[-1] > limit or z.shape[-2] > limit:
                return self.tiled_decode(z, return_dict=return_dict)

        self.set_input_data("latent_sample", z)
        self.run_model(mode="sync")
        sample = self.get_output_data(0)

        return DecoderOutput(sample=sample) if return_dict else (sample,)

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode ``z`` tile by tile and stitch the decoded tiles together.

        ``z`` is cut along dims 2 and 3 into overlapping tiles of at most
        ``tile_latent_min_size``; each tile is run through the MIGraphX model,
        blended with its upper and left neighbours, cropped, and concatenated.

        Returns a ``DecoderOutput`` when ``return_dict`` is true, otherwise a
        one-element tuple holding the stitched output.

        NOTE(review): tiles at the right/bottom border can be smaller than
        ``tile_latent_min_size`` -- confirm ``set_input_data`` accepts them.
        NOTE(review): all decoded tiles are kept until blending -- confirm
        ``get_output_data`` returns a fresh tensor per run, not a shared buffer.

        reference: diffusers/blob/v0.34.0/src/diffusers/models/autoencoders/autoencoder_kl.py#L452
        """
        tile_latent = self.tile_latent_min_size
        tile_sample = self.tile_sample_min_size
        overlap = self.tile_overlap_factor

        stride = int(tile_latent * (1 - overlap))   # step between tile origins
        blend_px = int(tile_sample * overlap)       # width of the blended band
        keep = tile_sample - blend_px               # pixels kept from each tile

        def run_tile(latent_tile: torch.Tensor) -> torch.Tensor:
            # One synchronous MIGraphX run for a single latent tile.
            self.set_input_data("latent_sample", latent_tile)
            self.run_model(mode="sync")
            return self.get_output_data(0)

        # Decode every tile in row-major order; neighbouring tiles overlap so
        # the seams can be blended away afterwards.
        decoded_grid = [
            [
                run_tile(z[:, :, top : top + tile_latent, left : left + tile_latent])
                for left in range(0, z.shape[3], stride)
            ]
            for top in range(0, z.shape[2], stride)
        ]

        stitched_rows = []
        upper_row = None
        for decoded_row in decoded_grid:
            pieces = []
            left_tile = None
            for col, raw_tile in enumerate(decoded_row):
                current = raw_tile
                # Blend with the tile above first, then with the tile to the
                # left, and keep only the non-overlapping part.
                if upper_row is not None:
                    current = self.blend_v(upper_row[col], current, blend_px)
                if left_tile is not None:
                    current = self.blend_h(left_tile, current, blend_px)
                pieces.append(current[:, :, :keep, :keep])
                left_tile = raw_tile
            stitched_rows.append(torch.cat(pieces, dim=3))
            upper_row = decoded_row

        sample = torch.cat(stitched_rows, dim=2)
        return DecoderOutput(sample=sample) if return_dict else (sample,)
