Unverified Commit 5791f4ac authored by Partho's avatar Partho Committed by GitHub
Browse files

[Type Hints] VAE models (#344)

* [Type Hints] VAE models

* apply suggestions from code review

apply suggestions to also return the return type
parent 878af0e1
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
......@@ -293,7 +295,7 @@ class DiagonalGaussianDistribution(object):
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self, generator=None):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
return x
......@@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=1,
act_fn="silu",
latent_channels=3,
sample_size=32,
num_vq_embeddings=256,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
):
super().__init__()
......@@ -382,7 +384,7 @@ class VQModel(ModelMixin, ConfigMixin):
dec = self.decoder(quant)
return dec
def forward(self, sample):
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
x = sample
h = self.encode(x)
dec = self.decode(h)
......@@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
sample_size=32,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
sample_size: int = 32,
):
super().__init__()
......@@ -440,7 +442,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
dec = self.decoder(z)
return dec
def forward(self, sample, sample_posterior=False):
def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
x = sample
posterior = self.encode(x)
if sample_posterior:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment