Unverified commit 36f8c425, authored by Francesco Saverio Zuppichini and committed by GitHub
Browse files

ResNet: update modules names (#16196)

* updated names

* fit in one line

* typo
parent 5bdf3313
...@@ -96,11 +96,11 @@ class ResNetConvLayer(nn.Sequential): ...@@ -96,11 +96,11 @@ class ResNetConvLayer(nn.Sequential):
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu" self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
): ):
super().__init__() super().__init__()
self.conv = nn.Conv2d( self.convolution = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
) )
self.bn = nn.BatchNorm2d(out_channels) self.normalization = nn.BatchNorm2d(out_channels)
self.act = ACT2FN[activation] if activation is not None else nn.Identity() self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
class ResNetEmbeddings(nn.Sequential): class ResNetEmbeddings(nn.Sequential):
...@@ -111,7 +111,7 @@ class ResNetEmbeddings(nn.Sequential): ...@@ -111,7 +111,7 @@ class ResNetEmbeddings(nn.Sequential):
def __init__(self, num_channels: int, out_channels: int, activation: str = "relu"): def __init__(self, num_channels: int, out_channels: int, activation: str = "relu"):
super().__init__() super().__init__()
self.embedder = ResNetConvLayer(num_channels, out_channels, kernel_size=7, stride=2, activation=activation) self.embedder = ResNetConvLayer(num_channels, out_channels, kernel_size=7, stride=2, activation=activation)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
class ResNetShortCut(nn.Sequential): class ResNetShortCut(nn.Sequential):
...@@ -122,8 +122,8 @@ class ResNetShortCut(nn.Sequential): ...@@ -122,8 +122,8 @@ class ResNetShortCut(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, stride: int = 2): def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
super().__init__() super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.bn = nn.BatchNorm2d(out_channels) self.normalization = nn.BatchNorm2d(out_channels)
class ResNetBasicLayer(nn.Module): class ResNetBasicLayer(nn.Module):
...@@ -192,21 +192,20 @@ class ResNetStage(nn.Sequential): ...@@ -192,21 +192,20 @@ class ResNetStage(nn.Sequential):
def __init__( def __init__(
self, self,
config: ResNetConfig,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
stride: int = 2, stride: int = 2,
depth: int = 2, depth: int = 2,
layer_type: str = "basic",
activation: str = "relu",
): ):
super().__init__() super().__init__()
layer = ResNetBottleNeckLayer if layer_type == "bottleneck" else ResNetBasicLayer layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
self.layers = nn.Sequential( self.layers = nn.Sequential(
# downsampling is done in the first layer with stride of 2 # downsampling is done in the first layer with stride of 2
layer(in_channels, out_channels, stride=stride, activation=activation), layer(in_channels, out_channels, stride=stride, activation=config.hidden_act),
*[layer(out_channels, out_channels, activation=activation) for _ in range(depth - 1)], *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
) )
...@@ -217,21 +216,16 @@ class ResNetEncoder(nn.Module): ...@@ -217,21 +216,16 @@ class ResNetEncoder(nn.Module):
# based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
self.stages.append( self.stages.append(
ResNetStage( ResNetStage(
config,
config.embedding_size, config.embedding_size,
config.hidden_sizes[0], config.hidden_sizes[0],
stride=2 if config.downsample_in_first_stage else 1, stride=2 if config.downsample_in_first_stage else 1,
depth=config.depths[0], depth=config.depths[0],
layer_type=config.layer_type,
activation=config.hidden_act,
) )
) )
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
self.stages.append( self.stages.append(ResNetStage(config, in_channels, out_channels, depth=depth))
ResNetStage(
in_channels, out_channels, depth=depth, layer_type=config.layer_type, activation=config.hidden_act
)
)
def forward( def forward(
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment