"tests/python/common/test_heterograph-kernel.py" did not exist on "3192beb42dc9a31bdfd1d5947484c49224072914"
Unverified Commit 91fd1812 authored by Aryan V S, committed by GitHub

Improve typehints and docs in `diffusers/models` (#5312)



* improvement: add missing typehints and docs to diffusers/models/attention.py

* chore: convert doc strings to raw python strings

add missing typehints

* improvement: add missing typehints and docs to diffusers/models/adapter.py

* improvement: add missing typehints and docs to diffusers/models/lora.py

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* docs: include suggestion by @sayakpaul in src/diffusers/models/adapter.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/lora.py

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0fa32bd6
src/diffusers/models/adapter.py

@@ -231,7 +231,11 @@ class T2IAdapter(ModelMixin, ConfigMixin):
             The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
             also determine the number of downsample blocks in the Adapter.
         num_res_blocks (`int`, *optional*, defaults to 2):
-            Number of ResNet blocks in each downsample block
+            Number of ResNet blocks in each downsample block.
+        downscale_factor (`int`, *optional*, defaults to 8):
+            A factor that determines the total downscale factor of the Adapter.
+        adapter_type (`str`, *optional*, defaults to `full_adapter`):
+            The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
     """
 
     @register_to_config
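The two newly documented arguments are plain constructor keywords on the adapter config. A hedged usage sketch (argument names follow the docstring above and the defaults it states; double-check the constructor signature in your diffusers version, and treat the 512x512 conditioning tensor as an arbitrary example):

```python
import torch
from diffusers import T2IAdapter

# Illustrative configuration mirroring the documented defaults:
# 2 ResNet blocks per downsample block, a base downscale_factor of 8,
# and the "full_adapter" variant.
adapter = T2IAdapter(
    in_channels=3,
    channels=(320, 640, 1280, 1280),
    num_res_blocks=2,
    downscale_factor=8,
    adapter_type="full_adapter",
)

# A conditioning image (e.g. a sketch or depth map) goes in; a list of
# per-resolution feature maps comes out, one per downsample block.
cond = torch.randn(1, 3, 512, 512)
features = adapter(cond)
print([tuple(f.shape) for f in features])
```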
@@ -275,6 +279,10 @@ class T2IAdapter(ModelMixin, ConfigMixin):
 
 
 class FullAdapter(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
     def __init__(
         self,
         in_channels: int = 3,
@@ -321,6 +329,10 @@ class FullAdapter(nn.Module):
 
 
 class FullAdapterXL(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
     def __init__(
         self,
         in_channels: int = 3,
@@ -367,7 +379,22 @@ class FullAdapterXL(nn.Module):
 
 
 class AdapterBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+    r"""
+    An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
+    `FullAdapterXL` models.
+
+    Parameters:
+        in_channels (`int`):
+            Number of channels of AdapterBlock's input.
+        out_channels (`int`):
+            Number of channels of AdapterBlock's output.
+        num_res_blocks (`int`):
+            Number of ResNet blocks in the AdapterBlock.
+        down (`bool`, *optional*, defaults to `False`):
+            Whether to perform downsampling on AdapterBlock's input.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
         super().__init__()
 
         self.downsample = None
@@ -382,7 +409,7 @@ class AdapterBlock(nn.Module):
             *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This method takes tensor x as input and performs operations downsampling and convolutional layers if the
         self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
@@ -400,13 +427,21 @@ class AdapterBlock(nn.Module):
 
 
 class AdapterResnetBlock(nn.Module):
-    def __init__(self, channels):
+    r"""
+    An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
+
+    Parameters:
+        channels (`int`):
+            Number of channels of AdapterResnetBlock's input and output.
+    """
+
+    def __init__(self, channels: int):
         super().__init__()
         self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
         self.act = nn.ReLU()
         self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
         layer on the input tensor. It returns addition with the input tensor.
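To make the documented behaviour concrete, here is a minimal, self-contained sketch of the same ResNet-like pattern described in the docstring above. It is a toy stand-in for illustration, not the diffusers class itself:

```python
import torch
import torch.nn as nn


class TinyResnetBlock(nn.Module):
    """Illustrative stand-in for AdapterResnetBlock: conv3x3 -> ReLU -> conv1x1, plus a skip connection."""

    def __init__(self, channels: int):
        super().__init__()
        self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.block2(self.act(self.block1(x)))
        return h + x  # residual addition keeps input and output shapes identical


print(TinyResnetBlock(64)(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```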
@@ -423,6 +458,10 @@ class AdapterResnetBlock(nn.Module):
 
 
 class LightAdapter(nn.Module):
+    r"""
+    See [`T2IAdapter`] for more information.
+    """
+
     def __init__(
         self,
         in_channels: int = 3,
@@ -449,7 +488,7 @@ class LightAdapter(nn.Module):
 
         self.total_downscale_factor = downscale_factor * (2 ** len(channels))
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         r"""
         This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
         feature tensor corresponds to a different level of processing within the LightAdapter.
@@ -466,7 +505,22 @@ class LightAdapter(nn.Module):
 
 
 class LightAdapterBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
+    r"""
+    A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
+    `LightAdapter` model.
+
+    Parameters:
+        in_channels (`int`):
+            Number of channels of LightAdapterBlock's input.
+        out_channels (`int`):
+            Number of channels of LightAdapterBlock's output.
+        num_res_blocks (`int`):
+            Number of LightAdapterResnetBlocks in the LightAdapterBlock.
+        down (`bool`, *optional*, defaults to `False`):
+            Whether to perform downsampling on LightAdapterBlock's input.
+    """
+
+    def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
         super().__init__()
 
         mid_channels = out_channels // 4
@@ -478,7 +532,7 @@ class LightAdapterBlock(nn.Module):
         self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
         self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
         layer, a sequence of residual blocks, and out convolutional layer.
@@ -494,13 +548,22 @@ class LightAdapterBlock(nn.Module):
 
 
 class LightAdapterResnetBlock(nn.Module):
-    def __init__(self, channels):
+    """
+    A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
+    architecture than `AdapterResnetBlock`.
+
+    Parameters:
+        channels (`int`):
+            Number of channels of LightAdapterResnetBlock's input and output.
+    """
+
+    def __init__(self, channels: int):
         super().__init__()
         self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
         self.act = nn.ReLU()
         self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
         another convolutional layer and adds it to input tensor.
......
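One detail worth spelling out from the LightAdapter hunk above: `total_downscale_factor` is plain arithmetic, multiplying the documented base `downscale_factor` by 2 for every entry in `channels`. A small check of that formula (the three-entry channel tuple is an assumed example, not necessarily the library default):

```python
# Arithmetic behind the LightAdapter line shown above:
#     self.total_downscale_factor = downscale_factor * (2 ** len(channels))
downscale_factor = 8           # documented default
channels = (320, 640, 1280)    # assumed example tuple
total_downscale_factor = downscale_factor * (2 ** len(channels))
print(total_downscale_factor)  # 64, i.e. the output features are 64x smaller than the input
```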
src/diffusers/models/attention.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +26,17 @@ from .lora import LoRACompatibleLinear
 
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
-    def __init__(self, query_dim, context_dim, n_heads, d_head):
+    r"""
+    A gated self-attention dense layer that combines visual features and object features.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        context_dim (`int`): The number of channels in the context.
+        n_heads (`int`): The number of heads to use for attention.
+        d_head (`int`): The number of channels in each head.
+    """
+
+    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
         super().__init__()
 
         # we need a linear projection since we need cat visual feature and obj feature
@@ -43,7 +53,7 @@ class GatedSelfAttentionDense(nn.Module):
 
         self.enabled = True
 
-    def forward(self, x, objs):
+    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x
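As an aside on the docstring added here: "combines visual features and object features" means the object tokens are projected to the visual token width (the linear projection mentioned in the context comment), concatenated along the sequence dimension, run through self-attention, and added back through a learnable gate. A simplified, hedged sketch of that idea using plain torch modules, not the diffusers implementation (which uses its own Attention and FeedForward classes and two separate gates):

```python
import torch
import torch.nn as nn


class ToyGatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim: int, context_dim: int, n_heads: int):
        super().__init__()
        self.linear = nn.Linear(context_dim, query_dim)  # project object features to the visual width
        self.norm = nn.LayerNorm(query_dim)
        self.attn = nn.MultiheadAttention(query_dim, n_heads, batch_first=True)
        self.alpha = nn.Parameter(torch.zeros(1))  # gate starts at zero, so the layer is a no-op at init

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        n_visual = x.shape[1]
        tokens = self.norm(torch.cat([x, self.linear(objs)], dim=1))  # concat visual + object tokens
        attn_out, _ = self.attn(tokens, tokens, tokens)
        return x + torch.tanh(self.alpha) * attn_out[:, :n_visual]  # gated residual on the visual tokens only


x = torch.randn(2, 64, 320)     # visual tokens
objs = torch.randn(2, 30, 768)  # grounding/object tokens
print(ToyGatedSelfAttentionDense(320, 768, 8)(x, objs).shape)  # torch.Size([2, 64, 320])
```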
@@ -67,15 +77,25 @@ class BasicTransformerBlock(nn.Module):
         attention_head_dim (`int`): The number of channels in each head.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
-        only_cross_attention (`bool`, *optional*):
-            Whether to use only cross-attention layers. In this case two cross attention layers are used.
-        double_self_attention (`bool`, *optional*):
-            Whether to use two self-attention layers. In this case no cross attention layers are used.
         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         num_embeds_ada_norm (:
             obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
         attention_bias (:
             obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool` *optional*, defaults to False):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, *optional*, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
     """
 
     def __init__(
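For readers skimming this docstring, the newly documented options are all constructor keywords. A hedged usage sketch; the positional arguments `dim`, `num_attention_heads`, and `attention_head_dim` come from the existing docstring, but verify the exact signature and keyword names against your installed diffusers version before relying on it:

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# dim = num_attention_heads * attention_head_dim is the usual convention.
block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,  # size of encoder_hidden_states for cross-attention
    activation_fn="geglu",
    norm_type="layer_norm",   # "ada_norm"/"ada_norm_zero" additionally need num_embeds_ada_norm
    attention_type="default",
)

hidden_states = torch.randn(2, 64, 320)          # (batch, tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder output
out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```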
@@ -175,7 +195,7 @@ class BasicTransformerBlock(nn.Module):
         timestep: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         class_labels: Optional[torch.LongTensor] = None,
-    ):
+    ) -> torch.FloatTensor:
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         if self.use_ada_layer_norm:
@@ -301,7 +321,7 @@ class FeedForward(nn.Module):
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         for module in self.net:
             if isinstance(module, (LoRACompatibleLinear, GEGLU)):
                 hidden_states = module(hidden_states, scale)
@@ -313,6 +333,11 @@ class FeedForward(nn.Module):
 class GELU(nn.Module):
     r"""
     GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
     """
 
     def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
@@ -320,7 +345,7 @@ class GELU(nn.Module):
         self.proj = nn.Linear(dim_in, dim_out)
         self.approximate = approximate
 
-    def gelu(self, gate):
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
             return F.gelu(gate, approximate=self.approximate)
         # mps: gelu is not implemented for float16
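The `# mps: gelu is not implemented for float16` context line refers to the fallback path on Apple Silicon: the tensor is upcast to float32 for the GELU call and cast back afterwards. A hedged sketch of that pattern in isolation:

```python
import torch
import torch.nn.functional as F


def gelu_mps_safe(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    """Apply GELU, upcasting to float32 on mps where the float16 kernel may be missing."""
    if gate.device.type != "mps":
        return F.gelu(gate, approximate=approximate)
    # Upcast, apply the op, then cast back to the original dtype.
    return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)


print(gelu_mps_safe(torch.randn(4, 8), approximate="tanh").shape)  # torch.Size([4, 8])
```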
@@ -345,7 +370,7 @@ class GEGLU(nn.Module):
         super().__init__()
         self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
 
-    def gelu(self, gate):
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
             return F.gelu(gate)
         # mps: gelu is not implemented for float16
@@ -357,34 +382,41 @@ class GEGLU(nn.Module):
 
 
 class ApproximateGELU(nn.Module):
-    """
-    The approximate form of Gaussian Error Linear Unit (GELU)
-
-    For more details, see section 2: https://arxiv.org/abs/1606.08415
+    r"""
+    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
+    https://arxiv.org/abs/1606.08415.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
     """
 
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
 
 
 class AdaLayerNorm(nn.Module):
-    """
+    r"""
     Norm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the dictionary of embeddings.
     """
 
-    def __init__(self, embedding_dim, num_embeddings):
+    def __init__(self, embedding_dim: int, num_embeddings: int):
         super().__init__()
         self.emb = nn.Embedding(num_embeddings, embedding_dim)
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
 
-    def forward(self, x, timestep):
+    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(self.emb(timestep)))
         scale, shift = torch.chunk(emb, 2)
         x = self.norm(x) * (1 + scale) + shift
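The AdaLayerNorm forward shown above is a plain scale-and-shift conditioning: a learned timestep embedding is projected to twice the hidden width, split into `scale` and `shift`, and applied to the layer-normalized activations. A hedged usage sketch against the class as it appears in this file (the import path reflects this version of diffusers and may have moved since; note that `torch.chunk(emb, 2)` splits along dim 0, so a scalar timestep is used here):

```python
import torch
from diffusers.models.attention import AdaLayerNorm

# embedding_dim must match the hidden size of x; num_embeddings is the number of
# discrete timesteps the embedding table covers (both values are illustrative).
norm = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)

x = torch.randn(1, 64, 320)   # (batch, tokens, hidden)
timestep = torch.tensor(10)   # scalar timestep index into the embedding table
out = norm(x, timestep)
print(out.shape)              # torch.Size([1, 64, 320])
```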
@@ -392,11 +424,15 @@ class AdaLayerNorm(nn.Module):
 
 
 class AdaLayerNormZero(nn.Module):
-    """
+    r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the dictionary of embeddings.
     """
 
-    def __init__(self, embedding_dim, num_embeddings):
+    def __init__(self, embedding_dim: int, num_embeddings: int):
         super().__init__()
 
         self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
@@ -405,7 +441,13 @@ class AdaLayerNormZero(nn.Module):
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
 
-    def forward(self, x, timestep, class_labels, hidden_dtype=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: torch.Tensor,
+        class_labels: torch.LongTensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
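The `Tuple` added to the `typing` import earlier in this file exists for this return annotation: the conditioning embedding is projected to `6 * embedding_dim`, chunked into six modulation tensors, and the method returns the modulated input plus four of them. A minimal sketch of just the chunking and broadcasting step, detached from the class:

```python
import torch

batch, embedding_dim = 2, 320
emb = torch.randn(batch, 6 * embedding_dim)  # stands in for the linear projection output in AdaLayerNormZero

# Six per-sample modulation vectors: shift/scale/gate for attention, shift/scale/gate for the MLP.
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)

x = torch.randn(batch, 64, embedding_dim)
x_norm = torch.nn.functional.layer_norm(x, (embedding_dim,))
modulated = x_norm * (1 + scale_msa[:, None]) + shift_msa[:, None]  # same broadcast as in the diff
print(modulated.shape)  # torch.Size([2, 64, 320])
```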
@@ -413,8 +455,15 @@ class AdaLayerNormZero(nn.Module):
 
 
 class AdaGroupNorm(nn.Module):
-    """
+    r"""
     GroupNorm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the dictionary of embeddings.
+        num_groups (`int`): The number of groups to separate the channels into.
+        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
+        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
     """
 
     def __init__(
@@ -431,7 +480,7 @@ class AdaGroupNorm(nn.Module):
 
         self.linear = nn.Linear(embedding_dim, out_dim * 2)
 
-    def forward(self, x, emb):
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
         if self.act:
             emb = self.act(emb)
         emb = self.linear(emb)
......
src/diffusers/models/lora.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -40,7 +40,35 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
 
 
 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+    r"""
+    A linear layer that is used with LoRA.
+
+    Parameters:
+        in_features (`int`):
+            Number of input features.
+        out_features (`int`):
+            Number of output features.
+        rank (`int`, `optional`, defaults to 4):
+            The rank of the LoRA layer.
+        network_alpha (`float`, `optional`, defaults to `None`):
+            The value of the network alpha used for stable learning and preventing underflow. This value has the same
+            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
+            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        device (`torch.device`, `optional`, defaults to `None`):
+            The device to use for the layer's weights.
+        dtype (`torch.dtype`, `optional`, defaults to `None`):
+            The dtype to use for the layer's weights.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int = 4,
+        network_alpha: Optional[float] = None,
+        device: Optional[Union[torch.device, str]] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
         super().__init__()
 
         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
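The documented `rank` and `network_alpha` interact in the usual LoRA way: the layer is a low-rank bottleneck (`in_features -> rank -> out_features`) whose output is scaled by `network_alpha / rank` when an alpha is given. A hedged, standalone sketch of that arithmetic, not the diffusers class itself (which additionally handles dtype casting, as the surrounding hunks show):

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 4, network_alpha: Optional[float] = None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)  # project to the low-rank space
        self.up = nn.Linear(rank, out_features, bias=False)   # project back up
        self.network_alpha = network_alpha
        self.rank = rank
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # zero init: the LoRA path starts as a no-op

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.up(self.down(hidden_states))
        if self.network_alpha is not None:
            out = out * (self.network_alpha / self.rank)  # kohya-style alpha scaling
        return out


lora = ToyLoRALinear(320, 320, rank=4, network_alpha=8.0)
print(lora(torch.randn(2, 320)).abs().sum())  # 0, because the up projection is zero-initialized
```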
@@ -55,7 +83,7 @@ class LoRALinearLayer(nn.Module):
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_dtype = hidden_states.dtype
         dtype = self.down.weight.dtype
 
@@ -69,8 +97,37 @@ class LoRALinearLayer(nn.Module):
 
 
 class LoRAConv2dLayer(nn.Module):
+    r"""
+    A convolutional layer that is used with LoRA.
+
+    Parameters:
+        in_features (`int`):
+            Number of input features.
+        out_features (`int`):
+            Number of output features.
+        rank (`int`, `optional`, defaults to 4):
+            The rank of the LoRA layer.
+        kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
+            The kernel size of the convolution.
+        stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
+            The stride of the convolution.
+        padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
+            The padding of the convolution.
+        network_alpha (`float`, `optional`, defaults to `None`):
+            The value of the network alpha used for stable learning and preventing underflow. This value has the same
+            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
+            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+    """
+
     def __init__(
-        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int = 4,
+        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
+        stride: Union[int, Tuple[int, int]] = (1, 1),
+        padding: Union[int, Tuple[int, int], str] = 0,
+        network_alpha: Optional[float] = None,
     ):
         super().__init__()
 
@@ -87,7 +144,7 @@ class LoRAConv2dLayer(nn.Module):
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_dtype = hidden_states.dtype
         dtype = self.down.weight.dtype
@@ -112,7 +169,7 @@ class LoRACompatibleConv(nn.Conv2d):
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
         if self.lora_layer is None:
             return
 
@@ -164,7 +221,7 @@ class LoRACompatibleConv(nn.Conv2d):
         self.w_up = None
         self.w_down = None
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         if self.lora_layer is None:
             # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
             # see: https://github.com/huggingface/diffusers/pull/4315
@@ -190,7 +247,7 @@ class LoRACompatibleLinear(nn.Linear):
     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
-    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
         if self.lora_layer is None:
             return
 
@@ -238,7 +295,7 @@ class LoRACompatibleLinear(nn.Linear):
         self.w_up = None
         self.w_down = None
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         if self.lora_layer is None:
             out = super().forward(hidden_states)
             return out
......
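The visible `LoRACompatibleLinear.forward` branch handles only the no-LoRA case; the remainder of the method is elided by the diff viewer. The general wiring it documents is a base layer plus an optional, scale-weighted LoRA branch attached via `set_lora_layer`. A simplified, self-contained sketch of that pattern under that assumption (the residual wiring and names here are illustrative, not the diffusers implementation):

```python
from typing import Optional

import torch
import torch.nn as nn


class ToyLoRACompatibleLinear(nn.Linear):
    """Simplified stand-in: nn.Linear plus an optional, scale-weighted LoRA branch."""

    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_layer: Optional[nn.Module] = None

    def set_lora_layer(self, lora_layer: Optional[nn.Module]):
        self.lora_layer = lora_layer

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            out = out + scale * self.lora_layer(hidden_states)  # assumed residual wiring
        return out


layer = ToyLoRACompatibleLinear(320, 320)
# A toy rank-4 LoRA branch; the real classes use LoRALinearLayer as documented above.
layer.set_lora_layer(nn.Sequential(nn.Linear(320, 4, bias=False), nn.Linear(4, 320, bias=False)))
print(layer(torch.randn(2, 320), scale=0.5).shape)  # torch.Size([2, 320])
```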