Commit 08c85229 authored by Patrick von Platen

add license disclaimers to schedulers

parent 2b8bc91c
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
+ import math
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
- import math
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
@@ -20,6 +23,7 @@ class SinusoidalPosEmb(nn.Module):
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -28,6 +32,7 @@ class Downsample1d(nn.Module):
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -36,57 +41,61 @@ class Upsample1d(nn.Module):
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
- '''
+ """
Conv1d --> GroupNorm --> Mish
- '''
+ """
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
- Rearrange('batch channels horizon -> batch channels 1 horizon'),
+ Rearrange("batch channels horizon -> batch channels 1 horizon"),
nn.GroupNorm(n_groups, out_channels),
- Rearrange('batch channels 1 horizon -> batch channels horizon'),
+ Rearrange("batch channels 1 horizon -> batch channels horizon"),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
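An illustrative shape check for the Conv1dBlock defined above (sizes are hypothetical, chosen only for the example): padding = kernel_size // 2 keeps the horizon length unchanged, so only the channel dimension changes.

import torch

block = Conv1dBlock(inp_channels=14, out_channels=32, kernel_size=5)
x = torch.randn(2, 14, 128)       # [ batch x inp_channels x horizon ]
out = block(x)                    # [ batch x out_channels x horizon ]
assert out.shape == (2, 32, 128)  # horizon preserved, channels 14 -> 32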
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
- self.blocks = nn.ModuleList([
+ self.blocks = nn.ModuleList(
+ [
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
- ])
+ ]
+ )
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
- Rearrange('batch t -> batch t 1'),
+ Rearrange("batch t -> batch t 1"),
)
- self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
- if inp_channels != out_channels else nn.Identity()
+ self.residual_conv = (
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+ )
def forward(self, x, t):
- '''
+ """
x : [ batch_size x inp_channels x horizon ]
t : [ batch_size x embed_dim ]
returns:
out : [ batch_size x out_channels x horizon ]
- '''
+ """
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
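A second illustrative check (hypothetical sizes): time_mlp projects the time embedding to [ batch x out_channels x 1 ], so it broadcasts over the horizon when added to the first Conv1dBlock's output, and the 1x1 residual_conv matches channels for the skip connection.

import torch

res_block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=128, horizon=64)
x = torch.randn(2, 14, 64)       # [ batch x inp_channels x horizon ]
t = torch.randn(2, 128)          # [ batch x embed_dim ]
out = res_block(x, t)
assert out.shape == (2, 32, 64)  # [ batch x out_channels x horizon ]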
class TemporalUnet(nn.Module):
def __init__(
self,
horizon,
@@ -99,7 +108,7 @@ class TemporalUnet(nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
- print(f'[ models/temporal ] Channel dimensions: {in_out}')
+ print(f"[ models/temporal ] Channel dimensions: {in_out}")
time_dim = dim
self.time_mlp = nn.Sequential(
@@ -117,11 +126,15 @@ class TemporalUnet(nn.Module):
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
- self.downs.append(nn.ModuleList([
+ self.downs.append(
+ nn.ModuleList(
+ [
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
- Downsample1d(dim_out) if not is_last else nn.Identity()
- ]))
+ Downsample1d(dim_out) if not is_last else nn.Identity(),
+ ]
+ )
+ )
if not is_last:
horizon = horizon // 2
@@ -133,11 +146,15 @@ class TemporalUnet(nn.Module):
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
- self.ups.append(nn.ModuleList([
+ self.ups.append(
+ nn.ModuleList(
+ [
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
- Upsample1d(dim_in) if not is_last else nn.Identity()
- ]))
+ Upsample1d(dim_in) if not is_last else nn.Identity(),
+ ]
+ )
+ )
if not is_last:
horizon = horizon * 2
@@ -148,11 +165,11 @@ class TemporalUnet(nn.Module):
)
def forward(self, x, cond, time):
- '''
+ """
x : [ batch x horizon x transition ]
- '''
+ """
- x = einops.rearrange(x, 'b h t -> b t h')
+ x = einops.rearrange(x, "b h t -> b t h")
t = self.time_mlp(time)
h = []
@@ -174,11 +191,11 @@ class TemporalUnet(nn.Module):
x = self.final_conv(x)
- x = einops.rearrange(x, 'b t h -> b h t')
+ x = einops.rearrange(x, "b t h -> b h t")
return x
class TemporalValue(nn.Module):
def __init__(
self,
horizon,
@@ -207,11 +224,15 @@ class TemporalValue(nn.Module):
print(in_out)
for dim_in, dim_out in in_out:
- self.blocks.append(nn.ModuleList([
+ self.blocks.append(
+ nn.ModuleList(
+ [
ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
- Downsample1d(dim_out)
- ]))
+ Downsample1d(dim_out),
+ ]
+ )
+ )
horizon = horizon // 2
@@ -224,11 +245,11 @@ class TemporalValue(nn.Module):
)
def forward(self, x, cond, time, *args):
- '''
+ """
x : [ batch x horizon x transition ]
- '''
+ """
- x = einops.rearrange(x, 'b h t -> b t h')
+ x = einops.rearrange(x, "b h t -> b t h")
t = self.time_mlp(time)
......
@@ -233,6 +233,7 @@ def english_cleaners(text):
text = collapse_whitespace(text)
return text
try:
_inflect = inflect.engine()
except:
......
- # Copyright 2022 The HuggingFace Team. All rights reserved.
+ # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+ # and https://github.com/hojonathanho/diffusion
import math
import numpy as np
@@ -31,6 +35,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
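The body of betas_for_alpha_bar is collapsed in this diff. As a sketch only (the standard cosine-schedule construction from Nichol & Dhariwal, 2021, which the docstring above describes; the function name below is made up for illustration), the betas are typically derived from the ratio of consecutive alpha_bar values and clipped at max_beta:

import math
import numpy as np

def betas_for_alpha_bar_sketch(num_diffusion_timesteps, max_beta=0.999):
    # cosine cumulative-product schedule, as in the alpha_bar shown above
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped to avoid singularities
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)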
......
- # Copyright 2022 The HuggingFace Team. All rights reserved.
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import numpy as np
@@ -31,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
- # Copyright 2022 The HuggingFace Team. All rights reserved.
+ # Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,9 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import numpy as np
+ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+ import math
+ import numpy as np
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -30,6 +34,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
......
@@ -17,11 +17,11 @@
import inspect
import tempfile
import unittest
- import numpy as np
- import pytest
+ import numpy as np
import torch
+ import pytest
from diffusers import (
BDDM,
DDIM,
@@ -30,10 +30,10 @@ from diffusers import (
PNDM,
DDIMScheduler,
DDPMScheduler,
+ GLIDESuperResUNetModel,
LatentDiffusion,
PNDMScheduler,
UNetModel,
- GLIDESuperResUNetModel
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
@@ -180,7 +180,7 @@ class ModelTesterMixin:
model.to(torch_device)
model.train()
output = model(**inputs_dict)
- noise = torch.randn((inputs_dict["x"].shape[0], ) + self.get_output_shape).to(torch_device)
+ noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@@ -249,6 +249,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
print(output_slice)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
model_class = GLIDESuperResUNetModel
@@ -278,7 +279,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"attention_resolutions": (2,),
- "channel_mult": (1,2),
+ "channel_mult": (1, 2),
"in_channels": 6,
"out_channels": 6,
"model_channels": 32,
@@ -287,7 +288,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
"num_res_blocks": 2,
"resblock_updown": True,
"resolution": 32,
- "use_scale_shift_norm": True
+ "use_scale_shift_norm": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -308,7 +309,9 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self):
- model, loading_info = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy", output_loading_info=True)
+ model, loading_info = GLIDESuperResUNetModel.from_pretrained(
+ "fusing/glide-super-res-dummy", output_loading_info=True
+ )
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
......