"docs/vscode:/vscode.git/clone" did not exist on "181688012a2abadc93b316d91a513bee36193615"
Unverified commit 9c7e2051, authored by Daniel Regado and committed by GitHub
Browse files

Comprehensive type checking for `from_pretrained` kwargs (#10758)



* More robust from_pretrained init_kwargs type checking

* Corrected for Python 3.10

* Type checks subclasses and fixed type warnings

* More type corrections and skip tokenizer type checking

* make style && make quality

* Updated docs and types for Lumina pipelines

* Fixed check for empty signature

* changed location of helper functions

* make style

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent 64dec70e
......@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
......@@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`SiglipVisionModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`SiglipImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
......@@ -214,8 +218,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
......
......@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
......@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`PreTrainedModel`, *optional*):
image_encoder (`SiglipVisionModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*):
feature_extractor (`SiglipImageProcessor`, *optional*):
Image processor for IP Adapter.
"""
......@@ -217,8 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
image_encoder: Optional[SiglipVisionModel] = None,
feature_extractor: Optional[SiglipImageProcessor] = None,
):
super().__init__()
......
......@@ -19,15 +19,31 @@ from typing import Callable, List, Optional, Union
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPTokenizerFast,
)
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import (
StableDiffusionLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LMSDiscreteScheduler
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline(
def __init__(
self,
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
......@@ -18,7 +18,7 @@ from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput
from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel
class CustomLocalPipeline(DiffusionPipeline):
......@@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
......
......@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
import torch
from diffusers import SchedulerMixin, UNet2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline):
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
def __init__(self, unet, scheduler):
def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
......
......@@ -91,10 +91,10 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
text_encoder = Gemma2Model(config)
components = {
"transformer": transformer.eval(),
"transformer": transformer,
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder.eval(),
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment