Unverified Commit 5791f4ac authored by Partho, committed by GitHub

[Type Hints] VAE models (#344)

* [Type Hints] VAE models

* apply suggestions from code review

apply suggestions to also annotate the return types
parent 878af0e1
+from typing import Optional, Tuple
 import numpy as np
 import torch
 import torch.nn as nn

@@ -293,7 +295,7 @@ class DiagonalGaussianDistribution(object):
         if self.deterministic:
             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

-    def sample(self, generator=None):
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
         return x
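For context, a minimal usage sketch (not part of the commit; the import path diffusers.models.vae and the channel packing of parameters are assumptions inferred from the surrounding code): the annotated signature accepts an optional torch.Generator for reproducible sampling and is declared to return a torch.FloatTensor.

import torch
from diffusers.models.vae import DiagonalGaussianDistribution  # assumed module path

# `parameters` is assumed to pack mean and logvar along the channel axis,
# i.e. 2 * latent_channels channels (latent_channels = 4 here).
parameters = torch.randn(1, 8, 32, 32)
posterior = DiagonalGaussianDistribution(parameters)

generator = torch.Generator().manual_seed(0)  # matches the Optional[torch.Generator] hint
z = posterior.sample(generator=generator)     # annotated to return torch.FloatTensor
print(z.shape)                                # expected (1, 4, 32, 32)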
@@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock2D",),
-        up_block_types=("UpDecoderBlock2D",),
-        block_out_channels=(64,),
-        layers_per_block=1,
-        act_fn="silu",
-        latent_channels=3,
-        sample_size=32,
-        num_vq_embeddings=256,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
     ):
         super().__init__()

@@ -382,7 +384,7 @@ class VQModel(ModelMixin, ConfigMixin):
         dec = self.decoder(quant)
         return dec

-    def forward(self, sample):
+    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         x = sample
         h = self.encode(x)
         dec = self.decode(h)
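As a hedged illustration (a sketch under the defaults shown in the diff, not the library's documented example; in this version of the module forward returns the reconstruction tensor directly), the annotated VQModel.forward maps one float tensor to another:

import torch
from diffusers import VQModel  # assumed public import

model = VQModel()  # defaults from the annotated __init__: in_channels=3, sample_size=32, ...
sample: torch.FloatTensor = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    reconstruction = model(sample)  # forward(sample: torch.FloatTensor) -> torch.FloatTensor
print(reconstruction.shape)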
@@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock2D",),
-        up_block_types=("UpDecoderBlock2D",),
-        block_out_channels=(64,),
-        layers_per_block=1,
-        act_fn="silu",
-        latent_channels=4,
-        sample_size=32,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 4,
+        sample_size: int = 32,
     ):
         super().__init__()

@@ -440,7 +442,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         dec = self.decoder(z)
         return dec

-    def forward(self, sample, sample_posterior=False):
+    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
         x = sample
         posterior = self.encode(x)
         if sample_posterior:
...
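Similarly, a minimal sketch for the annotated AutoencoderKL.forward (an assumption-laden example, not from the diff): sample_posterior switches between decoding a stochastic latent sample and the posterior mode.

import torch
from diffusers import AutoencoderKL  # assumed public import

vae = AutoencoderKL()  # defaults from the annotated __init__: latent_channels=4, sample_size=32, ...
sample: torch.FloatTensor = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    recon_mode = vae(sample)                            # sample_posterior=False (default)
    recon_sampled = vae(sample, sample_posterior=True)  # draws z from the posterior
print(recon_mode.shape, recon_sampled.shape)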