Unverified Commit 39dfb7ab authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 (#7244)

update
parent 19683569
...@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer ...@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...models import StableCascadeUNet from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import logging, replace_example_docstring from ...utils import is_torch_version, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
...@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
device = self._execution_device device = self._execution_device
dtype = self.decoder.dtype dtype = self.decoder.dtype
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment