Commit bd9c9fbf authored by Patrick von Platen's avatar Patrick von Platen
Browse files
parents f941fc99 e29fc446
......@@ -5,7 +5,6 @@ import math
import torch
import torch.nn as nn
try:
import einops
from einops.layers.torch import Rearrange
......@@ -13,7 +12,6 @@ except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
......@@ -107,14 +105,21 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__(
self,
horizon,
training_horizon,
transition_dim,
cond_dim,
predict_epsilon=False,
clip_denoised=True,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
self.transition_dim = transition_dim
self.cond_dim = cond_dim
self.predict_epsilon = predict_epsilon
self.clip_denoised = clip_denoised
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
......@@ -138,19 +143,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.downs.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon // 2
training_horizon = training_horizon // 2
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
......@@ -158,15 +163,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon * 2
training_horizon = training_horizon * 2
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5),
......@@ -232,7 +237,6 @@ class TemporalValue(nn.Module):
print(in_out)
for dim_in, dim_out in in_out:
self.blocks.append(
nn.ModuleList(
[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment