Commit 9c96682a authored by Nathan Lambert

ddpm changes for rl, add rl unet

parent 1997b908
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import math

import torch
import torch.nn as nn

import einops
from einops.layers.torch import Rearrange


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
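# Example: for integer timesteps of shape [batch_size], SinusoidalPosEmb(dim)
# returns embeddings of shape [batch_size, dim] (sin half concatenated with cos half):
#   SinusoidalPosEmb(32)(torch.arange(4)).shape  # -> torch.Size([4, 32])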

class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)
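# With kernel 3 / stride 2 / padding 1, Downsample1d halves an even-length horizon
# (L -> L/2); ConvTranspose1d with kernel 4 / stride 2 / padding 1 exactly doubles it
# (L -> 2L), so the two are shape-inverses for the power-of-two horizons used below.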

class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)
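# The Rearrange pair temporarily inserts a singleton dimension so GroupNorm sees a
# 4-D tensor; numerically this is equivalent to applying GroupNorm to the 3-D
# [batch, channels, horizon] tensor directly.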

class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()
        # `horizon` is accepted for interface parity but not used inside this block

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size),
            Conv1dBlock(out_channels, out_channels, kernel_size),
        ])

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
            if inp_channels != out_channels else nn.Identity()

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
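# Example: ResidualTemporalBlock(14, 32, embed_dim=32, horizon=32) maps
#   x [batch, 14, horizon] and t [batch, 32]  ->  out [batch, 32, horizon],
# broadcasting the time embedding over the horizon axis via the trailing Rearrange.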

class TemporalUnet(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        '''
            x : [ batch x horizon x transition ]
        '''
        # `cond` is accepted for API compatibility but not used in this forward pass
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = einops.rearrange(x, 'b t h -> b h t')
        return x
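# Example (sketch; sizes assumed for illustration): shape round-trip through TemporalUnet.
#   model = TemporalUnet(horizon=32, transition_dim=14, cond_dim=0)
#   x = torch.randn(2, 32, 14)              # [batch, horizon, transition]
#   t = torch.randint(0, 1000, (2,))        # diffusion timesteps
#   model(x, cond=None, time=t).shape       # -> torch.Size([2, 32, 14])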

class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        # note: time_mlp always emits `dim` features, so a custom `time_dim`
        # must equal `dim` for the final_block input size below to match
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])
        print(in_out)

        for dim_in, dim_out in in_out:
            self.blocks.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                Downsample1d(dim_out)
            ]))

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        '''
            x : [ batch x horizon x transition ]
        '''
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)

        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
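For orientation, a minimal shape check of TemporalValue under the same assumed sizes as above (a sketch, not part of the commit):

    model = TemporalValue(horizon=32, transition_dim=14, cond_dim=0)
    x = torch.randn(2, 32, 14)          # [batch, horizon, transition]
    t = torch.randint(0, 1000, (2,))
    model(x, None, t).shape             # -> torch.Size([2, 1])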
@@ -105,12 +105,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # hacks - were probs added for training stability
         if self.config.variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
+        # for rl-diffuser https://arxiv.org/abs/2205.09991
+        elif self.config.variance_type == "fixed_small_log":
+            variance = self.log(self.clip(variance, min_value=1e-20))
         elif self.config.variance_type == "fixed_large":
             variance = self.get_beta(t)

         return variance

-    def step(self, residual, sample, t):
+    def step(self, residual, sample, t, predict_epsilon=True):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
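With `fixed_small_log`, `get_variance` returns the clipped variance in log space, matching the parameterization used by rl-diffuser; a caller sampling with this variance type would presumably exponentiate before scaling noise, e.g. `noise * (0.5 * log_variance).exp()` (a sketch; the consuming code is not shown in this hunk).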
@@ -119,7 +122,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        if predict_epsilon:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        else:
+            pred_original_sample = residual

         # 3. Clip "predicted x_0"
         if self.config.clip_sample:
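The `predict_epsilon` flag selects between the two standard parameterizations. With `predict_epsilon=True`, the model output `residual` is treated as the predicted noise eps, and x_0 is recovered via formula (15) of the DDPM paper (here beta_prod_t = 1 - alpha_prod_t):

    pred_x0 = (x_t - sqrt(1 - alpha_prod_t) * eps) / sqrt(alpha_prod_t)

With `predict_epsilon=False`, as in diffuser-style RL models that regress the clean trajectory directly, `residual` already is the predicted x_0 and is used unchanged.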
@@ -64,3 +64,13 @@ class SchedulerMixin:
             return torch.clamp(tensor, min_value, max_value)

         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def log(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.log(tensor)
+        elif tensor_format == "pt":
+            return torch.log(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
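Like `clip` above, `log` dispatches on the scheduler's `tensor_format` attribute, so the same scheduler code can serve NumPy- and PyTorch-based pipelines. A quick sketch (assuming `scheduler` is an instance of any `SchedulerMixin` subclass):

    scheduler.tensor_format = "pt"
    scheduler.log(torch.tensor([1.0, 10.0]))  # uses torch.log
    scheduler.tensor_format = "np"
    scheduler.log(np.array([1.0, 10.0]))      # uses np.log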