Commit 08c85229 authored by Patrick von Platen
Browse files

add license disclaimers to schedulers

parent 2b8bc91c
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import einops import einops
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
import math
class SinusoidalPosEmb(nn.Module): class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
...@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module): ...@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb return emb
class Downsample1d(nn.Module): class Downsample1d(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
...@@ -28,6 +32,7 @@ class Downsample1d(nn.Module): ...@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class Upsample1d(nn.Module): class Upsample1d(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
...@@ -36,57 +41,61 @@ class Upsample1d(nn.Module): ...@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
class Conv1dBlock(nn.Module): class Conv1dBlock(nn.Module):
''' """
Conv1d --> GroupNorm --> Mish Conv1d --> GroupNorm --> Mish
''' """
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__() super().__init__()
self.block = nn.Sequential( self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
Rearrange('batch channels horizon -> batch channels 1 horizon'), Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels), nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'), Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(), nn.Mish(),
) )
def forward(self, x): def forward(self, x):
return self.block(x) return self.block(x)
class ResidualTemporalBlock(nn.Module):
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__() super().__init__()
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList(
Conv1dBlock(inp_channels, out_channels, kernel_size), [
Conv1dBlock(out_channels, out_channels, kernel_size), Conv1dBlock(inp_channels, out_channels, kernel_size),
]) Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
nn.Mish(), nn.Mish(),
nn.Linear(embed_dim, out_channels), nn.Linear(embed_dim, out_channels),
Rearrange('batch t -> batch t 1'), Rearrange("batch t -> batch t 1"),
) )
self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ self.residual_conv = (
if inp_channels != out_channels else nn.Identity() nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t): def forward(self, x, t):
''' """
x : [ batch_size x inp_channels x horizon ] x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ] t : [ batch_size x embed_dim ]
returns: returns:
out : [ batch_size x out_channels x horizon ] out : [ batch_size x out_channels x horizon ]
''' """
out = self.blocks[0](x) + self.time_mlp(t) out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out) out = self.blocks[1](out)
return out + self.residual_conv(x) return out + self.residual_conv(x)
class TemporalUnet(nn.Module):
class TemporalUnet(nn.Module):
def __init__( def __init__(
self, self,
horizon, horizon,
...@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module): ...@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}') print(f"[ models/temporal ] Channel dimensions: {in_out}")
time_dim = dim time_dim = dim
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
...@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module): ...@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
for ind, (dim_in, dim_out) in enumerate(in_out): for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([ self.downs.append(
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), nn.ModuleList(
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), [
Downsample1d(dim_out) if not is_last else nn.Identity() ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
])) ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
if not is_last: if not is_last:
horizon = horizon // 2 horizon = horizon // 2
...@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module): ...@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1) is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([ self.ups.append(
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), nn.ModuleList(
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), [
Upsample1d(dim_in) if not is_last else nn.Identity() ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
])) ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
if not is_last: if not is_last:
horizon = horizon * 2 horizon = horizon * 2
...@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module): ...@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
) )
def forward(self, x, cond, time): def forward(self, x, cond, time):
''' """
x : [ batch x horizon x transition ] x : [ batch x horizon x transition ]
''' """
x = einops.rearrange(x, 'b h t -> b t h') x = einops.rearrange(x, "b h t -> b t h")
t = self.time_mlp(time) t = self.time_mlp(time)
h = [] h = []
...@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module): ...@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
x = self.final_conv(x) x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t') x = einops.rearrange(x, "b t h -> b h t")
return x return x
class TemporalValue(nn.Module):
class TemporalValue(nn.Module):
def __init__( def __init__(
self, self,
horizon, horizon,
...@@ -207,11 +224,15 @@ class TemporalValue(nn.Module): ...@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
print(in_out) print(in_out)
for dim_in, dim_out in in_out: for dim_in, dim_out in in_out:
self.blocks.append(nn.ModuleList([ self.blocks.append(
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon), nn.ModuleList(
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon), [
Downsample1d(dim_out) ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
])) ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
Downsample1d(dim_out),
]
)
)
horizon = horizon // 2 horizon = horizon // 2
...@@ -224,11 +245,11 @@ class TemporalValue(nn.Module): ...@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
) )
def forward(self, x, cond, time, *args): def forward(self, x, cond, time, *args):
''' """
x : [ batch x horizon x transition ] x : [ batch x horizon x transition ]
''' """
x = einops.rearrange(x, 'b h t -> b t h') x = einops.rearrange(x, "b h t -> b t h")
t = self.time_mlp(time) t = self.time_mlp(time)
...@@ -239,4 +260,4 @@ class TemporalValue(nn.Module): ...@@ -239,4 +260,4 @@ class TemporalValue(nn.Module):
x = x.view(len(x), -1) x = x.view(len(x), -1)
out = self.final_block(torch.cat([x, t], dim=-1)) out = self.final_block(torch.cat([x, t], dim=-1))
return out return out
\ No newline at end of file
...@@ -233,6 +233,7 @@ def english_cleaners(text): ...@@ -233,6 +233,7 @@ def english_cleaners(text):
text = collapse_whitespace(text) text = collapse_whitespace(text)
return text return text
try: try:
_inflect = inflect.engine() _inflect = inflect.engine()
except: except:
......
# Copyright 2022 The HuggingFace Team. All rights reserved. # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math import math
import numpy as np import numpy as np
...@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to :param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities. prevent singularities.
""" """
def alpha_bar(time_step): def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
# Copyright 2022 The HuggingFace Team. All rights reserved. # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
import numpy as np import numpy as np
...@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to :param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities. prevent singularities.
""" """
def alpha_bar(time_step): def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
# Copyright 2022 The HuggingFace Team. All rights reserved. # Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -11,9 +11,13 @@ ...@@ -11,9 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
import numpy as np
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): ...@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to :param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities. prevent singularities.
""" """
def alpha_bar(time_step): def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
import inspect import inspect
import tempfile import tempfile
import unittest import unittest
import numpy as np
import pytest
import numpy as np
import torch import torch
import pytest
from diffusers import ( from diffusers import (
BDDM, BDDM,
DDIM, DDIM,
...@@ -30,10 +30,10 @@ from diffusers import ( ...@@ -30,10 +30,10 @@ from diffusers import (
PNDM, PNDM,
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
GLIDESuperResUNetModel,
LatentDiffusion, LatentDiffusion,
PNDMScheduler, PNDMScheduler,
UNetModel, UNetModel,
GLIDESuperResUNetModel
) )
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
...@@ -105,7 +105,7 @@ class ModelTesterMixin: ...@@ -105,7 +105,7 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item() max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
def test_determinism(self): def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
...@@ -121,7 +121,7 @@ class ModelTesterMixin: ...@@ -121,7 +121,7 @@ class ModelTesterMixin:
out_2 = out_2[~np.isnan(out_2)] out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2)) max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5) self.assertLessEqual(max_diff, 1e-5)
def test_output(self): def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
...@@ -130,11 +130,11 @@ class ModelTesterMixin: ...@@ -130,11 +130,11 @@ class ModelTesterMixin:
with torch.no_grad(): with torch.no_grad():
output = model(**inputs_dict) output = model(**inputs_dict)
self.assertIsNotNone(output) self.assertIsNotNone(output)
expected_shape = inputs_dict["x"].shape expected_shape = inputs_dict["x"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self): def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common() init_dict, _ = self.prepare_init_args_and_inputs_for_common()
...@@ -145,14 +145,14 @@ class ModelTesterMixin: ...@@ -145,14 +145,14 @@ class ModelTesterMixin:
expected_arg_names = ["x", "timesteps"] expected_arg_names = ["x", "timesteps"]
self.assertListEqual(arg_names[:2], expected_arg_names) self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self): def test_model_from_config(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict) model = self.model_class(**init_dict)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
# test if the model can be loaded from the config # test if the model can be loaded from the config
# and has all the expected shape # and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -160,17 +160,17 @@ class ModelTesterMixin: ...@@ -160,17 +160,17 @@ class ModelTesterMixin:
new_model = self.model_class.from_config(tmpdirname) new_model = self.model_class.from_config(tmpdirname)
new_model.to(torch_device) new_model.to(torch_device)
new_model.eval() new_model.eval()
# check if all parameter shapes are the same # check if all parameter shapes are the same
for param_name in model.state_dict().keys(): for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name] param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name] param_2 = new_model.state_dict()[param_name]
self.assertEqual(param_1.shape, param_2.shape) self.assertEqual(param_1.shape, param_2.shape)
with torch.no_grad(): with torch.no_grad():
output_1 = model(**inputs_dict) output_1 = model(**inputs_dict)
output_2 = new_model(**inputs_dict) output_2 = new_model(**inputs_dict)
self.assertEqual(output_1.shape, output_2.shape) self.assertEqual(output_1.shape, output_2.shape)
def test_training(self): def test_training(self):
...@@ -180,7 +180,7 @@ class ModelTesterMixin: ...@@ -180,7 +180,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.train() model.train()
output = model(**inputs_dict) output = model(**inputs_dict)
noise = torch.randn((inputs_dict["x"].shape[0], ) + self.get_output_shape).to(torch_device) noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise) loss = torch.nn.functional.mse_loss(output, noise)
loss.backward() loss.backward()
...@@ -198,11 +198,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -198,11 +198,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10]).to(torch_device) time_step = torch.tensor([10]).to(torch_device)
return {"x": noise, "timesteps": time_step} return {"x": noise, "timesteps": time_step}
@property @property
def get_input_shape(self): def get_input_shape(self):
return (3, 32, 32) return (3, 32, 32)
@property @property
def get_output_shape(self): def get_output_shape(self):
return (3, 32, 32) return (3, 32, 32)
...@@ -217,7 +217,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -217,7 +217,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True) model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
self.assertIsNotNone(model) self.assertIsNotNone(model)
...@@ -227,7 +227,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -227,7 +227,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
image = model(**self.dummy_input) image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained(self): def test_output_pretrained(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy") model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.eval() model.eval()
...@@ -235,13 +235,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -235,13 +235,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
time_step = torch.tensor([10]) time_step = torch.tensor([10])
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step) output = model(noise, time_step)
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
...@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
print(output_slice) print(output_slice)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
model_class = GLIDESuperResUNetModel model_class = GLIDESuperResUNetModel
...@@ -266,19 +267,19 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -266,19 +267,19 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10] * noise.shape[0], device=torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"x": noise, "timesteps": time_step, "low_res": low_res} return {"x": noise, "timesteps": time_step, "low_res": low_res}
@property @property
def get_input_shape(self): def get_input_shape(self):
return (3, 32, 32) return (3, 32, 32)
@property @property
def get_output_shape(self): def get_output_shape(self):
return (6, 32, 32) return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
init_dict = { init_dict = {
"attention_resolutions": (2,), "attention_resolutions": (2,),
"channel_mult": (1,2), "channel_mult": (1, 2),
"in_channels": 6, "in_channels": 6,
"out_channels": 6, "out_channels": 6,
"model_channels": 32, "model_channels": 32,
...@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
"num_res_blocks": 2, "num_res_blocks": 2,
"resblock_updown": True, "resblock_updown": True,
"resolution": 32, "resolution": 32,
"use_scale_shift_norm": True "use_scale_shift_norm": True,
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
...@@ -302,13 +303,15 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -302,13 +303,15 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
output = model(**inputs_dict) output = model(**inputs_dict)
output, _ = torch.split(output, 3, dim=1) output, _ = torch.split(output, 3, dim=1)
self.assertIsNotNone(output) self.assertIsNotNone(output)
expected_shape = inputs_dict["x"].shape expected_shape = inputs_dict["x"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy", output_loading_info=True) model, loading_info = GLIDESuperResUNetModel.from_pretrained(
"fusing/glide-super-res-dummy", output_loading_info=True
)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
...@@ -316,7 +319,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -316,7 +319,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
image = model(**self.dummy_input) image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero # TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@unittest.skip("GLIDESuperResUNetModel always outputs zero") @unittest.skip("GLIDESuperResUNetModel always outputs zero")
def test_output_pretrained(self): def test_output_pretrained(self):
...@@ -326,14 +329,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -326,14 +329,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
noise = torch.randn(1, 3, 32, 32) noise = torch.randn(1, 3, 32, 32)
low_res = torch.randn(1, 3, 4, 4) low_res = torch.randn(1, 3, 4, 4)
time_step = torch.tensor([42] * noise.shape[0]) time_step = torch.tensor([42] * noise.shape[0])
with torch.no_grad(): with torch.no_grad():
output = model(noise, time_step, low_res) output = model(noise, time_step, low_res)
output, _ = torch.split(output, 3, dim=1) output, _ = torch.split(output, 3, dim=1)
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment