Unverified Commit 14e3a28c authored by Naoki Ainoya, committed by GitHub
Browse files

Rename 'CLIPFeatureExtractor' class to 'CLIPImageProcessor' (#2732)

The 'CLIPFeatureExtractor' class name has been replaced with 'CLIPImageProcessor' ahead of the old name's planned deprecation. This commit applies the rename to all affected files.
parent 8e35ef01
...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union ...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
...@@ -77,7 +77,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -77,7 +77,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
vae_encoder: OnnxRuntimeModel vae_encoder: OnnxRuntimeModel
...@@ -87,7 +87,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -87,7 +87,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
unet: OnnxRuntimeModel unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -100,7 +100,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -100,7 +100,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
unet: OnnxRuntimeModel, unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union ...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
...@@ -77,7 +77,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -77,7 +77,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
vae_encoder: OnnxRuntimeModel vae_encoder: OnnxRuntimeModel
...@@ -87,7 +87,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -87,7 +87,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
unet: OnnxRuntimeModel unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -100,7 +100,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -100,7 +100,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
unet: OnnxRuntimeModel, unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union ...@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
...@@ -63,7 +63,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -63,7 +63,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -75,7 +75,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -75,7 +75,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
unet: OnnxRuntimeModel unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPImageProcessor
def __init__( def __init__(
self, self,
...@@ -86,7 +86,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -86,7 +86,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
unet: OnnxRuntimeModel, unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel, safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
...@@ -76,7 +76,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -76,7 +76,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -89,7 +89,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -89,7 +89,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention from ...models.attention_processor import Attention
...@@ -183,7 +183,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): ...@@ -183,7 +183,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -196,7 +196,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): ...@@ -196,7 +196,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from torch import nn from torch import nn
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput from ...models.controlnet import ControlNetOutput
...@@ -174,7 +174,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline): ...@@ -174,7 +174,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -188,7 +188,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline): ...@@ -188,7 +188,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union ...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union
import PIL import PIL
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
...@@ -53,7 +53,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -53,7 +53,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
# TODO: feature_extractor is required to encode images (if they are in PIL format), # TODO: feature_extractor is required to encode images (if they are in PIL format),
...@@ -67,7 +67,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -67,7 +67,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
...@@ -284,7 +284,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -284,7 +284,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of configuration of
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor` `CLIPImageProcessor`
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor from ...image_processor import VaeImageProcessor
...@@ -115,7 +115,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -115,7 +115,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -128,7 +128,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -128,7 +128,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
...@@ -161,7 +161,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -161,7 +161,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -174,7 +174,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -174,7 +174,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
...@@ -105,7 +105,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -105,7 +105,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["feature_extractor"] _optional_components = ["feature_extractor"]
...@@ -119,7 +119,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -119,7 +119,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union ...@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -84,7 +84,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -84,7 +84,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -97,7 +97,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): ...@@ -97,7 +97,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -71,7 +71,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): ...@@ -71,7 +71,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
......
...@@ -15,7 +15,7 @@ import inspect ...@@ -15,7 +15,7 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, PNDMScheduler
...@@ -75,7 +75,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline): ...@@ -75,7 +75,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -88,7 +88,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline): ...@@ -88,7 +88,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: DDIMScheduler, scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -23,7 +23,7 @@ import torch.nn.functional as F ...@@ -23,7 +23,7 @@ import torch.nn.functional as F
from transformers import ( from transformers import (
BlipForConditionalGeneration, BlipForConditionalGeneration,
BlipProcessor, BlipProcessor,
CLIPFeatureExtractor, CLIPImageProcessor,
CLIPTextModel, CLIPTextModel,
CLIPTokenizer, CLIPTokenizer,
) )
...@@ -297,7 +297,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -297,7 +297,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
requires_safety_checker (bool): requires_safety_checker (bool):
Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the
...@@ -318,7 +318,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -318,7 +318,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
inverse_scheduler: DDIMInverseScheduler, inverse_scheduler: DDIMInverseScheduler,
caption_generator: BlipForConditionalGeneration, caption_generator: BlipForConditionalGeneration,
......
...@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -111,7 +111,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -111,7 +111,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
...@@ -124,7 +124,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -124,7 +124,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL import PIL
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.import_utils import is_accelerate_available
...@@ -68,7 +68,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): ...@@ -68,7 +68,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args: Args:
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Feature extractor for image pre-processing before being encoded. Feature extractor for image pre-processing before being encoded.
image_encoder ([`CLIPVisionModelWithProjection`]): image_encoder ([`CLIPVisionModelWithProjection`]):
CLIP vision model for encoding images. CLIP vision model for encoding images.
...@@ -91,7 +91,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): ...@@ -91,7 +91,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
""" """
# image encoding components # image encoding components
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPImageProcessor
image_encoder: CLIPVisionModelWithProjection image_encoder: CLIPVisionModelWithProjection
# image noising components # image noising components
...@@ -109,7 +109,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): ...@@ -109,7 +109,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
def __init__( def __init__(
self, self,
# image encoding components # image encoding components
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection, image_encoder: CLIPVisionModelWithProjection,
# image noising components # image noising components
image_normalizer: StableUnCLIPImageNormalizer, image_normalizer: StableUnCLIPImageNormalizer,
......
...@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Union ...@@ -5,7 +5,7 @@ from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from packaging import version from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
...@@ -45,7 +45,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -45,7 +45,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]): safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
...@@ -59,7 +59,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): ...@@ -59,7 +59,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: SafeStableDiffusionSafetyChecker, safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__() super().__init__()
......
...@@ -19,7 +19,7 @@ import PIL ...@@ -19,7 +19,7 @@ import PIL
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from transformers import ( from transformers import (
CLIPFeatureExtractor, CLIPImageProcessor,
CLIPTextModelWithProjection, CLIPTextModelWithProjection,
CLIPTokenizer, CLIPTokenizer,
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
...@@ -48,7 +48,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -48,7 +48,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
tokenizer (`CLIPTokenizer`): tokenizer (`CLIPTokenizer`):
Tokenizer of class Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`. Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]): image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
...@@ -73,7 +73,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -73,7 +73,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
text_proj: UnCLIPTextProjModel text_proj: UnCLIPTextProjModel
text_encoder: CLIPTextModelWithProjection text_encoder: CLIPTextModelWithProjection
tokenizer: CLIPTokenizer tokenizer: CLIPTokenizer
feature_extractor: CLIPFeatureExtractor feature_extractor: CLIPImageProcessor
image_encoder: CLIPVisionModelWithProjection image_encoder: CLIPVisionModelWithProjection
super_res_first: UNet2DModel super_res_first: UNet2DModel
super_res_last: UNet2DModel super_res_last: UNet2DModel
...@@ -87,7 +87,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -87,7 +87,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
text_encoder: CLIPTextModelWithProjection, text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
text_proj: UnCLIPTextProjModel, text_proj: UnCLIPTextProjModel,
feature_extractor: CLIPFeatureExtractor, feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection, image_encoder: CLIPVisionModelWithProjection,
super_res_first: UNet2DModel, super_res_first: UNet2DModel,
super_res_last: UNet2DModel, super_res_last: UNet2DModel,
...@@ -264,7 +264,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -264,7 +264,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed. `CLIPImageProcessor`. Can be left to `None` only when `image_embeddings` are passed.
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
decoder_num_inference_steps (`int`, *optional*, defaults to 25): decoder_num_inference_steps (`int`, *optional*, defaults to 25):
......
...@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Union ...@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Union
import PIL.Image import PIL.Image
import torch import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers from ...schedulers import KarrasDiffusionSchedulers
...@@ -41,12 +41,12 @@ class VersatileDiffusionPipeline(DiffusionPipeline): ...@@ -41,12 +41,12 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionMegaSafetyChecker`]): safety_checker ([`StableDiffusionMegaSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful. Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
tokenizer: CLIPTokenizer tokenizer: CLIPTokenizer
image_feature_extractor: CLIPFeatureExtractor image_feature_extractor: CLIPImageProcessor
text_encoder: CLIPTextModel text_encoder: CLIPTextModel
image_encoder: CLIPVisionModel image_encoder: CLIPVisionModel
image_unet: UNet2DConditionModel image_unet: UNet2DConditionModel
...@@ -57,7 +57,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): ...@@ -57,7 +57,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
def __init__( def __init__(
self, self,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
image_feature_extractor: CLIPFeatureExtractor, image_feature_extractor: CLIPImageProcessor,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
image_encoder: CLIPVisionModel, image_encoder: CLIPVisionModel,
image_unet: UNet2DConditionModel, image_unet: UNet2DConditionModel,
......
...@@ -20,7 +20,7 @@ import PIL ...@@ -20,7 +20,7 @@ import PIL
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from transformers import ( from transformers import (
CLIPFeatureExtractor, CLIPImageProcessor,
CLIPTextModelWithProjection, CLIPTextModelWithProjection,
CLIPTokenizer, CLIPTokenizer,
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
...@@ -55,7 +55,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): ...@@ -55,7 +55,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
""" """
tokenizer: CLIPTokenizer tokenizer: CLIPTokenizer
image_feature_extractor: CLIPFeatureExtractor image_feature_extractor: CLIPImageProcessor
text_encoder: CLIPTextModelWithProjection text_encoder: CLIPTextModelWithProjection
image_encoder: CLIPVisionModelWithProjection image_encoder: CLIPVisionModelWithProjection
image_unet: UNet2DConditionModel image_unet: UNet2DConditionModel
...@@ -68,7 +68,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): ...@@ -68,7 +68,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
def __init__( def __init__(
self, self,
tokenizer: CLIPTokenizer, tokenizer: CLIPTokenizer,
image_feature_extractor: CLIPFeatureExtractor, image_feature_extractor: CLIPImageProcessor,
text_encoder: CLIPTextModelWithProjection, text_encoder: CLIPTextModelWithProjection,
image_encoder: CLIPVisionModelWithProjection, image_encoder: CLIPVisionModelWithProjection,
image_unet: UNet2DConditionModel, image_unet: UNet2DConditionModel,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment