Unverified commit 36f8c425, authored by Francesco Saverio Zuppichini, committed by GitHub
Browse files

ResNet: update modules names (#16196)

* updated names

* fit in one line

* typo
parent 5bdf3313
......@@ -96,11 +96,11 @@ class ResNetConvLayer(nn.Sequential):
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
):
super().__init__()
self.conv = nn.Conv2d(
self.convolution = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = ACT2FN[activation] if activation is not None else nn.Identity()
self.normalization = nn.BatchNorm2d(out_channels)
self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
class ResNetEmbeddings(nn.Sequential):
......@@ -111,7 +111,7 @@ class ResNetEmbeddings(nn.Sequential):
def __init__(self, num_channels: int, out_channels: int, activation: str = "relu"):
super().__init__()
self.embedder = ResNetConvLayer(num_channels, out_channels, kernel_size=7, stride=2, activation=activation)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
class ResNetShortCut(nn.Sequential):
......@@ -122,8 +122,8 @@ class ResNetShortCut(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.normalization = nn.BatchNorm2d(out_channels)
class ResNetBasicLayer(nn.Module):
......@@ -192,21 +192,20 @@ class ResNetStage(nn.Sequential):
def __init__(
self,
config: ResNetConfig,
in_channels: int,
out_channels: int,
stride: int = 2,
depth: int = 2,
layer_type: str = "basic",
activation: str = "relu",
):
super().__init__()
layer = ResNetBottleNeckLayer if layer_type == "bottleneck" else ResNetBasicLayer
layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
self.layers = nn.Sequential(
# downsampling is done in the first layer with stride of 2
layer(in_channels, out_channels, stride=stride, activation=activation),
*[layer(out_channels, out_channels, activation=activation) for _ in range(depth - 1)],
layer(in_channels, out_channels, stride=stride, activation=config.hidden_act),
*[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
)
......@@ -217,21 +216,16 @@ class ResNetEncoder(nn.Module):
# based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
self.stages.append(
ResNetStage(
config,
config.embedding_size,
config.hidden_sizes[0],
stride=2 if config.downsample_in_first_stage else 1,
depth=config.depths[0],
layer_type=config.layer_type,
activation=config.hidden_act,
)
)
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
self.stages.append(
ResNetStage(
in_channels, out_channels, depth=depth, layer_type=config.layer_type, activation=config.hidden_act
)
)
self.stages.append(ResNetStage(config, in_channels, out_channels, depth=depth))
def forward(
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment