"vscode:/vscode.git/clone" did not exist on "3a278d701d3a0bba25ad52891653330ece2cb472"
Unverified commit 14e3a28c authored by Naoki Ainoya, committed by GitHub

Rename 'CLIPFeatureExtractor' class to 'CLIPImageProcessor' (#2732)

The `CLIPFeatureExtractor` class has been renamed to `CLIPImageProcessor`, ahead of the old name's planned deprecation in transformers. This commit applies the necessary changes to the affected files.
parent 8e35ef01
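For downstream code that has to run against both older and newer transformers releases during the deprecation window, a compatibility import is a common pattern. A minimal sketch (the fallback branch assumes a release that predates the rename):

```python
# Prefer the new class name; fall back to the old one on transformers
# releases that predate the rename. Both names resolve to the same processor.
try:
    from transformers import CLIPImageProcessor
except ImportError:
    from transformers import CLIPFeatureExtractor as CLIPImageProcessor
```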
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image
import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
@@ -47,7 +47,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
@@ -60,7 +60,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Union
import PIL
import torch
from transformers import (
-CLIPFeatureExtractor,
+CLIPImageProcessor,
CLIPSegForImageSegmentation,
CLIPSegProcessor,
CLIPTextModel,
@@ -52,7 +52,7 @@ class TextInpainting(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -66,7 +66,7 @@ class TextInpainting(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
):
super().__init__()
......
@@ -5,7 +5,7 @@ import PIL
import torch
from torch.nn import functional as F
from transformers import (
-CLIPFeatureExtractor,
+CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
@@ -50,7 +50,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
@@ -75,7 +75,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
text_proj: UnCLIPTextProjModel
text_encoder: CLIPTextModelWithProjection
tokenizer: CLIPTokenizer
-feature_extractor: CLIPFeatureExtractor
+feature_extractor: CLIPImageProcessor
image_encoder: CLIPVisionModelWithProjection
super_res_first: UNet2DModel
super_res_last: UNet2DModel
@@ -90,7 +90,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_proj: UnCLIPTextProjModel,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection,
super_res_first: UNet2DModel,
super_res_last: UNet2DModel,
@@ -270,7 +270,7 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
The images to use for the image interpolation. Only accepts a list of two PIL Images or If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
-`CLIPFeatureExtractor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed.
+`CLIPImageProcessor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed.
steps (`int`, *optional*, defaults to 5):
The number of interpolation images to generate.
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
......
@@ -6,7 +6,7 @@ from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
@@ -104,7 +104,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -116,7 +116,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
):
super().__init__()
......
@@ -22,7 +22,7 @@ from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
from diffusers import (
FlaxAutoencoderKL,
@@ -652,7 +652,7 @@ def main():
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
-feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
......
@@ -23,7 +23,7 @@ from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
@@ -632,7 +632,7 @@ def main():
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
-feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings
......
@@ -25,7 +25,7 @@ from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
from diffusers import (
FlaxAutoencoderKL,
@@ -640,7 +640,7 @@ def main():
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
-feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(
......
@@ -20,7 +20,7 @@ from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision import transforms
from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
from diffusers import (
FlaxAutoencoderKL,
@@ -567,7 +567,7 @@ def main():
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
-feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(
......
@@ -25,7 +25,7 @@ from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
from diffusers import (
FlaxAutoencoderKL,
@@ -667,7 +667,7 @@ def main():
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
-feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(
......
@@ -19,7 +19,7 @@ from argparse import Namespace
import torch
from transformers import (
-CLIPFeatureExtractor,
+CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
@@ -774,7 +774,7 @@ if __name__ == "__main__":
vae.load_state_dict(converted_vae_checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
+image_feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
......
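To make concrete what the renamed component does in pipelines like the one above, here is a minimal sketch of preprocessing a PIL image into the tensor a CLIP vision encoder expects (the blank test image is a placeholder):

```python
from PIL import Image
from transformers import CLIPImageProcessor

# The image processor resizes, crops, and normalizes a PIL image into the
# pixel_values tensor consumed by the CLIP vision encoder.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
inputs = processor(images=Image.new("RGB", (512, 512)), return_tensors="pt")
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224])
```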
@@ -7,9 +7,9 @@ components - all of which are needed to have a functioning end-to-end diffusion
As an example, [Stable Diffusion](https://huggingface.co/blog/stable_diffusion) has three independently trained models:
- [Autoencoder](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/vae.py#L392)
- [Conditional Unet](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/models/unet_2d_condition.py#L12)
-- [CLIP text encoder](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPTextModel)
+- [CLIP text encoder](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTextModel)
- a scheduler component, [scheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py),
-- a [CLIPFeatureExtractor](https://huggingface.co/docs/transformers/v4.21.2/en/model_doc/clip#transformers.CLIPFeatureExtractor),
+- a [CLIPImageProcessor](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPImageProcessor),
- as well as a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py).
All of these components are necessary to run stable diffusion in inference even though they were trained
or created independently from each other.
......
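To make the composition described in that doc passage concrete, a minimal sketch of loading the full pipeline and inspecting the renamed component (assumes the `runwayml/stable-diffusion-v1-5` checkpoint referenced above):

```python
from diffusers import StableDiffusionPipeline

# from_pretrained wires up every component listed above: VAE, UNet,
# text encoder, tokenizer, scheduler, feature extractor, safety checker.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# After this commit, the feature_extractor component is a CLIPImageProcessor.
print(type(pipe.feature_extractor).__name__)
```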
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from packaging import version
-from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
@@ -73,7 +73,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
@@ -86,7 +86,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
@@ -19,7 +19,7 @@ import numpy as np
import PIL
import torch
from packaging import version
-from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
@@ -112,7 +112,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
@@ -125,7 +125,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
@@ -18,7 +18,7 @@ from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
-from transformers import CLIPFeatureExtractor
+from transformers import CLIPImageProcessor
from diffusers.utils import is_accelerate_available
@@ -156,7 +156,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
# TODO: feature_extractor is required to encode initial images (if they are in PIL format),
@@ -170,7 +170,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = False,
):
super().__init__()
......
@@ -3,7 +3,7 @@ from itertools import repeat
from typing import Callable, List, Optional, Union
import torch
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
@@ -84,7 +84,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
safety_checker ([`Q16SafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -98,7 +98,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
@@ -19,7 +19,7 @@ import numpy as np
import PIL
import torch
from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
@@ -142,7 +142,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
@@ -155,7 +155,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
......
@@ -24,7 +24,7 @@ from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
@@ -103,7 +103,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -117,7 +117,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
......
@@ -23,7 +23,7 @@ from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
@@ -127,7 +127,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -141,7 +141,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
......
@@ -24,7 +24,7 @@ from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
-from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
+from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
@@ -124,7 +124,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
-feature_extractor ([`CLIPFeatureExtractor`]):
+feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -138,7 +138,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
......
@@ -17,7 +17,7 @@ from typing import Callable, List, Optional, Union
import numpy as np
import torch
-from transformers import CLIPFeatureExtractor, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -38,7 +38,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel
-feature_extractor: CLIPFeatureExtractor
+feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
@@ -51,7 +51,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -333,7 +333,7 @@ class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
-feature_extractor: CLIPFeatureExtractor,
+feature_extractor: CLIPImageProcessor,
):
deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
......
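As a quick way to check which side of the rename an installed transformers release is on, one can probe for the warning the old name is expected to raise once the deprecation has landed. A sketch; the exact warning category and text depend on the transformers version:

```python
import warnings

from transformers import CLIPFeatureExtractor  # old name, kept as an alias

# On releases where the rename has landed, instantiating the old class is
# expected to emit a deprecation warning pointing at CLIPImageProcessor.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    CLIPFeatureExtractor()
for w in caught:
    print(w.category.__name__, w.message)
```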