import warnings
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline


# Keep references to the original `__call__` implementations. They must be
# captured before the classes are patched further down in this module, so that
# `pipeline_call` can still delegate to the unmodified behavior.
sd_call = StableDiffusionPipeline.__call__
sdxl_call = StableDiffusionXLPipeline.__call__

def pipeline_call(self, *args, **kwargs):
    """Call the original pipeline ``__call__`` with load-time shapes enforced.

    When ``self.config`` contains a non-empty ``_mgx_models`` entry, the
    ``num_images_per_prompt``, ``height`` and ``width`` keyword arguments are
    overwritten with ``self.config._batch``, ``self.config._img_height`` and
    ``self.config._img_width`` respectively. A ``RuntimeWarning`` is emitted
    for each of those values that the caller passed explicitly as a keyword
    and that differs from the configured one. Otherwise the arguments are
    forwarded untouched.

    Args:
        *args: Positional arguments forwarded unchanged to the original
            ``__call__``.
        **kwargs: Keyword arguments forwarded to the original ``__call__``,
            possibly with the three entries above replaced.

    Returns:
        Whatever the original ``__call__`` of the matching pipeline class
        returns.

    Raises:
        NotImplementedError: If ``self`` is neither a
            ``StableDiffusionPipeline`` nor a ``StableDiffusionXLPipeline``.
    """
    # NOTE(review): only keyword arguments are inspected. If a caller passes
    # height/width positionally, the forced keyword values would presumably
    # collide with them in the wrapped call -- confirm against callers.
    if "_mgx_models" in self.config and len(self.config._mgx_models) > 0:
        valid_batch = self.config._batch
        valid_height = self.config._img_height
        valid_width = self.config._img_width
        batch = kwargs.get("num_images_per_prompt", None)
        height = kwargs.get("height", None)
        width = kwargs.get("width", None)

        # Warn only when the caller explicitly asked for something different;
        # an omitted argument is silently filled in below.
        if batch is not None and batch != valid_batch:
            warnings.warn(
                f"Argument `num_images_per_prompt` mismatch between pipeline loading "
                f"({valid_batch}) and runtime ({batch}). Forcing use loading-time "
                f"value ({valid_batch}).",
                RuntimeWarning
            )
        if height is not None and height != valid_height:
            warnings.warn(
                f"Image height mismatch between pipeline loading ({valid_height}) and "
                f"runtime ({height}). Forcing use loading-time height ({valid_height}).",
                RuntimeWarning
            )
        if width is not None and width != valid_width:
            warnings.warn(
                f"Image width mismatch between pipeline loading ({valid_width}) and "
                f"runtime ({width}). Forcing use loading-time width ({valid_width}).",
                RuntimeWarning
            )

        # Always force the load-time values, whether or not a warning fired.
        kwargs["num_images_per_prompt"] = valid_batch
        kwargs["height"] = valid_height
        kwargs["width"] = valid_width

    # Delegate to the original implementation captured before patching.
    if isinstance(self, StableDiffusionPipeline):
        return sd_call(self, *args, **kwargs)
    elif isinstance(self, StableDiffusionXLPipeline):
        return sdxl_call(self, *args, **kwargs)
    else:
        raise NotImplementedError(f"Not supported {self.__class__.__name__}!")


# Replace `__call__` on both pipeline classes with the wrapper. This is a
# module-level side effect: it takes effect as soon as this module is executed
# and applies to the classes themselves, not to individual instances.
StableDiffusionPipeline.__call__ = pipeline_call
StableDiffusionXLPipeline.__call__ = pipeline_call
