"...csrc/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "39f9920849dc772abd65fff1296cb15de8d64f1e"
Unverified Commit cf2578ae authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Refactor to inherit from nn.Module instead of nn.ModuleList (#17501)

* Refactor to inherit from nn.Module instead of nn.ModuleList

* Fix typo

* Empty to trigger CI re-run

Blender Bot tests failing (should be unrelated to this PR) and pass locally). I don't have sufficient permisisons to re-run the CI workflow (totally or from failed)
parent 77ea5130
......@@ -949,7 +949,24 @@ class BeitConvModule(nn.Module):
return output
class BeitPyramidPoolingModule(nn.ModuleList):
class BeitPyramidPoolingBlock(nn.Module):
def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
super().__init__()
self.layers = [
nn.AdaptiveAvgPool2d(pool_scale),
BeitConvModule(in_channels, channels, kernel_size=1),
]
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class BeitPyramidPoolingModule(nn.Module):
"""
Pyramid Pooling Module (PPM) used in PSPNet.
......@@ -969,17 +986,15 @@ class BeitPyramidPoolingModule(nn.ModuleList):
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
BeitConvModule(self.in_channels, self.channels, kernel_size=1),
)
)
self.blocks = []
for i, pool_scale in enumerate(pool_scales):
block = BeitPyramidPoolingBlock(pool_scale=pool_scale, in_channels=in_channels, channels=channels)
self.blocks.append(block)
self.add_module(str(i), block)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
ppm_outs = []
for ppm in self:
for ppm in self.blocks:
ppm_out = ppm(x)
upsampled_ppm_out = nn.functional.interpolate(
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
......
......@@ -866,8 +866,26 @@ class Data2VecVisionConvModule(nn.Module):
return output
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingBlock(nn.Module):
def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
super().__init__()
self.layers = [
nn.AdaptiveAvgPool2d(pool_scale),
Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
]
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: torch.Tensor) -> torch.Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingModule(nn.ModuleList):
class Data2VecVisionPyramidPoolingModule(nn.Module):
"""
Pyramid Pooling Module (PPM) used in PSPNet.
......@@ -887,17 +905,17 @@ class Data2VecVisionPyramidPoolingModule(nn.ModuleList):
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
Data2VecVisionConvModule(self.in_channels, self.channels, kernel_size=1),
)
self.blocks = []
for i, pool_scale in enumerate(pool_scales):
block = Data2VecVisionPyramidPoolingBlock(
pool_scale=pool_scale, in_channels=in_channels, channels=channels
)
self.blocks.append(block)
self.add_module(str(i), block)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
ppm_outs = []
for ppm in self:
for ppm in self.blocks:
ppm_out = ppm(x)
upsampled_ppm_out = nn.functional.interpolate(
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment