"...source/git@developer.sourcefind.cn:OpenDAS/diffusers.git" did not exist on "757babfcadeabda57a1f080335779892e08f5b0b"
Unverified Commit 39dfb7ab authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Raise an error when trying to use SD Cascade Decoder with dtype bfloat16 and torch < 2.2 (#7244)

update
parent 19683569
...@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer ...@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...models import StableCascadeUNet from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler from ...schedulers import DDPMWuerstchenScheduler
from ...utils import logging, replace_example_docstring from ...utils import is_torch_version, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
...@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
device = self._execution_device device = self._execution_device
dtype = self.decoder.dtype dtype = self.decoder.dtype
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment