Unverified commit 9c7e2051, authored by Daniel Regado, committed by GitHub

Comprehensive type checking for `from_pretrained` kwargs (#10758)



* More robust from_pretrained init_kwargs type checking

* Corrected for Python 3.10

* Type checks subclasses and fixed type warnings

* More type corrections and skip tokenizer type checking

* make style && make quality

* Updated docs and types for Lumina pipelines

* Fixed check for empty signature

* changed location of helper functions

* make style

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent 64dec70e
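In user-facing terms, this change replaces the old name-based component check with a recursive one that understands `Union`, `List[...]`, `Tuple[...]`, and `Dict[...]` annotations, and it warns (without raising) when a passed component doesn't match the pipeline's `__init__` annotation. A minimal sketch of the effect; the model IDs are illustrative:

```python
from diffusers import AutoencoderKL, StableDiffusionPipeline

# Deliberately pass a VAE where the UNet is expected.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    unet=vae,  # wrong component type on purpose
)
# from_pretrained now logs something like:
#   Expected types for unet: (<class '...UNet2DConditionModel'>,), got <class '...AutoencoderKL'>.
# The check only warns; pipeline instantiation still proceeds.
```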
@@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         scheduler: Union[
             DDIMScheduler,
@@ -246,7 +246,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: Union[
@@ -232,8 +232,8 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
             Tuple[HunyuanDiT2DControlNetModel],
             HunyuanDiT2DMultiControlNetModel,
         ],
-        text_encoder_2=T5EncoderModel,
-        tokenizer_2=MT5Tokenizer,
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[MT5Tokenizer] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import (
-    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipImageProcessor,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
-        image_encoder (`PreTrainedModel`, *optional*):
+        image_encoder (`SiglipVisionModel`, *optional*):
             Pre-trained Vision Model for IP Adapter.
-        feature_extractor (`BaseImageProcessor`, *optional*):
+        feature_extractor (`SiglipImageProcessor`, *optional*):
             Image processor for IP Adapter.
     """
@@ -202,8 +202,8 @@ class StableDiffusion3ControlNetPipeline(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: Optional[SiglipVisionModel] = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import (
-    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipImageProcessor,
+    SiglipModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -223,8 +223,8 @@ class StableDiffusion3ControlNetInpaintingPipeline(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: SiglipModel = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()
@@ -17,6 +17,8 @@ from typing import List, Optional, Tuple, Union

 import torch

+from ...models import UNet1DModel
+from ...schedulers import SchedulerMixin
 from ...utils import is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline

@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)
@@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union

 import torch

+from ...models import UNet2DModel
 from ...schedulers import DDIMScheduler
 from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor

@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
         super().__init__()
         # make sure scheduler can always be converted to DDIM
@@ -17,6 +17,8 @@ from typing import List, Optional, Tuple, Union

 import torch

+from ...models import UNet2DModel
+from ...schedulers import DDPMScheduler
 from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
     scheduler: RePaintScheduler

     model_cpu_offload_seq = "unet"

-    def __init__(self, unet, scheduler):
+    def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
         super().__init__()
         self.register_modules(unet=unet, scheduler=scheduler)
@@ -207,8 +207,8 @@ class HunyuanDiTPipeline(DiffusionPipeline):
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
-        text_encoder_2=T5EncoderModel,
-        tokenizer_2=MT5Tokenizer,
+        text_encoder_2: Optional[T5EncoderModel] = None,
+        tokenizer_2: Optional[MT5Tokenizer] = None,
     ):
         super().__init__()
@@ -20,7 +20,7 @@ import urllib.parse as ul
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor

@@ -144,13 +144,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`AutoModel`]):
-            Frozen text-encoder. Lumina-T2I uses
-            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
-            [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`AutoModel`):
-            Tokenizer of class
-            [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
+        text_encoder ([`GemmaPreTrainedModel`]):
+            Frozen Gemma text-encoder.
+        tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
+            Gemma tokenizer.
         transformer ([`Transformer2DModel`]):
             A text conditioned `Transformer2DModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):

@@ -185,8 +182,8 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
         transformer: LuminaNextDiT2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKL,
-        text_encoder: AutoModel,
-        tokenizer: AutoTokenizer,
+        text_encoder: GemmaPreTrainedModel,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
     ):
         super().__init__()
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...image_processor import VaeImageProcessor
 from ...loaders import Lumina2LoraLoaderMixin

@@ -143,13 +143,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`AutoModel`]):
-            Frozen text-encoder. Lumina-T2I uses
-            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
-            [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`AutoModel`):
-            Tokenizer of class
-            [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
+        text_encoder ([`Gemma2PreTrainedModel`]):
+            Frozen Gemma2 text-encoder.
+        tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`):
+            Gemma tokenizer.
         transformer ([`Transformer2DModel`]):
             A text conditioned `Transformer2DModel` to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):

@@ -165,8 +162,8 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
         transformer: Lumina2Transformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKL,
-        text_encoder: AutoModel,
-        tokenizer: AutoTokenizer,
+        text_encoder: Gemma2PreTrainedModel,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
     ):
         super().__init__()
@@ -20,7 +20,7 @@ import warnings
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor

@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
     def __init__(
         self,
-        tokenizer: AutoTokenizer,
-        text_encoder: AutoModelForCausalLM,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
@@ -17,7 +17,7 @@ import os
 import re
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin

 import requests
 import torch

@@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
             break
     if has_transformers_component and not is_transformers_version(">", "4.47.1"):
         raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
+    """
+    Checks if an object is an instance of any of the provided types. For collections, it checks if every element is
+    of the correct type as well.
+    """
+    if not isinstance(class_or_tuple, tuple):
+        class_or_tuple = (class_or_tuple,)
+
+    # Unpack unions
+    unpacked_class_or_tuple = []
+    for t in class_or_tuple:
+        if get_origin(t) is Union:
+            unpacked_class_or_tuple.extend(get_args(t))
+        else:
+            unpacked_class_or_tuple.append(t)
+    class_or_tuple = tuple(unpacked_class_or_tuple)
+
+    if Any in class_or_tuple:
+        return True
+
+    obj_type = type(obj)
+    # Classes with obj's type
+    class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}
+
+    # Singular types (e.g. int, ControlNet, ...)
+    # Untyped collections (e.g. List, but not List[int])
+    elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
+    if () in elem_class_or_tuple:
+        return True
+    # Typed lists or sets
+    elif obj_type in (list, set):
+        return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
+    # Typed tuples
+    elif obj_type is tuple:
+        return any(
+            # Tuples with any length and single type (e.g. Tuple[int, ...])
+            (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
+            or
+            # Tuples with fixed length and any types (e.g. Tuple[int, str])
+            (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
+            for t in elem_class_or_tuple
+        )
+    # Typed dicts
+    elif obj_type is dict:
+        return any(
+            all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
+            for kt, vt in elem_class_or_tuple
+        )
+    else:
+        return False
+
+
+def _get_detailed_type(obj: Any) -> Type:
+    """
+    Gets a detailed type for an object, including nested types for collections.
+    """
+    obj_type = type(obj)
+
+    if obj_type in (list, set):
+        obj_origin_type = List if obj_type is list else Set
+        elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
+        return obj_origin_type[elems_type]
+    elif obj_type is tuple:
+        return Tuple[tuple(_get_detailed_type(x) for x in obj)]
+    elif obj_type is dict:
+        keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
+        values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
+        return Dict[keys_type, values_type]
+    else:
+        return obj_type
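As a quick illustration of the helper semantics (a sketch only; `_is_valid_type` and `_get_detailed_type` are private helpers, imported here purely to demonstrate behavior):

```python
from typing import Dict, List, Optional, Tuple

from diffusers.pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type

assert _is_valid_type(1, int)                        # singular type
assert _is_valid_type(None, Optional[int])           # unions are unpacked
assert _is_valid_type([1, 2], List[int])             # every element is checked
assert _is_valid_type((1, "a"), Tuple[int, str])     # fixed-length tuple
assert _is_valid_type((1, 2, 3), Tuple[int, ...])    # variadic tuple
assert not _is_valid_type({"a": 1}, Dict[str, str])  # value type mismatch

# _get_detailed_type produces the nested type used in warning messages,
# e.g. typing.List[typing.Union[int, str]] (union member order may vary).
print(_get_detailed_type([1, "a"]))
```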
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import enum
 import fnmatch
 import importlib
 import inspect

@@ -79,10 +78,12 @@ from .pipeline_loading_utils import (
     _fetch_class_library_tuple,
     _get_custom_components_and_folders,
     _get_custom_pipeline_class,
+    _get_detailed_type,
     _get_final_device_map,
     _get_ignore_patterns,
     _get_pipeline_class,
     _identify_model_variants,
+    _is_valid_type,
     _maybe_raise_error_for_incorrect_transformers,
     _maybe_raise_warning_for_inpainting,
     _resolve_custom_pipeline_and_cls,

@@ -876,26 +877,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

-        for key in init_dict.keys():
-            if key not in passed_class_obj:
-                continue
-
-            if "scheduler" in key:
-                continue
-
-            class_obj = passed_class_obj[key]
-            _expected_class_types = []
-            for expected_type in expected_types[key]:
-                if isinstance(expected_type, enum.EnumMeta):
-                    _expected_class_types.extend(expected_type.__members__.keys())
-                else:
-                    _expected_class_types.append(expected_type.__name__)
-
-            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
-            if not _is_valid_type:
-                logger.warning(
-                    f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
-                )
-
         # Special case: safety_checker must be loaded separately when using `from_flax`
         if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
             raise NotImplementedError(

@@ -1015,10 +996,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
             )

-        # 10. Instantiate the pipeline
+        # 10. Type checking init arguments
+        for kw, arg in init_kwargs.items():
+            # Too complex to validate with type annotation alone
+            if "scheduler" in kw:
+                continue
+            # Many tokenizer annotations don't include its "Fast" variant, so skip this
+            # e.g T5Tokenizer but not T5TokenizerFast
+            elif "tokenizer" in kw:
+                continue
+            elif (
+                arg is not None  # Skip if None
+                and not expected_types[kw] == (inspect.Signature.empty,)  # Skip if no type annotations
+                and not _is_valid_type(arg, expected_types[kw])  # Check type
+            ):
+                logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
+
+        # 11. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)

-        # 11. Save where the model was instantiated from
+        # 12. Save where the model was instantiated from
         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
         if device_map is not None:
             setattr(model, "hf_device_map", final_device_map)
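The loop above pulls `expected_types` from the pipeline's `__init__` annotations and applies its skip rules (schedulers, tokenizers, `None`, and unannotated arguments) before calling `_is_valid_type`. Below is a self-contained sketch of that flow; `ToyPipeline` and its fake components are hypothetical, and the `expected_types` construction is a simplification of what diffusers derives internally:

```python
import inspect
from typing import List, Union

from diffusers.pipelines.pipeline_loading_utils import _get_detailed_type, _is_valid_type


class FakeUNet: ...


class FakeVAE: ...


class ToyPipeline:
    def __init__(self, unet: Union[FakeUNet, None], vaes: List[FakeVAE], scheduler=None): ...


# One tuple of expected types per kwarg, read off the __init__ signature
expected_types = {
    name: (param.annotation,)
    for name, param in inspect.signature(ToyPipeline.__init__).parameters.items()
    if name != "self"
}

init_kwargs = {"unet": FakeUNet(), "vaes": [FakeVAE(), "oops"], "scheduler": None}

for kw, arg in init_kwargs.items():
    if "scheduler" in kw or "tokenizer" in kw:  # exempt, as in the loop above
        continue
    if arg is None or expected_types[kw] == (inspect.Signature.empty,):
        continue
    if not _is_valid_type(arg, expected_types[kw]):
        print(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.")
# Only `vaes` is flagged: one element is a str, not a FakeVAE.
```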
@@ -20,7 +20,7 @@ import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor

@@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
     def __init__(
         self,
-        tokenizer: AutoTokenizer,
-        text_encoder: AutoModelForCausalLM,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: DPMSolverMultistepScheduler,
@@ -15,7 +15,7 @@
 from typing import Callable, Dict, List, Optional, Union

 import torch
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer

 from ...models import StableCascadeUNet
 from ...schedulers import DDPMWuerstchenScheduler

@@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
     Args:
         tokenizer (`CLIPTokenizer`):
             The CLIP tokenizer.
-        text_encoder (`CLIPTextModel`):
+        text_encoder (`CLIPTextModelWithProjection`):
             The CLIP text encoder.
         decoder ([`StableCascadeUNet`]):
             The Stable Cascade decoder unet.

@@ -93,7 +93,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
         self,
         decoder: StableCascadeUNet,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: CLIPTextModelWithProjection,
         scheduler: DDPMWuerstchenScheduler,
         vqgan: PaellaVQModel,
         latent_dim_scale: float = 10.67,
@@ -15,7 +15,7 @@ from typing import Callable, Dict, List, Optional, Union

 import PIL
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...models import StableCascadeUNet
 from ...schedulers import DDPMWuerstchenScheduler

@@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
     Args:
         tokenizer (`CLIPTokenizer`):
             The decoder tokenizer to be used for text inputs.
-        text_encoder (`CLIPTextModel`):
+        text_encoder (`CLIPTextModelWithProjection`):
             The decoder text encoder to be used for text inputs.
         decoder (`StableCascadeUNet`):
             The decoder model to be used for decoder image generation pipeline.

@@ -60,14 +60,18 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
             The scheduler to be used for decoder image generation pipeline.
         vqgan (`PaellaVQModel`):
             The VQGAN model to be used for decoder image generation pipeline.
-        feature_extractor ([`~transformers.CLIPImageProcessor`]):
-            Model that extracts features from generated images to be used as inputs for the `image_encoder`.
-        image_encoder ([`CLIPVisionModelWithProjection`]):
-            Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
         prior_prior (`StableCascadeUNet`):
             The prior model to be used for prior pipeline.
+        prior_text_encoder (`CLIPTextModelWithProjection`):
+            The prior text encoder to be used for text inputs.
+        prior_tokenizer (`CLIPTokenizer`):
+            The prior tokenizer to be used for text inputs.
         prior_scheduler (`DDPMWuerstchenScheduler`):
             The scheduler to be used for prior pipeline.
+        prior_feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            Model that extracts features from generated images to be used as inputs for the `image_encoder`.
+        prior_image_encoder ([`CLIPVisionModelWithProjection`]):
+            Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
     """

@@ -76,12 +80,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModel,
+        text_encoder: CLIPTextModelWithProjection,
         decoder: StableCascadeUNet,
         scheduler: DDPMWuerstchenScheduler,
         vqgan: PaellaVQModel,
         prior_prior: StableCascadeUNet,
-        prior_text_encoder: CLIPTextModel,
+        prior_text_encoder: CLIPTextModelWithProjection,
         prior_tokenizer: CLIPTokenizer,
         prior_scheduler: DDPMWuerstchenScheduler,
         prior_feature_extractor: Optional[CLIPImageProcessor] = None,
@@ -141,7 +141,7 @@ class StableUnCLIPPipeline(
         image_noising_scheduler: KarrasDiffusionSchedulers,
         # regular denoising components
         tokenizer: CLIPTokenizer,
-        text_encoder: CLIPTextModelWithProjection,
+        text_encoder: CLIPTextModel,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         # vae
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from transformers import (
-    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipImageProcessor,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )

@@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-        image_encoder (`PreTrainedModel`, *optional*):
+        image_encoder (`SiglipVisionModel`, *optional*):
             Pre-trained Vision Model for IP Adapter.
-        feature_extractor (`BaseImageProcessor`, *optional*):
+        feature_extractor (`SiglipImageProcessor`, *optional*):
             Image processor for IP Adapter.
     """

@@ -197,8 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: SiglipVisionModel = None,
+        feature_extractor: SiglipImageProcessor = None,
     ):
         super().__init__()