Commit a677565f authored by Patrick von Platen
parents ff885b0e d182a6ad
......@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
-    def forward(self, x, t):
+    def forward(self, x, timesteps):
        assert x.shape[2] == x.shape[3] == self.resolution
-        if not torch.is_tensor(t):
-            t = torch.tensor([t], dtype=torch.long, device=x.device)
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
        # timestep embedding
-        temb = get_timestep_embedding(t, self.ch)
+        temb = get_timestep_embedding(timesteps, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
......
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module):
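    # Sinusoidal timestep embedding (same scheme as the Transformer positional encoding):
    # for a timestep t it returns [sin(t * w_0), ..., sin(t * w_{d/2-1}), cos(t * w_0), ..., cos(t * w_{d/2-1})],
    # with the frequencies w_i spaced geometrically from 1 down to 1/10000.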
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
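    # Halves the horizon with a stride-2 convolution (length L -> ceil(L / 2)).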
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
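    # Doubles the horizon with a stride-2 transposed convolution (length L -> 2 * L).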
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
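            # GroupNorm is applied on a temporary 4-D view via the two Rearrange layers
            # (kept from the referenced diffuser code); nn.GroupNorm would also accept
            # the 3-D (batch, channels, horizon) tensor directly.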
Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
])
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
Rearrange('batch t -> batch t 1'),
)
self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
if inp_channels != out_channels else nn.Identity()
def forward(self, x, t):
'''
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
'''
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
class TemporalUnet(nn.Module):
def __init__(
self,
horizon,
transition_dim,
cond_dim,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
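        # e.g. with transition_dim=14, dim=32, dim_mults=(1, 2, 4, 8):
        #   dims   = [14, 32, 64, 128, 256]
        #   in_out = [(14, 32), (32, 64), (64, 128), (128, 256)]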
print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
if not is_last:
horizon = horizon // 2
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
if not is_last:
horizon = horizon * 2
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5),
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, x, cond, time):
'''
x : [ batch x horizon x transition ]
'''
x = einops.rearrange(x, 'b h t -> b t h')
t = self.time_mlp(time)
h = []
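        # U-Net skip connections: feature maps are stored here on the way down
        # and concatenated channel-wise (torch.cat(..., dim=1)) on the way up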
for resnet, resnet2, downsample in self.downs:
x = resnet(x, t)
x = resnet2(x, t)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block2(x, t)
for resnet, resnet2, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x
class TemporalValue(nn.Module):
def __init__(
self,
horizon,
transition_dim,
cond_dim,
dim=32,
time_dim=None,
out_dim=1,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = time_dim or dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.blocks = nn.ModuleList([])
print(in_out)
for dim_in, dim_out in in_out:
self.blocks.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out)
]))
horizon = horizon // 2
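        # `horizon` has been halved once per block above, so the value head flattens
        # (last channel dim) x (remaining horizon) features before the final MLP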
fc_dim = dims[-1] * max(horizon, 1)
self.final_block = nn.Sequential(
nn.Linear(fc_dim + time_dim, fc_dim // 2),
nn.Mish(),
nn.Linear(fc_dim // 2, out_dim),
)
def forward(self, x, cond, time, *args):
'''
x : [ batch x horizon x transition ]
'''
x = einops.rearrange(x, 'b h t -> b t h')
t = self.time_mlp(time)
for resnet, resnet2, downsample in self.blocks:
x = resnet(x, t)
x = resnet2(x, t)
x = downsample(x)
x = x.view(len(x), -1)
out = self.final_block(torch.cat([x, t], dim=-1))
return out
\ No newline at end of file
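Not part of the commit: a minimal usage sketch for the TemporalUnet added above, assuming the class is importable from this new module. The concrete dimensions (horizon=32, transition_dim=14, cond_dim=3) are illustrative only; the point is the expected input layout [ batch x horizon x transition ] and that the output keeps the same shape.

# Illustrative sketch only (not in the diff); assumes TemporalUnet and torch are importable.
import torch

model = TemporalUnet(horizon=32, transition_dim=14, cond_dim=3, dim=32, dim_mults=(1, 2, 4, 8))

x = torch.randn(2, 32, 14)          # [ batch x horizon x transition ]
t = torch.tensor([10, 10])          # one diffusion timestep per trajectory
out = model(x, cond=None, time=t)   # `cond` is accepted but unused in forward()

assert out.shape == x.shape         # the 1-D U-Net preserves [ batch x horizon x transition ]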
......@@ -233,8 +233,12 @@ def english_cleaners(text):
text = collapse_whitespace(text)
return text
-_inflect = inflect.engine()
+try:
+    _inflect = inflect.engine()
+except:
+    print("inflect is not installed")
+    _inflect = None
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
......
......@@ -105,12 +105,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        # hacks - were probably added for training stability
if self.config.variance_type == "fixed_small":
variance = self.clip(variance, min_value=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif self.config.variance_type == "fixed_small_log":
variance = self.log(self.clip(variance, min_value=1e-20))
elif self.config.variance_type == "fixed_large":
variance = self.get_beta(t)
return variance
-    def step(self, residual, sample, t):
+    def step(self, residual, sample, t, predict_epsilon=True):
# 1. compute alphas, betas
alpha_prod_t = self.get_alpha_prod(t)
alpha_prod_t_prev = self.get_alpha_prod(t - 1)
......@@ -119,7 +122,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
else:
pred_original_sample = residual
# 3. Clip "predicted x_0"
if self.config.clip_sample:
......
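For context (not part of the diff): when predict_epsilon=True the residual is interpreted as the predicted noise eps, and "predicted x_0" follows from inverting the DDPM forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps (formula (15) in the linked paper); when it is False, the model output is treated as x_0 directly. A small self-contained check of that inversion — the variable names below are illustrative, not the scheduler's attributes:

# Sketch only: verify that the "predicted x_0" formula inverts q(x_t | x_0).
import torch

torch.manual_seed(0)
x0 = torch.randn(4, 3, 8, 8)            # clean sample
eps = torch.randn_like(x0)              # noise the model would predict
alpha_prod_t = torch.tensor(0.7)        # stands in for the cumulative alpha product
beta_prod_t = 1 - alpha_prod_t

# forward process: x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps
sample = alpha_prod_t.sqrt() * x0 + beta_prod_t.sqrt() * eps

# predict_epsilon=True branch: recover x_0 from the predicted noise
pred_original_sample = (sample - beta_prod_t ** 0.5 * eps) / alpha_prod_t ** 0.5
assert torch.allclose(pred_original_sample, x0, atol=1e-5)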
......@@ -64,3 +64,13 @@ class SchedulerMixin:
return torch.clamp(tensor, min_value, max_value)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def log(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.log(tensor)
elif tensor_format == "pt":
return torch.log(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
......@@ -14,8 +14,10 @@
# limitations under the License.
import inspect
import tempfile
import unittest
import numpy as np
import torch
......@@ -82,7 +84,108 @@ class ConfigTester(unittest.TestCase):
assert config == new_config
-class ModelTesterMixin(unittest.TestCase):
+class ModelTesterMixin:
def test_from_pretrained_save_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = UNetModel.from_pretrained(tmpdirname)
new_model.to(torch_device)
with torch.no_grad():
image = model(**inputs_dict)
new_image = new_model(**inputs_dict)
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**inputs_dict)
second = model(**inputs_dict)
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsNotNone(output)
expected_shape = inputs_dict["x"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["x", "timesteps"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# test if the model can be loaded from the config
        # and has all the expected shapes
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_config(tmpdirname)
new_model = self.model_class.from_config(tmpdirname)
new_model.to(torch_device)
new_model.eval()
        # check if all parameter shapes are the same
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
self.assertEqual(param_1.shape, param_2.shape)
with torch.no_grad():
output_1 = model(**inputs_dict)
output_2 = new_model(**inputs_dict)
self.assertEqual(output_1.shape, output_2.shape)
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict)
noise = torch.randn(inputs_dict["x"].shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetModel
@property
def dummy_input(self):
batch_size = 4
......@@ -92,31 +195,51 @@ class ModelTesterMixin(unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
-        return (noise, time_step)
+        return {"x": noise, "timesteps": time_step}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"ch": 32,
"ch_mult": (1, 2),
"num_res_blocks": 2,
"attn_resolutions": (16,),
"resolution": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
def test_from_pretrained_save_pretrained(self):
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
model.to(torch_device)
image = model(**self.dummy_input)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = UNetModel.from_pretrained(tmpdirname)
new_model.to(torch_device)
assert image is not None, "Make sure output is not None"
dummy_input = self.dummy_input
def test_output_pretrained(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.eval()
image = model(*dummy_input)
new_image = new_model(*dummy_input)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
print(noise.shape)
time_step = torch.tensor([10])
def test_from_pretrained_hub(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.to(torch_device)
with torch.no_grad():
output = model(noise, time_step)
image = model(*self.dummy_input)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
print(output_slice)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
assert image is not None, "Make sure output is not None"
class PipelineTesterMixin(unittest.TestCase):
......