Commit 4497e78d authored by Nathan Lambert

merge unet-rl formatting

parents 49718b47 77aadfee
@@ -74,9 +74,9 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
 # Make marked copies of snippets of codes conform to the original
 fix-copies:
-	python utils/check_copies.py --fix_and_overwrite
-	python utils/check_table.py --fix_and_overwrite
 	python utils/check_dummies.py --fix_and_overwrite
+	python utils/check_table.py --fix_and_overwrite
+	python utils/check_copies.py --fix_and_overwrite
 # Run tests for the library
...
@@ -30,20 +30,32 @@ More precisely, 🤗 Diffusers offers:
 **Models**: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to *denoise* a noisy input to an image.
 *Examples*: UNet, Conditioned UNet, 3D UNet, Transformer UNet
-![model_diff_1_50](https://user-images.githubusercontent.com/23423619/171610307-dab0cd8b-75da-4d4e-9f5a-5922072e2bb5.png)
+<p align="center">
+    <img src="https://user-images.githubusercontent.com/10695622/174349667-04e9e485-793b-429a-affe-096e8199ad5b.png" width="800"/>
+    <br>
+    <em> Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
+<p>
 **Schedulers**: Algorithm class for both **inference** and **training**.
 The class provides functionality to compute previous image according to alpha, beta schedule as well as predict noise for training.
 *Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902)
-![sampling](https://user-images.githubusercontent.com/23423619/171608981-3ad05953-a684-4c82-89f8-62a459147a07.png)
-![training](https://user-images.githubusercontent.com/23423619/171608964-b3260cce-e6b4-4841-959d-7d8ba4b8d1b2.png)
+<p align="center">
+    <img src="https://user-images.githubusercontent.com/10695622/174349706-53d58acc-a4d1-4cda-b3e8-432d9dc7ad38.png" width="800"/>
+    <br>
+    <em> Sampling and training algorithms. Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
+<p>
 **Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ...
 *Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2
-![imagen](https://user-images.githubusercontent.com/23423619/171609001-c3f2c1c9-f597-4a16-9843-749bf3f9431c.png)
+<p align="center">
+    <img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
+    <br>
+    <em> Figure from ImageGen (https://imagen.research.google/). </em>
+<p>
 ## Philosophy
 - Readability and clarity is prefered over highly optimized code. A strong importance is put on providing readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper.
@@ -147,7 +159,8 @@ eta = 0.0 # <- deterministic sampling
 for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
     # 1. predict noise residual
-    orig_t = noise_scheduler.get_orig_t(t, num_inference_steps)
+    orig_t = len(noise_scheduler) // num_inference_steps * t
+
     with torch.inference_mode():
         residual = unet(image, orig_t)
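The new line computes the original training timestep directly from the inference step: a training schedule of `len(noise_scheduler)` steps is subsampled down to `num_inference_steps`, so each step `t` lands on every k-th training timestep. A minimal sketch of the arithmetic, with the schedule lengths as assumed values:

```python
# Illustration of the timestep mapping above; 1000 and 50 are assumed values,
# standing in for len(noise_scheduler) and num_inference_steps.
num_train_timesteps = 1000
num_inference_steps = 50

orig_ts = [num_train_timesteps // num_inference_steps * t for t in range(num_inference_steps)]
print(orig_ts[:3], orig_ts[-1])  # [0, 20, 40] 980 -> every 20th training timestep
```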
@@ -173,6 +186,10 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
+
+#### **Examples for other modalities:**
+[Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)
+
 ### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...)
 For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
...
 # flake8: noqa
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
+from .utils import is_transformers_available

 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
-from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
-from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .models.unet_rl import TemporalUNet
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
+from .pipelines import BDDM, DDIM, DDPM, PNDM
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+
+if is_transformers_available():
+    from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
+    from .models.unet_grad_tts import UNetGradTTSModel
+    from .pipelines import GLIDE, GradTTS, LatentDiffusion
+else:
+    from .utils.dummy_transformers_objects import *
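The `is_transformers_available` helper lives in `diffusers.utils` and is not part of this diff; a minimal sketch of how such a soft-dependency check is commonly implemented (an assumption, not the exact utility code):

```python
# Hypothetical sketch of a soft-dependency check like is_transformers_available.
import importlib.util


def is_transformers_available() -> bool:
    # True only if the optional `transformers` package can be found on the path.
    return importlib.util.find_spec("transformers") is not None
```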
@@ -241,7 +241,7 @@ class ConfigMixin:
         Returns:
             `str`: String containing all the attributes that make up this configuration instance in JSON format.
         """
-        config_dict = self._internal_dict
+        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
...
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

         if len(unexpected_keys) > 0:
-            logger.warninging(
+            logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         if len(missing_keys) > 0:
-            logger.warninging(
+            logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
                 for key, shape1, shape2 in mismatched_keys
             ]
         )
-        logger.warninging(
+        logger.warning(
             f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
             f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
             f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
...
@@ -287,14 +287,14 @@ class UNetModel(ModelMixin, ConfigMixin):
         self.norm_out = Normalize(block_in)
         self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

-    def forward(self, x, t):
+    def forward(self, x, timesteps):
         assert x.shape[2] == x.shape[3] == self.resolution

-        if not torch.is_tensor(t):
-            t = torch.tensor([t], dtype=torch.long, device=x.device)
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)

         # timestep embedding
-        temb = get_timestep_embedding(t, self.ch)
+        temb = get_timestep_embedding(timesteps, self.ch)
         temb = self.temb.dense[0](temb)
         temb = nonlinearity(temb)
         temb = self.temb.dense[1](temb)
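`get_timestep_embedding` is defined elsewhere in this file and not shown in the hunk; for reference, the standard DDPM-style sinusoidal embedding it computes looks roughly like this (a sketch, not the exact code in this commit):

```python
import math

import torch
import torch.nn.functional as F


def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int) -> torch.Tensor:
    # Sinusoidal timestep embedding in the style of DDPM: half sin, half cos.
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
    if embedding_dim % 2 == 1:  # zero-pad if the embedding dimension is odd
        emb = F.pad(emb, (0, 1))
    return emb
```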
...
@@ -190,7 +190,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.final_block = Block(dim, dim)
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)

-    def forward(self, x, mask, mu, t, spk=None):
+    def forward(self, x, timesteps, mu, mask, spk=None):
         if self.n_spks > 1:
             # Get speaker embedding
             spk = self.spk_emb(spk)
@@ -198,7 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)

-        t = self.time_pos_emb(t, scale=self.pe_scale)
+        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
         t = self.mlp(t)

         if self.n_spks < 2:
...
 # model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
+import math
+
 import torch
 import torch.nn as nn

 import einops
 from einops.layers.torch import Rearrange
-import math

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
@@ -24,6 +27,7 @@ class SinusoidalPosEmb(nn.Module):
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb

+
 class Downsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -32,6 +36,7 @@ class Downsample1d(nn.Module):
     def forward(self, x):
         return self.conv(x)

+
 class Upsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -40,57 +45,61 @@ class Upsample1d(nn.Module):
     def forward(self, x):
         return self.conv(x)

 class Conv1dBlock(nn.Module):
-    '''
+    """
     Conv1d --> GroupNorm --> Mish
-    '''
+    """

     def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
         super().__init__()

         self.block = nn.Sequential(
             nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            Rearrange('batch channels horizon -> batch channels 1 horizon'),
+            Rearrange("batch channels horizon -> batch channels 1 horizon"),
             nn.GroupNorm(n_groups, out_channels),
-            Rearrange('batch channels 1 horizon -> batch channels horizon'),
+            Rearrange("batch channels 1 horizon -> batch channels horizon"),
             nn.Mish(),
         )

     def forward(self, x):
         return self.block(x)

-class ResidualTemporalBlock(nn.Module):
+
+class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
         super().__init__()

-        self.blocks = nn.ModuleList([
-            Conv1dBlock(inp_channels, out_channels, kernel_size),
-            Conv1dBlock(out_channels, out_channels, kernel_size),
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                Conv1dBlock(inp_channels, out_channels, kernel_size),
+                Conv1dBlock(out_channels, out_channels, kernel_size),
+            ]
+        )

         self.time_mlp = nn.Sequential(
             nn.Mish(),
             nn.Linear(embed_dim, out_channels),
-            Rearrange('batch t -> batch t 1'),
+            Rearrange("batch t -> batch t 1"),
         )

-        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
-            if inp_channels != out_channels else nn.Identity()
+        self.residual_conv = (
+            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+        )

     def forward(self, x, t):
-        '''
+        """
         x : [ batch_size x inp_channels x horizon ]
         t : [ batch_size x embed_dim ]
         returns:
         out : [ batch_size x out_channels x horizon ]
-        '''
+        """
         out = self.blocks[0](x) + self.time_mlp(t)
         out = self.blocks[1](out)
         return out + self.residual_conv(x)
-class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
+
+class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
     def __init__(
         self,
         horizon,
@@ -105,6 +114,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
         in_out = list(zip(dims[:-1], dims[1:]))
         # print(f'[ models/temporal ] Channel dimensions: {in_out}')

         time_dim = dim
+
         self.time_mlp = nn.Sequential(
             SinusoidalPosEmb(dim),
@@ -121,11 +131,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (num_resolutions - 1)

-            self.downs.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
-                Downsample1d(dim_out) if not is_last else nn.Identity()
-            ]))
+            self.downs.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out) if not is_last else nn.Identity(),
+                    ]
+                )
+            )

             if not is_last:
                 horizon = horizon // 2
@@ -137,11 +151,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             is_last = ind >= (num_resolutions - 1)

-            self.ups.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
-                Upsample1d(dim_in) if not is_last else nn.Identity()
-            ]))
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
+                        Upsample1d(dim_in) if not is_last else nn.Identity(),
+                    ]
+                )
+            )

             if not is_last:
                 horizon = horizon * 2
@@ -152,11 +170,11 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
         )

     def forward(self, x, cond, time):
-        '''
+        """
         x : [ batch x horizon x transition ]
-        '''
+        """

-        x = einops.rearrange(x, 'b h t -> b t h')
+        x = einops.rearrange(x, "b h t -> b t h")

         t = self.time_mlp(time)
         h = []
@@ -178,11 +196,11 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
         x = self.final_conv(x)

-        x = einops.rearrange(x, 'b t h -> b h t')
+        x = einops.rearrange(x, "b t h -> b h t")
         return x
-class TemporalValue(nn.Module):
+
+class TemporalValue(nn.Module):
     def __init__(
         self,
         horizon,
@@ -211,11 +229,15 @@ class TemporalValue(nn.Module):
         print(in_out)
         for dim_in, dim_out in in_out:

-            self.blocks.append(nn.ModuleList([
-                ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
-                Downsample1d(dim_out)
-            ]))
+            self.blocks.append(
+                nn.ModuleList(
+                    [
+                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
+                        Downsample1d(dim_out),
+                    ]
+                )
+            )

             horizon = horizon // 2
@@ -228,11 +250,11 @@ class TemporalValue(nn.Module):
         )

     def forward(self, x, cond, time, *args):
-        '''
+        """
         x : [ batch x horizon x transition ]
-        '''
+        """

-        x = einops.rearrange(x, 'b h t -> b t h')
+        x = einops.rearrange(x, "b h t -> b t h")

         t = self.time_mlp(time)
@@ -243,4 +265,4 @@ class TemporalValue(nn.Module):
         x = x.view(len(x), -1)

         out = self.final_block(torch.cat([x, t], dim=-1))
-        return out
\ No newline at end of file
+        return out
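A quick shape check for the temporal blocks above, following the docstring conventions `x : [ batch_size x inp_channels x horizon ]` and `t : [ batch_size x embed_dim ]`. The import path assumes this commit's module layout, and the channel sizes are made-up values:

```python
import torch

from diffusers.models.unet_rl import ResidualTemporalBlock  # path assumed from this commit

block = ResidualTemporalBlock(inp_channels=14, out_channels=32, embed_dim=64, horizon=128)
x = torch.randn(2, 14, 128)  # [ batch_size x inp_channels x horizon ]
t = torch.randn(2, 64)  # [ batch_size x embed_dim ]

out = block(x, t)
print(out.shape)  # torch.Size([2, 32, 128]), i.e. [ batch_size x out_channels x horizon ]
```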
+from ..utils import is_transformers_available
 from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_grad_tts import GradTTS
+from .pipeline_pndm import PNDM

-try:
-    from .pipeline_glide import GLIDE
-except (NameError, ImportError):
-
-    class GLIDE:
-        pass
-
-from .pipeline_latent_diffusion import LatentDiffusion
-from .pipeline_pndm import PNDM
+if is_transformers_available():
+    from .pipeline_glide import GLIDE
+    from .pipeline_grad_tts import GradTTS
+    from .pipeline_latent_diffusion import LatentDiffusion
@@ -6,11 +6,8 @@ from shutil import copyfile

 import torch

-try:
-    from transformers import PreTrainedTokenizer
-except:
-    print("transformers is not installed")
+from transformers import PreTrainedTokenizer

 try:
     from unidecode import unidecode
@@ -237,7 +234,12 @@ def english_cleaners(text):
     return text

-_inflect = inflect.engine()
+try:
+    _inflect = inflect.engine()
+except:
+    print("inflect is not installed")
+    _inflect = None
+
 _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
 _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
 _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
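For context, `_comma_number_re` is the pattern the cleaners use to strip thousands separators before numbers are expanded to words; a minimal sketch of that substitution (the lambda replacement here is hypothetical, grounded only in the pattern itself):

```python
import re

_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")

# Drop the commas inside a matched number, in the spirit of the cleaners.
text = _comma_number_re.sub(lambda m: m.group(1).replace(",", ""), "It cost 2,500,000 dollars")
print(text)  # It cost 2500000 dollars
```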
...
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Denoising Diffusion Implicit Models (DDIM)
## Overview
DDIM was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) by *Jiaming Song, Chenlin Meng, Stefano Ermon*.
The abstract from the paper is the following:
*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*
Tips:
- ...
- ...
This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
#!/usr/bin/env python3
import os
import pathlib
import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id)
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
ddpm = DDIM.from_pretrained("fusing/" + model_id)
image = ddpm(batch_size=4)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
for i in range(image_processed.shape[0]):
image_pil = PIL.Image.fromarray(image_processed[i])
image_pil.save(os.path.join(path, f"image_{i}.png"))
#!/usr/bin/env python3
import torch
from diffusers import DDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = DDPMScheduler(
    model,
    image_size=128,
    timesteps=1000,  # number of steps
    loss_type="l1",  # L1 or L2
)
training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized to the range -1 to +1
loss = diffusion(training_images)
loss.backward()
# after a lot of training
sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape # (4, 3, 128, 128)
#!/usr/bin/env python3
# !pip install diffusers
import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"
# load model and scheduler
ddpm = DDIM.from_pretrained(model_id)
# run pipeline in inference (sample random noise and denoise)
image = ddpm()
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
# save image
image_pil.save("/home/patrick/images/show.png")
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Denoising Diffusion Probabilistic Models (DDPM)
## Overview
DDPM was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) by *Jonathan Ho, Ajay Jain, Pieter Abbeel*.
The abstract from the paper is the following:
*We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN. Our implementation is available at this https URL*
Tips:
- ...
- ...
This model was contributed by [???](https://huggingface.co/???). The original code can be found [here](https://github.com/hojonathanho/diffusion).
![ddpm](https://user-images.githubusercontent.com/23423619/171627620-e3406711-1e20-4a99-8e30-ec5a86a465be.png)
#!/usr/bin/env python3
import os
import pathlib
import numpy as np
import PIL.Image
from modeling_ddpm import DDPM
model_ids = [
"ddpm-lsun-cat",
"ddpm-lsun-cat-ema",
"ddpm-lsun-church-ema",
"ddpm-lsun-church",
"ddpm-lsun-bedroom",
"ddpm-lsun-bedroom-ema",
"ddpm-cifar10-ema",
"ddpm-cifar10",
"ddpm-celeba-hq",
"ddpm-celeba-hq-ema",
]
for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id)
pathlib.Path(path).mkdir(parents=True, exist_ok=True)
ddpm = DDPM.from_pretrained("fusing/" + model_id)
image = ddpm(batch_size=4)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
for i in range(image_processed.shape[0]):
image_pil = PIL.Image.fromarray(image_processed[i])
image_pil.save(os.path.join(path, f"image_{i}.png"))
#!/usr/bin/env python3
import torch
from diffusers import DDPMScheduler, UNetModel
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = DDPMScheduler(
    model,
    image_size=128,
    timesteps=1000,  # number of steps
    loss_type="l1",  # L1 or L2
)
training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized to the range -1 to +1
loss = diffusion(training_images)
loss.backward()
# after a lot of training
sampled_images = diffusion.sample(batch_size=4)
sampled_images.shape # (4, 3, 128, 128)
# References
[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)
[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
\ No newline at end of file
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}
### Convert the text encoder
config = CLIPTextConfig(
vocab_size=50257,
max_position_embeddings=128,
hidden_size=512,
intermediate_size=2048,
num_hidden_layers=16,
num_attention_heads=8,
use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model
hf_encoder.embeddings.token_embedding.weight = state_dict["token_embedding.weight"]
hf_encoder.embeddings.position_embedding.weight.data = state_dict["positional_embedding"]
hf_encoder.embeddings.padding_embedding.weight.data = state_dict["padding_embedding"]
hf_encoder.final_layer_norm.weight = state_dict["final_ln.weight"]
hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for layer_idx in range(config.num_hidden_layers):
hf_layer = hf_encoder.encoder.layers[layer_idx]
hf_layer.self_attn.qkv_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.weight"]
hf_layer.self_attn.qkv_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_qkv.bias"]
hf_layer.self_attn.out_proj.weight = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.weight"]
hf_layer.self_attn.out_proj.bias = state_dict[f"transformer.resblocks.{layer_idx}.attn.c_proj.bias"]
hf_layer.layer_norm1.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.weight"]
hf_layer.layer_norm1.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_1.bias"]
hf_layer.layer_norm2.weight = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.weight"]
hf_layer.layer_norm2.bias = state_dict[f"transformer.resblocks.{layer_idx}.ln_2.bias"]
hf_layer.mlp.fc1.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.weight"]
hf_layer.mlp.fc1.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_fc.bias"]
hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]
### Convert the Text-to-Image UNet
text2im_model = GLIDETextToImageUNetModel(
in_channels=3,
model_channels=192,
out_channels=6,
num_res_blocks=3,
attention_resolutions=(2, 4, 8),
dropout=0.1,
channel_mult=(1, 2, 3, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
transformer_dim=512,
)
text2im_model.load_state_dict(state_dict, strict=False)
text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
### Convert the Super-Resolution UNet
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel(
in_channels=6,
model_channels=192,
out_channels=6,
num_res_blocks=2,
attention_resolutions=(8, 16, 32),
dropout=0.1,
channel_mult=(1, 1, 2, 2, 4, 4),
num_heads=1,
num_head_channels=64,
num_heads_upsample=1,
use_scale_shift_norm=True,
resblock_updown=True,
)
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base")