"vscode:/vscode.git/clone" did not exist on "a09d129c020cba115e6584ff526224cf4c6f30ac"
Unverified Commit 07f8ebd5 authored by Samuel Ajisegiri's avatar Samuel Ajisegiri Committed by GitHub
Browse files

type hints: models/vae.py (#346)



* type hints: models/vae.py

* modify typings in vae.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
parent ada09bd3
...@@ -487,7 +487,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): ...@@ -487,7 +487,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
def encode(self, x, return_dict: bool = True): def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments) posterior = DiagonalGaussianDistribution(moments)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment