Unverified commit bdc01711 authored by amyeroberts, committed by GitHub

Refactor classes to inherit from nn.Module instead of nn.Sequential (#17493)

* Adapt Maskformer, VAN, ResNet and RegNet modules to inherit from nn.Module
parent b1160c0b
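The pattern repeated throughout the hunks below is: swap the `nn.Sequential` base class for `nn.Module`, keep the child modules in a plain Python list, re-register each child under its old integer name with `add_module`, and add an explicit `forward`. A minimal, self-contained sketch of that pattern (the class name and channel sizes are illustrative, not taken from the diff):

```python
import torch
from torch import Tensor, nn


class ExampleConvLayer(nn.Module):  # hypothetical stand-in; previously a subclass of nn.Sequential
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Keep the children in a plain list so they are not auto-registered under
        # the attribute name "layers".
        self.layers = [
            nn.Conv2d(in_features, out_features, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, out_features),
            nn.ReLU(inplace=True),
        ]
        # Re-register each child under the index-based name nn.Sequential would
        # have used ("0", "1", "2").
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        # Explicitly chain the children, which nn.Sequential used to do implicitly.
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


layer = ExampleConvLayer(3, 32)
print([name for name, _ in layer.named_parameters()])  # ['0.weight', '1.weight', '1.bias']
```

Because the children are registered as "0", "1", "2", the parameter names match what the old `nn.Sequential` subclass produced, so existing checkpoints keep loading.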
@@ -1958,7 +1958,7 @@ class MaskFormerSwinTransformerBackbone(nn.Module):
return [layer.dim for layer in self.model.encoder.layers]
class MaskFormerFPNConvLayer(nn.Sequential):
class MaskFormerFPNConvLayer(nn.Module):
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
"""
A basic module that executes conv - norm - relu in sequence, used in MaskFormer.
@@ -1969,11 +1969,26 @@ class MaskFormerFPNConvLayer(nn.Sequential):
out_features (`int`):
The number of output features (channels).
"""
super().__init__(
super().__init__()
self.layers = [
nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
nn.GroupNorm(32, out_features),
nn.ReLU(inplace=True),
)
]
for i, layer in enumerate(self.layers):
# Provide backwards compatibility from when the class inherited from nn.Sequential.
# In nn.Sequential subclasses, the name given to a layer is its index in the sequence.
# In nn.Module subclasses the name is derived from the instance attribute it is assigned to, e.g.
# self.my_layer_name = Layer().
# We can't give instance attributes integer names, i.e. self.0 is not permitted, so we need to register
# the modules explicitly
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class MaskFormerFPNLayer(nn.Module):
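As a quick sanity check of the backwards-compatibility comment above, here is a standalone sketch (toy channel counts, not the MaskFormer classes) showing that the `add_module` registration reproduces the exact `state_dict` keys an `nn.Sequential` produced:

```python
from torch import nn

sequential_version = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
    nn.GroupNorm(16, 16),
    nn.ReLU(inplace=True),
)

module_version = nn.Module()
children = [
    nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
    nn.GroupNorm(16, 16),
    nn.ReLU(inplace=True),
]
for i, child in enumerate(children):
    module_version.add_module(str(i), child)

# Same parameter names, so weights saved from the nn.Sequential version load
# into the refactored nn.Module version without any key remapping.
assert list(sequential_version.state_dict()) == list(module_version.state_dict())
print(list(module_version.state_dict()))  # ['0.weight', '1.weight', '1.bias']
```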
@@ -2101,7 +2116,38 @@ class MaskFormerSinePositionEmbedding(nn.Module):
return pos
class MaskformerMLPPredictionHead(nn.Sequential):
class IdentityBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
# The final MLP block keeps its linear projection; only its activation is the identity,
# matching the previous nn.Sequential(nn.Linear(in_dim, out_dim), nn.Identity()) layout.
self.layers = [nn.Linear(in_dim, out_dim), nn.Identity()]
# Maintain submodule indexing as if part of a Sequential block
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class NonLinearBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.layers = [nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True)]
# Maintain submodule indexing as if part of a Sequential block
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class MaskformerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
@@ -2116,18 +2162,27 @@ class MaskformerMLPPredictionHead(nn.Sequential):
num_layers (int, *optional*, defaults to 3):
The number of layers.
"""
super().__init__()
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
layers = []
self.layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
layer = nn.Sequential(
nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True) if i < num_layers - 1 else nn.Identity()
)
layers.append(layer)
super().__init__(*layers)
layer = NonLinearBlock(in_dim, out_dim) if i < num_layers - 1 else IdentityBlock(in_dim, out_dim)
self.layers.append(layer)
# Provide backwards compatibility from when the class inherited from nn.Sequential.
# In nn.Sequential subclasses, the name given to a layer is its index in the sequence.
# In nn.Module subclasses the name is derived from the instance attribute it is assigned to, e.g.
# self.my_layer_name = Layer().
# We can't give instance attributes integer names, i.e. self.0 is not permitted, so we need to register
# the modules explicitly
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class MaskFormerPixelLevelModule(nn.Module):
@@ -2253,20 +2308,21 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.input_projection.bias, 0)
# FPN
elif isinstance(module, MaskFormerFPNModel):
nn.init.xavier_uniform_(module.stem[0].weight, gain=xavier_std)
nn.init.xavier_uniform_(module.stem.get_submodule("0").weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNLayer):
nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNConvLayer):
nn.init.xavier_uniform_(module[0].weight, gain=xavier_std)
nn.init.xavier_uniform_(module.get_submodule("0").weight, gain=xavier_std)
# The MLP head
elif isinstance(module, MaskformerMLPPredictionHead):
# I was not able to find the correct initializer in the original implementation
# we'll use xavier
for layer in module:
nn.init.xavier_uniform_(layer[0].weight, gain=xavier_std)
nn.init.constant_(layer[0].bias, 0)
for submodule in module.modules():
if isinstance(submodule, nn.Linear):
nn.init.xavier_uniform_(submodule.weight, gain=xavier_std)
nn.init.constant_(submodule.bias, 0)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
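A plain `nn.Module` does not support integer indexing, which is why the initialization code above switches from `module[0]` to `module.get_submodule("0")`. A small illustration with a throwaway module (assumed names, not the MaskFormer classes; `get_submodule` requires PyTorch 1.9 or later):

```python
from torch import nn

toy = nn.Module()
toy.add_module("0", nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False))
toy.add_module("1", nn.GroupNorm(8, 8))

# Name-based lookup works on any nn.Module...
nn.init.xavier_uniform_(toy.get_submodule("0").weight, gain=1.0)

# ...whereas integer indexing only exists on container classes such as nn.Sequential.
try:
    toy[0]
except TypeError as err:
    print(f"indexing a plain nn.Module fails: {err}")
```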
@@ -100,7 +100,7 @@ class RegNetEmbeddings(nn.Module):
# Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet
class RegNetShortCut(nn.Sequential):
class RegNetShortCut(nn.Module):
"""
RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
@@ -111,6 +111,11 @@ class RegNetShortCut(nn.Sequential):
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
def forward(self, input: Tensor) -> Tensor:
hidden_state = self.convolution(input)
hidden_state = self.normalization(hidden_state)
return hidden_state
class RegNetSELayer(nn.Module):
"""
@@ -52,7 +52,7 @@ RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
class ResNetConvLayer(nn.Sequential):
class ResNetConvLayer(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
):
@@ -63,8 +63,14 @@ class ResNetConvLayer(nn.Sequential):
self.normalization = nn.BatchNorm2d(out_channels)
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
def forward(self, input: Tensor) -> Tensor:
hidden_state = self.convolution(input)
hidden_state = self.normalization(hidden_state)
hidden_state = self.activation(hidden_state)
return hidden_state
class ResNetEmbeddings(nn.Sequential):
class ResNetEmbeddings(nn.Module):
"""
ResNet Embeddings (stem) composed of a single aggressive convolution.
"""
@@ -76,8 +82,13 @@ class ResNetEmbeddings(nn.Sequential):
)
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, input: Tensor) -> Tensor:
embedding = self.embedder(input)
embedding = self.pooler(embedding)
return embedding
class ResNetShortCut(nn.Sequential):
class ResNetShortCut(nn.Module):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
@@ -88,6 +99,11 @@ class ResNetShortCut(nn.Sequential):
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
def forward(self, input: Tensor) -> Tensor:
hidden_state = self.convolution(input)
hidden_state = self.normalization(hidden_state)
return hidden_state
class ResNetBasicLayer(nn.Module):
"""
@@ -148,7 +164,7 @@ class ResNetBottleNeckLayer(nn.Module):
return hidden_state
class ResNetStage(nn.Sequential):
class ResNetStage(nn.Module):
"""
A ResNet stage composed by stacked layers.
"""
@@ -165,11 +181,25 @@ class ResNetStage(nn.Sequential):
layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
self.layers = nn.Sequential(
self.layers = [
# downsampling is done in the first layer with stride of 2
layer(in_channels, out_channels, stride=stride, activation=config.hidden_act),
*[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
)
]
# Provide backwards compatibility from when the class inherited from nn.Sequential.
# In nn.Sequential subclasses, the name given to a layer is its index in the sequence.
# In nn.Module subclasses the name is derived from the instance attribute it is assigned to, e.g.
# self.my_layer_name = Layer().
# We can't give instance attributes integer names, i.e. self.0 is not permitted, so we need to register
# the module explicitly
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class ResNetEncoder(nn.Module):
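The explicit `forward` methods added for ResNet (and for RegNet above and VAN below) simply feed each submodule's output into the next, which is what `nn.Sequential` did implicitly. A quick equivalence sketch with toy shapes, illustrative only:

```python
import torch
from torch import nn

torch.manual_seed(0)
convolution = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
normalization = nn.BatchNorm2d(16)
activation = nn.ReLU()

# The same three modules, wired up the old way and the new way.
sequential_version = nn.Sequential(convolution, normalization, activation).eval()

pixel_values = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    explicit_forward = activation(normalization(convolution(pixel_values)))
    assert torch.equal(explicit_forward, sequential_version(pixel_values))
print("explicit forward matches nn.Sequential")
```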
@@ -86,7 +86,7 @@ class VanDropPath(nn.Module):
return drop_path(x, self.drop_prob, self.training)
class VanOverlappingPatchEmbedder(nn.Sequential):
class VanOverlappingPatchEmbedder(nn.Module):
"""
Downsamples the input using a patchify operation with a `stride` of 4 by default, making adjacent windows overlap by
half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
@@ -100,8 +100,13 @@ class VanOverlappingPatchEmbedder(nn.Sequential):
)
self.normalization = nn.BatchNorm2d(hidden_size)
def forward(self, input: torch.Tensor) -> torch.Tensor:
hidden_state = self.convolution(input)
hidden_state = self.normalization(hidden_state)
return hidden_state
class VanMlpLayer(nn.Sequential):
class VanMlpLayer(nn.Module):
"""
MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
Transformer](https://arxiv.org/abs/2106.13797).
@@ -123,8 +128,17 @@ class VanMlpLayer(nn.Sequential):
self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
self.dropout2 = nn.Dropout(dropout_rate)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.in_dense(hidden_state)
hidden_state = self.depth_wise(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.dropout1(hidden_state)
hidden_state = self.out_dense(hidden_state)
hidden_state = self.dropout2(hidden_state)
return hidden_state
class VanLargeKernelAttention(nn.Sequential):
class VanLargeKernelAttention(nn.Module):
"""
Basic Large Kernel Attention (LKA).
"""
@@ -137,6 +151,12 @@ class VanLargeKernelAttention(nn.Sequential):
)
self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.depth_wise(hidden_state)
hidden_state = self.depth_wise_dilated(hidden_state)
hidden_state = self.point_wise(hidden_state)
return hidden_state
class VanLargeKernelAttentionLayer(nn.Module):
"""
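For context on the VAN hunks, `VanLargeKernelAttention.forward` now spells out the large kernel attention decomposition: a depth-wise convolution, a dilated depth-wise convolution, and a point-wise convolution. A toy shape check; the kernel sizes, padding, and channel count here are assumed from the standard LKA decomposition rather than read off this truncated diff:

```python
import torch
from torch import nn

hidden_size = 32
depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=5, padding=2, groups=hidden_size)
depth_wise_dilated = nn.Conv2d(
    hidden_size, hidden_size, kernel_size=7, dilation=3, padding=9, groups=hidden_size
)
point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)

hidden_state = torch.randn(2, hidden_size, 14, 14)
hidden_state = point_wise(depth_wise_dilated(depth_wise(hidden_state)))
print(hidden_state.shape)  # torch.Size([2, 32, 14, 14]) -- spatial resolution is preserved
```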