Commit 66ee73ee authored by patil-suraj's avatar patil-suraj
Browse files

refactor up/down sample blocks in unet_rl

parent 597b7ae2
......@@ -6,7 +6,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import ResidualTemporalBlock
from .resnet import Downsample, ResidualTemporalBlock, Upsample
class SinusoidalPosEmb(nn.Module):
......@@ -18,24 +18,6 @@ class SinusoidalPosEmb(nn.Module):
return get_timestep_embedding(x, self.dim)
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class RearrangeDim(nn.Module):
def __init__(self):
super().__init__()
......@@ -114,7 +96,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(),
Downsample(dim_out, use_conv=True, dims=1) if not is_last else nn.Identity(),
]
)
)
......@@ -134,7 +116,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(),
Upsample(dim_in, use_conv_transpose=True, dims=1) if not is_last else nn.Identity(),
]
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment