Unverified Commit cf2578ae authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Refactor to inherit from nn.Module instead of nn.ModuleList (#17501)

* Refactor to inherit from nn.Module instead of nn.ModuleList

* Fix typo

* Empty to trigger CI re-run

Blender Bot tests failing (should be unrelated to this PR) and pass locally). I don't have sufficient permisisons to re-run the CI workflow (totally or from failed)
parent 77ea5130
...@@ -949,7 +949,24 @@ class BeitConvModule(nn.Module): ...@@ -949,7 +949,24 @@ class BeitConvModule(nn.Module):
return output return output
class BeitPyramidPoolingModule(nn.ModuleList): class BeitPyramidPoolingBlock(nn.Module):
def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
super().__init__()
self.layers = [
nn.AdaptiveAvgPool2d(pool_scale),
BeitConvModule(in_channels, channels, kernel_size=1),
]
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class BeitPyramidPoolingModule(nn.Module):
""" """
Pyramid Pooling Module (PPM) used in PSPNet. Pyramid Pooling Module (PPM) used in PSPNet.
...@@ -969,17 +986,15 @@ class BeitPyramidPoolingModule(nn.ModuleList): ...@@ -969,17 +986,15 @@ class BeitPyramidPoolingModule(nn.ModuleList):
self.align_corners = align_corners self.align_corners = align_corners
self.in_channels = in_channels self.in_channels = in_channels
self.channels = channels self.channels = channels
for pool_scale in pool_scales: self.blocks = []
self.append( for i, pool_scale in enumerate(pool_scales):
nn.Sequential( block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
nn.AdaptiveAvgPool2d(pool_scale), self.blocks.append(block)
BeitConvModule(self.in_channels, self.channels, kernel_size=1), self.add_module(str(i), block)
)
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
ppm_outs = [] ppm_outs = []
for ppm in self: for ppm in self.blocks:
ppm_out = ppm(x) ppm_out = ppm(x)
upsampled_ppm_out = nn.functional.interpolate( upsampled_ppm_out = nn.functional.interpolate(
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
......
...@@ -866,8 +866,26 @@ class Data2VecVisionConvModule(nn.Module): ...@@ -866,8 +866,26 @@ class Data2VecVisionConvModule(nn.Module):
return output return output
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingBlock(nn.Module):
def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
super().__init__()
self.layers = [
nn.AdaptiveAvgPool2d(pool_scale),
Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
]
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision # Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingModule(nn.ModuleList): class Data2VecVisionPyramidPoolingModule(nn.Module):
""" """
Pyramid Pooling Module (PPM) used in PSPNet. Pyramid Pooling Module (PPM) used in PSPNet.
...@@ -887,17 +905,17 @@ class Data2VecVisionPyramidPoolingModule(nn.ModuleList): ...@@ -887,17 +905,17 @@ class Data2VecVisionPyramidPoolingModule(nn.ModuleList):
self.align_corners = align_corners self.align_corners = align_corners
self.in_channels = in_channels self.in_channels = in_channels
self.channels = channels self.channels = channels
for pool_scale in pool_scales: self.blocks = []
self.append( for i, pool_scale in enumerate(pool_scales):
nn.Sequential( block = Data2VecVisionPyramidPoolingBlock(
nn.AdaptiveAvgPool2d(pool_scale), pool_scale=pool_scale, in_channels=in_channels, channels=channels
Data2VecVisionConvModule(self.in_channels, self.channels, kernel_size=1),
)
) )
self.blocks.append(block)
self.add_module(str(i), block)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
ppm_outs = [] ppm_outs = []
for ppm in self: for ppm in self.blocks:
ppm_out = ppm(x) ppm_out = ppm(x)
upsampled_ppm_out = nn.functional.interpolate( upsampled_ppm_out = nn.functional.interpolate(
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment