Unverified Commit 0c9f174d authored by Aryan V S, committed by GitHub

Improve typehints and docs in `diffusers/models` (#5391)



* improvement: add typehints and docs to src/diffusers/models/attention_processor.py

* improvement: add typehints and docs to src/diffusers/models/vae.py

* improvement: add missing docs in src/diffusers/models/vq_model.py

* improvement: add typehints and docs to src/diffusers/models/transformer_temporal.py

* improvement: add typehints and docs to src/diffusers/models/t5_film_transformer.py

* improvement: add type hints to src/diffusers/models/unet_1d_blocks.py

* improvement: add missing type hints to src/diffusers/models/unet_2d_blocks.py

* fix: CI error (make fix-copies required)

* fix: CI error (make fix-copies required again)
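As a rough illustration of the pattern applied across these files (a minimal, self-contained sketch of a hypothetical block, not code taken from the diff): each public module gains an `Args:` docstring, and each `__init__`/`forward` gains explicit type hints and a return annotation.

from typing import Optional

import torch
from torch import nn


class ExampleBlock(nn.Module):
    r"""
    A hypothetical block, shown only to illustrate the docstring and typing style used in this PR.

    Args:
        channels (`int`):
            Number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.proj = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
        # Optional conditioning is added before the projection, mirroring the hinted signatures below.
        if temb is not None:
            hidden_states = hidden_states + temb
        return self.proj(hidden_states)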

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent d420d713
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
from torch import nn
@@ -23,6 +24,28 @@ from .modeling_utils import ModelMixin
class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.
Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`'s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""
@register_to_config
def __init__(
self,
@@ -63,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)
@@ -125,7 +148,27 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
class DecoderLayer(nn.Module):
r"""
T5 decoder layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__()
self.layer = nn.ModuleList()
@@ -152,13 +195,13 @@ class DecoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
@@ -183,7 +226,21 @@ class DecoderLayer(nn.Module):
class T5LayerSelfAttentionCond(nn.Module):
r"""
T5 style self-attention layer with conditioning.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
@@ -192,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)
@@ -211,7 +268,23 @@ class T5LayerSelfAttentionCond(nn.Module):
class T5LayerCrossAttention(nn.Module):
r"""
T5 style cross-attention layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
@@ -219,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
def forward(
self,
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
@@ -234,14 +307,30 @@ class T5LayerCrossAttention(nn.Module):
class T5LayerFFCond(nn.Module):
r"""
T5 style feed-forward conditional layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -252,7 +341,19 @@ class T5LayerFFCond(nn.Module):
class T5DenseGatedActDense(nn.Module):
r"""
T5 style feed-forward layer with gated activations and dropout.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
@@ -260,7 +361,7 @@ class T5DenseGatedActDense(nn.Module):
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
@@ -271,7 +372,17 @@ class T5DenseGatedActDense(nn.Module):
class T5LayerNorm(nn.Module):
r"""
T5 style layer normalization module.
Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
""" """
Construct a layernorm module in the T5 style. No bias and no subtraction of mean. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
""" """
...@@ -279,7 +390,7 @@ class T5LayerNorm(nn.Module): ...@@ -279,7 +390,7 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states): def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -307,14 +418,20 @@ class NewGELUActivation(nn.Module):
class T5FiLMLayer(nn.Module):
"""
T5 style FiLM Layer.
Args:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
""" """
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
...
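For context on the FiLM conditioning shown above: `T5FiLMLayer` projects the conditioning embedding to a per-feature scale and shift and applies `x * (1 + scale) + shift`. Below is a minimal standalone sketch of that operation; the tensor sizes are illustrative assumptions, not values from the diff.

import torch
from torch import nn

# Illustrative sizes only.
d_model, batch, seq_len = 768, 2, 16

# Same weight shape as T5FiLMLayer.scale_bias in the code above.
film = nn.Linear(d_model * 4, d_model * 2, bias=False)
x = torch.randn(batch, seq_len, d_model)
conditioning_emb = torch.randn(batch, 1, d_model * 4)

emb = film(conditioning_emb)
scale, shift = torch.chunk(emb, 2, dim=-1)
x = x * (1 + scale) + shift  # FiLM: feature-wise affine modulation, broadcast over the sequence
print(x.shape)  # torch.Size([2, 16, 768])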
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from torch import nn
@@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
"""
@@ -106,14 +110,14 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> TransformerTemporalModelOutput:
""" """
The [`TransformerTemporal`] forward method. The [`TransformerTemporal`] forward method.
...@@ -123,7 +127,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): ...@@ -123,7 +127,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention. self-attention.
timestep ( `torch.long`, *optional*): timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -24,17 +25,17 @@ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
conv_shortcut: bool = False,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
@@ -65,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
@@ -86,16 +87,16 @@ class DownResnetBlock1D(nn.Module):
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
@@ -125,7 +126,12 @@ class UpResnetBlock1D(nn.Module):
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
@@ -144,7 +150,7 @@ class UpResnetBlock1D(nn.Module):
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -155,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
@@ -166,13 +172,13 @@ class ValueFunctionMidBlock1D(nn.Module):
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
embed_dim: int,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity: Optional[str] = None,
):
super().__init__()
self.in_channels = in_channels
@@ -203,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
@@ -217,14 +223,14 @@ class MidResTemporalBlock1D(nn.Module):
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
@@ -235,7 +241,7 @@ class OutConv1DBlock(nn.Module):
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
@@ -245,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
]
)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
@@ -275,14 +281,14 @@ _kernels = {
class Downsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -292,14 +298,14 @@ class Downsample1d(nn.Module):
class Upsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -309,7 +315,7 @@ class Upsample1d(nn.Module):
class SelfAttention1d(nn.Module):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
@@ -329,7 +335,7 @@ class SelfAttention1d(nn.Module):
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
@@ -367,7 +373,7 @@ class SelfAttention1d(nn.Module):
class ResConvBlock(nn.Module):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
@@ -384,7 +390,7 @@ class ResConvBlock(nn.Module):
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
@@ -401,7 +407,7 @@ class ResConvBlock(nn.Module):
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
@@ -429,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
@@ -441,7 +447,7 @@ class UNetMidBlock1D(nn.Module):
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -460,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
@@ -471,7 +477,7 @@ class AttnDownBlock1D(nn.Module):
class DownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -484,7 +490,7 @@ class DownBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
@@ -494,7 +500,7 @@ class DownBlock1D(nn.Module):
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -506,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
@@ -515,7 +521,7 @@ class DownBlock1DNoSkip(nn.Module):
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -534,7 +540,12 @@ class AttnUpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -548,7 +559,7 @@ class AttnUpBlock1D(nn.Module):
class UpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
@@ -561,7 +572,12 @@ class UpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -574,7 +590,7 @@ class UpBlock1D(nn.Module):
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
@@ -586,7 +602,12 @@ class UpBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -596,7 +617,20 @@ class UpBlock1DNoSkip(nn.Module):
return hidden_states
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
@@ -614,7 +648,9 @@ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
@@ -632,7 +668,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(
mid_block_type: str,
num_layers: int,
in_channels: int,
mid_channels: int,
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
@@ -648,7 +692,9 @@ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_cha
raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
...
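The factory functions above dispatch on a block-type string and now carry typed signatures. A hedged usage sketch follows; it assumes the module path `diffusers.models.unet_1d_blocks` is importable as in this version of the library, and the argument values are illustrative, not library defaults.

# Illustrative call to the typed get_down_block factory shown above.
from diffusers.models.unet_1d_blocks import get_down_block

block = get_down_block(
    "DownResnetBlock1D",
    num_layers=1,
    in_channels=32,
    out_channels=64,
    temb_channels=128,
    add_downsample=True,
)
print(type(block).__name__)  # DownResnetBlock1D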
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
@@ -27,7 +27,7 @@ from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
@@ -39,16 +39,39 @@ class DecoderOutput(BaseOutput):
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -107,7 +130,8 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = x
sample = self.conv_in(sample)
@@ -152,16 +176,38 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group",  # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -227,7 +273,8 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False
def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = z
sample = self.conv_in(sample)
@@ -283,6 +330,16 @@ class Decoder(nn.Module):
class UpSample(nn.Module):
r"""
The `UpSample` layer of a variational autoencoder that upsamples its input.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
"""
def __init__(
self,
in_channels: int,
@@ -294,6 +351,7 @@ class UpSample(nn.Module):
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `UpSample` class."""
x = torch.relu(x)
x = self.deconv(x)
return x
@@ -342,6 +400,7 @@ class MaskConditionEncoder(nn.Module):
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionEncoder` class."""
out = {}
for l in range(len(self.layers)):
layer = self.layers[l]
@@ -352,19 +411,38 @@ class MaskConditionEncoder(nn.Module):
class MaskConditionDecoder(nn.Module):
r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
decoder with a conditioner on the mask and masked image.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group",  # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -437,7 +515,14 @@ class MaskConditionDecoder(nn.Module):
self.gradient_checkpointing = False
def forward(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionDecoder` class."""
sample = z
sample = self.conv_in(sample)
@@ -539,7 +624,14 @@ class VectorQuantizer(nn.Module):
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
self,
n_e: int,
vq_embed_dim: int,
beta: float,
remap=None,
unknown_index: str = "random",
sane_index_shape: bool = False,
legacy: bool = True,
):
super().__init__()
self.n_e = n_e
@@ -553,6 +645,7 @@ class VectorQuantizer(nn.Module):
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.used: torch.Tensor
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index  # "random" or "extra" or integer
if self.unknown_index == "extra":
@@ -567,7 +660,7 @@ class VectorQuantizer(nn.Module):
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
@@ -581,7 +674,7 @@ class VectorQuantizer(nn.Module):
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
@@ -591,7 +684,7 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
@@ -610,7 +703,7 @@ class VectorQuantizer(nn.Module):
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q: torch.FloatTensor = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
@@ -625,7 +718,7 @@ class VectorQuantizer(nn.Module):
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1)  # add batch axis
@@ -633,7 +726,7 @@ class VectorQuantizer(nn.Module):
indices = indices.reshape(-1)  # flatten again
# get quantized latent vectors
z_q: torch.FloatTensor = self.embedding(indices)
if shape is not None: if shape is not None:
z_q = z_q.view(shape) z_q = z_q.view(shape)
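`get_codebook_entry` is now annotated as `LongTensor` indices in, `FloatTensor` latents out. The lookup itself is a plain `nn.Embedding` call plus a reshape; a standalone sketch with made-up sizes:

```python
import torch
from torch import nn

codebook = nn.Embedding(num_embeddings=256, embedding_dim=3)  # 256 codes, 3-dim latents
indices = torch.randint(0, 256, (1, 8 * 8))                   # indices for one 8x8 latent grid

z_q = codebook(indices.reshape(-1))           # (64, 3) quantized vectors
z_q = z_q.view(1, 8, 8, 3)                    # (batch, height, width, channel)
z_q = z_q.permute(0, 3, 1, 2).contiguous()    # channel-first, as the decoder expects
print(z_q.shape)                              # torch.Size([1, 3, 8, 8])
```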
@@ -644,7 +737,7 @@ class VectorQuantizer(nn.Module):


 class DiagonalGaussianDistribution(object):
-    def __init__(self, parameters, deterministic=False):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
@@ -664,7 +757,7 @@ class DiagonalGaussianDistribution(object):
         x = self.mean + self.std * sample
         return x

-    def kl(self, other=None):
+    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
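The new `kl` hint documents that `other` may be another `DiagonalGaussianDistribution` or `None` (in which case the KL against a standard normal is returned). A quick usage sketch, assuming a VAE encoder that stacks means and log-variances along the channel dimension (4 latent channels gives 8 parameter channels):

```python
import torch
from diffusers.models.vae import DiagonalGaussianDistribution

parameters = torch.randn(1, 8, 32, 32)  # mean and logvar stacked along dim=1
posterior = DiagonalGaussianDistribution(parameters)

z = posterior.sample()   # reparameterized sample, shape (1, 4, 32, 32)
kl = posterior.kl()      # KL divergence against N(0, I), one value per batch element
mode = posterior.mode()  # the mean, useful for deterministic decoding
```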
@@ -680,23 +773,40 @@ class DiagonalGaussianDistribution(object):
             dim=[1, 2, 3],
         )

-    def nll(self, sample, dims=[1, 2, 3]):
+    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

-    def mode(self):
+    def mode(self) -> torch.Tensor:
        return self.mean


 class EncoderTiny(nn.Module):
+    r"""
+    The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
+
+    Args:
+        in_channels (`int`):
+            The number of input channels.
+        out_channels (`int`):
+            The number of output channels.
+        num_blocks (`Tuple[int, ...]`):
+            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
+            use.
+        block_out_channels (`Tuple[int, ...]`):
+            The number of output channels for each block.
+        act_fn (`str`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+    """
+
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        num_blocks: int,
-        block_out_channels: int,
+        num_blocks: Tuple[int, ...],
+        block_out_channels: Tuple[int, ...],
         act_fn: str,
     ):
         super().__init__()
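The corrected hints make it clear that `num_blocks` and `block_out_channels` are per-stage tuples, not single integers. An illustrative instantiation (the values below are chosen for the example, not taken from any released checkpoint):

```python
import torch
from diffusers.models.vae import EncoderTiny

encoder = EncoderTiny(
    in_channels=3,
    out_channels=4,                       # latent channels
    num_blocks=(1, 3, 3, 3),              # AutoencoderTinyBlocks per stage
    block_out_channels=(64, 64, 64, 64),  # channels per stage
    act_fn="relu",
)

image = torch.randn(1, 3, 256, 256)
latents = encoder(image)  # every stage after the first downsamples by 2
print(latents.shape)      # torch.Size([1, 4, 32, 32])
```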
@@ -718,7 +828,8 @@ class EncoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False

-    def forward(self, x):
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""The forward method of the `EncoderTiny` class."""
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
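The gradient-checkpointing branch shown here wraps each sub-module in a small closure because `torch.utils.checkpoint.checkpoint` expects a plain callable. The pattern in isolation, reduced to a single linear layer:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `layer` are recomputed during backward instead of being stored.
y = checkpoint(create_custom_forward(layer), x)
y.sum().backward()
```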
@@ -740,12 +851,31 @@ class EncoderTiny(nn.Module):


 class DecoderTiny(nn.Module):
+    r"""
+    The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
+
+    Args:
+        in_channels (`int`):
+            The number of input channels.
+        out_channels (`int`):
+            The number of output channels.
+        num_blocks (`Tuple[int, ...]`):
+            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
+            use.
+        block_out_channels (`Tuple[int, ...]`):
+            The number of output channels for each block.
+        upsampling_scaling_factor (`int`):
+            The scaling factor to use for upsampling.
+        act_fn (`str`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+    """
+
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        num_blocks: int,
-        block_out_channels: int,
+        num_blocks: Tuple[int, ...],
+        block_out_channels: Tuple[int, ...],
         upsampling_scaling_factor: int,
         act_fn: str,
     ):
@@ -772,7 +902,8 @@ class DecoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False

-    def forward(self, x):
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""The forward method of the `DecoderTiny` class."""
        # Clamp.
        x = torch.tanh(x / 3) * 3
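The `# Clamp.` step is a smooth clamp: `torch.tanh(x / 3) * 3` keeps incoming latents in roughly (-3, 3) without the hard cut-off of `torch.clamp`. A quick numeric check:

```python
import torch

x = torch.tensor([-10.0, -3.0, 0.0, 3.0, 10.0])
print(torch.tanh(x / 3) * 3)  # approx. [-2.99, -2.28, 0.00, 2.28, 2.99], always inside (-3, 3)
```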
......
@@ -53,10 +53,12 @@ class VQModel(ModelMixin, ConfigMixin):
             Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
             Tuple of block output channels.
+        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
         sample_size (`int`, *optional*, defaults to `32`): Sample input size.
         num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
         vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
         scaling_factor (`float`, *optional*, defaults to `0.18215`):
             The component-wise standard deviation of the trained latent space computed using the first batch of the
@@ -65,6 +67,8 @@ class VQModel(ModelMixin, ConfigMixin):
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
             Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        norm_type (`str`, *optional*, defaults to `"group"`):
+            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """

     @register_to_config
@@ -72,9 +76,9 @@ class VQModel(ModelMixin, ConfigMixin):
         self,
         in_channels: int = 3,
         out_channels: int = 3,
-        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
-        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
-        block_out_channels: Tuple[int] = (64,),
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
         layers_per_block: int = 1,
         act_fn: str = "silu",
         latent_channels: int = 3,
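The widened `Tuple[..., ...]` hints make explicit that multi-stage configurations are valid, not just the single-entry defaults. A sketch of a small two-stage model (configuration invented for illustration):

```python
import torch
from diffusers import VQModel

model = VQModel(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(64, 128),
    layers_per_block=1,
    latent_channels=3,
    sample_size=32,
)

image = torch.randn(1, 3, 32, 32)
reconstruction = model(image).sample  # encode -> quantize -> decode
print(reconstruction.shape)           # torch.Size([1, 3, 32, 32])
```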
......
@@ -106,7 +106,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
-
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
......
@@ -134,7 +134,6 @@ class AltDiffusionImg2ImgPipeline(
         feature_extractor ([`~transformers.CLIPImageProcessor`]):
             A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
     """
-
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
......
@@ -1508,9 +1508,9 @@ class DownBlockFlat(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_downsample=True,
-        downsample_padding=1,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
     ):
         super().__init__()
         resnets = []
@@ -1547,7 +1547,9 @@ class DownBlockFlat(nn.Module):

         self.gradient_checkpointing = False

-    def forward(self, hidden_states, temb=None, scale: float = 1.0):
+    def forward(
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()

         for resnet in self.resnets:
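The new return annotation spells out the existing contract: the block returns the transformed hidden states plus a tuple of residuals that the up blocks later consume as skip connections. Since `DownBlockFlat` is a fix-copies mirror of `DownBlock2D`, the same calling convention can be sketched with the 2D original (shapes invented):

```python
import torch
from diffusers.models.unet_2d_blocks import DownBlock2D

block = DownBlock2D(in_channels=32, out_channels=64, temb_channels=128, num_layers=2)

hidden_states = torch.randn(1, 32, 16, 16)
temb = torch.randn(1, 128)

hidden_states, output_states = block(hidden_states, temb=temb)
print(hidden_states.shape)  # torch.Size([1, 64, 8, 8]) after the downsampler
print(len(output_states))   # 3: one residual per resnet, plus one after downsampling
```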
@@ -1596,16 +1598,16 @@ class CrossAttnDownBlockFlat(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        downsample_padding=1,
-        add_downsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        attention_type="default",
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
     ):
         super().__init__()
         resnets = []
@@ -1682,8 +1684,8 @@ class CrossAttnDownBlockFlat(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals=None,
-    ):
+        additional_residuals: Optional[torch.FloatTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()

         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -1751,7 +1753,7 @@ class UpBlockFlat(nn.Module):
         prev_output_channel: int,
         out_channels: int,
         temb_channels: int,
-        resolution_idx: int = None,
+        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
@@ -1759,8 +1761,8 @@ class UpBlockFlat(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        output_scale_factor=1.0,
-        add_upsample=True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
     ):
         super().__init__()
         resnets = []
@@ -1794,7 +1796,14 @@ class UpBlockFlat(nn.Module):
         self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1855,7 +1864,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         out_channels: int,
         prev_output_channel: int,
         temb_channels: int,
-        resolution_idx: int = None,
+        resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -1864,15 +1873,15 @@ class CrossAttnUpBlockFlat(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        cross_attention_dim=1280,
-        output_scale_factor=1.0,
-        add_upsample=True,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        only_cross_attention=False,
-        upcast_attention=False,
-        attention_type="default",
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
     ):
         super().__init__()
         resnets = []
@@ -1949,7 +1958,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         upsample_size: Optional[int] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ):
+    ) -> torch.FloatTensor:
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -2066,8 +2075,8 @@ class UNetMidBlockFlat(nn.Module):
         attn_groups: Optional[int] = None,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
-        attention_head_dim=1,
-        output_scale_factor=1.0,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
     ):
         super().__init__()
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2138,7 +2147,7 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states, temb=None):
+    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
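`UNetMidBlockFlat.forward` keeps its single-tensor contract, now stated in the signature. The structure (one resnet, then alternating attention/resnet pairs at the bottleneck resolution) is easiest to see with the `UNetMidBlock2D` original it is copied from:

```python
import torch
from diffusers.models.unet_2d_blocks import UNetMidBlock2D

mid = UNetMidBlock2D(in_channels=64, temb_channels=128, attention_head_dim=8)

hidden_states = torch.randn(1, 64, 8, 8)
temb = torch.randn(1, 128)

out = mid(hidden_states, temb)
print(out.shape)  # torch.Size([1, 64, 8, 8]); the spatial size is unchanged
```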
@@ -2162,13 +2171,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads=1,
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        dual_cross_attention=False,
-        use_linear_projection=False,
-        upcast_attention=False,
-        attention_type="default",
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
     ):
         super().__init__()
@@ -2308,12 +2317,12 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attention_head_dim=1,
-        output_scale_factor=1.0,
-        cross_attention_dim=1280,
-        skip_time_act=False,
-        only_cross_attention=False,
-        cross_attention_norm=None,
+        attention_head_dim: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        skip_time_act: bool = False,
+        only_cross_attention: bool = False,
+        cross_attention_norm: Optional[str] = None,
     ):
         super().__init__()
@@ -2389,7 +2398,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ):
+    ) -> torch.FloatTensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         lora_scale = cross_attention_kwargs.get("scale", 1.0)
......