Unverified Commit 139f707e authored by lawfordp2017's avatar lawfordp2017 Committed by GitHub
Browse files

Correction for non-integral image resolutions with quantizations other than float32 (#7356)

* Correction for non-integral image resolutions with quantizations other than float32.

* Support for training, and use of diffusers-style casting.
parent e4546fd5
......@@ -521,9 +521,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
if isinstance(block, SDCascadeResBlock):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
orig_type = x.dtype
x = torch.nn.functional.interpolate(
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, skip, use_reentrant=False
)
......@@ -547,9 +549,11 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
if isinstance(block, SDCascadeResBlock):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
orig_type = x.dtype
x = torch.nn.functional.interpolate(
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = x.to(orig_type)
x = block(x, skip)
elif isinstance(block, SDCascadeAttnBlock):
x = block(x, clip)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment