Unverified Commit 9a92b817 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Allegro VAE fix (#9811)

fix
parent 0d1d267b
......@@ -1091,8 +1091,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
encoder_local_batch_size: int = 2,
decoder_local_batch_size: int = 2,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
......@@ -1103,18 +1101,14 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*):
PyTorch random number generator.
encoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the encoder's batch inference.
decoder_local_batch_size (`int`, *optional*, defaults to 2):
Local batch size for the decoder's batch inference.
"""
x = sample
posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample
dec = self.decode(z).sample
if not return_dict:
return (dec,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment