import importlib
import inspect
import os
import os.path as osp
from typing import Optional, Union
from typing_extensions import Self

import torch
from transformers import AutoConfig
from huggingface_hub import read_dduf_file

from diffusers import __version__, DiffusionPipeline
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.quantizers import PipelineQuantizationConfig

from diffusers.utils import (
    _get_detailed_type,
    _is_valid_type,
    is_accelerate_available,
    is_accelerate_version,
    is_torch_version,
    logging,
)
from diffusers.utils.hub_utils import _check_legacy_sharding_variant_format
from diffusers.pipelines.pipeline_loading_utils import (
    ALL_IMPORTABLE_CLASSES,
    LOADABLE_CLASSES,
    _get_final_device_map,
    _get_pipeline_class,
    _identify_model_variants,
    _maybe_raise_error_for_incorrect_transformers,
    _maybe_raise_warning_for_inpainting,
    _maybe_warn_for_wrong_component_in_quant_config,
    _resolve_custom_pipeline_and_cls,
    _update_init_kwargs_with_connected_pipeline,
    load_sub_model,
    maybe_raise_or_warn,
)

from .utils import load_model_args


# Library names collected by iterating LOADABLE_CLASSES.
# Keep this a plain `for` loop (not `list(...)` or a comprehension): the loop variable
# `library` stays bound as a module-level global, and `from_pretrained` below refers to that
# global (`library=library` and `maybe_raise_or_warn(..., library, ...)`) without defining it.
LIBRARIES = []
for library in LOADABLE_CLASSES:
    LIBRARIES.append(library)

# `device_map` strategies accepted by `from_pretrained`; anything else raises NotImplementedError there.
SUPPORTED_DEVICE_MAP = ["balanced"]
logger = logging.get_logger(__name__)


@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
    r"""
    Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.

    This function replaces `DiffusionPipeline.from_pretrained`. It follows the regular diffusers loading flow,
    except that every component named in the MIGraphX model arguments (see `migraphx_config`) is loaded with the
    `MIGraphX`-prefixed class found in the parent package of this module instead of the regular class.

    The pipeline is set in evaluation mode (`model.eval()`) by default.

    If you get the error message below, you need to finetune the weights for your downstream task:

    ```
    Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
    - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    ```

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                    hosted on the Hub.
                - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                    saved using [`~DiffusionPipeline.save_pretrained`].
                - A path to a *directory* (for example `./my_pipeline_directory/`) containing a DDUF file (select
                    it with `dduf_file`).
        torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
            Override the default `torch.dtype` and load the model with another dtype. To load submodels with
            different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
            Set the default dtype for unspecified components with `default` (for example `{'transformer':
            torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set,
            `torch.float32` is used. Any other non-`None` value logs a warning and falls back to `torch.float32`.
        custom_pipeline (`str`, *optional*):

            <Tip warning={true}>

            🧪 This is an experimental feature and may change in the future.

            </Tip>

            Can be either:

                - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom
                    pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines
                    the custom pipeline.
                - A string, the *file name* of a community pipeline hosted on GitHub under
                    [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                    names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                    instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                    current main branch of GitHub.
                - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                    must contain a file called `pipeline.py` that defines the custom pipeline.

            For more information on how to load and create custom pipelines, please have a look at [Loading and
            Adding Custom
            Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        custom_revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
            `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
            version.
        mirror (`str`, *optional*):
            Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
            information.
        device_map (`str`, *optional*):
            Strategy that dictates how the different components of a pipeline should be placed on available
            devices. Currently, only "balanced" `device_map` is supported. Check out
            [this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
            to know more.
        max_memory (`Dict`, *optional*):
            A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
            each GPU and the available CPU RAM if unset.
        offload_folder (`str` or `os.PathLike`, *optional*):
            The path to offload weights if device_map contains the value `"disk"`.
        offload_state_dict (`bool`, *optional*):
            If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
            the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
            when there is some disk offload.
        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
            Speed up model loading only loading the pretrained weights and not initializing the weights. This also
            tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
            Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
            argument to `True` will raise an error.
        use_safetensors (`bool`, *optional*, defaults to `None`):
            If set to `None`, the safetensors weights are downloaded if they're available **and** if the
            safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
            weights. If set to `False`, safetensors weights are not loaded.
        use_onnx (`bool`, *optional*, defaults to `None`):
            If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
            will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
            `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
            with `.onnx` and `.pb`.
        quantization_config (`PipelineQuantizationConfig`, *optional*):
            Quantization settings forwarded to every component loaded through `load_sub_model`. Any other type
            raises a `ValueError`.
        migraphx_config (*optional*):
            Passed to `load_model_args` to build the MIGraphX model arguments. Every pipeline component whose name
            is a key of that mapping is loaded with the `MIGraphX<ClassName>` class and receives its entry as extra
            keyword arguments. If the mapping also has a `'pipeline'` entry, its `batch` and `img_size` values are
            registered in the pipeline config as `_batch`, `_img_height` and `_img_width`.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
            class). The overwritten components are passed directly to the pipelines `__init__` method. See example
            below for more information.
        variant (`str`, *optional*):
            Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
            loading `from_flax`.
        dduf_file(`str`, *optional*):
            Load weights from the specified dduf file.

    <Tip>

    To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
    `huggingface-cli login`.

    </Tip>

    Examples:

    ```py
    >>> from diffusers import DiffusionPipeline

    >>> # Download pipeline from huggingface.co and cache.
    >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

    >>> # Download pipeline that requires an authorization token
    >>> # For more information on access tokens, please refer to [this section
    >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
    >>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

    >>> # Use a different scheduler
    >>> from diffusers import LMSDiscreteScheduler

    >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
    >>> pipeline.scheduler = scheduler
    ```
    """
    # Copy the kwargs to re-use during loading connected pipeline.
    kwargs_copied = kwargs.copy()

    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    from_flax = kwargs.pop("from_flax", False)
    torch_dtype = kwargs.pop("torch_dtype", None)
    custom_pipeline = kwargs.pop("custom_pipeline", None)
    custom_revision = kwargs.pop("custom_revision", None)
    provider = kwargs.pop("provider", None)
    sess_options = kwargs.pop("sess_options", None)
    provider_options = kwargs.pop("provider_options", None)
    device_map = kwargs.pop("device_map", None)
    max_memory = kwargs.pop("max_memory", None)
    offload_folder = kwargs.pop("offload_folder", None)
    offload_state_dict = kwargs.pop("offload_state_dict", None)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
    variant = kwargs.pop("variant", None)
    dduf_file = kwargs.pop("dduf_file", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    use_onnx = kwargs.pop("use_onnx", None)
    load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
    quantization_config = kwargs.pop("quantization_config", None)
    migraphx_config = kwargs.pop("migraphx_config", None)

    # Per-component MIGraphX loading arguments; None when MIGraphX loading was not requested.
    mgx_model_args = None if migraphx_config is None else load_model_args(migraphx_config)

    if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
        # Report the value the caller actually passed, *then* fall back to float32.
        logger.warning(
            f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
        )
        torch_dtype = torch.float32

    if low_cpu_mem_usage and not is_accelerate_available():
        low_cpu_mem_usage = False
        logger.warning(
            "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
            " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
            " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
            " install accelerate\n```\n."
        )

    if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
        raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")

    if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `low_cpu_mem_usage=False`."
        )

    if device_map is not None and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `device_map=None`."
        )

    if device_map is not None and not is_accelerate_available():
        raise NotImplementedError(
            "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
        )

    if device_map is not None and not isinstance(device_map, str):
        raise ValueError("`device_map` must be a string.")

    if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
        raise NotImplementedError(
            f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
        )

    if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
        if is_accelerate_version("<", "0.28.0"):
            raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")

    if low_cpu_mem_usage is False and device_map is not None:
        raise ValueError(
            f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
            " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
        )

    if dduf_file:
        if custom_pipeline:
            raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
        if load_connected_pipeline:
            raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")

    # 1. Download the checkpoints and configs
    # use snapshot download here to get it working from from_pretrained
    if not os.path.isdir(pretrained_model_name_or_path):
        # `os.fspath` lets an `os.PathLike` input reach this repo-id check instead of failing on `.count`.
        if os.fspath(pretrained_model_name_or_path).count("/") > 1:
            raise ValueError(
                f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
                " is neither a valid local path nor a valid repo id. Please check the parameter."
            )
        cached_folder = cls.download(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            from_flax=from_flax,
            use_safetensors=use_safetensors,
            use_onnx=use_onnx,
            custom_pipeline=custom_pipeline,
            custom_revision=custom_revision,
            variant=variant,
            dduf_file=dduf_file,
            load_connected_pipeline=load_connected_pipeline,
            **kwargs,
        )
    else:
        cached_folder = pretrained_model_name_or_path

    # The variant filenames can have the legacy sharding checkpoint format that we check and throw
    # a warning if detected.
    if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
        warn_msg = (
            f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
            "Please check your files carefully:\n\n"
            "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
            "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
            "If you find any files in the deprecated format:\n"
            "1. Remove all existing checkpoint files for this variant.\n"
            "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
            "This will ensure you're using the most up-to-date and compatible checkpoint format."
        )
        logger.warning(warn_msg)

    dduf_entries = None
    if dduf_file:
        dduf_file_path = os.path.join(cached_folder, dduf_file)
        dduf_entries = read_dduf_file(dduf_file_path)
        # The reader contains already all the files needed, no need to check it again
        cached_folder = ""

    config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries)

    if dduf_file:
        _maybe_raise_error_for_incorrect_transformers(config_dict)

    # pop out "_ignore_files" as it is only needed for download
    config_dict.pop("_ignore_files", None)

    # 2. Define which model components should load variants
    # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
    # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
    # with variant being `"fp16"`.
    model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
    # With more than one MIGraphX entry, the missing variant files are tolerated (components may come from MIGraphX).
    if len(model_variants) == 0 and variant is not None and (mgx_model_args is None or len(mgx_model_args) <= 1):
        error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
        raise ValueError(error_message)

    # 3. Load the pipeline class, if using custom module then load it from the hub
    # if we load from explicit class, let's use it
    custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
        folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
    )
    pipeline_class = _get_pipeline_class(
        cls,
        config=config_dict,
        load_connected_pipeline=load_connected_pipeline,
        custom_pipeline=custom_pipeline,
        class_name=custom_class_name,
        cache_dir=cache_dir,
        revision=custom_revision,
    )

    if device_map is not None and pipeline_class._load_connected_pipes:
        raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")

    # DEPRECATED: To be removed in 1.0.0
    # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
    # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
    _maybe_raise_warning_for_inpainting(
        pipeline_class=pipeline_class,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        config=config_dict,
    )

    # 4. Define expected modules given pipeline signature
    # and define non-None initialized modules (=`init_kwargs`)

    # some modules can be passed directly to the init
    # in this case they are already instantiated in `kwargs`
    # extract them here
    expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
    expected_types = pipeline_class._get_signature_types()
    passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
    passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
    init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

    # define init kwargs and make sure that optional component modules are filtered out
    init_kwargs = {
        k: init_dict.pop(k)
        for k in optional_kwargs
        if k in init_dict and k not in pipeline_class._optional_components
    }
    init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

    # remove `null` components
    def load_module(name, value):
        if value[0] is None:
            return False
        if name in passed_class_obj and passed_class_obj[name] is None:
            return False
        return True

    init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

    # Special case: safety_checker must be loaded separately when using `from_flax`
    if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
        raise NotImplementedError(
            "The safety checker cannot be automatically loaded when loading weights `from_flax`."
            " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
            " separately if you need it."
        )

    # 5. Throw nice warnings / errors for fast accelerate loading
    if len(unused_kwargs) > 0:
        logger.warning(
            f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
        )

    # import it here to avoid circular import
    from diffusers import pipelines

    # 6. device map delegation
    # NOTE: `library` is not a local of this function; it resolves to the module-level variable left
    # bound by the `for library in LOADABLE_CLASSES` loop at import time.
    final_device_map = None
    if device_map is not None:
        final_device_map = _get_final_device_map(
            device_map=device_map,
            pipeline_class=pipeline_class,
            passed_class_obj=passed_class_obj,
            init_dict=init_dict,
            library=library,
            max_memory=max_memory,
            torch_dtype=torch_dtype,
            cached_folder=cached_folder,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    # 7. Load each module in the pipeline
    current_device_map = None
    _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
    mxr_paths = []  # (component name, compiled MIGraphX model path) pairs, logged after loading
    mgx_models = []  # names of components loaded through MIGraphX classes
    for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):

        # 7.0 Load MIGraphX replacements for components named in the MIGraphX model arguments.
        if mgx_model_args is not None and name in mgx_model_args:
            class_name = "MIGraphX" + class_name
            # The MIGraphX classes are looked up on the parent package of this module.
            mgx_sd_module = importlib.import_module("..", package=__name__)
            class_obj = getattr(mgx_sd_module, class_name, None)
            if class_obj is None:
                raise NotImplementedError(
                    f"Class {class_name} is not implemented in package mgx_sd."
                )
            _subfolder = name
            # NOTE(review): this checks `pretrained_model_name_or_path`, not `cached_folder`, so the
            # `vae_decoder` fallback only triggers when a local directory was passed — confirm intended.
            if name == 'vae' and osp.isdir(osp.join(pretrained_model_name_or_path, 'vae_decoder')):
                _subfolder = 'vae_decoder'
            loaded_sub_model = class_obj.from_pretrained(
                pretrained_model_name_or_path, subfolder=_subfolder, **mgx_model_args[name],
                pipeline_class=pipeline_class
            )

            init_kwargs[name] = loaded_sub_model
            mxr_paths.append((name, loaded_sub_model.mxr_path))
            mgx_models.append(name)
            continue

        # 7.1 device_map shenanigans
        if final_device_map is not None and len(final_device_map) > 0:
            component_device = final_device_map.get(name, None)
            if component_device is not None:
                current_device_map = {"": component_device}
            else:
                current_device_map = None

        # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
        class_name = class_name[4:] if class_name.startswith("Flax") else class_name

        # 7.3 Define all importable classes
        is_pipeline_module = hasattr(pipelines, library_name)
        importable_classes = ALL_IMPORTABLE_CLASSES
        loaded_sub_model = None

        # 7.4 Use passed sub model or load class_name from library_name
        if name in passed_class_obj:
            # if the model is in a pipeline module, then we load it from the pipeline
            # check that passed_class_obj has correct parent class
            maybe_raise_or_warn(
                library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
            )

            loaded_sub_model = passed_class_obj[name]
        else:
            # load sub model; a dict `torch_dtype` is resolved per component, then `default`, then float32
            sub_model_dtype = (
                torch_dtype.get(name, torch_dtype.get("default", torch.float32))
                if isinstance(torch_dtype, dict)
                else torch_dtype
            )
            loaded_sub_model = load_sub_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=importable_classes,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                torch_dtype=sub_model_dtype,
                provider=provider,
                sess_options=sess_options,
                device_map=current_device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,
                model_variants=model_variants,
                name=name,
                from_flax=from_flax,
                variant=variant,
                low_cpu_mem_usage=low_cpu_mem_usage,
                cached_folder=cached_folder,
                use_safetensors=use_safetensors,
                dduf_entries=dduf_entries,
                provider_options=provider_options,
                quantization_config=quantization_config,
            )
            logger.info(
                f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
            )

        init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

    for name, mxr_path in mxr_paths:
        logger.info(f"{name} used migraphx model: {mxr_path}")

    # 8. Handle connected pipelines.
    if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
        init_kwargs = _update_init_kwargs_with_connected_pipeline(
            init_kwargs=init_kwargs,
            passed_pipe_kwargs=passed_pipe_kwargs,
            passed_class_objs=passed_class_obj,
            folder=cached_folder,
            **kwargs_copied,
        )

    # 9. Potentially add passed objects if expected
    missing_modules = set(expected_modules) - set(init_kwargs.keys())
    passed_modules = list(passed_class_obj.keys())
    optional_modules = pipeline_class._optional_components
    if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
        for module in missing_modules:
            init_kwargs[module] = passed_class_obj.get(module, None)
    elif len(missing_modules) > 0:
        passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - set(optional_kwargs)
        raise ValueError(
            f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
        )

    # 10. Type checking init arguments
    mgx_model_names = set(mgx_models)
    for kw, arg in init_kwargs.items():
        # Too complex to validate with type annotation alone
        if "scheduler" in kw:
            continue
        # Many tokenizer annotations don't include its "Fast" variant, so skip this
        # e.g T5Tokenizer but not T5TokenizerFast
        elif "tokenizer" in kw:
            continue
        # MIGraphX components deliberately replace the annotated class, so they may not match the
        # pipeline's signature annotation; warning about them would only be noise.
        elif kw in mgx_model_names:
            continue
        elif (
            arg is not None  # Skip if None
            and not expected_types[kw] == (inspect.Signature.empty,)  # Skip if no type annotations
            and not _is_valid_type(arg, expected_types[kw])  # Check type
        ):
            logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")

    # 11. Instantiate the pipeline
    model = pipeline_class(**init_kwargs)

    # Record MIGraphX settings in the pipeline config when any component was loaded through MIGraphX.
    if mgx_model_args is not None and 'pipeline' in mgx_model_args and len(mgx_models) > 0:
        mgx_pipeline_args = mgx_model_args['pipeline']
        model.register_to_config(_mgx_models=tuple(mgx_models))
        model.register_to_config(_batch=mgx_pipeline_args['batch'])
        # A single `img_size` value is used for both height and width.
        model.register_to_config(_img_height=mgx_pipeline_args['img_size'])
        model.register_to_config(_img_width=mgx_pipeline_args['img_size'])

    # 12. Save where the model was instantiated from
    model.register_to_config(_name_or_path=pretrained_model_name_or_path)
    if device_map is not None:
        setattr(model, "hf_device_map", final_device_map)
    return model


# Install the MIGraphX-aware loader on diffusers' DiffusionPipeline at import time, so that
# `DiffusionPipeline.from_pretrained(...)` (and subclasses that do not override it) use it.
setattr(DiffusionPipeline, "from_pretrained", from_pretrained)
