Commit 554b374d authored by Patrick von Platen's avatar Patrick von Platen
Browse files

Merge branch 'main' of https://github.com/huggingface/diffusers into main

parents d5ab55e4 a0520193
from .rl import ValueGuidedRLPipeline
from .value_guided_sampling import ValueGuidedRLPipeline
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import tqdm
from ...models.unet_1d import UNet1DModel
from ...pipeline_utils import DiffusionPipeline
from ...utils.dummy_pt_objects import DDPMScheduler
class ValueGuidedRLPipeline(DiffusionPipeline):
    """Pipeline for sampling action trajectories with value-function guidance.

    At each diffusion step the gradient of a learned value function nudges the
    noisy trajectory toward higher-value plans before the denoising update.

    Args:
        value_function: UNet1DModel that predicts the value of a trajectory.
        unet: UNet1DModel that denoises trajectories.
        scheduler: DDPMScheduler driving the reverse diffusion process.
        env: environment with a D4RL-style API (`get_dataset()`,
            `observation_space`, `action_space`).
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        # Per-key normalization statistics; dataset entries that don't support
        # mean()/std() (e.g. metadata) are skipped.
        self.means = {}
        self.stds = {}
        for key, value in self.data.items():
            try:
                self.means[key] = value.mean()
                self.stds[key] = value.std()
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` with the dataset statistics stored for `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert `normalize` for `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Recursively move `x_in` (dict, tensor, or array-like) onto the unet's device."""
        if type(x_in) is dict:
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite the observation slice of `x_in` at each conditioned timestep."""
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """Denoise `x` while guiding it with value-function gradients.

        Returns:
            (x, y): final trajectories and the last value predictions
            (`y` stays None when `n_guide_steps` is 0).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad
                # No guidance on the final, nearly noise-free timesteps.
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)
            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """Plan `batch_size` trajectories from `obs` and return the best first action."""
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        # sort output trajectories by value
        # NOTE(review): if n_guide_steps == 0, y is None and this argsort would
        # fail before reaching the fallback below — confirm intended usage.
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        # select the action with the highest value
        if y is not None:
            selected_index = 0
        else:
            # if we didn't run value guiding, select a random action
            selected_index = np.random.randint(0, batch_size)

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
...@@ -62,14 +62,21 @@ def get_timestep_embedding( ...@@ -62,14 +62,21 @@ def get_timestep_embedding(
class TimestepEmbedding(nn.Module): class TimestepEmbedding(nn.Module):
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
super().__init__() super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim) self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = None self.act = None
if act_fn == "silu": if act_fn == "silu":
self.act = nn.SiLU() self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) elif act_fn == "mish":
self.act = nn.Mish()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
def forward(self, sample): def forward(self, sample):
sample = self.linear_1(sample) sample = self.linear_1(sample)
......
...@@ -5,6 +5,75 @@ import torch.nn as nn ...@@ -5,6 +5,75 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class Upsample1D(nn.Module):
    """An upsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        use_conv_transpose: if True, upsample with a transposed convolution
            instead of nearest-neighbor interpolation.
        out_channels: number of output channels (defaults to `channels`).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels if out_channels is not None else channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            # Kernel 4, stride 2, padding 1 doubles the temporal length.
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(x)

        # Nearest-neighbor 2x upsampling, optionally refined by a conv.
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.use_conv:
            upsampled = self.conv(upsampled)
        return upsampled
class Downsample1D(nn.Module):
    """A downsampling layer with an optional convolution.

    Parameters:
        channels: channels in the inputs and outputs.
        use_conv: a bool determining if a convolution is applied.
        out_channels: number of output channels (defaults to `channels`;
            must equal `channels` when no convolution is used).
        padding: padding for the strided convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels if out_channels is not None else channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        stride = 2
        if use_conv:
            # Strided conv both downsamples and changes the channel count.
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average pooling cannot change channels.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.conv(x)
class Upsample2D(nn.Module): class Upsample2D(nn.Module):
""" """
An upsampling layer with an optional convolution. An upsampling layer with an optional convolution.
...@@ -12,7 +81,8 @@ class Upsample2D(nn.Module): ...@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
Parameters: Parameters:
channels: channels in the inputs and outputs. channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied. use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. use_conv_transpose:
out_channels:
""" """
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
...@@ -80,7 +150,8 @@ class Downsample2D(nn.Module): ...@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
Parameters: Parameters:
channels: channels in the inputs and outputs. channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied. use_conv: a bool determining if a convolution is applied.
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. out_channels:
padding:
""" """
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
...@@ -415,6 +486,69 @@ class Mish(torch.nn.Module): ...@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
# unet_rl.py
def rearrange_dims(tensor):
    """Shuffle a trailing singleton axis in or out of `tensor`.

    - 2D ``(B, C)``       -> ``(B, C, 1)``
    - 3D ``(B, C, H)``    -> ``(B, C, 1, H)``
    - 4D ``(B, C, 1, H)`` -> ``(B, C, H)``

    Raises:
        ValueError: if the tensor rank is not 2, 3 or 4.
    """
    if len(tensor.shape) == 2:
        return tensor[:, :, None]
    if len(tensor.shape) == 3:
        return tensor[:, :, None, :]
    elif len(tensor.shape) == 4:
        return tensor[:, :, 0, :]
    else:
        # Fix: report the rank; the old message used len(tensor), which is the
        # size of the first dimension, not the number of dimensions.
        raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
class Conv1dBlock(nn.Module):
    """Conv1d --> GroupNorm --> Mish.

    Parameters:
        inp_channels: input channel count.
        out_channels: output channel count.
        kernel_size: conv kernel size; `kernel_size // 2` padding keeps length.
        n_groups: number of GroupNorm groups.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, x):
        out = self.conv1d(x)
        # GroupNorm is applied on a temporary 4D view of the activations.
        out = rearrange_dims(out)
        out = self.group_norm(out)
        out = rearrange_dims(out)
        return self.mish(out)
# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """Residual block of two Conv1dBlocks with an additive time embedding.

    Parameters:
        inp_channels: channels of the input trajectory tensor.
        out_channels: channels produced by the block.
        embed_dim: dimensionality of the time embedding `t`.
        kernel_size: convolution kernel size for both Conv1dBlocks.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()

        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # 1x1 projection so the skip connection matches out_channels.
        if inp_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)

    def forward(self, x, t):
        """
        Args:
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        embed = self.time_emb(self.time_emb_act(t))
        # Broadcast the per-sample embedding across the horizon axis.
        hidden = self.conv_in(x) + rearrange_dims(embed)
        hidden = self.conv_out(hidden)
        return hidden + self.residual_conv(x)
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter. r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@dataclass @dataclass
...@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.) implements for all the model (such as downloading or saving, etc.)
Parameters: Parameters:
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime. sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to : flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding. obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to : down_block_types (`Tuple[str]`, *optional*, defaults to :
...@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to : block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels. obj:`(32, 32, 64)`): Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
act_fn (`str`, *optional*, defaults to None): optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
downsample_each_block (`bool`, *optional*, defaults to `False`):
experimental feature for using a UNet without upsampling.
""" """
@register_to_config @register_to_config
...@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2, out_channels: int = 2,
extra_in_channels: int = 0, extra_in_channels: int = 0,
time_embedding_type: str = "fourier", time_embedding_type: str = "fourier",
freq_shift: int = 0,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False, use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
mid_block_type: str = "UNetMidBlock1D",
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64), block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
): ):
super().__init__() super().__init__()
self.sample_size = sample_size self.sample_size = sample_size
# time # time
...@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
) )
timestep_input_dim = 2 * block_out_channels[0] timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional": elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0] timestep_input_dim = block_out_channels[0]
if use_timestep_embedding: if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4 time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)
self.down_blocks = nn.ModuleList([]) self.down_blocks = nn.ModuleList([])
self.mid_block = None self.mid_block = None
...@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
if i == 0: if i == 0:
input_channel += extra_in_channels input_channel += extra_in_channels
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block( down_block = get_down_block(
down_block_type, down_block_type,
num_layers=layers_per_block,
in_channels=input_channel, in_channels=input_channel,
out_channels=output_channel, out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
# mid # mid
self.mid_block = get_mid_block( self.mid_block = get_mid_block(
mid_block_type=mid_block_type, mid_block_type,
mid_channels=block_out_channels[-1],
in_channels=block_out_channels[-1], in_channels=block_out_channels[-1],
out_channels=None, mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
) )
# up # up
reversed_block_out_channels = list(reversed(block_out_channels)) reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0] output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]
for i, up_block_type in enumerate(up_block_types): for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block( up_block = get_up_block(
up_block_type, up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel, in_channels=prev_output_channel,
out_channels=output_channel, out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
) )
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly # out
# Totally fine to add another layer with a if statement - no need for nn.Identity here num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)
def forward( def forward(
self, self,
...@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
""" """
# 1. time
if len(timestep.shape) == 0:
timestep = timestep[None]
timestep_embed = self.time_proj(timestep)[..., None] # 1. time
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
# 2. down # 2. down
down_block_res_samples = () down_block_res_samples = ()
...@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin): ...@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples down_block_res_samples += res_samples
# 3. mid # 3. mid
sample = self.mid_block(sample) if self.mid_block:
sample = self.mid_block(sample, timestep_embed)
# 4. up # 4. up
for i, upsample_block in enumerate(self.up_blocks): for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:] res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1] down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_samples) sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)
if not return_dict: if not return_dict:
return (sample,) return (sample,)
......
...@@ -17,6 +17,256 @@ import torch ...@@ -17,6 +17,256 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
class DownResnetBlock1D(nn.Module):
    """Down-path block: a stack of ResidualTemporalBlock1Ds followed by an
    optional nonlinearity and a strided-conv downsample.

    Parameters:
        in_channels: channels of the incoming hidden states.
        out_channels: channels produced by the block (defaults to in_channels).
        num_layers: number of additional resnets after the first one.
        temb_channels: dimensionality of the time embedding.
        non_linearity: one of None / "swish" / "mish" / "silu".
        add_downsample: whether to append a Downsample1D.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # there will always be at least one resnet
        blocks = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
        blocks += [
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)
            for _ in range(num_layers)
        ]
        self.resnets = nn.ModuleList(blocks)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) if add_downsample else None

    def forward(self, hidden_states, temb=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        # The pre-activation features are returned as the skip connection.
        output_states = (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
    """Up-path block: concatenates a skip connection, runs a stack of
    ResidualTemporalBlock1Ds, then optionally applies a nonlinearity and a
    transposed-conv upsample.

    Parameters:
        in_channels: channels of the incoming hidden states (the skip tensor
            is expected to have the same channel count).
        out_channels: channels produced by the block (defaults to in_channels).
        num_layers: number of additional resnets after the first one.
        temb_channels: dimensionality of the time embedding.
        non_linearity: one of None / "swish" / "mish" / "silu".
        add_upsample: whether to append an Upsample1D.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.add_upsample = add_upsample
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        # The first resnet consumes the concatenated skip, hence 2 * in_channels.
        blocks = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
        blocks += [
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)
            for _ in range(num_layers)
        ]
        self.resnets = nn.ModuleList(blocks)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        self.upsample = Upsample1D(out_channels, use_conv_transpose=True) if add_upsample else None

    def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
        if res_hidden_states_tuple is not None:
            # Concatenate the most recent skip connection along channels.
            hidden_states = torch.cat((hidden_states, res_hidden_states_tuple[-1]), dim=1)

        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
    """Mid block of the value-function UNet: two resnet + downsample stages,
    each halving the channel count.

    Parameters:
        in_channels: channels entering the block.
        out_channels: used to size the downsampling convolutions.
        embed_dim: dimensionality of the time embedding.
    """

    def __init__(self, in_channels, out_channels, embed_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim

        # NOTE(review): resnets are sized from in_channels while the down
        # layers use out_channels — this assumes in_channels == out_channels;
        # confirm with callers.
        self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
        self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
        self.down2 = Downsample1D(out_channels // 4, use_conv=True)

    def forward(self, x, temb=None):
        out = self.res1(x, temb)
        out = self.down1(out)
        out = self.res2(out, temb)
        return self.down2(out)
class MidResTemporalBlock1D(nn.Module):
    """Middle block: a stack of ResidualTemporalBlock1Ds with at most one of
    an upsample or a downsample applied afterwards.

    Parameters:
        in_channels: channels entering the block.
        out_channels: channels produced by the block.
        embed_dim: dimensionality of the time embedding.
        num_layers: number of additional resnets after the first one.
        add_downsample / add_upsample: mutually exclusive resampling flags.
        non_linearity: one of None / "swish" / "mish" / "silu".

    Raises:
        ValueError: if both add_downsample and add_upsample are set.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
        self.resnets = nn.ModuleList(resnets)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = nn.Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            self.nonlinearity = None

        self.upsample = None
        if add_upsample:
            # Fix: the upsampling path must build an Upsample1D; it previously
            # constructed a Downsample1D.
            self.upsample = Upsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

        if self.upsample and self.downsample:
            raise ValueError("Block cannot downsample and upsample")

    def forward(self, hidden_states, temb):
        hidden_states = self.resnets[0](hidden_states, temb)
        for resnet in self.resnets[1:]:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample:
            hidden_states = self.upsample(hidden_states)
        if self.downsample:
            # Fix: the result was previously assigned to self.downsample,
            # clobbering the module and discarding the downsampled output.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
class OutConv1DBlock(nn.Module):
    """Final projection head: Conv1d -> GroupNorm -> activation -> 1x1 Conv1d.

    Parameters:
        num_groups_out: number of GroupNorm groups.
        out_channels: channels of the final output.
        embed_dim: channels of the incoming hidden states.
        act_fn: "silu" or "mish"; anything else falls back to Mish.
    """

    def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
        super().__init__()
        self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
        self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
        if act_fn == "silu":
            self.final_conv1d_act = nn.SiLU()
        elif act_fn == "mish":
            self.final_conv1d_act = nn.Mish()
        else:
            # Fix: an unrecognized (or None) act_fn previously left
            # final_conv1d_act undefined, crashing with AttributeError in
            # forward(); default to Mish instead.
            self.final_conv1d_act = nn.Mish()
        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

    def forward(self, hidden_states, temb=None):
        hidden_states = self.final_conv1d_1(hidden_states)
        # GroupNorm is applied on a temporary 4D view of the activations.
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_gn(hidden_states)
        hidden_states = rearrange_dims(hidden_states)
        hidden_states = self.final_conv1d_act(hidden_states)
        hidden_states = self.final_conv1d_2(hidden_states)
        return hidden_states
class OutValueFunctionBlock(nn.Module):
    """Scalar value head: flatten the features, concatenate the time
    embedding, and regress a single value through a small Mish MLP.

    Parameters:
        fc_dim: flattened feature size expected from the hidden states.
        embed_dim: dimensionality of the time embedding.
    """

    def __init__(self, fc_dim, embed_dim):
        super().__init__()
        self.final_block = nn.ModuleList(
            [
                nn.Linear(fc_dim + embed_dim, fc_dim // 2),
                nn.Mish(),
                nn.Linear(fc_dim // 2, 1),
            ]
        )

    def forward(self, hidden_states, temb):
        # Flatten everything but the batch axis, then append the embedding.
        out = hidden_states.view(hidden_states.shape[0], -1)
        out = torch.cat((out, temb), dim=-1)
        for layer in self.final_block:
            out = layer(out)
        return out
_kernels = { _kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
...@@ -62,7 +312,7 @@ class Upsample1d(nn.Module): ...@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1 self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d) self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states): def forward(self, hidden_states, temb=None):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
...@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module): ...@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
return output return output
def get_down_block(down_block_type, out_channels, in_channels):
    """Instantiate a down block from its registry name.

    Raises:
        ValueError: if `down_block_type` is not a known block name.
    """
    registry = {
        "DownBlock1D": DownBlock1D,
        "AttnDownBlock1D": AttnDownBlock1D,
        "DownBlock1DNoSkip": DownBlock1DNoSkip,
    }
    block_cls = registry.get(down_block_type)
    if block_cls is None:
        raise ValueError(f"{down_block_type} does not exist.")
    return block_cls(out_channels=out_channels, in_channels=in_channels)
def get_up_block(up_block_type, in_channels, out_channels):
    """Instantiate an up block from its registry name.

    Raises:
        ValueError: if `up_block_type` is not a known block name.
    """
    registry = {
        "UpBlock1D": UpBlock1D,
        "AttnUpBlock1D": AttnUpBlock1D,
        "UpBlock1DNoSkip": UpBlock1DNoSkip,
    }
    block_cls = registry.get(up_block_type)
    if block_cls is None:
        raise ValueError(f"{up_block_type} does not exist.")
    return block_cls(in_channels=in_channels, out_channels=out_channels)
def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
    """Instantiate the middle block from its registry name.

    Raises:
        ValueError: if `mid_block_type` is not a known block name.
    """
    if mid_block_type != "UNetMidBlock1D":
        raise ValueError(f"{mid_block_type} does not exist.")
    return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
class UNetMidBlock1D(nn.Module): class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None): def __init__(self, mid_channels, in_channels, out_channels=None):
super().__init__() super().__init__()
...@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module): ...@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states): def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets): for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states) hidden_states = resnet(hidden_states)
...@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module): ...@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic") self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple): def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module): ...@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic") self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple): def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module): ...@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, res_hidden_states_tuple): def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module): ...@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
hidden_states = resnet(hidden_states) hidden_states = resnet(hidden_states)
return hidden_states return hidden_states
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
    """Instantiate the 1D down-sampling block identified by ``down_block_type``.

    Args:
        down_block_type: One of ``"DownResnetBlock1D"``, ``"DownBlock1D"``,
            ``"AttnDownBlock1D"`` or ``"DownBlock1DNoSkip"``.
        num_layers: Layer count (used only by ``DownResnetBlock1D``).
        in_channels: Number of channels fed into the block.
        out_channels: Number of channels produced by the block.
        temb_channels: Time-embedding width (used only by ``DownResnetBlock1D``).
        add_downsample: Whether to append a downsampler (``DownResnetBlock1D`` only).

    Raises:
        ValueError: If ``down_block_type`` does not name a known block.
    """
    # The resnet variant is the only one that consumes the temporal arguments.
    if down_block_type == "DownResnetBlock1D":
        return DownResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    # The remaining variants only care about the channel counts.
    channel_kwargs = {"out_channels": out_channels, "in_channels": in_channels}
    if down_block_type == "DownBlock1D":
        return DownBlock1D(**channel_kwargs)
    if down_block_type == "AttnDownBlock1D":
        return AttnDownBlock1D(**channel_kwargs)
    if down_block_type == "DownBlock1DNoSkip":
        return DownBlock1DNoSkip(**channel_kwargs)
    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    """Instantiate the 1D up-sampling block identified by ``up_block_type``.

    Args:
        up_block_type: One of ``"UpResnetBlock1D"``, ``"UpBlock1D"``,
            ``"AttnUpBlock1D"`` or ``"UpBlock1DNoSkip"``.
        num_layers: Layer count (used only by ``UpResnetBlock1D``).
        in_channels: Number of channels fed into the block.
        out_channels: Number of channels produced by the block.
        temb_channels: Time-embedding width (used only by ``UpResnetBlock1D``).
        add_upsample: Whether to append an upsampler (``UpResnetBlock1D`` only).

    Raises:
        ValueError: If ``up_block_type`` does not name a known block.
    """
    # The resnet variant is the only one that consumes the temporal arguments.
    if up_block_type == "UpResnetBlock1D":
        return UpResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
        )
    # The remaining variants only care about the channel counts.
    channel_kwargs = {"in_channels": in_channels, "out_channels": out_channels}
    if up_block_type == "UpBlock1D":
        return UpBlock1D(**channel_kwargs)
    if up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(**channel_kwargs)
    if up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(**channel_kwargs)
    raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
    """Instantiate the 1D mid block identified by ``mid_block_type``.

    Args:
        mid_block_type: One of ``"MidResTemporalBlock1D"``,
            ``"ValueFunctionMidBlock1D"`` or ``"UNetMidBlock1D"``.
        num_layers: Layer count (used only by ``MidResTemporalBlock1D``).
        in_channels: Number of channels fed into the block.
        mid_channels: Middle channel count (used only by ``UNetMidBlock1D``).
        out_channels: Number of channels produced by the block.
        embed_dim: Embedding width (unused by ``UNetMidBlock1D``).
        add_downsample: Whether to downsample (``MidResTemporalBlock1D`` only).

    Raises:
        ValueError: If ``mid_block_type`` does not name a known block.
    """
    if mid_block_type == "MidResTemporalBlock1D":
        return MidResTemporalBlock1D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            embed_dim=embed_dim,
            add_downsample=add_downsample,
        )
    if mid_block_type == "ValueFunctionMidBlock1D":
        return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
    if mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
    raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    """Instantiate the optional output head identified by ``out_block_type``.

    All arguments are keyword-only.

    Args:
        out_block_type: ``"OutConv1DBlock"``, ``"ValueFunction"`` or anything
            else (which yields no output block).
        num_groups_out: Group count for the conv head's normalization.
        embed_dim: Embedding width passed to either head.
        out_channels: Channel count for the conv head.
        act_fn: Activation-function name for the conv head.
        fc_dim: Fully-connected width for the value-function head.

    Returns:
        The constructed block, or ``None`` — unlike the other factories, an
        unrecognised type is not an error here.
    """
    if out_block_type == "OutConv1DBlock":
        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
    if out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim)
    return None
...@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): ...@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to : flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding. obj:`True`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to : down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types. types.
......
...@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): ...@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding. Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
......
...@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__) ...@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = { LOADABLE_CLASSES = {
"diffusers": { "diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"], "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"FlaxSchedulerMixin": ["save_config", "from_config"], "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
}, },
"transformers": { "transformers": {
...@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler >>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5" >>> model_id = "runwayml/stable-diffusion-v1-5"
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config( >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id, ... model_id,
... subfolder="scheduler", ... subfolder="scheduler",
... ) ... )
...@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path): if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict( config_dict = cls.load_config(
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
resume_download=resume_download, resume_download=resume_download,
...@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder) config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub # 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
...@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {} init_kwargs = {}
......
...@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__) ...@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = { LOADABLE_CLASSES = {
"diffusers": { "diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"], "ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"], "SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
}, },
...@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
if torch_device is None: if torch_device is None:
return self return self
module_names, _ = self.extract_init_dict(dict(self.config)) module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys(): for name in module_names.keys():
module = getattr(self, name) module = getattr(self, name)
if isinstance(module, torch.nn.Module): if isinstance(module, torch.nn.Module):
...@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
Returns: Returns:
`torch.device`: The torch device on which the pipeline is located. `torch.device`: The torch device on which the pipeline is located.
""" """
module_names, _ = self.extract_init_dict(dict(self.config)) module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys(): for name in module_names.keys():
module = getattr(self, name) module = getattr(self, name)
if isinstance(module, torch.nn.Module): if isinstance(module, torch.nn.Module):
...@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin): ...@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Download pipeline, but overwrite scheduler >>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler >>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) >>> pipeline.scheduler = scheduler
``` ```
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
...@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs # 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path): if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict( config_dict = cls.load_config(
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
resume_download=resume_download, resume_download=resume_download,
...@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder) config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub # 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
...@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
if len(unused_kwargs) > 0: if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.") logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
......
...@@ -40,7 +40,7 @@ available a colab notebook to directly try them out. ...@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | | [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | | [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
......
...@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
""" """
message = ( message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`." " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
) )
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
......
...@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png") ...@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login` # make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5",
...@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png") ...@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login` # make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained( pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", "runwayml/stable-diffusion-v1-5",
...@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler ...@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline # load the pipeline
# make sure you're logged in with `huggingface-cli login` # make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4" model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image # let's download an initial image
......
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~ConfigMixin.from_config`] functions. [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502 For more details, see the original paper: https://arxiv.org/abs/2010.02502
...@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
"PNDMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -23,7 +23,12 @@ import flax ...@@ -23,7 +23,12 @@ import flax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
...@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~ConfigMixin.from_config`] functions. [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502 For more details, see the original paper: https://arxiv.org/abs/2010.02502
...@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion. stable diffusion.
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property @property
def has_state(self): def has_state(self):
return True return True
......
...@@ -22,7 +22,7 @@ import numpy as np ...@@ -22,7 +22,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import BaseOutput, deprecate from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~ConfigMixin.from_config`] functions. [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239 For more details, see the original paper: https://arxiv.org/abs/2006.11239
...@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
"DDIMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
...@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# for rl-diffuser https://arxiv.org/abs/2205.09991 # for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log": elif variance_type == "fixed_small_log":
variance = torch.log(torch.clamp(variance, min=1e-20)) variance = torch.log(torch.clamp(variance, min=1e-20))
variance = torch.exp(0.5 * variance)
elif variance_type == "fixed_large": elif variance_type == "fixed_large":
variance = self.betas[t] variance = self.betas[t]
elif variance_type == "fixed_large_log": elif variance_type == "fixed_large_log":
...@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
""" """
message = ( message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`." " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
) )
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
...@@ -301,7 +295,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -301,7 +295,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_noise = torch.randn( variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype model_output.shape, generator=generator, device=device, dtype=model_output.dtype
) )
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
else:
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
......
...@@ -24,7 +24,12 @@ from jax import random ...@@ -24,7 +24,12 @@ from jax import random
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import deprecate from ..utils import deprecate
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
...@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~ConfigMixin.from_config`] functions. [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239 For more details, see the original paper: https://arxiv.org/abs/2006.11239
...@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property @property
def has_state(self): def has_state(self):
return True return True
...@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
message = ( message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`." " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
) )
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput from .scheduling_utils import SchedulerMixin, SchedulerOutput
...@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~ConfigMixin.from_config`] functions. [`~SchedulerMixin.from_pretrained`] functions.
Args: Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. num_train_timesteps (`int`): number of diffusion steps used to train the model.
...@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
_compatible_classes = [ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -23,7 +23,12 @@ import jax ...@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
...@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~ConfigMixin.from_config`] functions. [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
...@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property @property
def has_state(self): def has_state(self):
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment