Commit e779b250 authored by Patrick von Platen

add first template for DDPM forward

parent 95f4256f
#!/usr/bin/env python3
import torch

from diffusers import GaussianDiffusion, UNetConfig, UNetModel

config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8))
model = UNetModel(config)
print(model.config)

model.save_pretrained("/home/patrick/diffusion_example")

diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000,   # number of diffusion steps
    loss_type="l1",   # L1 or L2 loss
)

# images need to be normalized to the range [-1, 1]
training_images = torch.randn(8, 3, 128, 128)
loss = diffusion(training_images)
loss.backward()

# after a lot of training
sampled_images = diffusion.sample(batch_size=4)
print(sampled_images.shape)  # (4, 3, 128, 128)
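For context, the loss computed by GaussianDiffusion above is driven by the standard DDPM forward (noising) process q(x_t | x_0). The snippet below is a minimal illustrative sketch of that process, not code from this commit; the linear beta schedule and all variable names are assumptions made for the example.

# Illustrative sketch, not part of this commit: the standard DDPM forward process,
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, with a linear beta
# schedule assumed here for concreteness.
import torch

timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps)       # assumed linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t


def q_sample(x_0, t, noise=None):
    # Jump directly from clean images x_0 to their noised version at timestep t.
    if noise is None:
        noise = torch.randn_like(x_0)
    ac_t = alphas_cumprod[t].view(-1, 1, 1, 1)      # broadcast over (B, C, H, W)
    return ac_t.sqrt() * x_0 + (1.0 - ac_t).sqrt() * noise


x_0 = torch.randn(8, 3, 128, 128)                   # stand-in for normalized images
t = torch.randint(0, timesteps, (8,))               # one random timestep per image
x_t = q_sample(x_0, t)                              # what the UNet sees during training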
@@ -4,4 +4,5 @@
__version__ = "0.0.1"
from .models import UNetModel
from .models.unet import GaussianDiffusion # TODO(PVP): move somewhere else
from .models.unet import UNetConfig, UNetModel
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .unet import UNetModel
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import unet
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all
from .configuration_unet import UNetConfig
from .modeling_unet import GaussianDiffusion, UNetModel
@@ -10,10 +10,36 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helpers functions
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
from ...configuration_utils import PretrainedConfig


class UNetModel:
    def __init__(self, config):
        self.config = config

        print("I can diffuse!")


class UNetConfig(PretrainedConfig):
    model_type = "unet"

    def __init__(
        self,
        dim=64,
        dim_mults=(1, 2, 4, 8),
        init_dim=None,
        out_dim=None,
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        learned_variance=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.dim = dim
        self.dim_mults = dim_mults
        self.init_dim = init_dim
        self.out_dim = out_dim
        self.channels = channels
        self.with_time_emb = with_time_emb
        self.resnet_block_groups = resnet_block_groups
        self.learned_variance = learned_variance
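As a reading aid for the config fields above, the short sketch below shows how dim and dim_mults typically translate into per-resolution channel widths in the lucidrains UNet this file is copied from. It is illustrative only, not code from this commit, and the variable names are assumptions.

# Illustrative sketch, not part of this commit: with dim=64 and dim_mults=(1, 2, 4, 8)
# the UNet resolution levels typically use 64, 128, 256 and 512 channels.
dim = 64
dim_mults = (1, 2, 4, 8)

channel_widths = [dim * m for m in dim_mults]
print(channel_widths)  # [64, 128, 256, 512]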
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest

import torch

from diffusers import UNetConfig, UNetModel


global_rng = random.Random()


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()


class ModelTesterMixin(unittest.TestCase):
    def test_from_pretrained_save_pretrained(self):
        config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
        model = UNetModel(config)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = UNetModel.from_pretrained(tmpdirname)

        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes)
        time_step = torch.tensor([10])

        image = model(noise, time_step)
        new_image = new_model(noise, time_step)

        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
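A natural companion check, sketched here only as a suggestion and not part of this commit, is a round-trip of the config alone. It assumes UNetConfig exposes transformers-style save_pretrained / from_pretrained via PretrainedConfig, which may not hold in this early version of the library.

# Illustrative sketch, not part of this commit: a possible config round-trip test.
# Assumes UNetConfig inherits save_pretrained / from_pretrained from PretrainedConfig.
import tempfile
import unittest

from diffusers import UNetConfig


class ConfigTesterMixin(unittest.TestCase):
    def test_config_save_load(self):
        config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)

        with tempfile.TemporaryDirectory() as tmpdirname:
            config.save_pretrained(tmpdirname)
            new_config = UNetConfig.from_pretrained(tmpdirname)

        assert new_config.dim == config.dim
        assert tuple(new_config.dim_mults) == tuple(config.dim_mults)
        assert new_config.resnet_block_groups == config.resnet_block_groups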