Commit 05e265fb authored by Patrick von Platen's avatar Patrick von Platen
Browse files
parents a677565f 694ad984
...@@ -30,19 +30,31 @@ More precisely, 🤗 Diffusers offers: ...@@ -30,19 +30,31 @@ More precisely, 🤗 Diffusers offers:
**Models**: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to *denoise* a noisy input to an image. **Models**: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to *denoise* a noisy input to an image.
*Examples*: UNet, Conditioned UNet, 3D UNet, Transformer UNet *Examples*: UNet, Conditioned UNet, 3D UNet, Transformer UNet
![model_diff_1_50](https://user-images.githubusercontent.com/23423619/171610307-dab0cd8b-75da-4d4e-9f5a-5922072e2bb5.png) <p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174349667-04e9e485-793b-429a-affe-096e8199ad5b.png" width="800"/>
<br>
<em> Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
<p>
**Schedulers**: Algorithm class for both **inference** and **training**. **Schedulers**: Algorithm class for both **inference** and **training**.
The class provides functionality to compute previous image according to alpha, beta schedule as well as predict noise for training. The class provides functionality to compute previous image according to alpha, beta schedule as well as predict noise for training.
*Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902) *Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902)
![sampling](https://user-images.githubusercontent.com/23423619/171608981-3ad05953-a684-4c82-89f8-62a459147a07.png) <p align="center">
![training](https://user-images.githubusercontent.com/23423619/171608964-b3260cce-e6b4-4841-959d-7d8ba4b8d1b2.png) <img src="https://user-images.githubusercontent.com/10695622/174349706-53d58acc-a4d1-4cda-b3e8-432d9dc7ad38.png" width="800"/>
<br>
<em> Sampling and training algorithms. Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
<p>
**Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ... **Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ...
*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2 *Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2
![imagen](https://user-images.githubusercontent.com/23423619/171609001-c3f2c1c9-f597-4a16-9843-749bf3f9431c.png) <p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
<br>
<em> Figure from ImageGen (https://imagen.research.google/). </em>
<p>
## Philosophy ## Philosophy
...@@ -173,6 +185,10 @@ image_pil = PIL.Image.fromarray(image_processed[0]) ...@@ -173,6 +185,10 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png") image_pil.save("test.png")
``` ```
#### **Examples for other modalities:**
[Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)
### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...) ### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...)
For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines). For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
......
...@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module): ...@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
logger.warninging( logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
...@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module): ...@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
else: else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0: if len(missing_keys) > 0:
logger.warninging( logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference." " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
...@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module): ...@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
for key, shape1, shape2 in mismatched_keys for key, shape1, shape2 in mismatched_keys
] ]
) )
logger.warninging( logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
......
...@@ -18,6 +18,7 @@ import inspect ...@@ -18,6 +18,7 @@ import inspect
import tempfile import tempfile
import unittest import unittest
import numpy as np import numpy as np
import pytest
import torch import torch
...@@ -32,6 +33,7 @@ from diffusers import ( ...@@ -32,6 +33,7 @@ from diffusers import (
LatentDiffusion, LatentDiffusion,
PNDMScheduler, PNDMScheduler,
UNetModel, UNetModel,
GLIDESuperResUNetModel
) )
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
...@@ -94,7 +96,7 @@ class ModelTesterMixin: ...@@ -94,7 +96,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname)
new_model = UNetModel.from_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device) new_model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
...@@ -178,7 +180,7 @@ class ModelTesterMixin: ...@@ -178,7 +180,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.train() model.train()
output = model(**inputs_dict) output = model(**inputs_dict)
noise = torch.randn(inputs_dict["x"].shape).to(torch_device) noise = torch.randn((inputs_dict["x"].shape[0], ) + self.get_output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise) loss = torch.nn.functional.mse_loss(output, noise)
loss.backward() loss.backward()
...@@ -197,6 +199,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -197,6 +199,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step} return {"x": noise, "timesteps": time_step}
@property
def get_input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self): def prepare_init_args_and_inputs_for_common(self):
init_dict = { init_dict = {
"ch": 32, "ch": 32,
...@@ -227,7 +237,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -227,7 +237,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
torch.cuda.manual_seed_all(0) torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
print(noise.shape)
time_step = torch.tensor([10]) time_step = torch.tensor([10])
with torch.no_grad(): with torch.no_grad():
...@@ -240,6 +249,98 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -240,6 +249,98 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
print(output_slice) print(output_slice)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
model_class = GLIDESuperResUNetModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 6
sizes = (32, 32)
low_res_size = (4, 4)
torch_device = "cpu"
noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"x": noise, "timesteps": time_step, "low_res": low_res}
@property
def get_input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"attention_resolutions": (2,),
"channel_mult": (1,2),
"in_channels": 6,
"out_channels": 6,
"model_channels": 32,
"num_head_channels": 8,
"num_heads_upsample": 1,
"num_res_blocks": 2,
"resblock_updown": True,
"resolution": 32,
"use_scale_shift_norm": True
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
output, _ = torch.split(output, 3, dim=1)
self.assertIsNotNone(output)
expected_shape = inputs_dict["x"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self):
model, loading_info = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@unittest.skip("GLIDESuperResUNetModel always outputs zero")
def test_output_pretrained(self):
model = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, 3, 32, 32)
low_res = torch.randn(1, 3, 4, 4)
time_step = torch.tensor([42] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step, low_res)
output, _ = torch.split(output, 3, dim=1)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
print(output_slice)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class PipelineTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment