Unverified Commit 443aa14e authored by M. Tolga Cangöz, committed by GitHub

Fix Tiling in `ConsistencyDecoderVAE` (#7290)



* Fix typos

* Add docstring to `decode` method in `ConsistencyDecoderVAE`

* Fix tiling

* Enable tiled VAE decoding with customizable tile sample size and overlap factor

* Revert "Enable tiled VAE decoding with customizable tile sample size and overlap factor"

This reverts commit 181049675e83cea7b33ae2bbeba2aff7ae1b1761.

* Add VAE tiling test for `ConsistencyDecoderVAE`

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 288632ad
@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         >>> pipe = StableDiffusionPipeline.from_pretrained(
         ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
         ... ).to("cuda")

-        >>> pipe("horse", generator=torch.manual_seed(0)).images
+        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+        >>> image
         ```
     """
@@ -72,6 +73,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self,
         scaling_factor: float = 0.18215,
         latent_channels: int = 4,
+        sample_size: int = 32,
         encoder_act_fn: str = "silu",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         encoder_double_z: bool = True,
@@ -153,6 +155,16 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False

+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
     # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
@@ -272,7 +284,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         Args:
             x (`torch.FloatTensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
+                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.

         Returns:
@@ -305,6 +317,19 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         return_dict: bool = True,
         num_inference_steps: int = 2,
     ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+        """
+        Decodes the input latent vector `z` using the consistency decoder VAE model.
+
+        Args:
+            z (torch.FloatTensor): The input latent vector.
+            generator (Optional[torch.Generator]): The random number generator. Default is None.
+            return_dict (bool): Whether to return the output as a dictionary. Default is True.
+            num_inference_steps (int): The number of inference steps. Default is 2.
+
+        Returns:
+            Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
+        """
+
         z = (z * self.config.scaling_factor - self.means) / self.stds

         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
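With the docstring in place, here is a minimal usage sketch of `decode`; the latent shape and the printed output shape are assumptions based on the 8x downscale factor computed above:

```python
import torch
from diffusers import ConsistencyDecoderVAE

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
latents = torch.randn(1, 4, 32, 32)  # (batch, latent_channels, height, width)

with torch.no_grad():
    decoded = vae.decode(latents, generator=torch.manual_seed(0), num_inference_steps=2)

print(decoded.sample.shape)  # expected: torch.Size([1, 3, 256, 256])
```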
@@ -345,7 +370,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b

-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
+    def tiled_encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.

         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
...
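The blend line kept as context above is the heart of seamless tiling: the overlapping columns of two adjacent tiles are linearly cross-faded. A standalone sketch that mirrors that one-liner (the free function and toy tensors are illustrative):

```python
import torch


def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` columns of the left tile `a`
    # into the first `blend_extent` columns of the right tile `b`.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent  # 0.0 -> all `a`, approaching 1.0 -> all `b`
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, x] * w
    return b


left, right = torch.zeros(1, 3, 8, 8), torch.ones(1, 3, 8, 8)
out = blend_h(left, right, blend_extent=4)
print(out[0, 0, 0])  # ramps 0.00, 0.25, 0.50, 0.75, then 1.0 past the seam
```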
@@ -1116,3 +1116,33 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
         )

         assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+    def test_vae_tiling(self):
+        vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        out_1 = pipe(
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
+        ).images[0]
+
+        # make sure tiled vae decode yields the same result
+        pipe.enable_vae_tiling()
+        out_2 = pipe(
+            "horse",
+            num_inference_steps=2,
+            output_type="pt",
+            generator=torch.Generator("cpu").manual_seed(0),
+        ).images[0]
+
+        assert torch_all_close(out_1, out_2, atol=5e-3)
+
+        # test that tiled decode works with various shapes
+        shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+        for shape in shapes:
+            image = torch.zeros(shape, device=torch_device)
+            pipe.vae.decode(image)
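End to end, the fix means the usual tiling switch now works with this VAE. A minimal user-facing sketch, using fp16 and CUDA as in the class docstring (swap the device if needed):

```python
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

# Decode latents tile by tile to cap peak VRAM on large outputs.
pipe.enable_vae_tiling()

image = pipe("horse", generator=torch.manual_seed(0)).images[0]
image.save("horse.png")
```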