Unverified Commit bdc01711 authored by amyeroberts, committed by GitHub

Refactor classes to inherit from nn.Module instead of nn.Sequential (#17493)

* Adapt Maskformer, VAN, ResNet and RegNet modules to inherit from nn.Module
parent b1160c0b
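The pattern repeated across the diffs below: each former nn.Sequential subclass becomes an nn.Module that keeps its layers in a plain Python list, registers every layer under its old numeric name with add_module so existing checkpoints keep the same state_dict keys, and chains the layers in an explicit forward. A minimal sketch of that pattern, using an illustrative ToyConvBlock that is not part of this commit:

import torch
from torch import Tensor, nn


class ToyConvBlock(nn.Module):
    """Keeps the same parameter names as an nn.Sequential version ("0.weight", "1.weight", ...)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Keep the layers in a plain list so they are not auto-registered under an attribute name
        self.layers = [
            nn.Conv2d(in_features, out_features, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, out_features),
            nn.ReLU(inplace=True),
        ]
        # Register each layer under its sequence index, matching the old nn.Sequential state_dict keys
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


block = ToyConvBlock(32, 64)
print(sorted(block.state_dict().keys()))  # ['0.weight', '1.bias', '1.weight'] - same keys as nn.Sequential
out = block(torch.randn(1, 32, 8, 8))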
@@ -1958,7 +1958,7 @@ class MaskFormerSwinTransformerBackbone(nn.Module):
         return [layer.dim for layer in self.model.encoder.layers]


-class MaskFormerFPNConvLayer(nn.Sequential):
+class MaskFormerFPNConvLayer(nn.Module):
     def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
         """
         A basic module that executes conv - norm - in sequence used in MaskFormer.
@@ -1969,11 +1969,26 @@ class MaskFormerFPNConvLayer(nn.Sequential):
             out_features (`int`):
                 The number of outputs features (channels).
         """
-        super().__init__(
+        super().__init__()
+        self.layers = [
             nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
             nn.GroupNorm(32, out_features),
             nn.ReLU(inplace=True),
-        )
+        ]
+        for i, layer in enumerate(self.layers):
+            # Provide backwards compatibility from when the class inherited from nn.Sequential
+            # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+            # In nn.Module subclasses it is derived from the instance attribute it is assigned to, e.g.
+            # self.my_layer_name = Layer()
+            # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
+            # explicitly
+            self.add_module(str(i), layer)
+
+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state


 class MaskFormerFPNLayer(nn.Module):
@@ -2101,7 +2116,38 @@ class MaskFormerSinePositionEmbedding(nn.Module):
         return pos


-class MaskformerMLPPredictionHead(nn.Sequential):
+class IdentityBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # Create as an iterable here so that the identity layer isn't registered
+        # with the name of the instance variable it's assigned to
+        self.layers = [nn.Identity()]
+        # Maintain submodule indexing as if part of a Sequential block
+        self.add_module("0", self.layers[0])
+
+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state
+
+
+class NonLinearBlock(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.layers = [nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True)]
+        # Maintain submodule indexing as if part of a Sequential block
+        for i, layer in enumerate(self.layers):
+            self.add_module(str(i), layer)
+
+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state
+
+
+class MaskformerMLPPredictionHead(nn.Module):
     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
         """
         A classic Multi Layer Perceptron (MLP).
@@ -2116,18 +2162,27 @@ class MaskformerMLPPredictionHead(nn.Sequential):
             num_layers (int, *optional*, defaults to 3):
                 The number of layers.
         """
+        super().__init__()
         in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
         out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]

-        layers = []
+        self.layers = []
         for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
-            layer = nn.Sequential(
-                nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True) if i < num_layers - 1 else nn.Identity()
-            )
-            layers.append(layer)
-
-        super().__init__(*layers)
+            layer = NonLinearBlock(in_dim, out_dim) if i < num_layers - 1 else IdentityBlock()
+            self.layers.append(layer)
+            # Provide backwards compatibility from when the class inherited from nn.Sequential
+            # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+            # In nn.Module subclasses it is derived from the instance attribute it is assigned to, e.g.
+            # self.my_layer_name = Layer()
+            # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
+            # explicitly
+            self.add_module(str(i), layer)
+
+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state


 class MaskFormerPixelLevelModule(nn.Module):
@@ -2253,20 +2308,21 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.input_projection.bias, 0)
         # FPN
         elif isinstance(module, MaskFormerFPNModel):
-            nn.init.xavier_uniform_(module.stem[0].weight, gain=xavier_std)
+            nn.init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std)
         elif isinstance(module, MaskFormerFPNLayer):
             nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std)
         elif isinstance(module, MaskFormerFPNConvLayer):
-            nn.init.xavier_uniform_(module[0].weight, gain=xavier_std)
+            nn.init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std)
         # The MLP head
         elif isinstance(module, MaskformerMLPPredictionHead):
             # I was not able to find the correct initializer in the original implementation
             # we'll use xavier
-            for layer in module:
-                nn.init.xavier_uniform_(layer[0].weight, gain=xavier_std)
-                nn.init.constant_(layer[0].bias, 0)
+            for submodule in module.modules():
+                if isinstance(submodule, nn.Linear):
+                    nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
+                    nn.init.constant_(submodule.bias, 0)
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
...
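Because the refactored classes no longer inherit nn.Sequential's integer indexing, the _init_weights changes above swap module[0] for module.get_submodule("0") and walk module.modules() to find the nn.Linear layers. A small sketch of the same access pattern on a stand-alone module (ToyHead is hypothetical, not part of this commit):

import torch
from torch import nn


class ToyHead(nn.Module):
    """A stand-in for a refactored head: layers registered under numeric names, no __getitem__."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4), nn.ReLU(inplace=True), nn.Linear(4, 2)]
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)


head = ToyHead()

# head[0] would raise TypeError ('ToyHead' object is not subscriptable), so address the
# first submodule by its registered name instead
nn.init.xavier_uniform_(head.get_submodule("0").weight, gain=1.0)

# Initialise every nn.Linear the way the refactored MaskformerMLPPredictionHead branch does:
# walk all submodules and filter by type instead of relying on positional indexing
for submodule in head.modules():
    if isinstance(submodule, nn.Linear):
        nn.init.xavier_uniform_(submodule.weight, gain=1.0)
        nn.init.constant_(submodule.bias, 0)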
@@ -100,7 +100,7 @@ class RegNetEmbeddings(nn.Module):


 # Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet
-class RegNetShortCut(nn.Sequential):
+class RegNetShortCut(nn.Module):
     """
     RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
     downsample the input using `stride=2`.
@@ -111,6 +111,11 @@ class RegNetShortCut(nn.Sequential):
         self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
         self.normalization = nn.BatchNorm2d(out_channels)

+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = self.convolution(input)
+        hidden_state = self.normalization(hidden_state)
+        return hidden_state
+

 class RegNetSELayer(nn.Module):
     """
...
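For RegNetShortCut, and for the ResNet and VAN classes below, the submodules were already assigned to named attributes, so their state_dict keys are unchanged; the only addition is the explicit forward that nn.Sequential used to provide implicitly. A sketch of the compatibility this relies on, using hypothetical OldShortCut/NewShortCut stand-ins rather than the real classes:

import torch
from torch import Tensor, nn


class OldShortCut(nn.Sequential):
    """Pre-refactor style: named attributes, forward provided implicitly by nn.Sequential."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        self.normalization = nn.BatchNorm2d(out_channels)


class NewShortCut(nn.Module):
    """Post-refactor style: same attributes, plus an explicit forward."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        self.normalization = nn.BatchNorm2d(out_channels)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = self.convolution(input)
        hidden_state = self.normalization(hidden_state)
        return hidden_state


old = OldShortCut(8, 16)
new = NewShortCut(8, 16)
# The keys ("convolution.weight", "normalization.weight", ...) are identical, so a strict load succeeds
new.load_state_dict(old.state_dict(), strict=True)

old.eval()
new.eval()
x = torch.randn(1, 8, 32, 32)
assert torch.allclose(old(x), new(x))  # identical outputs once the old weights are loaded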
@@ -52,7 +52,7 @@ RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-class ResNetConvLayer(nn.Sequential):
+class ResNetConvLayer(nn.Module):
     def __init__(
         self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
     ):
@@ -63,8 +63,14 @@ class ResNetConvLayer(nn.Sequential):
         self.normalization = nn.BatchNorm2d(out_channels)
         self.activation = ACT2FN[activation] if activation is not None else nn.Identity()

+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = self.convolution(input)
+        hidden_state = self.normalization(hidden_state)
+        hidden_state = self.activation(hidden_state)
+        return hidden_state
+

-class ResNetEmbeddings(nn.Sequential):
+class ResNetEmbeddings(nn.Module):
     """
     ResNet Embeddings (stem) composed of a single aggressive convolution.
     """
@@ -76,8 +82,13 @@ class ResNetEmbeddings(nn.Sequential):
         )
         self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

+    def forward(self, input: Tensor) -> Tensor:
+        embedding = self.embedder(input)
+        embedding = self.pooler(embedding)
+        return embedding
+

-class ResNetShortCut(nn.Sequential):
+class ResNetShortCut(nn.Module):
     """
     ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
     downsample the input using `stride=2`.
@@ -88,6 +99,11 @@ class ResNetShortCut(nn.Sequential):
         self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
         self.normalization = nn.BatchNorm2d(out_channels)

+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = self.convolution(input)
+        hidden_state = self.normalization(hidden_state)
+        return hidden_state
+

 class ResNetBasicLayer(nn.Module):
     """
@@ -148,7 +164,7 @@ class ResNetBottleNeckLayer(nn.Module):
         return hidden_state


-class ResNetStage(nn.Sequential):
+class ResNetStage(nn.Module):
     """
     A ResNet stage composed by stacked layers.
     """
@@ -165,11 +181,25 @@ class ResNetStage(nn.Sequential):
         layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer

-        self.layers = nn.Sequential(
+        self.layers = [
             # downsampling is done in the first layer with stride of 2
             layer(in_channels, out_channels, stride=stride, activation=config.hidden_act),
             *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
-        )
+        ]
+        # Provide backwards compatibility from when the class inherited from nn.Sequential
+        # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
+        # In nn.Module subclasses it is derived from the instance attribute it is assigned to, e.g.
+        # self.my_layer_name = Layer()
+        # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
+        # the module explicitly
+        for i, layer in enumerate(self.layers):
+            self.add_module(str(i), layer)
+
+    def forward(self, input: Tensor) -> Tensor:
+        hidden_state = input
+        for layer in self.layers:
+            hidden_state = layer(hidden_state)
+        return hidden_state


 class ResNetEncoder(nn.Module):
...
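ResNetStage keeps self.layers as a plain Python list and registers each layer by hand instead of wrapping them in nn.ModuleList: a ModuleList assigned to self.layers would prefix every parameter with "layers.", which would break the keys in existing checkpoints. A short comparison with illustrative StagePlainList/StageModuleList classes (not part of this commit):

import torch
from torch import Tensor, nn


class StagePlainList(nn.Module):
    """Plain list + add_module: parameters keep their old nn.Sequential-style names."""

    def __init__(self, channels: int, depth: int):
        super().__init__()
        self.layers = [nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(depth)]
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class StageModuleList(nn.Module):
    """nn.ModuleList: parameters gain a 'layers.' prefix, so old checkpoints no longer match."""

    def __init__(self, channels: int, depth: int):
        super().__init__()
        self.layers = nn.ModuleList(nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(depth))

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


print(list(StagePlainList(4, 2).state_dict()))   # ['0.weight', '0.bias', '1.weight', '1.bias']
print(list(StageModuleList(4, 2).state_dict()))  # ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias']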
@@ -86,7 +86,7 @@ class VanDropPath(nn.Module):
         return drop_path(x, self.drop_prob, self.training)


-class VanOverlappingPatchEmbedder(nn.Sequential):
+class VanOverlappingPatchEmbedder(nn.Module):
     """
     Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
     half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
@@ -100,8 +100,13 @@ class VanOverlappingPatchEmbedder(nn.Sequential):
         )
         self.normalization = nn.BatchNorm2d(hidden_size)

+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.convolution(input)
+        hidden_state = self.normalization(hidden_state)
+        return hidden_state
+

-class VanMlpLayer(nn.Sequential):
+class VanMlpLayer(nn.Module):
     """
     MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
     Transformer](https://arxiv.org/abs/2106.13797).
@@ -123,8 +128,17 @@ class VanMlpLayer(nn.Sequential):
         self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
         self.dropout2 = nn.Dropout(dropout_rate)

+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.in_dense(hidden_state)
+        hidden_state = self.depth_wise(hidden_state)
+        hidden_state = self.activation(hidden_state)
+        hidden_state = self.dropout1(hidden_state)
+        hidden_state = self.out_dense(hidden_state)
+        hidden_state = self.dropout2(hidden_state)
+        return hidden_state
+

-class VanLargeKernelAttention(nn.Sequential):
+class VanLargeKernelAttention(nn.Module):
     """
     Basic Large Kernel Attention (LKA).
     """
@@ -137,6 +151,12 @@ class VanLargeKernelAttention(nn.Sequential):
         )
         self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)

+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.depth_wise(hidden_state)
+        hidden_state = self.depth_wise_dilated(hidden_state)
+        hidden_state = self.point_wise(hidden_state)
+        return hidden_state
+

 class VanLargeKernelAttentionLayer(nn.Module):
     """
...
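For context on why the old VAN classes worked with no forward at all: nn.Sequential.forward simply calls every registered child in registration order, and submodules assigned as named attributes in __init__ are registered in assignment order. The refactor trades that implicit chaining for an explicit forward. A small demonstration with a hypothetical ImplicitChain:

import torch
from torch import nn


class ImplicitChain(nn.Sequential):
    """Pre-refactor style: no forward; nn.Sequential calls the children in assignment order."""

    def __init__(self):
        super().__init__()
        self.in_dense = nn.Conv2d(3, 8, kernel_size=1)
        self.activation = nn.GELU()
        self.out_dense = nn.Conv2d(8, 3, kernel_size=1)


chain = ImplicitChain()
x = torch.randn(1, 3, 4, 4)
# nn.Sequential.forward is equivalent to calling the assigned submodules one after another
expected = chain.out_dense(chain.activation(chain.in_dense(x)))
assert torch.allclose(chain(x), expected)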