Unverified Commit 14e3a28c authored by Naoki Ainoya, committed by GitHub

Rename 'CLIPFeatureExtractor' class to 'CLIPImageProcessor' (#2732)

The 'CLIPFeatureExtractor' class has been renamed to 'CLIPImageProcessor' ahead of the planned deprecation of the old name. This commit updates all affected files accordingly.
parent 8e35ef01
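
This commit only touches call sites inside this repository; downstream code that imports the old name directly needs the same one-line change. For scripts that must run against both pre- and post-rename transformers releases, a guarded import is one low-risk migration pattern. A minimal sketch (the only assumption is that older transformers releases lack the renamed class):

# Prefer the new class name; fall back on transformers releases
# that still only ship CLIPFeatureExtractor.
try:
    from transformers import CLIPImageProcessor
except ImportError:
    from transformers import CLIPFeatureExtractor as CLIPImageProcessor

# The same pretrained checkpoint loads under either name.
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")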
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import PIL.Image
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
     AutoencoderKL,
@@ -47,7 +47,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionMegaSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
     _optional_components = ["safety_checker", "feature_extractor"]
@@ -60,7 +60,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
......
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Union
 import PIL
 import torch
 from transformers import (
-    CLIPFeatureExtractor,
+    CLIPImageProcessor,
     CLIPSegForImageSegmentation,
     CLIPSegProcessor,
     CLIPTextModel,
@@ -52,7 +52,7 @@ class TextInpainting(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -66,7 +66,7 @@ class TextInpainting(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__()
......
@@ -5,7 +5,7 @@ import PIL
 import torch
 from torch.nn import functional as F
 from transformers import (
-    CLIPFeatureExtractor,
+    CLIPImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
     CLIPVisionModelWithProjection,
@@ -50,7 +50,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         tokenizer (`CLIPTokenizer`):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `image_encoder`.
         image_encoder ([`CLIPVisionModelWithProjection`]):
             Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
@@ -75,7 +75,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
     text_proj: UnCLIPTextProjModel
     text_encoder: CLIPTextModelWithProjection
     tokenizer: CLIPTokenizer
-    feature_extractor: CLIPFeatureExtractor
+    feature_extractor: CLIPImageProcessor
     image_encoder: CLIPVisionModelWithProjection
     super_res_first: UNet2DModel
     super_res_last: UNet2DModel
@@ -90,7 +90,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModelWithProjection,
         tokenizer: CLIPTokenizer,
         text_proj: UnCLIPTextProjModel,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection,
         super_res_first: UNet2DModel,
         super_res_last: UNet2DModel,
@@ -270,7 +270,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
                 The images to use for the image interpolation. Only accepts a list of two PIL Images; if you provide a tensor, it needs to comply with the
                 configuration of
                 [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
-                `CLIPFeatureExtractor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed.
+                `CLIPImageProcessor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed.
             steps (`int`, *optional*, defaults to 5):
                 The number of interpolation images to generate.
             decoder_num_inference_steps (`int`, *optional*, defaults to 25):
......
@@ -6,7 +6,7 @@ from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict
@@ -104,7 +104,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -116,7 +116,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__()
......
@@ -22,7 +22,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
 
 from diffusers import (
     FlaxAutoencoderKL,
@@ -652,7 +652,7 @@ def main():
         tokenizer=tokenizer,
         scheduler=scheduler,
         safety_checker=safety_checker,
-        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
     )
     outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
......
@@ -23,7 +23,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
@@ -632,7 +632,7 @@ def main():
         tokenizer=tokenizer,
         scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
         safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
-        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
     )
     pipeline.save_pretrained(args.output_dir)
     # Save the newly trained embeddings
......
@@ -25,7 +25,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
 
 from diffusers import (
     FlaxAutoencoderKL,
@@ -640,7 +640,7 @@ def main():
         tokenizer=tokenizer,
         scheduler=scheduler,
         safety_checker=safety_checker,
-        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
     )
     pipeline.save_pretrained(
......
@@ -20,7 +20,7 @@ from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
 
 from diffusers import (
     FlaxAutoencoderKL,
@@ -567,7 +567,7 @@ def main():
         tokenizer=tokenizer,
         scheduler=scheduler,
         safety_checker=safety_checker,
-        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
     )
     pipeline.save_pretrained(
......
@@ -25,7 +25,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
 
 from diffusers import (
     FlaxAutoencoderKL,
@@ -667,7 +667,7 @@ def main():
         tokenizer=tokenizer,
         scheduler=scheduler,
         safety_checker=safety_checker,
-        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
     )
     pipeline.save_pretrained(
......
@@ -19,7 +19,7 @@ from argparse import Namespace
 
 import torch
 from transformers import (
-    CLIPFeatureExtractor,
+    CLIPImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
     CLIPVisionModelWithProjection,
@@ -774,7 +774,7 @@ if __name__ == "__main__":
     vae.load_state_dict(converted_vae_checkpoint)
     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
+    image_feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
     text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
     image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
......
@@ -7,9 +7,9 @@ components - all of which are needed to have a functioning end-to-end diffusion
 As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models:
 - [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392)
 - [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12)
-- [CLIP text encoder](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPTextModel)
+- [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel)
 - a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py),
-- a [CLIPFeatureExtractor](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPFeatureExtractor),
+- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor),
 - as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py).
 All of these components are necessary to run stable diffusion in inference even though they were trained
 or created independently from each other.
......
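
A short sketch of how these components surface once a pipeline is loaded (illustrative only; the checkpoint id is the one referenced in the docs above, and the attribute names follow the standard DiffusionPipeline layout):

from diffusers import StableDiffusionPipeline

# Loading one checkpoint assembles all of the components listed above.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

print(type(pipe.vae).__name__)           # AutoencoderKL
print(type(pipe.unet).__name__)          # UNet2DConditionModel
print(type(pipe.text_encoder).__name__)  # CLIPTextModel
print(type(pipe.scheduler).__name__)     # PNDMScheduler for this checkpoint
# After this commit, the feature extractor loads as a CLIPImageProcessor.
print(type(pipe.feature_extractor).__name__)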
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
 
 from diffusers.utils import is_accelerate_available, is_accelerate_version
@@ -73,7 +73,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
     _optional_components = ["safety_checker", "feature_extractor"]
@@ -86,7 +86,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
......
@@ -19,7 +19,7 @@ import numpy as np
 import PIL
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
 
 from diffusers.utils import is_accelerate_available, is_accelerate_version
@@ -112,7 +112,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
     _optional_components = ["safety_checker", "feature_extractor"]
@@ -125,7 +125,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
......
@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 import PIL
 import torch
-from transformers import CLIPFeatureExtractor
+from transformers import CLIPImageProcessor
 
 from diffusers.utils import is_accelerate_available
@@ -156,7 +156,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
     # TODO: feature_extractor is required to encode initial images (if they are in PIL format),
@@ -170,7 +170,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = False,
     ):
         super().__init__()
......
@@ -3,7 +3,7 @@ from itertools import repeat
 from typing import Callable, List, Optional, Union
 
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
@@ -84,7 +84,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
         safety_checker ([`Q16SafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -98,7 +98,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
......
@@ -19,7 +19,7 @@ import numpy as np
 import PIL
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from diffusers.utils import is_accelerate_available, is_accelerate_version
@@ -142,7 +142,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
     _optional_components = ["safety_checker", "feature_extractor"]
@@ -155,7 +155,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
......
@@ -24,7 +24,7 @@ from flax.jax_utils import unreplicate
 from flax.training.common_utils import shard
 from packaging import version
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
 
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
 from ...schedulers import (
@@ -103,7 +103,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -117,7 +117,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
         ],
         safety_checker: FlaxStableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         dtype: jnp.dtype = jnp.float32,
     ):
         super().__init__()
......
@@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict
 from flax.jax_utils import unreplicate
 from flax.training.common_utils import shard
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
 
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
 from ...schedulers import (
@@ -127,7 +127,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -141,7 +141,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
             FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
         ],
         safety_checker: FlaxStableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         dtype: jnp.dtype = jnp.float32,
     ):
         super().__init__()
......
@@ -24,7 +24,7 @@ from flax.jax_utils import unreplicate
 from flax.training.common_utils import shard
 from packaging import version
 from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
 
 from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
 from ...schedulers import (
@@ -124,7 +124,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
         safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
@@ -138,7 +138,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
             FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
         ],
         safety_checker: FlaxStableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         dtype: jnp.dtype = jnp.float32,
     ):
         super().__init__()
......
@@ -17,7 +17,7 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import CLIPFeatureExtractor, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -38,7 +38,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     unet: OnnxRuntimeModel
     scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
     safety_checker: OnnxRuntimeModel
-    feature_extractor: CLIPFeatureExtractor
+    feature_extractor: CLIPImageProcessor
 
     _optional_components = ["safety_checker", "feature_extractor"]
@@ -51,7 +51,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         unet: OnnxRuntimeModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: OnnxRuntimeModel,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -333,7 +333,7 @@ class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
         unet: OnnxRuntimeModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
         safety_checker: OnnxRuntimeModel,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
         deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
......
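
The `deprecate(...)` call above shows how diffusers phases out old names; on the transformers side, the rename stays non-breaking because the old class can be kept as a thin alias that warns on use. A minimal sketch of that pattern (an assumption about the shape of the shim, not transformers' actual code):

import warnings

class CLIPImageProcessor:
    """Stand-in for the real image processor, to keep the sketch self-contained."""

class CLIPFeatureExtractor(CLIPImageProcessor):
    def __init__(self, *args, **kwargs):
        # Nudge callers toward the new name; behavior is otherwise unchanged.
        warnings.warn(
            "CLIPFeatureExtractor is deprecated, use CLIPImageProcessor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)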