Unverified Commit 37a787a1 authored by stano's avatar stano Committed by GitHub
Browse files

Add docstring for the AutoencoderKL's decode (#5242)

* Add docstring for the AutoencoderKL's decode

#5230

* Follow the style guidelines in AutoencoderKL's decode

#5230

---------

Co-authored-by: stano <>
parent d56825e4
...@@ -281,6 +281,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): ...@@ -281,6 +281,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
@apply_forward_hook @apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1: if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices) decoded = torch.cat(decoded_slices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment