Unverified Commit d34b18c7 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Deprecate Stable Cascade (#12537)



* update

* update

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 7536f647
...@@ -21,7 +21,7 @@ from ...models import StableCascadeUNet ...@@ -21,7 +21,7 @@ from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
...@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """
""" """
class StableCascadeDecoderPipeline(DiffusionPipeline): class StableCascadeDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
""" """
Pipeline for generating images from the Stable Cascade model. Pipeline for generating images from the Stable Cascade model.
...@@ -79,6 +79,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -79,6 +79,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
width=int(24*10.67)=256 in order to match the training conditions. width=int(24*10.67)=256 in order to match the training conditions.
""" """
_last_supported_version = "0.35.2"
unet_name = "decoder" unet_name = "decoder"
text_encoder_name = "text_encoder" text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->decoder->vqgan" model_cpu_offload_seq = "text_encoder->decoder->vqgan"
......
...@@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo ...@@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import StableCascadeUNet from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, replace_example_docstring from ...utils import is_torch_version, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
from .pipeline_stable_cascade import StableCascadeDecoderPipeline from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
...@@ -42,7 +42,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """ ...@@ -42,7 +42,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
""" """
class StableCascadeCombinedPipeline(DiffusionPipeline): class StableCascadeCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
""" """
Combined Pipeline for text-to-image generation using Stable Cascade. Combined Pipeline for text-to-image generation using Stable Cascade.
...@@ -74,6 +74,8 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -74,6 +74,8 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
""" """
_last_supported_version = "0.35.2"
_load_connected_pipes = True _load_connected_pipes = True
_optional_components = ["prior_feature_extractor", "prior_image_encoder"] _optional_components = ["prior_feature_extractor", "prior_image_encoder"]
......
...@@ -25,7 +25,7 @@ from ...models import StableCascadeUNet ...@@ -25,7 +25,7 @@ from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
if is_torch_xla_available(): if is_torch_xla_available():
...@@ -77,7 +77,7 @@ class StableCascadePriorPipelineOutput(BaseOutput): ...@@ -77,7 +77,7 @@ class StableCascadePriorPipelineOutput(BaseOutput):
negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray] negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
class StableCascadePriorPipeline(DiffusionPipeline): class StableCascadePriorPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
""" """
Pipeline for generating image prior for Stable Cascade. Pipeline for generating image prior for Stable Cascade.
...@@ -103,6 +103,8 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -103,6 +103,8 @@ class StableCascadePriorPipeline(DiffusionPipeline):
Default resolution for multiple images generated. Default resolution for multiple images generated.
""" """
_last_supported_version = "0.35.2"
unet_name = "prior" unet_name = "prior"
text_encoder_name = "text_encoder" text_encoder_name = "text_encoder"
model_cpu_offload_seq = "image_encoder->text_encoder->prior" model_cpu_offload_seq = "image_encoder->text_encoder->prior"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment