# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import math

import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar positions / timesteps.

    Maps ``x`` of shape ``[batch]`` to ``[batch, 2 * (dim // 2)]``. The first
    half of the features are sines and the second half cosines, over
    frequencies ``exp(-i * log(10000) / (dim // 2 - 1))`` for
    ``i = 0 .. dim // 2 - 1``. An odd ``dim`` therefore yields ``dim - 1``
    features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (dim 2 or 3).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        emb = x[:, None] * freqs[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)


class Downsample1d(nn.Module):
    """Strided conv that halves the temporal length (L -> ceil(L / 2)), keeping channels."""

    def __init__(self, dim):
        super().__init__()
        # kernel 3, stride 2, padding 1: L_out = floor((L - 1) / 2) + 1 = ceil(L / 2)
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    """Transposed conv that doubles the temporal length (L -> 2 * L), keeping channels."""

    def __init__(self, dim):
        super().__init__()
        # kernel 4, stride 2, padding 1: L_out = (L - 1) * 2 - 2 + 4 = 2 * L
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

    ``padding=kernel_size // 2`` preserves the temporal length for odd kernel
    sizes. ``out_channels`` must be divisible by ``n_groups`` (GroupNorm
    requirement).
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            # GroupNorm is applied to a 4-D view [batch, channels, 1, horizon].
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class ResidualTemporalBlock(nn.Module):
    """Two Conv1dBlocks with an additive time embedding and a residual connection.

    ``horizon`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size),
            Conv1dBlock(out_channels, out_channels, kernel_size),
        ])

        # Projects the time embedding to out_channels and adds a trailing
        # length-1 axis so it broadcasts over the horizon.
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        # 1x1 conv only when channel counts differ; otherwise the skip is identity.
        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
            if inp_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)


class TemporalUnet(nn.Module):
    """1-D temporal U-Net over trajectories, conditioned on a diffusion timestep.

    Args:
        horizon: trajectory length. Each Downsample1d maps a length L to
            ceil(L / 2) and each Upsample1d doubles it. The skip connections
            therefore only line up when ``horizon`` is divisible by
            ``2 ** (len(dim_mults) - 1)``.
        transition_dim: number of features per timestep (input/output channels).
        cond_dim: accepted for interface compatibility; unused.
        dim: base channel width; also the size of the time embedding.
        dim_mults: multipliers of ``dim`` giving each resolution level's channels.
    """

    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *(dim * m for m in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)

        # The decoder has num_resolutions - 1 levels. Each one undoes one of the
        # num_resolutions - 1 downsamplings above, so every level ends in an Upsample1d.
        for dim_in, dim_out in reversed(in_out[1:]):
            self.ups.append(nn.ModuleList([
                # dim_out * 2: features concatenated with the matching skip tensor.
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                Upsample1d(dim_in)
            ]))
            horizon = horizon * 2

        # The decoder (or the mid blocks, if there is a single level) emits
        # dims[1] = dim * dim_mults[0] channels. That equals `dim` for the
        # default dim_mults, but not in general.
        final_dim = dims[1]
        self.final_conv = nn.Sequential(
            Conv1dBlock(final_dim, final_dim, kernel_size=5),
            nn.Conv1d(final_dim, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        '''
            x : [ batch x horizon x transition ]
            cond : unused
            time : [ batch ] diffusion timesteps
            returns: [ batch x horizon x transition ]
        '''
        # Convolutions run over the horizon axis, so move features to channels.
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, 'b t h -> b h t')
        return x


class TemporalValue(nn.Module):
    """Temporal conv encoder that maps a trajectory (and timestep) to ``out_dim`` values.

    Args:
        horizon: trajectory length. It is halved (rounding up) by each level's
            Downsample1d to size the final linear layer.
        transition_dim: number of features per timestep.
        cond_dim: accepted for interface compatibility; unused.
        dim: base channel width and sinusoidal-embedding size.
        time_dim: size of the time embedding fed to the blocks; defaults to ``dim``.
        out_dim: size of the output vector per batch element.
        dim_mults: multipliers of ``dim`` giving each level's channels.
    """

    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *(dim * m for m in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            # Project to time_dim: the residual blocks and final_block below are
            # built for a time_dim-sized embedding (identical when time_dim == dim).
            nn.Linear(dim * 4, time_dim),
        )

        self.blocks = nn.ModuleList([])

        print(in_out)
        for dim_in, dim_out in in_out:

            self.blocks.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out)
            ]))

            # Downsample1d maps L -> ceil(L / 2). Track that exactly so that
            # fc_dim matches the flattened feature size in forward().
            horizon = (horizon + 1) // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        '''
            x : [ batch x horizon x transition ]
            cond, *args : unused
            time : [ batch ] diffusion timesteps
            returns: [ batch x out_dim ]
        '''
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        # Flatten channels x remaining horizon, then append the time embedding.
        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out