"tests/vscode:/vscode.git/clone" did not exist on "de54891fc89c958a4c1b1b37bf3b3d2155c6e9ab"
Commit 08c85229 authored by Patrick von Platen

add license disclaimers to schedulers

parent 2b8bc91c
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import math
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
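# Usage sketch (added for illustration, not part of this commit). It assumes the
# partially elided forward() maps a 1-D tensor of timesteps to a (len(t), dim)
# embedding built from dim // 2 sine and dim // 2 cosine features, as the
# torch.cat(...) above suggests.
pos_emb = SinusoidalPosEmb(dim=16)
timesteps = torch.arange(4)      # hypothetical batch of diffusion timesteps
emb = pos_emb(timesteps)         # expected shape: (4, 16)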
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
Rearrange('batch channels horizon -> batch channels 1 horizon'),
Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
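# Quick shape check (illustration only, not part of the commit): with the
# definition above, padding = kernel_size // 2 keeps the horizon length
# unchanged, so a [ batch x inp_channels x horizon ] input maps to
# [ batch x out_channels x horizon ].
block = Conv1dBlock(inp_channels=4, out_channels=8, kernel_size=5)
x = torch.randn(2, 4, 16)        # [ batch x inp_channels x horizon ]
y = block(x)                     # -> torch.Size([2, 8, 16])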
class ResidualTemporalBlock(nn.Module):
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
])
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
Rearrange('batch t -> batch t 1'),
Rearrange("batch t -> batch t 1"),
)
self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
if inp_channels != out_channels else nn.Identity()
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
'''
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
'''
"""
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
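# Shape walkthrough (illustration only, not part of the commit): the time
# embedding is projected to out_channels and broadcast over the horizon axis
# before the second conv block and the residual connection.
block = ResidualTemporalBlock(inp_channels=4, out_channels=8, embed_dim=16, horizon=32)
x = torch.randn(2, 4, 32)        # [ batch_size x inp_channels x horizon ]
t = torch.randn(2, 16)           # [ batch_size x embed_dim ]
out = block(x, t)                # [ batch_size x out_channels x horizon ] -> (2, 8, 32)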
class TemporalUnet(nn.Module):
class TemporalUnet(nn.Module):
def __init__(
self,
horizon,
@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
print(f"[ models/temporal ] Channel dimensions: {in_out}")
time_dim = dim
self.time_mlp = nn.Sequential(
@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
self.downs.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon // 2
@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
self.ups.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
if not is_last:
horizon = horizon * 2
@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
)
def forward(self, x, cond, time):
'''
x : [ batch x horizon x transition ]
'''
"""
x : [ batch x horizon x transition ]
"""
x = einops.rearrange(x, 'b h t -> b t h')
x = einops.rearrange(x, "b h t -> b t h")
t = self.time_mlp(time)
h = []
@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t')
x = einops.rearrange(x, "b t h -> b h t")
return x
class TemporalValue(nn.Module):
class TemporalValue(nn.Module):
def __init__(
self,
horizon,
@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
print(in_out)
for dim_in, dim_out in in_out:
self.blocks.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out)
]))
self.blocks.append(
nn.ModuleList(
[
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out),
]
)
)
horizon = horizon // 2
@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
)
def forward(self, x, cond, time, *args):
'''
x : [ batch x horizon x transition ]
'''
"""
x : [ batch x horizon x transition ]
"""
x = einops.rearrange(x, 'b h t -> b t h')
x = einops.rearrange(x, "b h t -> b t h")
t = self.time_mlp(time)
@@ -239,4 +260,4 @@ class TemporalValue(nn.Module):
x = x.view(len(x), -1)
out = self.final_block(torch.cat([x, t], dim=-1))
return out
\ No newline at end of file
return out
@@ -233,6 +233,7 @@ def english_cleaners(text):
text = collapse_whitespace(text)
return text
try:
_inflect = inflect.engine()
except:
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
import numpy as np
@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
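# Sketch of the collapsed remainder of betas_for_alpha_bar (hypothetical name and
# body, added for illustration): the standard construction derives each beta from
# the ratio of consecutive alpha_bar values and clips it at max_beta, as the
# docstring above describes, so that the cumulative product of (1 - beta)
# approximately follows alpha_bar (up to the clipping).
def betas_for_alpha_bar_sketch(num_diffusion_timesteps, max_beta=0.999):
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)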
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import numpy as np
@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,9 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import numpy as np
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
@@ -17,11 +17,11 @@
import inspect
import tempfile
import unittest
import numpy as np
import pytest
import numpy as np
import torch
import pytest
from diffusers import (
BDDM,
DDIM,
@@ -30,10 +30,10 @@ from diffusers import (
PNDM,
DDIMScheduler,
DDPMScheduler,
GLIDESuperResUNetModel,
LatentDiffusion,
PNDMScheduler,
UNetModel,
GLIDESuperResUNetModel
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
@@ -105,7 +105,7 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
@@ -121,7 +121,7 @@ class ModelTesterMixin:
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
@@ -130,11 +130,11 @@ class ModelTesterMixin:
with torch.no_grad():
output = model(**inputs_dict)
self.assertIsNotNone(output)
expected_shape = inputs_dict["x"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -145,14 +145,14 @@ class ModelTesterMixin:
expected_arg_names = ["x", "timesteps"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# test if the model can be loaded from the config
# and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -160,17 +160,17 @@ class ModelTesterMixin:
new_model = self.model_class.from_config(tmpdirname)
new_model.to(torch_device)
new_model.eval()
# check if all parameter shapes are the same
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
self.assertEqual(param_1.shape, param_2.shape)
with torch.no_grad():
output_1 = model(**inputs_dict)
output_2 = new_model(**inputs_dict)
self.assertEqual(output_1.shape, output_2.shape)
def test_training(self):
@@ -180,7 +180,7 @@ class ModelTesterMixin:
model.to(torch_device)
model.train()
output = model(**inputs_dict)
noise = torch.randn((inputs_dict["x"].shape[0], ) + self.get_output_shape).to(torch_device)
noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@@ -198,11 +198,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10]).to(torch_device)
return {"x": noise, "timesteps": time_step}
@property
def get_input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
return (3, 32, 32)
@@ -217,7 +217,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
self.assertIsNotNone(model)
@@ -227,7 +227,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.eval()
@@ -235,13 +235,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
time_step = torch.tensor([10])
with torch.no_grad():
output = model(noise, time_step)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
print(output_slice)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
model_class = GLIDESuperResUNetModel
@@ -266,19 +267,19 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"x": noise, "timesteps": time_step, "low_res": low_res}
@property
def get_input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"attention_resolutions": (2,),
"channel_mult": (1,2),
"channel_mult": (1, 2),
"in_channels": 6,
"out_channels": 6,
"model_channels": 32,
@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
"num_res_blocks": 2,
"resblock_updown": True,
"resolution": 32,
"use_scale_shift_norm": True
"use_scale_shift_norm": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -302,13 +303,15 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
output = model(**inputs_dict)
output, _ = torch.split(output, 3, dim=1)
self.assertIsNotNone(output)
expected_shape = inputs_dict["x"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self):
model, loading_info = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy", output_loading_info=True)
model, loading_info = GLIDESuperResUNetModel.from_pretrained(
"fusing/glide-super-res-dummy", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -316,7 +319,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@unittest.skip("GLIDESuperResUNetModel always outputs zero")
def test_output_pretrained(self):
@@ -326,14 +329,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, 3, 32, 32)
low_res = torch.randn(1, 3, 4, 4)
time_step = torch.tensor([42] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step, low_res)
output, _ = torch.split(output, 3, dim=1)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
......