Unverified Commit 0c9f174d authored by Aryan V S, committed by GitHub

Improve typehints and docs in `diffusers/models` (#5391)



* improvement: add typehints and docs to src/diffusers/models/attention_processor.py

* improvement: add typehints and docs to src/diffusers/models/vae.py

* improvement: add missing docs in src/diffusers/models/vq_model.py

* improvement: add typehints and docs to src/diffusers/models/transformer_temporal.py

* improvement: add typehints and docs to src/diffusers/models/t5_film_transformer.py

* improvement: add type hints to src/diffusers/models/unet_1d_blocks.py

* improvement: add missing type hints to src/diffusers/models/unet_2d_blocks.py

* fix: CI error (make fix-copies required)

* fix: CI error (make fix-copies required again)

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent d420d713
@@ -40,14 +40,50 @@ class Attention(nn.Module):
     A cross attention layer.
 
     Parameters:
-        query_dim (`int`): The number of channels in the query.
+        query_dim (`int`):
+            The number of channels in the query.
         cross_attention_dim (`int`, *optional*):
             The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
-        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
-        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        heads (`int`, *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`, *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
         bias (`bool`, *optional*, defaults to False):
             Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
     """
 
     def __init__(
@@ -57,7 +93,7 @@ class Attention(nn.Module):
         heads: int = 8,
         dim_head: int = 64,
         dropout: float = 0.0,
-        bias=False,
+        bias: bool = False,
         upcast_attention: bool = False,
         upcast_softmax: bool = False,
         cross_attention_norm: Optional[str] = None,
@@ -71,7 +107,7 @@ class Attention(nn.Module):
         eps: float = 1e-5,
         rescale_output_factor: float = 1.0,
         residual_connection: bool = False,
-        _from_deprecated_attn_block=False,
+        _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
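Taken together, the two hunks above give the constructor its documented, fully typed signature. A minimal usage sketch (not part of the diff; dimensions, values, and shapes are illustrative assumptions):

```python
import torch
from diffusers.models.attention_processor import Attention

# a cross-attention layer: 320-channel queries attending over 768-channel text embeddings
attn = Attention(
    query_dim=320,
    cross_attention_dim=768,
    heads=8,
    dim_head=40,
    dropout=0.0,
    bias=False,
)

hidden_states = torch.randn(2, 64, 320)          # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross_attention_dim)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```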
@@ -172,7 +208,17 @@ class Attention(nn.Module):
 
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ):
+    ) -> None:
+        r"""
+        Set whether to use memory efficient attention from `xformers` or not.
+
+        Args:
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
+        """
         is_lora = hasattr(self, "processor") and isinstance(
             self.processor,
             LORA_ATTENTION_PROCESSORS,
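For context, this setter is normally reached through `ModelMixin.enable_xformers_memory_efficient_attention` rather than called by hand. A hedged sketch, assuming `xformers` is installed and using an example checkpoint:

```python
from diffusers import UNet2DConditionModel

# any diffusers model works here; the checkpoint name is only an example
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
unet.enable_xformers_memory_efficient_attention()  # walks the module tree and calls the setter above

# Equivalently, on a single Attention module:
# attn.set_use_memory_efficient_attention_xformers(True)
```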
@@ -294,7 +340,14 @@ class Attention(nn.Module):
         self.set_processor(processor)
 
-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: int) -> None:
+        r"""
+        Set the slice size for attention computation.
+
+        Args:
+            slice_size (`int`):
+                The slice size for attention computation.
+        """
         if slice_size is not None and slice_size > self.sliceable_head_dim:
             raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
@@ -315,7 +368,16 @@ class Attention(nn.Module):
         self.set_processor(processor)
 
-    def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
+    def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
+        r"""
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+            _remove_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to remove LoRA layers from the model.
+        """
         if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
             deprecate(
                 "set_processor to offload LoRA",
@@ -342,6 +404,16 @@ class Attention(nn.Module):
         self.processor = processor
 
     def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        r"""
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
         if not return_deprecated_lora:
             return self.processor
@@ -421,7 +493,29 @@ class Attention(nn.Module):
         return lora_processor
 
-    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
         # The `Attention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
@@ -433,14 +527,36 @@ class Attention(nn.Module):
             **cross_attention_kwargs,
         )
 
-    def batch_to_head_dim(self, tensor):
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def head_to_batch_dim(self, tensor, out_dim=3):
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
+        the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
         head_size = self.heads
         batch_size, seq_len, dim = tensor.shape
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
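The two reshape helpers above are inverses of each other. A self-contained sketch of the same reshapes as free functions, with illustrative shapes:

```python
import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, dim) -> (batch * heads, seq_len, dim // heads)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)

def batch_to_head_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch * heads, seq_len, dim_head) -> (batch, seq_len, dim_head * heads)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // heads, heads, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // heads, seq_len, dim * heads)

x = torch.randn(2, 64, 320)
assert batch_to_head_dim(head_to_batch_dim(x, heads=8), heads=8).shape == x.shape
```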
@@ -451,7 +567,20 @@ class Attention(nn.Module):
         return tensor
 
-    def get_attention_scores(self, query, key, attention_mask=None):
+    def get_attention_scores(
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+    ) -> torch.Tensor:
+        r"""
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
         dtype = query.dtype
         if self.upcast_attention:
             query = query.float()
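A hedged sketch of the computation `get_attention_scores` documents: scaled `QK^T` plus an optional additive mask, followed by a softmax. The `baddbmm` formulation mirrors the style used in this file; function name and shapes are assumptions:

```python
import math
import torch

def attention_scores(query, key, attention_mask=None, scale=None):
    # scale mirrors 1 / sqrt(dim_head) when scale_qk=True
    scale = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
    baseline = (
        attention_mask
        if attention_mask is not None
        else torch.zeros(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
    )
    scores = torch.baddbmm(
        baseline,
        query,
        key.transpose(-1, -2),
        beta=1 if attention_mask is not None else 0,
        alpha=scale,
    )
    return scores.softmax(dim=-1)

q = torch.randn(16, 64, 40)   # (batch * heads, q_len, dim_head)
k = torch.randn(16, 77, 40)   # (batch * heads, kv_len, dim_head)
print(attention_scores(q, k).shape)  # torch.Size([16, 64, 77])
```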
@@ -485,7 +614,25 @@ class Attention(nn.Module):
         return attention_probs
 
-    def prepare_attention_mask(self, attention_mask, target_length, batch_size, out_dim=3):
+    def prepare_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+    ) -> torch.Tensor:
+        r"""
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                The attention mask to prepare.
+            target_length (`int`):
+                The target length of the attention mask. This is the length of the attention mask after padding.
+            batch_size (`int`):
+                The batch size, which is used to repeat the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`):
+                The output dimension of the attention mask. Can be either `3` or `4`.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
         head_size = self.heads
         if attention_mask is None:
             return attention_mask
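A rough, illustrative sketch of the two steps the new docstring describes: pad the mask along the key dimension to `target_length`, then repeat it once per head. The additive-mask convention and padding value used here are assumptions, not necessarily the exact values used by `diffusers`:

```python
import torch
import torch.nn.functional as F

mask = torch.zeros(2, 1, 70)                       # (batch, 1, current_length); 0 = attend
mask = F.pad(mask, (0, 77 - 70), value=-10000.0)   # pad the key dimension up to target_length
mask = mask.repeat_interleave(8, dim=0)            # one copy per head (out_dim == 3 layout)
print(mask.shape)                                  # torch.Size([16, 1, 77])
```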
@@ -514,7 +661,17 @@ class Attention(nn.Module):
         return attention_mask
 
-    def norm_encoder_hidden_states(self, encoder_hidden_states):
+    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""
+        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+        `Attention` class.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
         assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
 
         if isinstance(self.norm_cross, nn.LayerNorm):
@@ -542,12 +699,12 @@ class AttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        scale=1.0,
-    ):
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.Tensor:
         residual = hidden_states
 
         args = () if USE_PEFT_BACKEND else (scale,)
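A processor is just a callable with the signature shown above, so a custom one can be plugged in via `set_processor`. A toy sketch (the class name is made up; `attn` is assumed to be the `Attention` instance built earlier):

```python
import torch

class ZeroAttnProcessor:
    """Illustrative only: ignores the inputs and returns zeros shaped like the query."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        return torch.zeros_like(hidden_states)

attn.set_processor(ZeroAttnProcessor())
print(attn(torch.randn(2, 64, 320)).abs().max())  # tensor(0.)
```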
...@@ -624,12 +781,12 @@ class CustomDiffusionAttnProcessor(nn.Module): ...@@ -624,12 +781,12 @@ class CustomDiffusionAttnProcessor(nn.Module):
def __init__( def __init__(
self, self,
train_kv=True, train_kv: bool = True,
train_q_out=True, train_q_out: bool = True,
hidden_size=None, hidden_size: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
out_bias=True, out_bias: bool = True,
dropout=0.0, dropout: float = 0.0,
): ):
super().__init__() super().__init__()
self.train_kv = train_kv self.train_kv = train_kv
...@@ -648,7 +805,13 @@ class CustomDiffusionAttnProcessor(nn.Module): ...@@ -648,7 +805,13 @@ class CustomDiffusionAttnProcessor(nn.Module):
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out: if self.train_q_out:
...@@ -707,7 +870,14 @@ class AttnAddedKVProcessor: ...@@ -707,7 +870,14 @@ class AttnAddedKVProcessor:
encoder. encoder.
""" """
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
...@@ -767,7 +937,14 @@ class AttnAddedKVProcessor2_0: ...@@ -767,7 +937,14 @@ class AttnAddedKVProcessor2_0:
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
) )
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
...@@ -833,7 +1010,13 @@ class XFormersAttnAddedKVProcessor: ...@@ -833,7 +1010,13 @@ class XFormersAttnAddedKVProcessor:
def __init__(self, attention_op: Optional[Callable] = None): def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op self.attention_op = attention_op
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
...@@ -906,7 +1089,7 @@ class XFormersAttnProcessor: ...@@ -906,7 +1089,7 @@ class XFormersAttnProcessor:
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0, scale: float = 1.0,
): ) -> torch.FloatTensor:
residual = hidden_states residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,) args = () if USE_PEFT_BACKEND else (scale,)
@@ -986,12 +1169,12 @@ class AttnProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
         scale: float = 1.0,
-    ):
+    ) -> torch.FloatTensor:
         residual = hidden_states
 
         if attn.spatial_norm is not None:
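`AttnProcessor2_0` differs from the reference `AttnProcessor` mainly by delegating to PyTorch 2.x fused attention. A minimal sketch of that core call, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, dim_head) layout expected by the fused kernel
q = torch.randn(2, 8, 64, 40)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 64, 40])
```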
...@@ -1091,12 +1274,12 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1091,12 +1274,12 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
def __init__( def __init__(
self, self,
train_kv=True, train_kv: bool = True,
train_q_out=False, train_q_out: bool = False,
hidden_size=None, hidden_size: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
out_bias=True, out_bias: bool = True,
dropout=0.0, dropout: float = 0.0,
attention_op: Optional[Callable] = None, attention_op: Optional[Callable] = None,
): ):
super().__init__() super().__init__()
...@@ -1117,7 +1300,13 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): ...@@ -1117,7 +1300,13 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
batch_size, sequence_length, _ = ( batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
) )
...@@ -1197,12 +1386,12 @@ class CustomDiffusionAttnProcessor2_0(nn.Module): ...@@ -1197,12 +1386,12 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
def __init__( def __init__(
self, self,
train_kv=True, train_kv: bool = True,
train_q_out=True, train_q_out: bool = True,
hidden_size=None, hidden_size: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
out_bias=True, out_bias: bool = True,
dropout=0.0, dropout: float = 0.0,
): ):
super().__init__() super().__init__()
self.train_kv = train_kv self.train_kv = train_kv
...@@ -1221,7 +1410,13 @@ class CustomDiffusionAttnProcessor2_0(nn.Module): ...@@ -1221,7 +1410,13 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout)) self.to_out_custom_diffusion.append(nn.Dropout(dropout))
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
batch_size, sequence_length, _ = hidden_states.shape batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out: if self.train_q_out:
@@ -1290,10 +1485,16 @@ class SlicedAttnProcessor:
         `attention_head_dim` must be a multiple of the `slice_size`.
     """
 
-    def __init__(self, slice_size):
+    def __init__(self, slice_size: int):
         self.slice_size = slice_size
 
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
         residual = hidden_states
 
         input_ndim = hidden_states.ndim
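A hedged sketch of the trade-off `SlicedAttnProcessor` makes: attention is computed over chunks of the `batch * heads` dimension to bound peak memory, at the cost of a Python loop. Function name and shapes are illustrative:

```python
import torch

def sliced_attention(query, key, value, slice_size):
    batch_size_attention = query.shape[0]  # batch * heads
    hidden = torch.empty(
        batch_size_attention, query.shape[1], value.shape[2], dtype=query.dtype, device=query.device
    )
    for i in range(0, batch_size_attention, slice_size):
        q, k, v = query[i:i + slice_size], key[i:i + slice_size], value[i:i + slice_size]
        probs = (q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5).softmax(dim=-1)
        hidden[i:i + slice_size] = probs @ v
    return hidden

out = sliced_attention(
    torch.randn(16, 64, 40), torch.randn(16, 77, 40), torch.randn(16, 77, 40), slice_size=4
)
print(out.shape)  # torch.Size([16, 64, 40])
```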
...@@ -1374,7 +1575,14 @@ class SlicedAttnAddedKVProcessor: ...@@ -1374,7 +1575,14 @@ class SlicedAttnAddedKVProcessor:
def __init__(self, slice_size): def __init__(self, slice_size):
self.slice_size = slice_size self.slice_size = slice_size
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): def __call__(
self,
attn: "Attention",
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
residual = hidden_states residual = hidden_states
if attn.spatial_norm is not None: if attn.spatial_norm is not None:
@@ -1448,20 +1656,26 @@ class SlicedAttnAddedKVProcessor:
 class SpatialNorm(nn.Module):
     """
-    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
+    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+    Args:
+        f_channels (`int`):
+            The number of channels for input to group normalization layer, and output of the spatial norm layer.
+        zq_channels (`int`):
+            The number of channels for the quantized vector as described in the paper.
     """
 
     def __init__(
         self,
-        f_channels,
-        zq_channels,
+        f_channels: int,
+        zq_channels: int,
     ):
         super().__init__()
         self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
         self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
         self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
 
-    def forward(self, f, zq):
+    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
         f_size = f.shape[-2:]
         zq = F.interpolate(zq, size=f_size, mode="nearest")
         norm_f = self.norm_layer(f)
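A usage sketch for `SpatialNorm` as defined above (at this commit it lives in `attention_processor.py`): the quantized latents `zq` are resized to the feature map and turned into a per-pixel scale and shift for the group-normalized input. Channel counts are illustrative, with `f_channels` kept divisible by the 32 group-norm groups:

```python
import torch
from diffusers.models.attention_processor import SpatialNorm

norm = SpatialNorm(f_channels=64, zq_channels=4)
f = torch.randn(1, 64, 32, 32)   # feature map to normalize
zq = torch.randn(1, 4, 8, 8)     # quantized conditioning at a lower resolution
print(norm(f, zq).shape)         # torch.Size([1, 64, 32, 32])
```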
...@@ -1483,9 +1697,18 @@ class LoRAAttnProcessor(nn.Module): ...@@ -1483,9 +1697,18 @@ class LoRAAttnProcessor(nn.Module):
The dimension of the LoRA update matrices. The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*): network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
""" """
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
**kwargs,
):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -1512,7 +1735,7 @@ class LoRAAttnProcessor(nn.Module): ...@@ -1512,7 +1735,7 @@ class LoRAAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, *args, **kwargs): def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__ self_cls_name = self.__class__.__name__
deprecate( deprecate(
self_cls_name, self_cls_name,
...@@ -1547,9 +1770,18 @@ class LoRAAttnProcessor2_0(nn.Module): ...@@ -1547,9 +1770,18 @@ class LoRAAttnProcessor2_0(nn.Module):
The dimension of the LoRA update matrices. The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*): network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
""" """
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
**kwargs,
):
super().__init__() super().__init__()
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
...@@ -1578,7 +1810,7 @@ class LoRAAttnProcessor2_0(nn.Module): ...@@ -1578,7 +1810,7 @@ class LoRAAttnProcessor2_0(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, *args, **kwargs): def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__ self_cls_name = self.__class__.__name__
deprecate( deprecate(
self_cls_name, self_cls_name,
...@@ -1617,16 +1849,17 @@ class LoRAXFormersAttnProcessor(nn.Module): ...@@ -1617,16 +1849,17 @@ class LoRAXFormersAttnProcessor(nn.Module):
operator. operator.
network_alpha (`int`, *optional*): network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
""" """
def __init__( def __init__(
self, self,
hidden_size, hidden_size: int,
cross_attention_dim, cross_attention_dim: int,
rank=4, rank: int = 4,
attention_op: Optional[Callable] = None, attention_op: Optional[Callable] = None,
network_alpha=None, network_alpha: Optional[int] = None,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
...@@ -1656,7 +1889,7 @@ class LoRAXFormersAttnProcessor(nn.Module): ...@@ -1656,7 +1889,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, *args, **kwargs): def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__ self_cls_name = self.__class__.__name__
deprecate( deprecate(
self_cls_name, self_cls_name,
...@@ -1689,10 +1922,19 @@ class LoRAAttnAddedKVProcessor(nn.Module): ...@@ -1689,10 +1922,19 @@ class LoRAAttnAddedKVProcessor(nn.Module):
The number of channels in the `encoder_hidden_states`. The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4): rank (`int`, defaults to 4):
The dimension of the LoRA update matrices. The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
kwargs (`dict`):
Additional keyword arguments to pass to the `LoRALinearLayer` layers.
""" """
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): def __init__(
self,
hidden_size: int,
cross_attention_dim: Optional[int] = None,
rank: int = 4,
network_alpha: Optional[int] = None,
):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -1706,7 +1948,7 @@ class LoRAAttnAddedKVProcessor(nn.Module): ...@@ -1706,7 +1948,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, *args, **kwargs): def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
self_cls_name = self.__class__.__name__ self_cls_name = self.__class__.__name__
deprecate( deprecate(
self_cls_name, self_cls_name,
@@ -1764,7 +2006,7 @@ AttentionProcessor = Union[
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,
-    # depraceted
+    # deprecated
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -23,6 +24,28 @@ from .modeling_utils import ModelMixin ...@@ -23,6 +24,28 @@ from .modeling_utils import ModelMixin
class T5FilmDecoder(ModelMixin, ConfigMixin): class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.
Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`'s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
@@ -63,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
         self.post_dropout = nn.Dropout(p=dropout_rate)
         self.spec_out = nn.Linear(d_model, input_dims, bias=False)
 
-    def encoder_decoder_mask(self, query_input, key_input):
+    def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
         mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
         return mask.unsqueeze(-3)
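A worked example of the mask computed by `encoder_decoder_mask` above: an outer product of the decoder (query) and encoder (key) padding masks, with an extra dimension so it broadcasts over attention heads. The toy lengths are illustrative:

```python
import torch

query_input = torch.tensor([[1.0, 1.0, 0.0]])       # (batch, targets_length); 1 = real token
key_input = torch.tensor([[1.0, 1.0, 1.0, 0.0]])    # (batch, encoder_length)
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)).unsqueeze(-3)
print(mask.shape)  # torch.Size([1, 1, 3, 4]); 1 only where both positions are real tokens
```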
...@@ -125,7 +148,27 @@ class T5FilmDecoder(ModelMixin, ConfigMixin): ...@@ -125,7 +148,27 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6): r"""
T5 decoder layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__() super().__init__()
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
...@@ -152,13 +195,13 @@ class DecoderLayer(nn.Module): ...@@ -152,13 +195,13 @@ class DecoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
conditioning_emb=None, conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None, encoder_decoder_position_bias=None,
): ) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0]( hidden_states = self.layer[0](
hidden_states, hidden_states,
conditioning_emb=conditioning_emb, conditioning_emb=conditioning_emb,
...@@ -183,7 +226,21 @@ class DecoderLayer(nn.Module): ...@@ -183,7 +226,21 @@ class DecoderLayer(nn.Module):
class T5LayerSelfAttentionCond(nn.Module): class T5LayerSelfAttentionCond(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate): r"""
T5 style self-attention layer with conditioning.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__() super().__init__()
self.layer_norm = T5LayerNorm(d_model) self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
...@@ -192,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module): ...@@ -192,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
conditioning_emb=None, conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
# pre_self_attention_layer_norm # pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
...@@ -211,7 +268,23 @@ class T5LayerSelfAttentionCond(nn.Module): ...@@ -211,7 +268,23 @@ class T5LayerSelfAttentionCond(nn.Module):
class T5LayerCrossAttention(nn.Module): class T5LayerCrossAttention(nn.Module):
def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): r"""
T5 style cross-attention layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__() super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
...@@ -219,10 +292,10 @@ class T5LayerCrossAttention(nn.Module): ...@@ -219,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
key_value_states=None, key_value_states: Optional[torch.FloatTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention( attention_output = self.attention(
normed_hidden_states, normed_hidden_states,
...@@ -234,14 +307,30 @@ class T5LayerCrossAttention(nn.Module): ...@@ -234,14 +307,30 @@ class T5LayerCrossAttention(nn.Module):
class T5LayerFFCond(nn.Module): class T5LayerFFCond(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon): r"""
T5 style feed-forward conditional layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__() super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate) self.dropout = nn.Dropout(dropout_rate)
def forward(self, hidden_states, conditioning_emb=None): def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None: if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb) forwarded_states = self.film(forwarded_states, conditioning_emb)
...@@ -252,7 +341,19 @@ class T5LayerFFCond(nn.Module): ...@@ -252,7 +341,19 @@ class T5LayerFFCond(nn.Module):
class T5DenseGatedActDense(nn.Module): class T5DenseGatedActDense(nn.Module):
def __init__(self, d_model, d_ff, dropout_rate): r"""
T5 style feed-forward layer with gated activations and dropout.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__() super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False) self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False) self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
...@@ -260,7 +361,7 @@ class T5DenseGatedActDense(nn.Module): ...@@ -260,7 +361,7 @@ class T5DenseGatedActDense(nn.Module):
self.dropout = nn.Dropout(dropout_rate) self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation() self.act = NewGELUActivation()
def forward(self, hidden_states): def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states) hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear hidden_states = hidden_gelu * hidden_linear
...@@ -271,7 +372,17 @@ class T5DenseGatedActDense(nn.Module): ...@@ -271,7 +372,17 @@ class T5DenseGatedActDense(nn.Module):
class T5LayerNorm(nn.Module): class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): r"""
T5 style layer normalization module.
Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, `optional`, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
""" """
Construct a layernorm module in the T5 style. No bias and no subtraction of mean. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
""" """
@@ -279,7 +390,7 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
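A standalone sketch of the RMS-style normalization those comments describe, assuming scale-only behaviour (no mean subtraction, no bias) with the variance accumulated in `float32`:

```python
import torch

def t5_layer_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states

h = torch.randn(2, 5, 768)
print(t5_layer_norm(h, torch.ones(768)).shape)  # torch.Size([2, 5, 768])
```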
@@ -307,14 +418,20 @@ class NewGELUActivation(nn.Module):
 class T5FiLMLayer(nn.Module):
     """
-    FiLM Layer
+    T5 style FiLM Layer.
+
+    Args:
+        in_features (`int`):
+            Number of input features.
+        out_features (`int`):
+            Number of output features.
     """
 
-    def __init__(self, in_features, out_features):
+    def __init__(self, in_features: int, out_features: int):
         super().__init__()
         self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
 
-    def forward(self, x, conditioning_emb):
+    def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
         emb = self.scale_bias(conditioning_emb)
         scale, shift = torch.chunk(emb, 2, -1)
         x = x * (1 + scale) + shift
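A self-contained sketch of the FiLM conditioning applied in the forward above: a single linear projection of the conditioning embedding produces a per-channel scale and shift for the hidden states. Feature sizes are illustrative:

```python
import torch
from torch import nn

scale_bias = nn.Linear(128, 32 * 2, bias=False)  # in_features=128, out_features=32
x = torch.randn(1, 10, 32)                       # hidden states to modulate
conditioning_emb = torch.randn(1, 1, 128)        # conditioning embedding
scale, shift = torch.chunk(scale_bias(conditioning_emb), 2, dim=-1)
print((x * (1 + scale) + shift).shape)           # torch.Size([1, 10, 32])
```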
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Any, Dict, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): ...@@ -48,11 +48,15 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
attention_bias (`bool`, *optional*): attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter. Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*): double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers. Configure if each `TransformerBlock` should contain two self-attention layers.
""" """
...@@ -106,14 +110,14 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): ...@@ -106,14 +110,14 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep=None, timestep: Optional[torch.LongTensor] = None,
class_labels=None, class_labels: torch.LongTensor = None,
num_frames=1, num_frames: int = 1,
cross_attention_kwargs=None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True, return_dict: bool = True,
): ) -> TransformerTemporalModelOutput:
""" """
The [`TransformerTemporal`] forward method. The [`TransformerTemporal`] forward method.
...@@ -123,7 +127,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): ...@@ -123,7 +127,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention. self-attention.
timestep ( `torch.long`, *optional*): timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
from typing import Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -24,17 +25,17 @@ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange ...@@ -24,17 +25,17 @@ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange
class DownResnetBlock1D(nn.Module): class DownResnetBlock1D(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels=None, out_channels: Optional[int] = None,
num_layers=1, num_layers: int = 1,
conv_shortcut=False, conv_shortcut: bool = False,
temb_channels=32, temb_channels: int = 32,
groups=32, groups: int = 32,
groups_out=None, groups_out: Optional[int] = None,
non_linearity=None, non_linearity: Optional[str] = None,
time_embedding_norm="default", time_embedding_norm: str = "default",
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
...@@ -65,7 +66,7 @@ class DownResnetBlock1D(nn.Module): ...@@ -65,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
if add_downsample: if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = () output_states = ()
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
...@@ -86,16 +87,16 @@ class DownResnetBlock1D(nn.Module): ...@@ -86,16 +87,16 @@ class DownResnetBlock1D(nn.Module):
class UpResnetBlock1D(nn.Module): class UpResnetBlock1D(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels=None, out_channels: Optional[int] = None,
num_layers=1, num_layers: int = 1,
temb_channels=32, temb_channels: int = 32,
groups=32, groups: int = 32,
groups_out=None, groups_out: Optional[int] = None,
non_linearity=None, non_linearity: Optional[str] = None,
time_embedding_norm="default", time_embedding_norm: str = "default",
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
...@@ -125,7 +126,12 @@ class UpResnetBlock1D(nn.Module): ...@@ -125,7 +126,12 @@ class UpResnetBlock1D(nn.Module):
if add_upsample: if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True) self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None: if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
...@@ -144,7 +150,7 @@ class UpResnetBlock1D(nn.Module): ...@@ -144,7 +150,7 @@ class UpResnetBlock1D(nn.Module):
class ValueFunctionMidBlock1D(nn.Module): class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, embed_dim): def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
...@@ -155,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module): ...@@ -155,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True) self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x, temb=None): def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb) x = self.res1(x, temb)
x = self.down1(x) x = self.down1(x)
x = self.res2(x, temb) x = self.res2(x, temb)
...@@ -166,13 +172,13 @@ class ValueFunctionMidBlock1D(nn.Module): ...@@ -166,13 +172,13 @@ class ValueFunctionMidBlock1D(nn.Module):
class MidResTemporalBlock1D(nn.Module): class MidResTemporalBlock1D(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
out_channels, out_channels: int,
embed_dim, embed_dim: int,
num_layers: int = 1, num_layers: int = 1,
add_downsample: bool = False, add_downsample: bool = False,
add_upsample: bool = False, add_upsample: bool = False,
non_linearity=None, non_linearity: Optional[str] = None,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
...@@ -203,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module): ...@@ -203,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
if self.upsample and self.downsample: if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample") raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states, temb): def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]: for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb) hidden_states = resnet(hidden_states, temb)
...@@ -217,14 +223,14 @@ class MidResTemporalBlock1D(nn.Module): ...@@ -217,14 +223,14 @@ class MidResTemporalBlock1D(nn.Module):
class OutConv1DBlock(nn.Module): class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__() super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn) self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states) hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states) hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states) hidden_states = self.final_conv1d_gn(hidden_states)
...@@ -235,7 +241,7 @@ class OutConv1DBlock(nn.Module): ...@@ -235,7 +241,7 @@ class OutConv1DBlock(nn.Module):
class OutValueFunctionBlock(nn.Module): class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim, embed_dim, act_fn="mish"): def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__() super().__init__()
self.final_block = nn.ModuleList( self.final_block = nn.ModuleList(
[ [
...@@ -245,7 +251,7 @@ class OutValueFunctionBlock(nn.Module): ...@@ -245,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
] ]
) )
def forward(self, hidden_states, temb): def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1) hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1) hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block: for layer in self.final_block:
...@@ -275,14 +281,14 @@ _kernels = { ...@@ -275,14 +281,14 @@ _kernels = {
class Downsample1d(nn.Module): class Downsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"): def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__() super().__init__()
self.pad_mode = pad_mode self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1 self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d) self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states): def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
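For orientation while reading this hunk, here is a minimal, self-contained sketch of the per-channel FIR downsampling pattern that Downsample1d.forward builds: the smoothing kernel sits on the diagonal of a (C, C, K) weight so each channel is filtered independently. The 4-tap kernel values and the final stride-2 conv1d are illustrative assumptions, not the library's exact collapsed code.
# Sketch only: per-channel FIR downsampling in the style of Downsample1d.
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 32)                                  # (batch, channels, length)
kernel_1d = torch.tensor([0.125, 0.375, 0.375, 0.125])     # placeholder smoothing kernel
pad = kernel_1d.shape[0] // 2 - 1                          # = 1 for a 4-tap kernel

x = F.pad(x, (pad,) * 2, mode="reflect")
weight = x.new_zeros([x.shape[1], x.shape[1], kernel_1d.shape[0]])
indices = torch.arange(x.shape[1])
weight[indices, indices] = kernel_1d                       # kernel on the diagonal: one filter per channel
out = F.conv1d(x, weight, stride=2)                        # roughly halves the length: (1, 8, 16)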
...@@ -292,14 +298,14 @@ class Downsample1d(nn.Module): ...@@ -292,14 +298,14 @@ class Downsample1d(nn.Module):
class Upsample1d(nn.Module): class Upsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"): def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__() super().__init__()
self.pad_mode = pad_mode self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2 kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1 self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d) self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
...@@ -309,7 +315,7 @@ class Upsample1d(nn.Module): ...@@ -309,7 +315,7 @@ class Upsample1d(nn.Module):
class SelfAttention1d(nn.Module): class SelfAttention1d(nn.Module):
def __init__(self, in_channels, n_head=1, dropout_rate=0.0): def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__() super().__init__()
self.channels = in_channels self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels) self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
...@@ -329,7 +335,7 @@ class SelfAttention1d(nn.Module): ...@@ -329,7 +335,7 @@ class SelfAttention1d(nn.Module):
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection return new_projection
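As a quick aside, a toy version of the head-split this helper performs (all sizes here are made up for illustration): a (batch, seq, channels) projection is viewed and permuted into (batch, heads, seq, head_dim).
# Sketch only: the view + permute(0, 2, 1, 3) head-split pattern.
import torch

projection = torch.randn(2, 128, 64)                       # (batch, seq, channels)
n_head = 4
new_shape = projection.size()[:-1] + (n_head, projection.size(-1) // n_head)
per_head = projection.view(new_shape).permute(0, 2, 1, 3)  # (2, 4, 128, 16)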
def forward(self, hidden_states): def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states residual = hidden_states
batch, channel_dim, seq = hidden_states.shape batch, channel_dim, seq = hidden_states.shape
...@@ -367,7 +373,7 @@ class SelfAttention1d(nn.Module): ...@@ -367,7 +373,7 @@ class SelfAttention1d(nn.Module):
class ResConvBlock(nn.Module): class ResConvBlock(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, is_last=False): def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__() super().__init__()
self.is_last = is_last self.is_last = is_last
self.has_conv_skip = in_channels != out_channels self.has_conv_skip = in_channels != out_channels
...@@ -384,7 +390,7 @@ class ResConvBlock(nn.Module): ...@@ -384,7 +390,7 @@ class ResConvBlock(nn.Module):
self.group_norm_2 = nn.GroupNorm(1, out_channels) self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU() self.gelu_2 = nn.GELU()
def forward(self, hidden_states): def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states) hidden_states = self.conv_1(hidden_states)
...@@ -401,7 +407,7 @@ class ResConvBlock(nn.Module): ...@@ -401,7 +407,7 @@ class ResConvBlock(nn.Module):
class UNetMidBlock1D(nn.Module): class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None): def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__() super().__init__()
out_channels = in_channels if out_channels is None else out_channels out_channels = in_channels if out_channels is None else out_channels
...@@ -429,7 +435,7 @@ class UNetMidBlock1D(nn.Module): ...@@ -429,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets): for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states) hidden_states = resnet(hidden_states)
...@@ -441,7 +447,7 @@ class UNetMidBlock1D(nn.Module): ...@@ -441,7 +447,7 @@ class UNetMidBlock1D(nn.Module):
class AttnDownBlock1D(nn.Module): class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None): def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__() super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels mid_channels = out_channels if mid_channels is None else mid_channels
...@@ -460,7 +466,7 @@ class AttnDownBlock1D(nn.Module): ...@@ -460,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
...@@ -471,7 +477,7 @@ class AttnDownBlock1D(nn.Module): ...@@ -471,7 +477,7 @@ class AttnDownBlock1D(nn.Module):
class DownBlock1D(nn.Module): class DownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None): def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__() super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels mid_channels = out_channels if mid_channels is None else mid_channels
...@@ -484,7 +490,7 @@ class DownBlock1D(nn.Module): ...@@ -484,7 +490,7 @@ class DownBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states) hidden_states = self.down(hidden_states)
for resnet in self.resnets: for resnet in self.resnets:
...@@ -494,7 +500,7 @@ class DownBlock1D(nn.Module): ...@@ -494,7 +500,7 @@ class DownBlock1D(nn.Module):
class DownBlock1DNoSkip(nn.Module): class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None): def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__() super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels mid_channels = out_channels if mid_channels is None else mid_channels
...@@ -506,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module): ...@@ -506,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1) hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states) hidden_states = resnet(hidden_states)
...@@ -515,7 +521,7 @@ class DownBlock1DNoSkip(nn.Module): ...@@ -515,7 +521,7 @@ class DownBlock1DNoSkip(nn.Module):
class AttnUpBlock1D(nn.Module): class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None): def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__() super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels mid_channels = out_channels if mid_channels is None else mid_channels
...@@ -534,7 +540,12 @@ class AttnUpBlock1D(nn.Module): ...@@ -534,7 +540,12 @@ class AttnUpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic") self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple, temb=None): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -548,7 +559,7 @@ class AttnUpBlock1D(nn.Module): ...@@ -548,7 +559,7 @@ class AttnUpBlock1D(nn.Module):
class UpBlock1D(nn.Module): class UpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None): def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__() super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels mid_channels = in_channels if mid_channels is None else mid_channels
...@@ -561,7 +572,12 @@ class UpBlock1D(nn.Module): ...@@ -561,7 +572,12 @@ class UpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic") self.up = Upsample1d(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple, temb=None): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -574,7 +590,7 @@ class UpBlock1D(nn.Module): ...@@ -574,7 +590,7 @@ class UpBlock1D(nn.Module):
class UpBlock1DNoSkip(nn.Module): class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None): def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__() super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels mid_channels = in_channels if mid_channels is None else mid_channels
...@@ -586,7 +602,12 @@ class UpBlock1DNoSkip(nn.Module): ...@@ -586,7 +602,12 @@ class UpBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, res_hidden_states_tuple, temb=None): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
...@@ -596,7 +617,20 @@ class UpBlock1DNoSkip(nn.Module): ...@@ -596,7 +617,20 @@ class UpBlock1DNoSkip(nn.Module):
return hidden_states return hidden_states
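Across these 1D up blocks, the newly annotated res_hidden_states_tuple: Tuple[torch.FloatTensor, ...] is consumed the same way: the last skip tensor is taken and concatenated on the channel dimension, so the first resnet sees the combined channel count. A toy shape check, with invented sizes:
# Sketch only: channel-wise concat of a skip connection.
import torch

hidden_states = torch.randn(1, 32, 64)                # (batch, channels, length)
res_hidden_states_tuple = (torch.randn(1, 32, 64),)

res_hidden_states = res_hidden_states_tuple[-1]
merged = torch.cat([hidden_states, res_hidden_states], dim=1)
# merged.shape == torch.Size([1, 64, 64])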
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D": if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D( return DownResnetBlock1D(
in_channels=in_channels, in_channels=in_channels,
...@@ -614,7 +648,9 @@ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_ ...@@ -614,7 +648,9 @@ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_
raise ValueError(f"{down_block_type} does not exist.") raise ValueError(f"{down_block_type} does not exist.")
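For readers skimming the new annotations, a short usage sketch of this factory (channel sizes are invented; it assumes the definitions shown above in this file): the returned object is one of the classes grouped by the DownBlockType alias, and an unknown type name raises ValueError instead.
# Sketch only: constructing a 1D down block via the typed factory.
block: DownBlockType = get_down_block(
    "DownResnetBlock1D",
    num_layers=1,
    in_channels=32,
    out_channels=64,
    temb_channels=128,
    add_downsample=True,
)
# `block` is one of DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, or DownBlock1DNoSkip.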
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D": if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D( return UpResnetBlock1D(
in_channels=in_channels, in_channels=in_channels,
...@@ -632,7 +668,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan ...@@ -632,7 +668,15 @@ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_chan
raise ValueError(f"{up_block_type} does not exist.") raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): def get_mid_block(
mid_block_type: str,
num_layers: int,
in_channels: int,
mid_channels: int,
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D": if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D( return MidResTemporalBlock1D(
num_layers=num_layers, num_layers=num_layers,
...@@ -648,7 +692,9 @@ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_cha ...@@ -648,7 +692,9 @@ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_cha
raise ValueError(f"{mid_block_type} does not exist.") raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock": if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction": elif out_block_type == "ValueFunction":
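The hunk is truncated at this point. For orientation only, a hedged sketch of calling this keyword-only factory (argument values are invented, and it assumes the module's definitions above); its Optional[OutBlockType] annotation suggests that unrecognized out_block_type values fall through without returning a block.
# Sketch only: get_out_block takes keyword-only arguments.
out_block = get_out_block(
    out_block_type="OutConv1DBlock",
    num_groups_out=8,
    embed_dim=32,
    out_channels=14,
    act_fn="mish",
    fc_dim=128,
)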
......
...@@ -32,31 +32,31 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -32,31 +32,31 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def get_down_block( def get_down_block(
down_block_type, down_block_type: str,
num_layers, num_layers: int,
in_channels, in_channels: int,
out_channels, out_channels: int,
temb_channels, temb_channels: int,
add_downsample, add_downsample: bool,
resnet_eps, resnet_eps: float,
resnet_act_fn, resnet_act_fn: str,
transformer_layers_per_block=1, transformer_layers_per_block: int = 1,
num_attention_heads=None, num_attention_heads: Optional[int] = None,
resnet_groups=None, resnet_groups: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
downsample_padding=None, downsample_padding: Optional[int] = None,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
resnet_time_scale_shift="default", resnet_time_scale_shift: str = "default",
attention_type="default", attention_type: str = "default",
resnet_skip_time_act=False, resnet_skip_time_act: bool = False,
resnet_out_scale_factor=1.0, resnet_out_scale_factor: float = 1.0,
cross_attention_norm=None, cross_attention_norm: Optional[str] = None,
attention_head_dim=None, attention_head_dim: Optional[int] = None,
downsample_type=None, downsample_type: Optional[str] = None,
dropout=0.0, dropout: float = 0.0,
): ):
# If attn head dim is not defined, we default it to the number of heads # If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None: if attention_head_dim is None:
...@@ -241,33 +241,33 @@ def get_down_block( ...@@ -241,33 +241,33 @@ def get_down_block(
def get_up_block( def get_up_block(
up_block_type, up_block_type: str,
num_layers, num_layers: int,
in_channels, in_channels: int,
out_channels, out_channels: int,
prev_output_channel, prev_output_channel: int,
temb_channels, temb_channels: int,
add_upsample, add_upsample: bool,
resnet_eps, resnet_eps: float,
resnet_act_fn, resnet_act_fn: str,
resolution_idx=None, resolution_idx: Optional[int] = None,
transformer_layers_per_block=1, transformer_layers_per_block: int = 1,
num_attention_heads=None, num_attention_heads: Optional[int] = None,
resnet_groups=None, resnet_groups: Optional[int] = None,
cross_attention_dim=None, cross_attention_dim: Optional[int] = None,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
resnet_time_scale_shift="default", resnet_time_scale_shift: str = "default",
attention_type="default", attention_type: str = "default",
resnet_skip_time_act=False, resnet_skip_time_act: bool = False,
resnet_out_scale_factor=1.0, resnet_out_scale_factor: float = 1.0,
cross_attention_norm=None, cross_attention_norm: Optional[str] = None,
attention_head_dim=None, attention_head_dim: Optional[int] = None,
upsample_type=None, upsample_type: Optional[str] = None,
dropout=0.0, dropout: float = 0.0,
): ) -> nn.Module:
# If attn head dim is not defined, we default it to the number of heads # If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None: if attention_head_dim is None:
logger.warn( logger.warn(
...@@ -498,7 +498,7 @@ class AutoencoderTinyBlock(nn.Module): ...@@ -498,7 +498,7 @@ class AutoencoderTinyBlock(nn.Module):
) )
self.fuse = nn.ReLU() self.fuse = nn.ReLU()
def forward(self, x): def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))
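A stand-alone sketch of the conv + skip + ReLU fuse used in this forward; the layer shapes below are placeholders chosen for illustration, not AutoencoderTinyBlock's actual constructor arguments.
# Sketch only: residual fuse where a 1x1 skip conv matches the channel count.
import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 32, kernel_size=3, padding=1),
)
skip = nn.Conv2d(16, 32, kernel_size=1, bias=False)   # projects the residual to the output channels
fuse = nn.ReLU()

x = torch.randn(1, 16, 64, 64)
y = fuse(conv(x) + skip(x))                            # spatial size preserved: (1, 32, 64, 64)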
...@@ -546,8 +546,8 @@ class UNetMidBlock2D(nn.Module): ...@@ -546,8 +546,8 @@ class UNetMidBlock2D(nn.Module):
attn_groups: Optional[int] = None, attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
add_attention: bool = True, add_attention: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
): ):
super().__init__() super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
...@@ -617,7 +617,7 @@ class UNetMidBlock2D(nn.Module): ...@@ -617,7 +617,7 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]): for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None: if attn is not None:
...@@ -640,13 +640,13 @@ class UNetMidBlock2DCrossAttn(nn.Module): ...@@ -640,13 +640,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
): ):
super().__init__() super().__init__()
...@@ -785,12 +785,12 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -785,12 +785,12 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
skip_time_act=False, skip_time_act: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
cross_attention_norm=None, cross_attention_norm: Optional[str] = None,
): ):
super().__init__() super().__init__()
...@@ -866,7 +866,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module): ...@@ -866,7 +866,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0) lora_scale = cross_attention_kwargs.get("scale", 1.0)
...@@ -910,10 +910,10 @@ class AttnDownBlock2D(nn.Module): ...@@ -910,10 +910,10 @@ class AttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
downsample_padding=1, downsample_padding: int = 1,
downsample_type="conv", downsample_type: str = "conv",
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -989,7 +989,13 @@ class AttnDownBlock2D(nn.Module): ...@@ -989,7 +989,13 @@ class AttnDownBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None): def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0) lora_scale = cross_attention_kwargs.get("scale", 1.0)
...@@ -1028,16 +1034,16 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -1028,16 +1034,16 @@ class CrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
downsample_padding=1, downsample_padding: int = 1,
add_downsample=True, add_downsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1114,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module): ...@@ -1114,8 +1120,8 @@ class CrossAttnDownBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals=None, additional_residuals: Optional[torch.FloatTensor] = None,
): ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
...@@ -1188,9 +1194,9 @@ class DownBlock2D(nn.Module): ...@@ -1188,9 +1194,9 @@ class DownBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1227,7 +1233,9 @@ class DownBlock2D(nn.Module): ...@@ -1227,7 +1233,9 @@ class DownBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
for resnet in self.resnets: for resnet in self.resnets:
...@@ -1273,9 +1281,9 @@ class DownEncoderBlock2D(nn.Module): ...@@ -1273,9 +1281,9 @@ class DownEncoderBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1310,7 +1318,7 @@ class DownEncoderBlock2D(nn.Module): ...@@ -1310,7 +1318,7 @@ class DownEncoderBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def forward(self, hidden_states, scale: float = 1.0): def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, scale=scale) hidden_states = resnet(hidden_states, temb=None, scale=scale)
...@@ -1333,10 +1341,10 @@ class AttnDownEncoderBlock2D(nn.Module): ...@@ -1333,10 +1341,10 @@ class AttnDownEncoderBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1393,7 +1401,7 @@ class AttnDownEncoderBlock2D(nn.Module): ...@@ -1393,7 +1401,7 @@ class AttnDownEncoderBlock2D(nn.Module):
else: else:
self.downsamplers = None self.downsamplers = None
def forward(self, hidden_states, scale: float = 1.0): def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None, scale=scale) hidden_states = resnet(hidden_states, temb=None, scale=scale)
cross_attention_kwargs = {"scale": scale} cross_attention_kwargs = {"scale": scale}
...@@ -1418,9 +1426,9 @@ class AttnSkipDownBlock2D(nn.Module): ...@@ -1418,9 +1426,9 @@ class AttnSkipDownBlock2D(nn.Module):
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=np.sqrt(2.0), output_scale_factor: float = np.sqrt(2.0),
add_downsample=True, add_downsample: bool = True,
): ):
super().__init__() super().__init__()
self.attentions = nn.ModuleList([]) self.attentions = nn.ModuleList([])
...@@ -1487,7 +1495,13 @@ class AttnSkipDownBlock2D(nn.Module): ...@@ -1487,7 +1495,13 @@ class AttnSkipDownBlock2D(nn.Module):
self.downsamplers = None self.downsamplers = None
self.skip_conv = None self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
output_states = () output_states = ()
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
...@@ -1520,9 +1534,9 @@ class SkipDownBlock2D(nn.Module): ...@@ -1520,9 +1534,9 @@ class SkipDownBlock2D(nn.Module):
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0), output_scale_factor: float = np.sqrt(2.0),
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
): ):
super().__init__() super().__init__()
self.resnets = nn.ModuleList([]) self.resnets = nn.ModuleList([])
...@@ -1568,7 +1582,13 @@ class SkipDownBlock2D(nn.Module): ...@@ -1568,7 +1582,13 @@ class SkipDownBlock2D(nn.Module):
self.downsamplers = None self.downsamplers = None
self.skip_conv = None self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
skip_sample: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
output_states = () output_states = ()
for resnet in self.resnets: for resnet in self.resnets:
...@@ -1600,9 +1620,9 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1600,9 +1620,9 @@ class ResnetDownsampleBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
skip_time_act=False, skip_time_act: bool = False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1651,7 +1671,9 @@ class ResnetDownsampleBlock2D(nn.Module): ...@@ -1651,7 +1671,9 @@ class ResnetDownsampleBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
for resnet in self.resnets: for resnet in self.resnets:
...@@ -1698,13 +1720,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1698,13 +1720,13 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
skip_time_act=False, skip_time_act: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
cross_attention_norm=None, cross_attention_norm: Optional[str] = None,
): ):
super().__init__() super().__init__()
...@@ -1788,7 +1810,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module): ...@@ -1788,7 +1810,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
...@@ -1856,7 +1878,7 @@ class KDownBlock2D(nn.Module): ...@@ -1856,7 +1878,7 @@ class KDownBlock2D(nn.Module):
resnet_eps: float = 1e-5, resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu", resnet_act_fn: str = "gelu",
resnet_group_size: int = 32, resnet_group_size: int = 32,
add_downsample=False, add_downsample: bool = False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1891,7 +1913,9 @@ class KDownBlock2D(nn.Module): ...@@ -1891,7 +1913,9 @@ class KDownBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
for resnet in self.resnets: for resnet in self.resnets:
...@@ -1933,7 +1957,7 @@ class KCrossAttnDownBlock2D(nn.Module): ...@@ -1933,7 +1957,7 @@ class KCrossAttnDownBlock2D(nn.Module):
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 4, num_layers: int = 4,
resnet_group_size: int = 32, resnet_group_size: int = 32,
add_downsample=True, add_downsample: bool = True,
attention_head_dim: int = 64, attention_head_dim: int = 64,
add_self_attention: bool = False, add_self_attention: bool = False,
resnet_eps: float = 1e-5, resnet_eps: float = 1e-5,
...@@ -1996,7 +2020,7 @@ class KCrossAttnDownBlock2D(nn.Module): ...@@ -1996,7 +2020,7 @@ class KCrossAttnDownBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
...@@ -2065,9 +2089,9 @@ class AttnUpBlock2D(nn.Module): ...@@ -2065,9 +2089,9 @@ class AttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
upsample_type="conv", upsample_type: str = "conv",
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2142,7 +2166,14 @@ class AttnUpBlock2D(nn.Module): ...@@ -2142,7 +2166,14 @@ class AttnUpBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
...@@ -2170,7 +2201,7 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -2170,7 +2201,7 @@ class CrossAttnUpBlock2D(nn.Module):
out_channels: int, out_channels: int,
prev_output_channel: int, prev_output_channel: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1,
...@@ -2179,15 +2210,15 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -2179,15 +2210,15 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2264,7 +2295,7 @@ class CrossAttnUpBlock2D(nn.Module): ...@@ -2264,7 +2295,7 @@ class CrossAttnUpBlock2D(nn.Module):
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
...@@ -2343,7 +2374,7 @@ class UpBlock2D(nn.Module): ...@@ -2343,7 +2374,7 @@ class UpBlock2D(nn.Module):
prev_output_channel: int, prev_output_channel: int,
out_channels: int, out_channels: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -2351,8 +2382,8 @@ class UpBlock2D(nn.Module): ...@@ -2351,8 +2382,8 @@ class UpBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2386,7 +2417,14 @@ class UpBlock2D(nn.Module): ...@@ -2386,7 +2417,14 @@ class UpBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -2444,7 +2482,7 @@ class UpDecoderBlock2D(nn.Module): ...@@ -2444,7 +2482,7 @@ class UpDecoderBlock2D(nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -2452,9 +2490,9 @@ class UpDecoderBlock2D(nn.Module): ...@@ -2452,9 +2490,9 @@ class UpDecoderBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
temb_channels=None, temb_channels: Optional[int] = None,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2486,7 +2524,9 @@ class UpDecoderBlock2D(nn.Module): ...@@ -2486,7 +2524,9 @@ class UpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, temb=None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
for resnet in self.resnets: for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale) hidden_states = resnet(hidden_states, temb=temb, scale=scale)
...@@ -2502,7 +2542,7 @@ class AttnUpDecoderBlock2D(nn.Module): ...@@ -2502,7 +2542,7 @@ class AttnUpDecoderBlock2D(nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -2510,10 +2550,10 @@ class AttnUpDecoderBlock2D(nn.Module): ...@@ -2510,10 +2550,10 @@ class AttnUpDecoderBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
temb_channels=None, temb_channels: Optional[int] = None,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2568,7 +2608,9 @@ class AttnUpDecoderBlock2D(nn.Module): ...@@ -2568,7 +2608,9 @@ class AttnUpDecoderBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, temb=None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> torch.FloatTensor:
for resnet, attn in zip(self.resnets, self.attentions): for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale) hidden_states = resnet(hidden_states, temb=temb, scale=scale)
cross_attention_kwargs = {"scale": scale} cross_attention_kwargs = {"scale": scale}
...@@ -2588,16 +2630,16 @@ class AttnSkipUpBlock2D(nn.Module): ...@@ -2588,16 +2630,16 @@ class AttnSkipUpBlock2D(nn.Module):
prev_output_channel: int, prev_output_channel: int,
out_channels: int, out_channels: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=np.sqrt(2.0), output_scale_factor: float = np.sqrt(2.0),
add_upsample=True, add_upsample: bool = True,
): ):
super().__init__() super().__init__()
self.attentions = nn.ModuleList([]) self.attentions = nn.ModuleList([])
...@@ -2675,7 +2717,14 @@ class AttnSkipUpBlock2D(nn.Module): ...@@ -2675,7 +2717,14 @@ class AttnSkipUpBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
skip_sample=None,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
for resnet in self.resnets: for resnet in self.resnets:
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
...@@ -2711,16 +2760,16 @@ class SkipUpBlock2D(nn.Module): ...@@ -2711,16 +2760,16 @@ class SkipUpBlock2D(nn.Module):
prev_output_channel: int, prev_output_channel: int,
out_channels: int, out_channels: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=np.sqrt(2.0), output_scale_factor: float = np.sqrt(2.0),
add_upsample=True, add_upsample: bool = True,
upsample_padding=1, upsample_padding: int = 1,
): ):
super().__init__() super().__init__()
self.resnets = nn.ModuleList([]) self.resnets = nn.ModuleList([])
...@@ -2776,7 +2825,14 @@ class SkipUpBlock2D(nn.Module): ...@@ -2776,7 +2825,14 @@ class SkipUpBlock2D(nn.Module):
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
skip_sample=None,
scale: float = 1.0,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
for resnet in self.resnets: for resnet in self.resnets:
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
...@@ -2809,7 +2865,7 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2809,7 +2865,7 @@ class ResnetUpsampleBlock2D(nn.Module):
prev_output_channel: int, prev_output_channel: int,
out_channels: int, out_channels: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -2817,9 +2873,9 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2817,9 +2873,9 @@ class ResnetUpsampleBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
skip_time_act=False, skip_time_act: bool = False,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -2871,7 +2927,14 @@ class ResnetUpsampleBlock2D(nn.Module): ...@@ -2871,7 +2927,14 @@ class ResnetUpsampleBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
for resnet in self.resnets: for resnet in self.resnets:
# pop res hidden states # pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states = res_hidden_states_tuple[-1]
...@@ -2911,7 +2974,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2911,7 +2974,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
out_channels: int, out_channels: int,
prev_output_channel: int, prev_output_channel: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -2919,13 +2982,13 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -2919,13 +2982,13 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
skip_time_act=False, skip_time_act: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
cross_attention_norm=None, cross_attention_norm: Optional[str] = None,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -3013,7 +3076,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module): ...@@ -3013,7 +3076,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0) lora_scale = cross_attention_kwargs.get("scale", 1.0)
...@@ -3082,7 +3145,7 @@ class KUpBlock2D(nn.Module): ...@@ -3082,7 +3145,7 @@ class KUpBlock2D(nn.Module):
resnet_eps: float = 1e-5, resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu", resnet_act_fn: str = "gelu",
resnet_group_size: Optional[int] = 32, resnet_group_size: Optional[int] = 32,
add_upsample=True, add_upsample: bool = True,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -3120,7 +3183,14 @@ class KUpBlock2D(nn.Module): ...@@ -3120,7 +3183,14 @@ class KUpBlock2D(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
res_hidden_states_tuple = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None: if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
...@@ -3164,7 +3234,7 @@ class KCrossAttnUpBlock2D(nn.Module): ...@@ -3164,7 +3234,7 @@ class KCrossAttnUpBlock2D(nn.Module):
resnet_eps: float = 1e-5, resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu", resnet_act_fn: str = "gelu",
resnet_group_size: int = 32, resnet_group_size: int = 32,
attention_head_dim=1, # attention dim_head attention_head_dim: int = 1, # attention dim_head
cross_attention_dim: int = 768, cross_attention_dim: int = 768,
add_upsample: bool = True, add_upsample: bool = True,
upcast_attention: bool = False, upcast_attention: bool = False,
...@@ -3248,7 +3318,7 @@ class KCrossAttnUpBlock2D(nn.Module): ...@@ -3248,7 +3318,7 @@ class KCrossAttnUpBlock2D(nn.Module):
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
res_hidden_states_tuple = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None: if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
...@@ -3310,11 +3380,18 @@ class KAttentionBlock(nn.Module): ...@@ -3310,11 +3380,18 @@ class KAttentionBlock(nn.Module):
attention_head_dim (`int`): The number of channels in each head. attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. attention_bias (`bool`, *optional*, defaults to `False`):
num_embeds_ada_norm (: Configure if the attention layers should contain a bias parameter.
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. upcast_attention (`bool`, *optional*, defaults to `False`):
attention_bias (: Set to `True` to upcast the attention computation to `float32`.
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. temb_channels (`int`, *optional*, defaults to 768):
The number of channels in the token embedding.
add_self_attention (`bool`, *optional*, defaults to `False`):
Set to `True` to add self-attention to the block.
cross_attention_norm (`str`, *optional*, defaults to `None`):
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
group_size (`int`, *optional*, defaults to 32):
The number of groups to separate the channels into for group normalization.
""" """
def __init__( def __init__(
...@@ -3360,10 +3437,10 @@ class KAttentionBlock(nn.Module): ...@@ -3360,10 +3437,10 @@ class KAttentionBlock(nn.Module):
cross_attention_norm=cross_attention_norm, cross_attention_norm=cross_attention_norm,
) )
def _to_3d(self, hidden_states, height, weight): def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
def _to_4d(self, hidden_states, height, weight): def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
def forward( def forward(
...@@ -3376,7 +3453,7 @@ class KAttentionBlock(nn.Module): ...@@ -3376,7 +3453,7 @@ class KAttentionBlock(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 1. Self-Attention # 1. Self-Attention
......
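The `_to_3d` / `_to_4d` helpers above only move between the convolutional `(batch, channels, height, width)` layout and the `(batch, height * width, channels)` sequence layout that the attention layer expects (note the width argument is named `weight` in the code). A minimal, self-contained sketch of that round trip in plain PyTorch; the tensor sizes are arbitrary:

```python
import torch

batch, channels, height, width = 2, 8, 4, 6
x = torch.randn(batch, channels, height, width)

# `_to_3d`: (B, C, H, W) -> (B, H*W, C), the layout the attention layers consume.
x_3d = x.permute(0, 2, 3, 1).reshape(batch, height * width, -1)

# `_to_4d`: (B, H*W, C) -> (B, C, H, W), back to the convolutional layout.
x_4d = x_3d.permute(0, 2, 1).reshape(batch, -1, height, width)

assert torch.equal(x, x_4d)  # the round trip is lossless
```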
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -27,7 +27,7 @@ from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block ...@@ -27,7 +27,7 @@ from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block
@dataclass @dataclass
class DecoderOutput(BaseOutput): class DecoderOutput(BaseOutput):
""" r"""
Output of decoding method. Output of decoding method.
Args: Args:
...@@ -39,16 +39,39 @@ class DecoderOutput(BaseOutput): ...@@ -39,16 +39,39 @@ class DecoderOutput(BaseOutput):
class Encoder(nn.Module): class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block, so the output can hold both the
            mean and the log-variance of the latent distribution.
"""
def __init__( def __init__(
self, self,
in_channels=3, in_channels: int = 3,
out_channels=3, out_channels: int = 3,
down_block_types=("DownEncoderBlock2D",), down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels=(64,), block_out_channels: Tuple[int, ...] = (64,),
layers_per_block=2, layers_per_block: int = 2,
norm_num_groups=32, norm_num_groups: int = 32,
act_fn="silu", act_fn: str = "silu",
double_z=True, double_z: bool = True,
): ):
super().__init__() super().__init__()
self.layers_per_block = layers_per_block self.layers_per_block = layers_per_block
...@@ -107,7 +130,8 @@ class Encoder(nn.Module): ...@@ -107,7 +130,8 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, x): def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = x sample = x
sample = self.conv_in(sample) sample = self.conv_in(sample)
...@@ -152,16 +176,38 @@ class Encoder(nn.Module): ...@@ -152,16 +176,38 @@ class Encoder(nn.Module):
class Decoder(nn.Module): class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__( def __init__(
self, self,
in_channels=3, in_channels: int = 3,
out_channels=3, out_channels: int = 3,
up_block_types=("UpDecoderBlock2D",), up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels=(64,), block_out_channels: Tuple[int, ...] = (64,),
layers_per_block=2, layers_per_block: int = 2,
norm_num_groups=32, norm_num_groups: int = 32,
act_fn="silu", act_fn: str = "silu",
norm_type="group", # group, spatial norm_type: str = "group", # group, spatial
): ):
super().__init__() super().__init__()
self.layers_per_block = layers_per_block self.layers_per_block = layers_per_block
...@@ -227,7 +273,8 @@ class Decoder(nn.Module): ...@@ -227,7 +273,8 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, z, latent_embeds=None): def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = z sample = z
sample = self.conv_in(sample) sample = self.conv_in(sample)
...@@ -283,6 +330,16 @@ class Decoder(nn.Module): ...@@ -283,6 +330,16 @@ class Decoder(nn.Module):
class UpSample(nn.Module): class UpSample(nn.Module):
r"""
The `UpSample` layer of a variational autoencoder that upsamples its input.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -294,6 +351,7 @@ class UpSample(nn.Module): ...@@ -294,6 +351,7 @@ class UpSample(nn.Module):
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `UpSample` class."""
x = torch.relu(x) x = torch.relu(x)
x = self.deconv(x) x = self.deconv(x)
return x return x
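The `nn.ConvTranspose2d(..., kernel_size=4, stride=2, padding=1)` used by `UpSample` above is the standard exact-2x transposed convolution: the output size works out to `(H - 1) * 2 - 2 * 1 + 4 = 2 * H`. A quick stand-alone check of that arithmetic:

```python
import torch
import torch.nn as nn

deconv = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1)
x = torch.randn(1, 8, 16, 16)
assert deconv(x).shape == (1, 8, 32, 32)  # spatial size exactly doubles
```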
...@@ -342,6 +400,7 @@ class MaskConditionEncoder(nn.Module): ...@@ -342,6 +400,7 @@ class MaskConditionEncoder(nn.Module):
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor: def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionEncoder` class."""
out = {} out = {}
for l in range(len(self.layers)): for l in range(len(self.layers)):
layer = self.layers[l] layer = self.layers[l]
...@@ -352,19 +411,38 @@ class MaskConditionEncoder(nn.Module): ...@@ -352,19 +411,38 @@ class MaskConditionEncoder(nn.Module):
class MaskConditionDecoder(nn.Module): class MaskConditionDecoder(nn.Module):
"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
decoder with a conditioner on the mask and masked image.""" decoder with a conditioner on the mask and masked image.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
def __init__( def __init__(
self, self,
in_channels=3, in_channels: int = 3,
out_channels=3, out_channels: int = 3,
up_block_types=("UpDecoderBlock2D",), up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels=(64,), block_out_channels: Tuple[int, ...] = (64,),
layers_per_block=2, layers_per_block: int = 2,
norm_num_groups=32, norm_num_groups: int = 32,
act_fn="silu", act_fn: str = "silu",
norm_type="group", # group, spatial norm_type: str = "group", # group, spatial
): ):
super().__init__() super().__init__()
self.layers_per_block = layers_per_block self.layers_per_block = layers_per_block
...@@ -437,7 +515,14 @@ class MaskConditionDecoder(nn.Module): ...@@ -437,7 +515,14 @@ class MaskConditionDecoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, z, image=None, mask=None, latent_embeds=None): def forward(
self,
z: torch.FloatTensor,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `MaskConditionDecoder` class."""
sample = z sample = z
sample = self.conv_in(sample) sample = self.conv_in(sample)
...@@ -539,7 +624,14 @@ class VectorQuantizer(nn.Module): ...@@ -539,7 +624,14 @@ class VectorQuantizer(nn.Module):
# backwards compatibility we use the buggy version by default, but you can # backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it. # specify legacy=False to fix it.
def __init__( def __init__(
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True self,
n_e: int,
vq_embed_dim: int,
beta: float,
remap=None,
unknown_index: str = "random",
sane_index_shape: bool = False,
legacy: bool = True,
): ):
super().__init__() super().__init__()
self.n_e = n_e self.n_e = n_e
...@@ -553,6 +645,7 @@ class VectorQuantizer(nn.Module): ...@@ -553,6 +645,7 @@ class VectorQuantizer(nn.Module):
self.remap = remap self.remap = remap
if self.remap is not None: if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap))) self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.used: torch.Tensor
self.re_embed = self.used.shape[0] self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra": if self.unknown_index == "extra":
...@@ -567,7 +660,7 @@ class VectorQuantizer(nn.Module): ...@@ -567,7 +660,7 @@ class VectorQuantizer(nn.Module):
self.sane_index_shape = sane_index_shape self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds): def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape ishape = inds.shape
assert len(ishape) > 1 assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1) inds = inds.reshape(ishape[0], -1)
...@@ -581,7 +674,7 @@ class VectorQuantizer(nn.Module): ...@@ -581,7 +674,7 @@ class VectorQuantizer(nn.Module):
new[unknown] = self.unknown_index new[unknown] = self.unknown_index
return new.reshape(ishape) return new.reshape(ishape)
def unmap_to_all(self, inds): def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
ishape = inds.shape ishape = inds.shape
assert len(ishape) > 1 assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1) inds = inds.reshape(ishape[0], -1)
...@@ -591,7 +684,7 @@ class VectorQuantizer(nn.Module): ...@@ -591,7 +684,7 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape) return back.reshape(ishape)
def forward(self, z): def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten # reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous() z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim) z_flattened = z.view(-1, self.vq_embed_dim)
...@@ -610,7 +703,7 @@ class VectorQuantizer(nn.Module): ...@@ -610,7 +703,7 @@ class VectorQuantizer(nn.Module):
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients # preserve gradients
z_q = z + (z_q - z).detach() z_q: torch.FloatTensor = z + (z_q - z).detach()
# reshape back to match original input shape # reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous() z_q = z_q.permute(0, 3, 1, 2).contiguous()
...@@ -625,7 +718,7 @@ class VectorQuantizer(nn.Module): ...@@ -625,7 +718,7 @@ class VectorQuantizer(nn.Module):
return z_q, loss, (perplexity, min_encodings, min_encoding_indices) return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
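The `z_q = z + (z_q - z).detach()` line above is the straight-through estimator: the forward pass sees the quantized values, while the backward pass treats quantization as the identity so gradients can reach the encoder. A toy illustration, with rounding standing in for the nearest-codebook lookup:

```python
import torch

z = torch.randn(4, requires_grad=True)
z_q = torch.round(z)                  # stand-in for the nearest-codebook lookup
z_st = z + (z_q - z).detach()         # forward: quantized values, backward: identity

z_st.sum().backward()
assert torch.allclose(z_st, z_q)                  # values are (numerically) the quantized ones
assert torch.allclose(z.grad, torch.ones(4))      # the gradient bypasses the rounding entirely
```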
def get_codebook_entry(self, indices, shape): def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
# shape specifying (batch, height, width, channel) # shape specifying (batch, height, width, channel)
if self.remap is not None: if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis indices = indices.reshape(shape[0], -1) # add batch axis
...@@ -633,7 +726,7 @@ class VectorQuantizer(nn.Module): ...@@ -633,7 +726,7 @@ class VectorQuantizer(nn.Module):
indices = indices.reshape(-1) # flatten again indices = indices.reshape(-1) # flatten again
# get quantized latent vectors # get quantized latent vectors
z_q = self.embedding(indices) z_q: torch.FloatTensor = self.embedding(indices)
if shape is not None: if shape is not None:
z_q = z_q.view(shape) z_q = z_q.view(shape)
...@@ -644,7 +737,7 @@ class VectorQuantizer(nn.Module): ...@@ -644,7 +737,7 @@ class VectorQuantizer(nn.Module):
class DiagonalGaussianDistribution(object): class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False): def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
...@@ -664,7 +757,7 @@ class DiagonalGaussianDistribution(object): ...@@ -664,7 +757,7 @@ class DiagonalGaussianDistribution(object):
x = self.mean + self.std * sample x = self.mean + self.std * sample
return x return x
def kl(self, other=None): def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic: if self.deterministic:
return torch.Tensor([0.0]) return torch.Tensor([0.0])
else: else:
...@@ -680,23 +773,40 @@ class DiagonalGaussianDistribution(object): ...@@ -680,23 +773,40 @@ class DiagonalGaussianDistribution(object):
dim=[1, 2, 3], dim=[1, 2, 3],
) )
def nll(self, sample, dims=[1, 2, 3]): def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic: if self.deterministic:
return torch.Tensor([0.0]) return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi) logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
def mode(self): def mode(self) -> torch.Tensor:
return self.mean return self.mean
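`nll` above is the summed negative log-likelihood of a diagonal Gaussian, `0.5 * sum(log(2*pi) + logvar + (x - mean)**2 / var)`. A small cross-check against `torch.distributions.Normal`; the shapes are arbitrary:

```python
import math

import torch
from torch.distributions import Normal

# Arbitrary per-pixel Gaussian parameters: a batch of two 3x4x4 "images".
mean = torch.randn(2, 3, 4, 4)
logvar = torch.randn(2, 3, 4, 4)
sample = torch.randn(2, 3, 4, 4)
var, std = logvar.exp(), (0.5 * logvar).exp()

# Same expression as `DiagonalGaussianDistribution.nll` above.
logtwopi = math.log(2.0 * math.pi)
nll_manual = 0.5 * torch.sum(logtwopi + logvar + (sample - mean) ** 2 / var, dim=[1, 2, 3])

# Reference implementation from torch.distributions.
nll_ref = -Normal(mean, std).log_prob(sample).sum(dim=[1, 2, 3])
assert torch.allclose(nll_manual, nll_ref, atol=1e-4)
```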
class EncoderTiny(nn.Module): class EncoderTiny(nn.Module):
r"""
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
Args:
in_channels (`int`):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by that many `AutoencoderTinyBlock`s.
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
num_blocks: int, num_blocks: Tuple[int, ...],
block_out_channels: int, block_out_channels: Tuple[int, ...],
act_fn: str, act_fn: str,
): ):
super().__init__() super().__init__()
...@@ -718,7 +828,8 @@ class EncoderTiny(nn.Module): ...@@ -718,7 +828,8 @@ class EncoderTiny(nn.Module):
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, x): def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `EncoderTiny` class."""
if self.training and self.gradient_checkpointing: if self.training and self.gradient_checkpointing:
def create_custom_forward(module): def create_custom_forward(module):
...@@ -740,12 +851,31 @@ class EncoderTiny(nn.Module): ...@@ -740,12 +851,31 @@ class EncoderTiny(nn.Module):
class DecoderTiny(nn.Module): class DecoderTiny(nn.Module):
r"""
The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
Args:
in_channels (`int`):
The number of input channels.
out_channels (`int`):
The number of output channels.
num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by that many `AutoencoderTinyBlock`s.
block_out_channels (`Tuple[int, ...]`):
The number of output channels for each block.
upsampling_scaling_factor (`int`):
The scaling factor to use for upsampling.
act_fn (`str`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
"""
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
num_blocks: int, num_blocks: Tuple[int, ...],
block_out_channels: int, block_out_channels: Tuple[int, ...],
upsampling_scaling_factor: int, upsampling_scaling_factor: int,
act_fn: str, act_fn: str,
): ):
...@@ -772,7 +902,8 @@ class DecoderTiny(nn.Module): ...@@ -772,7 +902,8 @@ class DecoderTiny(nn.Module):
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, x): def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `DecoderTiny` class."""
# Clamp. # Clamp.
x = torch.tanh(x / 3) * 3 x = torch.tanh(x / 3) * 3
......
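`EncoderTiny` and `DecoderTiny` take their whole architecture from the two tuples: each entry adds one Conv2d stage followed by that many `AutoencoderTinyBlock`s, and `DecoderTiny` additionally soft-clamps its input to roughly `[-3, 3]` via `torch.tanh(x / 3) * 3` before running its layers. A rough construction sketch with made-up sizes; the channel and block counts are illustrative only, not a trained TAESD configuration:

```python
import torch
from diffusers.models.vae import DecoderTiny, EncoderTiny

encoder = EncoderTiny(
    in_channels=3, out_channels=4, num_blocks=(1, 2), block_out_channels=(8, 8), act_fn="relu"
)
decoder = DecoderTiny(
    in_channels=4, out_channels=3, num_blocks=(2, 1), block_out_channels=(8, 8),
    upsampling_scaling_factor=2, act_fn="relu",
)

image = torch.randn(1, 3, 32, 32)
latent = encoder(image)                     # 4 latent channels, reduced by the stride-2 stages
print(latent.shape, decoder(latent).shape)  # decoder upsamples back toward the input resolution
```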
...@@ -53,10 +53,12 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -53,10 +53,12 @@ class VQModel(ModelMixin, ConfigMixin):
Tuple of upsample block types. Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels. Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size. sample_size (`int`, *optional*, defaults to `32`): Sample input size.
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
scaling_factor (`float`, *optional*, defaults to `0.18215`): scaling_factor (`float`, *optional*, defaults to `0.18215`):
The component-wise standard deviation of the trained latent space computed using the first batch of the The component-wise standard deviation of the trained latent space computed using the first batch of the
...@@ -65,6 +67,8 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -65,6 +67,8 @@ class VQModel(ModelMixin, ConfigMixin):
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
norm_type (`str`, *optional*, defaults to `"group"`):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
""" """
@register_to_config @register_to_config
...@@ -72,9 +76,9 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -72,9 +76,9 @@ class VQModel(ModelMixin, ConfigMixin):
self, self,
in_channels: int = 3, in_channels: int = 3,
out_channels: int = 3, out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",), down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",), up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,), block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1, layers_per_block: int = 1,
act_fn: str = "silu", act_fn: str = "silu",
latent_channels: int = 3, latent_channels: int = 3,
......
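The `scaling_factor` convention described in the `VQModel` docstring above is symmetric: latents are multiplied by it before being handed to the diffusion model and divided by it again before decoding, so the two operations cancel. In plain arithmetic (0.18215 is the value quoted above; the tensor is only a stand-in for encoder output):

```python
import torch

scaling_factor = 0.18215
latents = torch.randn(1, 3, 64, 64)

z_for_diffusion = scaling_factor * latents                 # applied after encoding
z_for_decoding = (1 / scaling_factor) * z_for_diffusion    # `z = 1 / scaling_factor * z` before decoding
assert torch.allclose(z_for_decoding, latents)
```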
...@@ -106,7 +106,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -106,7 +106,6 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
feature_extractor ([`~transformers.CLIPImageProcessor`]): feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
""" """
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
......
...@@ -134,7 +134,6 @@ class AltDiffusionImg2ImgPipeline( ...@@ -134,7 +134,6 @@ class AltDiffusionImg2ImgPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]): feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
""" """
model_cpu_offload_seq = "text_encoder->unet->vae" model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"] _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"] _exclude_from_cpu_offload = ["safety_checker"]
......
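The `model_cpu_offload_seq = "text_encoder->unet->vae"` attribute on both pipelines above declares the order in which `enable_model_cpu_offload()` moves sub-models on and off the accelerator. A usage sketch; the checkpoint name is illustrative and any AltDiffusion checkpoint should behave the same way:

```python
import torch
from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
# Keeps every sub-model on the CPU and moves it to the GPU only for its own
# forward pass, following the text_encoder -> unet -> vae order declared above.
pipe.enable_model_cpu_offload()

image = pipe("a painting of a lighthouse at dawn").images[0]
```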
...@@ -1508,9 +1508,9 @@ class DownBlockFlat(nn.Module): ...@@ -1508,9 +1508,9 @@ class DownBlockFlat(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_downsample=True, add_downsample: bool = True,
downsample_padding=1, downsample_padding: int = 1,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1547,7 +1547,9 @@ class DownBlockFlat(nn.Module): ...@@ -1547,7 +1547,9 @@ class DownBlockFlat(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, scale: float = 1.0): def forward(
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
for resnet in self.resnets: for resnet in self.resnets:
...@@ -1596,16 +1598,16 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1596,16 +1598,16 @@ class CrossAttnDownBlockFlat(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
downsample_padding=1, downsample_padding: int = 1,
add_downsample=True, add_downsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1682,8 +1684,8 @@ class CrossAttnDownBlockFlat(nn.Module): ...@@ -1682,8 +1684,8 @@ class CrossAttnDownBlockFlat(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
additional_residuals=None, additional_residuals: Optional[torch.FloatTensor] = None,
): ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
output_states = () output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
...@@ -1751,7 +1753,7 @@ class UpBlockFlat(nn.Module): ...@@ -1751,7 +1753,7 @@ class UpBlockFlat(nn.Module):
prev_output_channel: int, prev_output_channel: int,
out_channels: int, out_channels: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
...@@ -1759,8 +1761,8 @@ class UpBlockFlat(nn.Module): ...@@ -1759,8 +1761,8 @@ class UpBlockFlat(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1794,7 +1796,14 @@ class UpBlockFlat(nn.Module): ...@@ -1794,7 +1796,14 @@ class UpBlockFlat(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.resolution_idx = resolution_idx self.resolution_idx = resolution_idx
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
and getattr(self, "s2", None) and getattr(self, "s2", None)
...@@ -1855,7 +1864,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1855,7 +1864,7 @@ class CrossAttnUpBlockFlat(nn.Module):
out_channels: int, out_channels: int,
prev_output_channel: int, prev_output_channel: int,
temb_channels: int, temb_channels: int,
resolution_idx: int = None, resolution_idx: Optional[int] = None,
dropout: float = 0.0, dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1,
...@@ -1864,15 +1873,15 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1864,15 +1873,15 @@ class CrossAttnUpBlockFlat(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
add_upsample=True, add_upsample: bool = True,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
...@@ -1949,7 +1958,7 @@ class CrossAttnUpBlockFlat(nn.Module): ...@@ -1949,7 +1958,7 @@ class CrossAttnUpBlockFlat(nn.Module):
upsample_size: Optional[int] = None, upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = ( is_freeu_enabled = (
getattr(self, "s1", None) getattr(self, "s1", None)
...@@ -2066,8 +2075,8 @@ class UNetMidBlockFlat(nn.Module): ...@@ -2066,8 +2075,8 @@ class UNetMidBlockFlat(nn.Module):
attn_groups: Optional[int] = None, attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
add_attention: bool = True, add_attention: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
): ):
super().__init__() super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
...@@ -2138,7 +2147,7 @@ class UNetMidBlockFlat(nn.Module): ...@@ -2138,7 +2147,7 @@ class UNetMidBlockFlat(nn.Module):
self.attentions = nn.ModuleList(attentions) self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None): def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]): for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None: if attn is not None:
...@@ -2162,13 +2171,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module): ...@@ -2162,13 +2171,13 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
num_attention_heads=1, num_attention_heads: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
dual_cross_attention=False, dual_cross_attention: bool = False,
use_linear_projection=False, use_linear_projection: bool = False,
upcast_attention=False, upcast_attention: bool = False,
attention_type="default", attention_type: str = "default",
): ):
super().__init__() super().__init__()
...@@ -2308,12 +2317,12 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -2308,12 +2317,12 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish", resnet_act_fn: str = "swish",
resnet_groups: int = 32, resnet_groups: int = 32,
resnet_pre_norm: bool = True, resnet_pre_norm: bool = True,
attention_head_dim=1, attention_head_dim: int = 1,
output_scale_factor=1.0, output_scale_factor: float = 1.0,
cross_attention_dim=1280, cross_attention_dim: int = 1280,
skip_time_act=False, skip_time_act: bool = False,
only_cross_attention=False, only_cross_attention: bool = False,
cross_attention_norm=None, cross_attention_norm: Optional[str] = None,
): ):
super().__init__() super().__init__()
...@@ -2389,7 +2398,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module): ...@@ -2389,7 +2398,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
): ) -> torch.FloatTensor:
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0) lora_scale = cross_attention_kwargs.get("scale", 1.0)
......