Unverified commit b1fe1706, authored by Sid Sahai, committed by GitHub

[Type Hint] Unet Models (#330)

* add void check

* remove void, add types for params
parent 9b704f76
-from typing import Dict, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size=None,
-        in_channels=3,
-        out_channels=3,
-        center_input_sample=False,
-        time_embedding_type="positional",
-        freq_shift=0,
-        flip_sin_to_cos=True,
-        down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
-        up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
-        block_out_channels=(224, 448, 672, 896),
-        layers_per_block=2,
-        mid_block_scale_factor=1,
-        downsample_padding=1,
-        act_fn="silu",
-        attention_head_dim=8,
-        norm_num_groups=32,
-        norm_eps=1e-5,
+        sample_size: Optional[int] = None,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        center_input_sample: bool = False,
+        time_embedding_type: str = "positional",
+        freq_shift: int = 0,
+        flip_sin_to_cos: bool = True,
+        down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+        block_out_channels: Tuple[int] = (224, 448, 672, 896),
+        layers_per_block: int = 2,
+        mid_block_scale_factor: float = 1,
+        downsample_padding: int = 1,
+        act_fn: str = "silu",
+        attention_head_dim: int = 8,
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-5,
     ):
         super().__init__()
...
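The hunk above only annotates the existing keyword arguments of `UNet2DModel.__init__`; defaults and behavior are unchanged. Below is a minimal usage sketch showing what the hints document. The constructor arguments come straight from the diff; the dict-style `["sample"]` return value of `forward()` is an assumption based on the `Dict` import in this diff and may differ in other diffusers versions.

```python
import torch

from diffusers import UNet2DModel

# Each keyword argument now carries an explicit hint, so a type checker
# (e.g. mypy) can flag calls like sample_size="64" or norm_eps=None.
model = UNet2DModel(
    sample_size=64,       # Optional[int]: spatial size of the input sample
    in_channels=3,        # int: channels in the input image
    out_channels=3,       # int: channels in the predicted output
    block_out_channels=(224, 448, 672, 896),  # Tuple[int]: per-block widths
    layers_per_block=2,   # int: ResNet layers per down/up block
)

noise = torch.randn(1, 3, 64, 64)
timestep = torch.tensor([10])
with torch.no_grad():
    # Assumed return shape at this commit: a dict with a "sample" key.
    out = model(noise, timestep)["sample"]
print(out.shape)  # torch.Size([1, 3, 64, 64])
```

One nuance of the chosen annotations: to a strict type checker, `Tuple[str]` means a tuple of exactly one string, so the four-element defaults would be flagged; `Tuple[str, ...]` describes a variable-length tuple. The hints as written still serve as inline documentation of the expected element type.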
-from typing import Dict, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        sample_size=None,
-        in_channels=4,
-        out_channels=4,
-        center_input_sample=False,
-        flip_sin_to_cos=True,
-        freq_shift=0,
-        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
-        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
-        block_out_channels=(320, 640, 1280, 1280),
-        layers_per_block=2,
-        downsample_padding=1,
-        mid_block_scale_factor=1,
-        act_fn="silu",
-        norm_num_groups=32,
-        norm_eps=1e-5,
-        cross_attention_dim=1280,
-        attention_head_dim=8,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: int = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        act_fn: str = "silu",
+        norm_num_groups: int = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: int = 1280,
+        attention_head_dim: int = 8,
     ):
         super().__init__()
...
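`UNet2DConditionModel` gets the same treatment, plus a `cross_attention_dim: int` hint for the width of the conditioning states. A minimal sketch of a conditioned forward pass follows; the constructor arguments are from the diff, while the `encoder_hidden_states` shape and the `["sample"]` return value are assumptions about the API at this commit.

```python
import torch

from diffusers import UNet2DConditionModel

# Defaults from the diff; cross_attention_dim is the feature width that the
# cross-attention blocks expect from the text encoder.
model = UNet2DConditionModel(
    sample_size=64,            # Optional[int]
    in_channels=4,             # int: latent channels
    out_channels=4,            # int
    cross_attention_dim=1280,  # int: width of conditioning embeddings
)

latents = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
# Hypothetical conditioning: 77 token embeddings of width cross_attention_dim.
encoder_hidden_states = torch.randn(1, 77, 1280)
with torch.no_grad():
    out = model(latents, timestep, encoder_hidden_states)["sample"]
print(out.shape)  # torch.Size([1, 4, 64, 64])
```

Because the hints are consumed by `@register_to_config`, they also document what ends up serialized in the model config, which is where mistyped values (e.g. a string where an `int` is expected) tend to surface otherwise.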