Commit ae81c3d6 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make from pretrained more general

parent db3757aa
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Denoising Diffusion Implicit Models (DDIM)
## Overview
DDPM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon*
The abstract from the paper is the following:
*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
Tips:
- ...
- ...
This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch
def compute_alpha(beta, t):
beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
return a
class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50):
seq = range(0, self.num_timesteps, self.num_timesteps // inference_time_steps)
b = self.noise_scheduler.betas
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.unet.to(torch_device)
x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
with torch.no_grad():
n = batch_size
seq_next = [-1] + list(seq[:-1])
x0_preds = []
xs = [x]
for i, j in zip(reversed(seq), reversed(seq_next)):
print(i)
t = (torch.ones(n) * i).to(x.device)
next_t = (torch.ones(n) * j).to(x.device)
at = compute_alpha(b, t.long())
at_next = compute_alpha(b, next_t.long())
xt = xs[-1].to('cuda')
et = self.unet(xt, t)
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
x0_preds.append(x0_t.to('cpu'))
# eta
c1 = (
eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
)
c2 = ((1 - at_next) - c1 ** 2).sqrt()
xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
xs.append(xt_next.to('cpu'))
import ipdb; ipdb.set_trace()
return xs, x0_preds
#!/usr/bin/env python3
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)
loss.backward()
# after a lot of training
sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape # (4, 3, 128, 128)
#!/usr/bin/env python3
# !pip install diffusers
from diffusers import DiffusionPipeline
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
......@@ -8,7 +8,7 @@ model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"
# load model and scheduler
ddpm = DiffusionPipeline.from_pretrained(model_id)
ddpm = DDIM.from_pretrained(model_id)
# run pipeline in inference (sample random noise and denoise)
image = ddpm()
......
......@@ -21,8 +21,6 @@ import torch
class DDPM(DiffusionPipeline):
modeling_file = "modeling_ddpm.py"
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
......
......@@ -56,7 +56,6 @@ class DiffusionPipeline(ConfigMixin):
register_dict = {name: (library, class_name)}
# save model index config
self.register(**register_dict)
......@@ -102,15 +101,15 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder)
module = config_dict["_module"]
class_name_ = config_dict["_class_name"]
if class_name_ == cls.__name__:
# if we load from explicit class, let's use it
if cls != DiffusionPipeline:
pipeline_class = cls
else:
# else we need to load the correct module from the Hub
class_name_ = config_dict["_class_name"]
module = config_dict["_module"]
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
......@@ -120,6 +119,7 @@ class DiffusionPipeline(ConfigMixin):
if library_name == module:
# TODO(Suraj)
# for vq
pass
library = importlib.import_module(library_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment