Unverified commit 0c9f174d authored by Aryan V S, committed by GitHub

Improve typehints and docs in `diffusers/models` (#5391)



* improvement: add typehints and docs to src/diffusers/models/attention_processor.py

* improvement: add typehints and docs to src/diffusers/models/vae.py

* improvement: add missing docs in src/diffusers/models/vq_model.py

* improvement: add typehints and docs to src/diffusers/models/transformer_temporal.py

* improvement: add typehints and docs to src/diffusers/models/t5_film_transformer.py

* improvement: add type hints to src/diffusers/models/unet_1d_blocks.py

* improvement: add missing type hints to src/diffusers/models/unet_2d_blocks.py

* fix: CI error (make fix-copies required)

* fix: CI error (make fix-copies required again)

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent d420d713
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
from torch import nn
......@@ -23,6 +24,28 @@ from .modeling_utils import ModelMixin
class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.
Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of each `DecoderLayer`.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""
@register_to_config
def __init__(
self,
......@@ -63,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
def encoder_decoder_mask(self, query_input, key_input):
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)
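A minimal sketch of what `encoder_decoder_mask` computes, with made-up padding masks and illustrative shapes:

```python
import torch

query_input = torch.tensor([[1.0, 1.0, 0.0]])  # (batch, query_len) padding mask
key_input = torch.tensor([[1.0, 0.0]])         # (batch, key_len) padding mask

# Outer product of the two masks, plus a broadcast dimension for the heads.
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
mask = mask.unsqueeze(-3)
print(mask.shape)  # torch.Size([1, 1, 3, 2])
```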
......@@ -125,7 +148,27 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
class DecoderLayer(nn.Module):
def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
r"""
T5 decoder layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__()
self.layer = nn.ModuleList()
......@@ -152,13 +195,13 @@ class DecoderLayer(nn.Module):
def forward(
self,
hidden_states,
conditioning_emb=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
):
) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
......@@ -183,7 +226,21 @@ class DecoderLayer(nn.Module):
class T5LayerSelfAttentionCond(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate):
r"""
T5 style self-attention layer with conditioning.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
......@@ -192,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
def forward(
self,
hidden_states,
conditioning_emb=None,
attention_mask=None,
):
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)
......@@ -211,7 +268,23 @@ class T5LayerSelfAttentionCond(nn.Module):
class T5LayerCrossAttention(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
r"""
T5 style cross-attention layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
......@@ -219,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
def forward(
self,
hidden_states,
key_value_states=None,
attention_mask=None,
):
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
......@@ -234,14 +307,30 @@ class T5LayerCrossAttention(nn.Module):
class T5LayerFFCond(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
r"""
T5 style feed-forward conditional layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, hidden_states, conditioning_emb=None):
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
......@@ -252,7 +341,19 @@ class T5LayerFFCond(nn.Module):
class T5DenseGatedActDense(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate):
r"""
T5 style feed-forward layer with gated activations and dropout.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
......@@ -260,7 +361,7 @@ class T5DenseGatedActDense(nn.Module):
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
......@@ -271,7 +372,17 @@ class T5DenseGatedActDense(nn.Module):
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
r"""
T5 style layer normalization module.
Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
......@@ -279,7 +390,7 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
......@@ -307,14 +418,20 @@ class NewGELUActivation(nn.Module):
class T5FiLMLayer(nn.Module):
"""
FiLM Layer
T5 style FiLM Layer.
Args:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
"""
def __init__(self, in_features, out_features):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
def forward(self, x, conditioning_emb):
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
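A minimal sketch of the FiLM modulation above, with hypothetical feature sizes:

```python
import torch
from torch import nn

in_features, out_features = 8, 4
scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

x = torch.randn(2, 10, out_features)               # hidden states
conditioning_emb = torch.randn(2, 1, in_features)  # conditioning embedding

emb = scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, dim=-1)
x = x * (1 + scale) + shift                        # feature-wise affine modulation
```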
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional
import torch
from torch import nn
......@@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
"""
......@@ -106,14 +110,14 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
num_frames=1,
cross_attention_kwargs=None,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
) -> TransformerTemporalModelOutput:
"""
The [`TransformerTemporal`] forward method.
......@@ -123,7 +127,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -24,17 +25,17 @@ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
num_layers=1,
conv_shortcut=False,
temb_channels=32,
groups=32,
groups_out=None,
non_linearity=None,
time_embedding_norm="default",
output_scale_factor=1.0,
add_downsample=True,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
conv_shortcut: bool = False,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
......@@ -65,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
......@@ -86,16 +87,16 @@ class DownResnetBlock1D(nn.Module):
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
num_layers=1,
temb_channels=32,
groups=32,
groups_out=None,
non_linearity=None,
time_embedding_norm="default",
output_scale_factor=1.0,
add_upsample=True,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
......@@ -125,7 +126,12 @@ class UpResnetBlock1D(nn.Module):
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
......@@ -144,7 +150,7 @@ class UpResnetBlock1D(nn.Module):
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, embed_dim):
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -155,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x, temb=None):
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
......@@ -166,13 +172,13 @@ class ValueFunctionMidBlock1D(nn.Module):
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dim,
in_channels: int,
out_channels: int,
embed_dim: int,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity=None,
non_linearity: Optional[str] = None,
):
super().__init__()
self.in_channels = in_channels
......@@ -203,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states, temb):
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
......@@ -217,14 +223,14 @@ class MidResTemporalBlock1D(nn.Module):
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
......@@ -235,7 +241,7 @@ class OutConv1DBlock(nn.Module):
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim, embed_dim, act_fn="mish"):
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
......@@ -245,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
]
)
def forward(self, hidden_states, temb):
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
......@@ -275,14 +281,14 @@ _kernels = {
class Downsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
......@@ -292,14 +298,14 @@ class Downsample1d(nn.Module):
class Upsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
......@@ -309,7 +315,7 @@ class Upsample1d(nn.Module):
class SelfAttention1d(nn.Module):
def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
......@@ -329,7 +335,7 @@ class SelfAttention1d(nn.Module):
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
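A minimal sketch of the multi-head reshape shown above, with hypothetical sizes:

```python
import torch

batch, seq, channels, n_head = 2, 16, 32, 4
projection = torch.randn(batch, seq, channels)

new_projection_shape = projection.size()[:-1] + (n_head, channels // n_head)
heads_first = projection.view(new_projection_shape).permute(0, 2, 1, 3)
print(heads_first.shape)  # torch.Size([2, 4, 16, 8])
```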
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
......@@ -367,7 +373,7 @@ class SelfAttention1d(nn.Module):
class ResConvBlock(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
......@@ -384,7 +390,7 @@ class ResConvBlock(nn.Module):
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states):
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
......@@ -401,7 +407,7 @@ class ResConvBlock(nn.Module):
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
......@@ -429,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
......@@ -441,7 +447,7 @@ class UNetMidBlock1D(nn.Module):
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
......@@ -460,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
......@@ -471,7 +477,7 @@ class AttnDownBlock1D(nn.Module):
class DownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
......@@ -484,7 +490,7 @@ class DownBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
......@@ -494,7 +500,7 @@ class DownBlock1D(nn.Module):
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
......@@ -506,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
......@@ -515,7 +521,7 @@ class DownBlock1DNoSkip(nn.Module):
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
......@@ -534,7 +540,12 @@ class AttnUpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
......@@ -548,7 +559,7 @@ class AttnUpBlock1D(nn.Module):
class UpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
......@@ -561,7 +572,12 @@ class UpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
......@@ -574,7 +590,7 @@ class UpBlock1D(nn.Module):
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
......@@ -586,7 +602,12 @@ class UpBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
......@@ -596,7 +617,20 @@ class UpBlock1DNoSkip(nn.Module):
return hidden_states
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
......@@ -614,7 +648,9 @@ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
......@@ -632,7 +668,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
def get_mid_block(
mid_block_type: str,
num_layers: int,
in_channels: int,
mid_channels: int,
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
......@@ -648,7 +692,9 @@ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_cha
raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple
import numpy as np
import torch
......@@ -27,7 +27,7 @@ from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block
@dataclass
class DecoderOutput(BaseOutput):
"""
r"""
Output of decoding method.
Args:
......@@ -39,16 +39,39 @@ class DecoderOutput(BaseOutput):
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
......@@ -107,7 +130,8 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, x):
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = x
sample = self.conv_in(sample)
......@@ -152,16 +176,38 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group", # group, spatial
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
......@@ -227,7 +273,8 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, z, latent_embeds=None):
def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = z
sample = self.conv_in(sample)
......@@ -283,6 +330,16 @@ class Decoder(nn.Module):
class UpSample(nn.Module):
r"""
The `UpSample` layer of a variational autoencoder that upsamples its input.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
"""
def __init__(
self,
in_channels: int,
......@@ -294,6 +351,7 @@ class UpSample(nn.Module):
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `UpSample` class."""
x = torch.relu(x)
x = self.deconv(x)
return x
......@@ -342,6 +400,7 @@ class MaskConditionEncoder(nn.Module):
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionEncoder` class."""
out = {}
for l in range(len(self.layers)):
layer = self.layers[l]
......@@ -352,19 +411,38 @@ class MaskConditionEncoder(nn.Module):
class MaskConditionDecoder(nn.Module):
"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
decoder with a conditioner on the mask and masked image."""
r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
decoder with a conditioner on the mask and masked image.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group", # group, spatial
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
......@@ -437,7 +515,14 @@ class MaskConditionDecoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, z, image=None, mask=None, latent_embeds=None):
def forward(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionDecoder` class."""
sample = z
sample = self.conv_in(sample)
......@@ -539,7 +624,14 @@ class VectorQuantizer(nn.Module):
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
self,
n_e: int,
vq_embed_dim: int,
beta: float,
remap=None,
unknown_index: str = "random",
sane_index_shape: bool = False,
legacy: bool = True,
):
super().__init__()
self.n_e = n_e
......@@ -553,6 +645,7 @@ class VectorQuantizer(nn.Module):
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.used: torch.Tensor
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
......@@ -567,7 +660,7 @@ class VectorQuantizer(nn.Module):
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
......@@ -581,7 +674,7 @@ class VectorQuantizer(nn.Module):
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
......@@ -591,7 +684,7 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z):
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
......@@ -610,7 +703,7 @@ class VectorQuantizer(nn.Module):
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
z_q: torch.FloatTensor = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
......@@ -625,7 +718,7 @@ class VectorQuantizer(nn.Module):
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
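A minimal sketch of the straight-through estimator used above: the forward value is the quantized `z_q`, while gradients flow back to `z` as if quantization were the identity (rounding stands in for the codebook lookup):

```python
import torch

z = torch.randn(4, 8, requires_grad=True)
z_q = torch.round(z)             # stand-in for the codebook lookup
z_q_st = z + (z_q - z).detach()  # value == z_q, gradient w.r.t. z == 1
z_q_st.sum().backward()
print(z.grad.unique())           # tensor([1.])
```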
def get_codebook_entry(self, indices, shape):
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
......@@ -633,7 +726,7 @@ class VectorQuantizer(nn.Module):
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
z_q: torch.FloatTensor = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
......@@ -644,7 +737,7 @@ class VectorQuantizer(nn.Module):
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
......@@ -664,7 +757,7 @@ class DiagonalGaussianDistribution(object):
x = self.mean + self.std * sample
return x
def kl(self, other=None):
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
......@@ -680,23 +773,40 @@ class DiagonalGaussianDistribution(object):
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self):
def mode(self) -> torch.Tensor:
return self.mean
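A minimal sketch of the closed-form KL term against a standard normal that `kl()` computes when `other` is `None` (shapes are illustrative):

```python
import torch

mean = torch.zeros(1, 4, 8, 8)
logvar = torch.zeros(1, 4, 8, 8)
var = logvar.exp()

kl = 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])
print(kl)  # tensor([0.]) for a standard normal
```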
class EncoderTiny(nn.Module):
r"""
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
Args:
in_channels (`int`):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`Tuple[int, ...]`):
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`s to
use.
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_blocks: int,
block_out_channels: int,
num_blocks: Tuple[int, ...],
block_out_channels: Tuple[int, ...],
act_fn: str,
):
super().__init__()
......@@ -718,7 +828,8 @@ class EncoderTiny(nn.Module):
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
def forward(self, x):
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `EncoderTiny` class."""
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
......@@ -740,12 +851,31 @@ class EncoderTiny(nn.Module):
class DecoderTiny(nn.Module):
r"""
The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
Args:
in_channels (`int`):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`Tuple[int, ...]`):
Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`s to
use.
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
upsampling_scaling_factor (`int`):
The scaling factor to use for upsampling.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_blocks: int,
block_out_channels: int,
num_blocks: Tuple[int, ...],
block_out_channels: Tuple[int, ...],
upsampling_scaling_factor: int,
act_fn: str,
):
......@@ -772,7 +902,8 @@ class DecoderTiny(nn.Module):
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
def forward(self, x):
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `DecoderTiny` class."""
# Clamp.
x = torch.tanh(x / 3) * 3
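A small illustration of the soft clamp above, which keeps latents roughly within `(-3, 3)` before decoding:

```python
import torch

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
print(torch.tanh(x / 3) * 3)  # large magnitudes are compressed toward ±3
```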
......
......@@ -53,10 +53,12 @@ class VQModel(ModelMixin, ConfigMixin):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
scaling_factor (`float`, *optional*, defaults to `0.18215`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
......@@ -65,6 +67,8 @@ class VQModel(ModelMixin, ConfigMixin):
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
norm_type (`str`, *optional*, defaults to `"group"`):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
@register_to_config
......@@ -72,9 +76,9 @@ class VQModel(ModelMixin, ConfigMixin):
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 3,
......
......@@ -106,7 +106,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
......
......@@ -134,7 +134,6 @@ class AltDiffusionImg2ImgPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
......
......@@ -1508,9 +1508,9 @@ class DownBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_padding: int = 1,
):
super().__init__()
resnets = []
......@@ -1547,7 +1547,9 @@ class DownBlockFlat(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0):
def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
for resnet in self.resnets:
......@@ -1596,16 +1598,16 @@ class CrossAttnDownBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
):
super().__init__()
resnets = []
......@@ -1682,8 +1684,8 @@ class CrossAttnDownBlockFlat(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals=None,
):
additional_residuals: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
......@@ -1751,7 +1753,7 @@ class UpBlockFlat(nn.Module):
prev_output_channel: int,
out_channels: int,
temb_channels: int,
resolution_idx: int = None,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
......@@ -1759,8 +1761,8 @@ class UpBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
resnets = []
......@@ -1794,7 +1796,14 @@ class UpBlockFlat(nn.Module):
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
......@@ -1855,7 +1864,7 @@ class CrossAttnUpBlockFlat(nn.Module):
out_channels: int,
prev_output_channel: int,
temb_channels: int,
resolution_idx: int = None,
resolution_idx: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
......@@ -1864,15 +1873,15 @@ class CrossAttnUpBlockFlat(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
):
super().__init__()
resnets = []
......@@ -1949,7 +1958,7 @@ class CrossAttnUpBlockFlat(nn.Module):
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
......@@ -2066,8 +2075,8 @@ class UNetMidBlockFlat(nn.Module):
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attention_head_dim=1,
output_scale_factor=1.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
......@@ -2138,7 +2147,7 @@ class UNetMidBlockFlat(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
......@@ -2162,13 +2171,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
attention_type: str = "default",
):
super().__init__()
......@@ -2308,12 +2317,12 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_head_dim=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
only_cross_attention=False,
cross_attention_norm=None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
skip_time_act: bool = False,
only_cross_attention: bool = False,
cross_attention_norm: Optional[str] = None,
):
super().__init__()
......@@ -2389,7 +2398,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
......