Commit 08c85229 authored by Patrick von Platen

add license disclaimers to schedulers

parent 2b8bc91c
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import math

import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
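The hunk above elides most of forward(); for orientation, here is a minimal sketch of the standard transformer-style sinusoidal timestep embedding this module computes. It is an illustration reconstructed from the well-known formula, not the literal file contents, and relies on the math/torch imports at the top of the file.

# Sketch only: class name and body are reconstructed for illustration;
# only the last two lines of forward() appear in the diff above.
class SinusoidalPosEmbSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)           # log-spaced frequencies
        emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
        emb = x[:, None] * emb[None, :]                  # [batch, half_dim]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)  # [batch, dim]
        return emb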
class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
    def forward(self, x):
        return self.conv(x)

class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
    def forward(self, x):
        return self.conv(x)
class Conv1dBlock(nn.Module): class Conv1dBlock(nn.Module):
''' """
Conv1d --> GroupNorm --> Mish Conv1d --> GroupNorm --> Mish
''' """
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__() super().__init__()
self.block = nn.Sequential( self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
Rearrange('batch channels horizon -> batch channels 1 horizon'), Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels), nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'), Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(), nn.Mish(),
) )
def forward(self, x): def forward(self, x):
return self.block(x) return self.block(x)
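As a quick shape check: the padding of kernel_size // 2 (for odd kernels) keeps the horizon axis fixed while the channel count changes. A hypothetical usage sketch:

# Hypothetical usage: channel count changes, horizon length does not.
block = Conv1dBlock(inp_channels=4, out_channels=8, kernel_size=5)
x = torch.randn(2, 4, 16)  # [ batch x channels x horizon ]
print(block(x).shape)      # torch.Size([2, 8, 16])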
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)
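A hypothetical shape check: time_mlp projects t to out_channels, and its trailing Rearrange adds a singleton horizon axis so the embedding broadcasts over the convolution output.

res = ResidualTemporalBlock(inp_channels=4, out_channels=8, embed_dim=32, horizon=16)
x = torch.randn(2, 4, 16)  # [ batch_size x inp_channels x horizon ]
t = torch.randn(2, 32)     # [ batch_size x embed_dim ]
print(res(x, t).shape)     # torch.Size([2, 8, 16])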
class TemporalUnet(nn.Module):
    def __init__(
        self,
        horizon,
@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f"[ models/temporal ] Channel dimensions: {in_out}")

        time_dim = dim
        self.time_mlp = nn.Sequential(
@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                horizon = horizon // 2
@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                horizon = horizon * 2
@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
        )

    def forward(self, x, cond, time):
        """
        x : [ batch x horizon x transition ]
        """
        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)
        h = []
@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
        x = self.final_conv(x)

        x = einops.rearrange(x, "b t h -> b h t")
        return x
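For end-to-end orientation, a usage sketch. The remaining __init__ parameters are elided by the diff, so the call below assumes they match the upstream jannerm/diffuser signature; treat the argument names and values as hypothetical.

# Assumed signature from jannerm/diffuser: TemporalUnet(horizon, transition_dim, cond_dim,
# dim=32, dim_mults=(1, 2, 4, 8)); horizon must be divisible by 2 ** (len(dim_mults) - 1).
model = TemporalUnet(horizon=32, transition_dim=14, cond_dim=2)
x = torch.randn(3, 32, 14)            # [ batch x horizon x transition ]
time = torch.randint(0, 1000, (3,))   # one diffusion timestep per sample
out = model(x, cond=None, time=time)  # cond is unused in the visible forward body
print(out.shape)                      # torch.Size([3, 32, 14])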
class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
        print(in_out)
        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        Downsample1d(dim_out),
                    ]
                )
            )

            horizon = horizon // 2
@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
        )

    def forward(self, x, cond, time, *args):
        """
        x : [ batch x horizon x transition ]
        """
        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)
...
@@ -233,6 +233,7 @@ def english_cleaners(text):
    text = collapse_whitespace(text)
    return text

try:
    _inflect = inflect.engine()
except:
...
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math

import numpy as np
@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
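The body of betas_for_alpha_bar around this inner function is elided by the diff. A sketch of the standard discretization it performs, following Nichol & Dhariwal's improved-DDPM construction that this docstring describes (an illustration, not the literal file contents; it uses the math/numpy imports above):

# Sketch: beta_i = 1 - alpha_bar(t2) / alpha_bar(t1) over consecutive
# normalized timesteps, clipped at max_beta to avoid singularities near t = 1.
def betas_for_alpha_bar_sketch(num_diffusion_timesteps, max_beta=0.999):
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)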
...
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math

import numpy as np
@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
...
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,9 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math

import numpy as np

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
...
@@ -17,11 +17,11 @@
import inspect
import tempfile
import unittest

import numpy as np
import torch

import pytest

from diffusers import (
    BDDM,
    DDIM,
@@ -30,10 +30,10 @@ from diffusers import (
    PNDM,
    DDIMScheduler,
    DDPMScheduler,
    GLIDESuperResUNetModel,
    LatentDiffusion,
    PNDMScheduler,
    UNetModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
@@ -180,7 +180,7 @@ class ModelTesterMixin:
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)
        noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
        print(output_slice)
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
    model_class = GLIDESuperResUNetModel
@@ -278,7 +279,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "attention_resolutions": (2,),
            "channel_mult": (1, 2),
            "in_channels": 6,
            "out_channels": 6,
            "model_channels": 32,
@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
            "num_res_blocks": 2,
            "resblock_updown": True,
            "resolution": 32,
            "use_scale_shift_norm": True,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
@@ -308,7 +309,9 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_from_pretrained_hub(self):
        model, loading_info = GLIDESuperResUNetModel.from_pretrained(
            "fusing/glide-super-res-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)
...