"src/diffusers/utils/dummy_torchao_objects.py" did not exist on "b4e6a7403d9ff66dab627098abb16fc68a15d024"
Unverified Commit 77cad127 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multi-layer perceptron in ops (#6053)

* Adding an MLP block.

* Adding documentation

* Update typos.

* Fix inplace for Dropout.

* Apply recommendations from code review.

* Making changes on pre-trained models.

* Fix linter
parent e65372e1
...@@ -87,6 +87,7 @@ TorchVision provides commonly used building blocks as layers: ...@@ -87,6 +87,7 @@ TorchVision provides commonly used building blocks as layers:
DeformConv2d DeformConv2d
DropBlock2d DropBlock2d
DropBlock3d DropBlock3d
MLP
FrozenBatchNorm2d FrozenBatchNorm2d
SqueezeExcitation SqueezeExcitation
StochasticDepth StochasticDepth
......
...@@ -5,14 +5,14 @@ import torch ...@@ -5,14 +5,14 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
from ..ops.misc import MLP
from ..ops.stochastic_depth import StochasticDepth from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param from ._utils import _ovewrite_named_param
from .convnext import Permute from .convnext import Permute # TODO: move Permute on ops
from .vision_transformer import MLPBlock
__all__ = [ __all__ = [
...@@ -263,7 +263,13 @@ class SwinTransformerBlock(nn.Module): ...@@ -263,7 +263,13 @@ class SwinTransformerBlock(nn.Module):
) )
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(dim, int(dim * mlp_ratio), dropout) self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)
for m in self.mlp.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x: Tensor): def forward(self, x: Tensor):
x = x + self.stochastic_depth(self.attn(self.norm1(x))) x = x + self.stochastic_depth(self.attn(self.norm1(x)))
...@@ -412,7 +418,7 @@ _COMMON_META = { ...@@ -412,7 +418,7 @@ _COMMON_META = {
class Swin_T_Weights(WeightsEnum): class Swin_T_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_t-4c37bd06.pth", url="https://download.pytorch.org/models/swin_t-704ceda3.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
), ),
...@@ -435,7 +441,7 @@ class Swin_T_Weights(WeightsEnum): ...@@ -435,7 +441,7 @@ class Swin_T_Weights(WeightsEnum):
class Swin_S_Weights(WeightsEnum): class Swin_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_s-30134662.pth", url="https://download.pytorch.org/models/swin_s-5e29d889.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
), ),
...@@ -458,7 +464,7 @@ class Swin_S_Weights(WeightsEnum): ...@@ -458,7 +464,7 @@ class Swin_S_Weights(WeightsEnum):
class Swin_B_Weights(WeightsEnum): class Swin_B_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth", url="https://download.pytorch.org/models/swin_b-68c6b09e.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
), ),
......
...@@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Dict ...@@ -6,7 +6,7 @@ from typing import Any, Callable, List, NamedTuple, Optional, Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..ops.misc import Conv2dNormActivation from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
...@@ -37,21 +37,48 @@ class ConvStemConfig(NamedTuple): ...@@ -37,21 +37,48 @@ class ConvStemConfig(NamedTuple):
activation_layer: Callable[..., nn.Module] = nn.ReLU activation_layer: Callable[..., nn.Module] = nn.ReLU
class MLPBlock(MLP):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        # Keep the original ViT initialization: Xavier-uniform weights and
        # near-zero normal biases on every linear layer of the block.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.normal_(module.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            # Legacy checkpoints store the two linears as ``linear_1``/``linear_2``;
            # in the MLP Sequential they live at indices 0 and 3
            # (linear, activation, dropout, linear, dropout).
            for idx in range(2):
                for kind in ("weight", "bias"):
                    legacy_key = f"{prefix}linear_{idx+1}.{kind}"
                    current_key = f"{prefix}{3*idx}.{kind}"
                    if legacy_key in state_dict:
                        state_dict[current_key] = state_dict.pop(legacy_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
class EncoderBlock(nn.Module): class EncoderBlock(nn.Module):
......
...@@ -19,7 +19,7 @@ from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d ...@@ -19,7 +19,7 @@ from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP
from .poolers import MultiScaleRoIAlign from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool from .ps_roi_pool import ps_roi_pool, PSRoIPool
...@@ -61,6 +61,7 @@ __all__ = [ ...@@ -61,6 +61,7 @@ __all__ = [
"Conv2dNormActivation", "Conv2dNormActivation",
"Conv3dNormActivation", "Conv3dNormActivation",
"SqueezeExcitation", "SqueezeExcitation",
"MLP",
"generalized_box_iou_loss", "generalized_box_iou_loss",
"distance_box_iou_loss", "distance_box_iou_loss",
"complete_box_iou_loss", "complete_box_iou_loss",
......
...@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation): ...@@ -129,7 +129,7 @@ class Conv2dNormActivation(ConvNormActivation):
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1 dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
...@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation): ...@@ -179,7 +179,7 @@ class Conv3dNormActivation(ConvNormActivation):
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d`` norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1 dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
...@@ -253,3 +253,47 @@ class SqueezeExcitation(torch.nn.Module): ...@@ -253,3 +253,47 @@ class SqueezeExcitation(torch.nn.Module):
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input) scale = self._scale(input)
return scale * input return scale * input
class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.

    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the linear layer. If ``None`` this layer won't be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = True,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        # `inplace=None` defers to the default `inplace` of the activation/dropout layers.
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        # Every hidden layer gets linear -> (norm) -> (activation) -> dropout.
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            # Guard added: honor the documented contract that
            # ``activation_layer=None`` skips the activation instead of raising.
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        # Final projection: no norm/activation, only dropout.
        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)
        _log_api_usage_once(self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment