Unverified commit 9c7e2051, authored by Daniel Regado and committed by GitHub
Browse files

Comprehensive type checking for `from_pretrained` kwargs (#10758)



* More robust from_pretrained init_kwargs type checking

* Corrected for Python 3.10

* Type checks subclasses and fixed type warnings

* More type corrections and skip tokenizer type checking

* make style && make quality

* Updated docs and types for Lumina pipelines

* Fixed check for empty signature

* changed location of helper functions

* make style

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent 64dec70e
...@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
from transformers import ( from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection, CLIPTextModelWithProjection,
CLIPTokenizer, CLIPTokenizer,
PreTrainedModel, SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel, T5EncoderModel,
T5TokenizerFast, T5TokenizerFast,
) )
...@@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`): tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`SiglipVisionModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`SiglipImageProcessor`, *optional*):
Image processor for IP Adapter.
""" """
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
...@@ -214,8 +218,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -214,8 +218,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer, tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel, text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast, tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None, image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: BaseImageProcessor = None, feature_extractor: Optional[SiglipImageProcessor] = None,
): ):
super().__init__() super().__init__()
......
...@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from transformers import ( from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection, CLIPTextModelWithProjection,
CLIPTokenizer, CLIPTokenizer,
PreTrainedModel, SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel, T5EncoderModel,
T5TokenizerFast, T5TokenizerFast,
) )
...@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`): tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`PreTrainedModel`, *optional*): image_encoder (`SiglipVisionModel`, *optional*):
Pre-trained Vision Model for IP Adapter. Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*): feature_extractor (`SiglipImageProcessor`, *optional*):
Image processor for IP Adapter. Image processor for IP Adapter.
""" """
...@@ -217,8 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -217,8 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer, tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel, text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast, tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None, image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: BaseImageProcessor = None, feature_extractor: Optional[SiglipImageProcessor] = None,
): ):
super().__init__() super().__init__()
......
...@@ -19,15 +19,31 @@ from typing import Callable, List, Optional, Union ...@@ -19,15 +19,31 @@ from typing import Callable, List, Optional, Union
import torch import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPTokenizerFast,
)
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import (
StableDiffusionLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LMSDiscreteScheduler from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline( ...@@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline(
def __init__( def __init__(
self, self,
vae, vae: AutoencoderKL,
text_encoder, text_encoder: CLIPTextModel,
tokenizer, tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
unet, unet: UNet2DConditionModel,
scheduler, scheduler: KarrasDiffusionSchedulers,
safety_checker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union ...@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
import torch import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
class CustomLocalPipeline(DiffusionPipeline): class CustomLocalPipeline(DiffusionPipeline):
...@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline): ...@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`]. [`DDPMScheduler`], or [`DDIMScheduler`].
""" """
def __init__(self, unet, scheduler): def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
super().__init__() super().__init__()
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
......
...@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union ...@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
import torch import torch
from diffusers import SchedulerMixin, UNet2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline): ...@@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`]. [`DDPMScheduler`], or [`DDIMScheduler`].
""" """
def __init__(self, unet, scheduler): def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
super().__init__() super().__init__()
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
......
...@@ -91,10 +91,10 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester ...@@ -91,10 +91,10 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
text_encoder = Gemma2Model(config) text_encoder = Gemma2Model(config)
components = { components = {
"transformer": transformer.eval(), "transformer": transformer,
"vae": vae.eval(), "vae": vae.eval(),
"scheduler": scheduler, "scheduler": scheduler,
"text_encoder": text_encoder.eval(), "text_encoder": text_encoder,
"tokenizer": tokenizer, "tokenizer": tokenizer,
} }
return components return components
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment