Unverified Commit 3dc97bd1 authored by Tolga Cangöz, committed by GitHub

Update `CLIPFeatureExtractor` to `CLIPImageProcessor` and `DPTFeatureExtractor` to `DPTImageProcessor` (#9002)

* fix: update `CLIPFeatureExtractor` to `CLIPImageProcessor` in codebase

* `make style && make quality`

* Update `DPTFeatureExtractor` to `DPTImageProcessor` in codebase

* `make style`

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 6d32b292
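
For context: `CLIPFeatureExtractor` and `DPTFeatureExtractor` are the deprecated class names in `transformers`, and `CLIPImageProcessor` / `DPTImageProcessor` are their drop-in replacements, which is all this commit changes. A minimal, illustrative sketch of the rename on the caller side (the CLIP checkpoint name below is only an example and not part of this PR; the DPT checkpoint is the one used in the updated docstring):

```python
# Before (deprecated names, kept in transformers only for backward compatibility):
# from transformers import CLIPFeatureExtractor, DPTFeatureExtractor

# After: the image-processor classes expose the same from_pretrained / __call__ API.
from transformers import CLIPImageProcessor, DPTImageProcessor

clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")  # example checkpoint
dpt_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
```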
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from datasets import Dataset, load_dataset
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig
+from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig
 from diffusers import logging
@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
     return images
-def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
+def preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor:
     """
     Preprocesses a list of images into a batch of tensors.
@@ -95,14 +95,12 @@ class Index:
     def build_index(
         self,
         model=None,
-        feature_extractor: CLIPFeatureExtractor = None,
+        feature_extractor: CLIPImageProcessor = None,
         torch_dtype=torch.float32,
     ):
         if not self.index_initialized:
             model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
-            feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
-                self.config.clip_name_or_path
-            )
+            feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)
             self.dataset = get_dataset_with_emb_from_clip_model(
                 self.dataset,
                 model,
@@ -136,7 +134,7 @@ class Retriever:
         index: Index = None,
         dataset: Dataset = None,
         model=None,
-        feature_extractor: CLIPFeatureExtractor = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         self.config = config
         self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
@@ -148,7 +146,7 @@ class Retriever:
         index: Index = None,
         dataset: Dataset = None,
         model=None,
-        feature_extractor: CLIPFeatureExtractor = None,
+        feature_extractor: CLIPImageProcessor = None,
         **kwargs,
     ):
         config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
@@ -156,7 +154,7 @@ class Retriever:
     @staticmethod
     def _build_index(
-        config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
+        config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None
     ):
         dataset = dataset or load_dataset(config.dataset_name)
         dataset = dataset[config.dataset_set]
...
@@ -76,13 +76,13 @@ EXAMPLE_DOC_STRING = """
         >>> import numpy as np
         >>> from PIL import Image
-        >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+        >>> from transformers import DPTImageProcessor, DPTForDepthEstimation
         >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
         >>> from diffusers.utils import load_image
         >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
-        >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
+        >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
         >>> controlnet = ControlNetModel.from_pretrained(
         ...     "diffusers/controlnet-depth-sdxl-1.0-small",
         ...     variant="fp16",
...
@@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict
 from flax.jax_utils import unreplicate
 from flax.training.common_utils import shard
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
 from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
 from ...schedulers import (
@@ -149,7 +149,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
             FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
         ],
         safety_checker: FlaxStableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         dtype: jnp.dtype = jnp.float32,
     ):
         super().__init__()
...
@@ -16,7 +16,7 @@ import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ....image_processor import VaeImageProcessor
 from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
@@ -66,8 +66,8 @@ class StableDiffusionModelEditingPipeline(
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
             about a model's potential harms.
-        feature_extractor ([`~transformers.CLIPFeatureExtractor`]):
-            A `CLIPFeatureExtractor` to extract features from generated images; used as inputs to the `safety_checker`.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
         with_to_k ([`bool`]):
             Whether to edit the key projection matrices along with the value projection matrices.
         with_augs ([`list`]):
@@ -86,7 +86,7 @@ class StableDiffusionModelEditingPipeline(
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
         with_to_k: bool = True,
         with_augs: list = AUGS_CONST,
...
@@ -20,7 +20,7 @@ import numpy as np
 import PIL.Image
 import torch
 from packaging import version
-from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
+from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -111,7 +111,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         depth_estimator: DPTForDepthEstimation,
-        feature_extractor: DPTFeatureExtractor,
+        feature_extractor: DPTImageProcessor,
     ):
         super().__init__()
...
@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import PIL.Image
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import VaeImageProcessor
 from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
@@ -138,7 +138,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
...
@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import PIL.Image
 import torch
 from transformers import (
-    CLIPFeatureExtractor,
+    CLIPImageProcessor,
     CLIPProcessor,
     CLIPTextModel,
     CLIPTokenizer,
@@ -193,7 +193,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
...
@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import VaeImageProcessor
 from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
@@ -209,7 +209,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -225,7 +225,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
         adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
...
@@ -237,7 +237,7 @@ class StableDiffusionXLAdapterPipeline(
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
...
@@ -26,8 +26,8 @@ from transformers import (
     CLIPTextModel,
     CLIPTokenizer,
     DPTConfig,
-    DPTFeatureExtractor,
     DPTForDepthEstimation,
+    DPTImageProcessor,
 )
 from diffusers import (
@@ -145,9 +145,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(
             backbone_featmap_shape=[1, 384, 24, 24],
         )
         depth_estimator = DPTForDepthEstimation(depth_estimator_config).eval()
-        feature_extractor = DPTFeatureExtractor.from_pretrained(
-            "hf-internal-testing/tiny-random-DPTForDepthEstimation"
-        )
+        feature_extractor = DPTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-DPTForDepthEstimation")
         components = {
             "unet": unet,
...