"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c85510f958e6955d88ea1bafb4f320074bfbd0c1"
Unverified commit 1c57242d, authored by amyeroberts, committed by GitHub

Fix bug - layer names and activation from previous refactor (#17524)

* Fix activation and layers in MLP head

* Remove unused import
parent babeff55
@@ -2116,26 +2116,10 @@ class MaskFormerSinePositionEmbedding(nn.Module):
         return pos
 
 
-class IdentityBlock(nn.Module):
-    def __init__(self):
-        super().__init__()
-        # Create as an iterable here so that the identity layer isn't registered
-        # with the name of the instance variable its assigned to
-        self.layers = [nn.Identity()]
-        # Maintain submodule indexing as if part of a Sequential block
-        self.add_module("0", self.layers[0])
-
-    def forward(self, input: Tensor) -> Tensor:
-        hidden_state = input
-        for layer in self.layers:
-            hidden_state = layer(hidden_state)
-        return hidden_state
-
-
-class NonLinearBlock(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int) -> None:
-        super().__init__()
-        self.layers = [nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True)]
+class PredictionBlock(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
+        super().__init__()
+        self.layers = [nn.Linear(in_dim, out_dim), activation]
         # Maintain submodule indexing as if part of a Sequential block
         for i, layer in enumerate(self.layers):
             self.add_module(str(i), layer)
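The add_module(str(i), layer) loop registers each submodule under its index ("0", "1") rather than under the name of the Python attribute it is assigned to, so parameter names stay compatible with the nn.Sequential-based layout the class previously inherited from. Below is a minimal, self-contained sketch of that naming pattern; TinyPredictionBlock is a made-up stand-in for illustration and is not part of the commit:

from torch import Tensor, nn


class TinyPredictionBlock(nn.Module):
    # Illustrative stand-in for the PredictionBlock pattern shown above.
    def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
        super().__init__()
        # A plain Python list is not auto-registered, so no "layers." prefix appears;
        # each entry is registered explicitly under its index, as nn.Sequential would name it.
        self.layers = [nn.Linear(in_dim, out_dim), activation]
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: Tensor) -> Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


block = TinyPredictionBlock(4, 2, nn.ReLU())
sequential = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

# Both expose the same parameter names, so checkpoints saved with the old
# nn.Sequential-style layout still load.
print(sorted(block.state_dict()))       # ['0.bias', '0.weight']
print(sorted(sequential.state_dict()))  # ['0.bias', '0.weight']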
@@ -2168,7 +2152,8 @@ class MaskformerMLPPredictionHead(nn.Module):
         self.layers = []
         for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
-            layer = NonLinearBlock(in_dim, out_dim) if i < num_layers - 1 else IdentityBlock()
+            activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
+            layer = PredictionBlock(in_dim, out_dim, activation=activation)
             self.layers.append(layer)
 
             # Provide backwards compatibility from when the class inherited from nn.Sequential
             # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
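This hunk is the activation fix named in the commit message: with the old selection, the final block was a bare IdentityBlock and so carried no Linear projection at all, whereas the new code gives every block a Linear and only switches the activation (ReLU between layers, Identity after the last projection). A short illustrative sketch of the new selection with made-up dimensions, not taken from the commit:

from torch import nn

# Hypothetical head dimensions: two hidden projections, then an output of size 4.
num_layers = 3
in_dims = [256, 128, 128]
out_dims = [128, 128, 4]

blocks = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
    # ReLU between layers, Identity (no non-linearity) after the final projection,
    # mirroring the selection introduced above.
    activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
    blocks.append(nn.Sequential(nn.Linear(in_dim, out_dim), activation))

head = nn.Sequential(*blocks)
# Every block keeps its Linear layer; only the last block's activation is Identity.
print(head)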
@@ -181,19 +181,11 @@ class ResNetStage(nn.Module):
         layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
 
-        self.layers = [
+        self.layers = nn.Sequential(
             # downsampling is done in the first layer with stride of 2
             layer(in_channels, out_channels, stride=stride, activation=config.hidden_act),
             *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)],
-        ]
-        # Provide backwards compatibility from when the class inherited from nn.Sequential
-        # In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
-        # In nn.Module subclasses they derived from the instance attribute they are assigned to e.g.
-        # self.my_layer_name = Layer()
-        # We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
-        # the module explicitly
-        for i, layer in enumerate(self.layers):
-            self.add_module(str(i), layer)
+        )
 
     def forward(self, input: Tensor) -> Tensor:
         hidden_state = input
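With self.layers built as an nn.Sequential, every stacked layer is registered automatically under the attribute name plus its index ("layers.0", "layers.1", ...), so the manual add_module loop and its backwards-compatibility comments are no longer needed. A minimal sketch of that naming behaviour, using plain Conv2d stand-ins instead of the real ResNet layers and config, not part of the commit:

import torch
from torch import nn


class TinyStage(nn.Module):
    # Illustrative stand-in for the ResNetStage pattern above; the real class
    # builds ResNetBasicLayer / ResNetBottleNeckLayer blocks from the config.
    def __init__(self, in_channels: int, out_channels: int, depth: int = 2) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            # downsampling happens in the first layer via stride 2
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            *[nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) for _ in range(depth - 1)],
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


stage = TinyStage(3, 8)
# nn.Sequential names its children by index, so parameters come out as
# "layers.0.weight", "layers.0.bias", "layers.1.weight", "layers.1.bias", ...
print(sorted(stage.state_dict()))
print(stage(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 16, 16])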