"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "34fab8b511846ca4d3432f17317026fed2d883ad"
Unverified commit 3dc97bd1, authored by Tolga Cangöz, committed by GitHub

Update `CLIPFeatureExtractor` to `CLIPImageProcessor` and `DPTFeatureExtractor` to `DPTImageProcessor` (#9002)

* fix: update `CLIPFeatureExtractor` to `CLIPImageProcessor` in codebase

* `make style && make quality`

* Update `DPTFeatureExtractor` to `DPTImageProcessor` in codebase

* `make style`

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 6d32b292
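
For context, the change is mechanical: in `transformers`, `CLIPFeatureExtractor` and `DPTFeatureExtractor` are deprecated in favor of `CLIPImageProcessor` and `DPTImageProcessor`, and the call sites do not change. A minimal sketch of the drop-in replacement (the checkpoint IDs are the ones already used in the files touched below):

```python
# Hedged sketch: only the imported class name changes; the preprocessing
# behaviour and call signature stay the same.
from transformers import CLIPImageProcessor, DPTImageProcessor

clip_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
dpt_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

# Both are called exactly like the classes they replace, e.g.:
#   pixel_values = clip_processor(images=pil_image, return_tensors="pt").pixel_values
```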
@@ -289,9 +289,9 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="sche
 3. Load an image processor:
 ```python
-from transformers import CLIPFeatureExtractor
+from transformers import CLIPImageProcessor
-feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
+feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor")
 ```
 <Tip warning={true}>
...
@@ -212,14 +212,14 @@ TCD-LoRA is very versatile, and it can be combined with other adapter types like
 import torch
 import numpy as np
 from PIL import Image
-from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+from transformers import DPTImageProcessor, DPTForDepthEstimation
 from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
 from diffusers.utils import load_image, make_image_grid
 from scheduling_tcd import TCDScheduler
 device = "cuda"
 depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
-feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
+feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
 def get_depth_map(image):
     image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
...
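
The `get_depth_map` helper is truncated in the hunk above. A hypothetical completion, shown only to illustrate that `DPTImageProcessor` is called exactly as the old `DPTFeatureExtractor` was (the normalization details here are an assumption, not the doc's exact code):

```python
import numpy as np
import torch
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

device = "cuda"
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

def get_depth_map(image):
    # Preprocess with the image processor (same signature as before the rename).
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        depth = depth_estimator(pixel_values).predicted_depth
    # Normalize to [0, 255] and return a 3-channel image usable as ControlNet conditioning.
    depth = depth.squeeze().cpu().numpy()
    depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    return Image.fromarray((depth * 255.0).astype(np.uint8)).convert("RGB")
```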
@@ -307,7 +307,7 @@ print(pipeline)
 Looking at the output of the code above, you can see that `pipeline` is an instance of [`StableDiffusionPipeline`] and is made up of a total of seven components:
-- `"feature_extractor"`: an instance of [`~transformers.CLIPFeatureExtractor`]
+- `"feature_extractor"`: an instance of [`~transformers.CLIPImageProcessor`]
 - `"safety_checker"`: a [component](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) for screening harmful content
 - `"scheduler"`: an instance of [`PNDMScheduler`]
 - `"text_encoder"`: an instance of [`~transformers.CLIPTextModel`]
...
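
As a quick check that a loaded pipeline actually reports the renamed class, `DiffusionPipeline` exposes its registered components as a dict. A sketch (the checkpoint name is illustrative):

```python
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# `components` maps the seven names listed above to their instances;
# `feature_extractor` is now a CLIPImageProcessor.
for name, component in pipeline.components.items():
    print(name, type(component).__name__)
```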
@@ -24,7 +24,7 @@ import PIL
 from PIL import Image
 from diffusers import StableDiffusionPipeline
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 def image_grid(imgs, rows, cols):
...
@@ -1435,9 +1435,9 @@ import requests
 import torch
 from diffusers import DiffusionPipeline
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPModel
+from transformers import CLIPImageProcessor, CLIPModel
-feature_extractor = CLIPFeatureExtractor.from_pretrained(
+feature_extractor = CLIPImageProcessor.from_pretrained(
     "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
 )
 clip_model = CLIPModel.from_pretrained(
@@ -2122,7 +2122,7 @@ import torch
 import open_clip
 from open_clip import SimpleTokenizer
 from diffusers import DiffusionPipeline
-from transformers import CLIPFeatureExtractor, CLIPModel
+from transformers import CLIPImageProcessor, CLIPModel
 def download_image(url):
@@ -2130,7 +2130,7 @@ def download_image(url):
     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
 # Loading additional models
-feature_extractor = CLIPFeatureExtractor.from_pretrained(
+feature_extractor = CLIPImageProcessor.from_pretrained(
     "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
 )
 clip_model = CLIPModel.from_pretrained(
...
@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch.nn import functional as F
 from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
 from diffusers import (
     AutoencoderKL,
@@ -86,7 +86,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMi
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         coca_model=None,
         coca_tokenizer=None,
         coca_transform=None,
...
@@ -7,7 +7,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 from torchvision import transforms
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
 from diffusers import (
     AutoencoderKL,
@@ -32,9 +32,9 @@ EXAMPLE_DOC_STRING = """
 import torch
 from diffusers import DiffusionPipeline
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPModel
+from transformers import CLIPImageProcessor, CLIPModel
-feature_extractor = CLIPFeatureExtractor.from_pretrained(
+feature_extractor = CLIPImageProcessor.from_pretrained(
     "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
 )
 clip_model = CLIPModel.from_pretrained(
@@ -139,7 +139,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__()
         self.register_modules(
...
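
Many of the remaining hunks in this commit make the same two-line change to a community pipeline constructor. A generic sketch of that pattern (the class and argument names here are illustrative, not taken from any one file):

```python
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers


class MyGuidedPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor,  # formerly annotated as CLIPFeatureExtractor
    ):
        super().__init__()
        # register_modules stores the components on the pipeline and records
        # their classes in the saved config, so only the annotation changes.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
```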
@@ -9,7 +9,7 @@ import torch
 from numpy import exp, pi, sqrt
 from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -275,7 +275,7 @@ class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__()
         self.register_modules(
...
@@ -15,7 +15,7 @@ from diffusers.utils import logging
 try:
     from ligo.segments import segment
-    from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+    from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 except ImportError:
     raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
@@ -144,7 +144,7 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler],
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__()
         self.register_modules(
...
@@ -189,7 +189,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
...
@@ -332,7 +332,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
         requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
             Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
...
@@ -9,7 +9,7 @@ import numpy as np
 import PIL.Image
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 # from ...configuration_utils import FrozenDict
 # from ...models import AutoencoderKL, UNet2DConditionModel
@@ -87,7 +87,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
         cc_projection ([`CCProjection`]):
             Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size.
@@ -102,7 +102,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         cc_projection: CCProjection,
         requires_safety_checker: bool = True,
     ):
...
@@ -3,7 +3,7 @@ from typing import Dict, Optional
 import torch
 import torchvision.transforms.functional as FF
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from diffusers import StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -69,7 +69,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__(
...
@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import intel_extension_for_pytorch as ipex
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from diffusers.configuration_utils import FrozenDict
 from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
@@ -86,7 +86,7 @@ class StableDiffusionIPEXPipeline(
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -100,7 +100,7 @@ class StableDiffusionIPEXPipeline(
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
...
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
     network_from_onnx_path,
     save_engine,
 )
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict, deprecate
@@ -679,7 +679,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -693,7 +693,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
...
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
     network_from_onnx_path,
     save_engine,
 )
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict, deprecate
@@ -683,7 +683,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -697,7 +697,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
...
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
     network_from_onnx_path,
     save_engine,
 )
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict, deprecate
@@ -595,7 +595,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -609,7 +609,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae"],
...
@@ -43,7 +43,7 @@ from PIL import Image
 from torch.utils.data import default_collate
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig
+from transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig
 from webdataset.tariterators import (
     base_plus_ext,
     tar_file_expander,
@@ -205,7 +205,7 @@ class Text2ImageDataset:
         pin_memory: bool = False,
         persistent_workers: bool = False,
         control_type: str = "canny",
-        feature_extractor: Optional[DPTFeatureExtractor] = None,
+        feature_extractor: Optional[DPTImageProcessor] = None,
     ):
         if not isinstance(train_shards_path_or_url, str):
             train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
@@ -1011,7 +1011,7 @@ def main(args):
         controlnet = pre_controlnet
     if args.control_type == "depth":
-        feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
+        feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
         depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
         depth_model.requires_grad_(False)
     else:
...
@@ -45,7 +45,7 @@
 "    UniPCMultistepScheduler,\n",
 "    EulerDiscreteScheduler,\n",
 ")\n",
-"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
+"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
 "# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
 "\n",
 "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
...
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
 import torch
 from PIL import Image
 from retriever import Retriever, normalize_images, preprocess_images
-from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
 from diffusers import (
     AutoencoderKL,
@@ -47,7 +47,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -65,7 +65,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         retriever: Optional[Retriever] = None,
     ):
         super().__init__()
...