"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d03b776a5cd1f4d125eacf127f95d8571a852137"
Unverified Commit 052bf328 authored by Ollin Boer Bohan, committed by GitHub

Fix AutoencoderTiny encoder scaling convention (#4682)

* Fix AutoencoderTiny encoder scaling convention

  * Add [-1, 1] -> [0, 1] rescaling to EncoderTiny

  * Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny
    (i.e. immediately after the final conv, as early as possible)

  * Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward

  * Update AutoencoderTinyIntegrationTests to protect against scaling issues.
    The new test constructs a simple image, round-trips it through AutoencoderTiny,
    and confirms the decoded result is approximately equal to the source image.
    This test checks behavior with and without tiling enabled.
    This test will fail if new AutoencoderTiny scaling issues are introduced.

  * Context: Raw TAESD weights expect images in [0, 1], but diffusers'
    convention represents images with zero-centered values in [-1, 1],
    so AutoencoderTiny needs to scale / unscale images at the start of
    encoding and at the end of decoding in order to work with diffusers
    (see the sketch after this commit message).

* Re-add existing AutoencoderTiny test, update golden values

* Add comments to AutoencoderTiny.forward
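
The commit message above hinges on the [-1, 1] vs. [0, 1] image conventions; the snippet below is a minimal illustration of that remapping (not part of the commit, just the two affine maps it describes):

import torch

# diffusers convention: images are zero-centered in [-1, 1]
image_diffusers = torch.rand(1, 3, 64, 64) * 2 - 1

# EncoderTiny now remaps its input to [0, 1] before running the TAESD layers
image_taesd = image_diffusers.add(1).div(2)   # [-1, 1] -> [0, 1]

# DecoderTiny remaps its [0, 1] output back to the diffusers convention
image_back = image_taesd.mul(2).sub(1)        # [0, 1] -> [-1, 1]

# the two maps are inverses up to float rounding
assert torch.allclose(image_back, image_diffusers, atol=1e-6)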
parent 80871ac5
@@ -312,9 +312,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
-        # Refer to the following discussion to know why this is needed.
-        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
-        output = output.mul_(2).sub_(1)
 
         if not return_dict:
             return (output,)
@@ -333,8 +330,15 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
         enc = self.encode(sample).latents
+
+        # scale latents to be in [0, 1], then quantize latents to a byte tensor,
+        # as if we were storing the latents in an RGBA uint8 image.
         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
-        unscaled_enc = self.unscale_latents(scaled_enc)
+
+        # unquantize latents back into [0, 1], then unscale latents back to their original range,
+        # as if we were loading the latents from an RGBA uint8 image.
+        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
+
         dec = self.decode(unscaled_enc)
 
         if not return_dict:
...
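
For readers of this hunk, here is a small self-contained sketch of why the `/ 255.0` matters (illustrative only: the `magnitude`/`shift` values and helper names below are stand-ins, not the exact AutoencoderTiny attributes):

import torch

def scale_latents(x, magnitude=3.0, shift=0.5):
    # map raw latents (roughly [-magnitude, magnitude]) into [0, 1]
    return x.div(2 * magnitude).add(shift).clamp(0, 1)

def unscale_latents(x, magnitude=3.0, shift=0.5):
    # inverse of scale_latents
    return x.sub(shift).mul(2 * magnitude)

latents = torch.rand(1, 4, 8, 8) * 4 - 2  # safely inside the assumed range

# quantize as if storing the latents in an RGBA uint8 image
quantized = scale_latents(latents).mul(255).round().byte()

# buggy: feeding raw bytes (0..255) straight into unscale_latents blows up the range
bad = unscale_latents(quantized.float())

# fixed: dequantize back into [0, 1] first, as this commit does
good = unscale_latents(quantized.float() / 255.0)

assert torch.allclose(good, latents, atol=0.05)  # only small quantization error remains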
@@ -732,7 +732,8 @@ class EncoderTiny(nn.Module):
             x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
         else:
-            x = self.layers(x)
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.add(1).div(2))
 
         return x
@@ -790,4 +791,5 @@ class DecoderTiny(nn.Module):
         else:
             x = self.layers(x)
 
-        return x
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        return x.mul(2).sub(1)
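
A toy sketch of where the remapping now lives after these two hunks (hypothetical ToyEncoder/ToyDecoder modules, not the real EncoderTiny/DecoderTiny): the encoder rescales its input before the TAESD-style layers, and the decoder rescales right after its last layer, so every caller of these modules sees one consistent convention.

import torch
from torch import nn

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Conv2d(3, 4, 3, padding=1)  # stand-in for the TAESD blocks

    def forward(self, x):
        # scale image from [-1, 1] to [0, 1] before the layers run
        return self.layers(x.add(1).div(2))

class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in decoder whose last layer outputs in [0, 1]
        self.layers = nn.Sequential(nn.Conv2d(4, 3, 3, padding=1), nn.Sigmoid())

    def forward(self, x):
        x = self.layers(x)
        # scale image from [0, 1] to [-1, 1] immediately after the final layer
        return x.mul(2).sub(1)

img = torch.rand(1, 3, 8, 8) * 2 - 1           # diffusers-style input
out = ToyDecoder()(ToyEncoder()(img))
assert out.min() >= -1 and out.max() <= 1      # output is back in the diffusers range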
@@ -312,10 +312,32 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
         assert sample.shape == image.shape
 
         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
+        expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
 
         assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_tae_roundtrip(self, enable_tiling):
+        # load the autoencoder
+        model = self.get_sd_vae_model()
+        if enable_tiling:
+            model.enable_tiling()
+
+        # make a black image with a white square in the middle,
+        # which is large enough to split across multiple tiles
+        image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
+        image[..., 256:768, 256:768] = 1.0
+
+        # round-trip the image through the autoencoder
+        with torch.no_grad():
+            sample = model(image).sample
+
+        # the autoencoder reconstruction should match original image, sorta
+        def downscale(x):
+            return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)
+
+        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
+
 
 @slow
 class AutoencoderKLIntegrationTests(unittest.TestCase):
...
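
A standalone version of the new round-trip check, for readers who want to try it outside the test harness (a usage sketch under assumptions: the madebyollin/taesd checkpoint stands in for get_sd_vae_model(), and CPU execution is allowed as a fallback):

import torch
from diffusers import AutoencoderTiny

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device)
vae.enable_tiling()  # exercise the tiled encode/decode path as well

# black image with a centered white square, in the diffusers [-1, 1] convention
image = -torch.ones(1, 3, 1024, 1024, device=device)
image[..., 256:768, 256:768] = 1.0

# round-trip the image through the autoencoder
with torch.no_grad():
    recon = vae(image).sample

# compare low-frequency content, since TAESD is a lossy autoencoder
def downscale(x):
    return torch.nn.functional.avg_pool2d(x, vae.spatial_scale_factor)

print((downscale(recon) - downscale(image)).abs().max())  # should stay well under 0.125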