Unverified Commit 188bca30 authored by zhangp365's avatar zhangp365 Committed by GitHub
Browse files

fixed a dtype bfloat16 bug in torch_utils.py (#10125)



* fixed a dtype bfloat16 bug in torch_utils.py

when generating 1024*1024 image with bfloat16 dtype, there is an exception:
  File "/opt/conda/lib/python3.10/site-packages/diffusers/utils/torch_utils.py", line 107, in fourier_filter
    x_freq = fftn(x, dim=(-2, -1))
RuntimeError: Unsupported dtype BFloat16

* remove whitespace in torch_utils.py

* Update src/diffusers/utils/torch_utils.py

* Update torch_utils.py

---------
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent cd892041
......@@ -102,6 +102,9 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
# fftn does not support bfloat16
elif x.dtype == torch.bfloat16:
x = x.to(dtype=torch.float32)
# FFT
x_freq = fftn(x, dim=(-2, -1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment