"vscode:/vscode.git/clone" did not exist on "13e781f9a5693eff2b7f3d6848d9ef205b24e493"
Unverified Commit 3dc97bd1 authored by Tolga Cangöz's avatar Tolga Cangöz Committed by GitHub
Browse files

Update `CLIPFeatureExtractor` to `CLIPImageProcessor` and...


Update `CLIPFeatureExtractor` to `CLIPImageProcessor` and `DPTFeatureExtractor` to `DPTImageProcessor` (#9002)

* fix: update `CLIPFeatureExtractor` to `CLIPImageProcessor` in codebase

* `make style && make quality`

* Update `DPTFeatureExtractor` to `DPTImageProcessor` in codebase

* `make style`

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent 6d32b292
......@@ -6,7 +6,7 @@ import numpy as np
import torch
from datasets import Dataset, load_dataset
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig
from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig
from diffusers import logging
......@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
return images
def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
def preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor:
"""
Preprocesses a list of images into a batch of tensors.
......@@ -95,14 +95,12 @@ class Index:
def build_index(
self,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
feature_extractor: CLIPImageProcessor = None,
torch_dtype=torch.float32,
):
if not self.index_initialized:
model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
self.config.clip_name_or_path
)
feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)
self.dataset = get_dataset_with_emb_from_clip_model(
self.dataset,
model,
......@@ -136,7 +134,7 @@ class Retriever:
index: Index = None,
dataset: Dataset = None,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
feature_extractor: CLIPImageProcessor = None,
):
self.config = config
self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
......@@ -148,7 +146,7 @@ class Retriever:
index: Index = None,
dataset: Dataset = None,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
feature_extractor: CLIPImageProcessor = None,
**kwargs,
):
config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
......@@ -156,7 +154,7 @@ class Retriever:
@staticmethod
def _build_index(
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None
):
dataset = dataset or load_dataset(config.dataset_name)
dataset = dataset[config.dataset_set]
......
......@@ -76,13 +76,13 @@ EXAMPLE_DOC_STRING = """
>>> import numpy as np
>>> from PIL import Image
>>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation
>>> from transformers import DPTImageProcessor, DPTForDepthEstimation
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(
... "diffusers/controlnet-depth-sdxl-1.0-small",
... variant="fp16",
......
......@@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
from ...schedulers import (
......@@ -149,7 +149,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
......
......@@ -16,7 +16,7 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ....image_processor import VaeImageProcessor
from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
......@@ -66,8 +66,8 @@ class StableDiffusionModelEditingPipeline(
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPFeatureExtractor`]):
A `CLIPFeatureExtractor` to extract features from generated images; used as inputs to the `safety_checker`.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
with_to_k ([`bool`]):
Whether to edit the key projection matrices along with the value projection matrices.
with_augs ([`list`]):
......@@ -86,7 +86,7 @@ class StableDiffusionModelEditingPipeline(
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
with_to_k: bool = True,
with_augs: list = AUGS_CONST,
......
......@@ -20,7 +20,7 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
......@@ -111,7 +111,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
depth_estimator: DPTForDepthEstimation,
feature_extractor: DPTFeatureExtractor,
feature_extractor: DPTImageProcessor,
):
super().__init__()
......
......@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
......@@ -138,7 +138,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
......@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image
import torch
from transformers import (
CLIPFeatureExtractor,
CLIPImageProcessor,
CLIPProcessor,
CLIPTextModel,
CLIPTokenizer,
......@@ -193,7 +193,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
......@@ -19,7 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
......@@ -209,7 +209,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
......@@ -225,7 +225,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
......@@ -237,7 +237,7 @@ class StableDiffusionXLAdapterPipeline(
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
......
......@@ -26,8 +26,8 @@ from transformers import (
CLIPTextModel,
CLIPTokenizer,
DPTConfig,
DPTFeatureExtractor,
DPTForDepthEstimation,
DPTImageProcessor,
)
from diffusers import (
......@@ -145,9 +145,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(
backbone_featmap_shape=[1, 384, 24, 24],
)
depth_estimator = DPTForDepthEstimation(depth_estimator_config).eval()
feature_extractor = DPTFeatureExtractor.from_pretrained(
"hf-internal-testing/tiny-random-DPTForDepthEstimation"
)
feature_extractor = DPTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-DPTForDepthEstimation")
components = {
"unet": unet,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment