Commit 19c66f79 authored by Gus-Guo's avatar Gus-Guo Committed by Shaoshuai Shi
Browse files

support more flexible setting of conv head; slice inputs when batch size is too large in PFNLayer to avoid bugs

support more flexible setting of conv head; slice inputs when batch size is too large in PFNLayer to avoid bugs (#124)

* support more flexible setting

* slice inputs of nn.Linear when batch size is too large
parent 6b3e52c9
...@@ -7,8 +7,8 @@ class BaseBEVBackbone(nn.Module): ...@@ -7,8 +7,8 @@ class BaseBEVBackbone(nn.Module):
super().__init__() super().__init__()
self.model_cfg = model_cfg self.model_cfg = model_cfg
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == \ assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
len(self.model_cfg.NUM_FILTERS) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS) assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS num_filters = self.model_cfg.NUM_FILTERS
...@@ -36,16 +36,16 @@ class BaseBEVBackbone(nn.Module): ...@@ -36,16 +36,16 @@ class BaseBEVBackbone(nn.Module):
nn.ReLU() nn.ReLU()
]) ])
self.blocks.append(nn.Sequential(*cur_layers)) self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0:
self.deblocks.append(nn.Sequential( self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d( nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx], num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx], upsample_strides[idx],
stride=upsample_strides[idx], bias=False stride=upsample_strides[idx], bias=False
), ),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01), nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU() nn.ReLU()
)) ))
c_in = sum(num_upsample_filters) c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels: if len(upsample_strides) > num_levels:
...@@ -73,12 +73,16 @@ class BaseBEVBackbone(nn.Module): ...@@ -73,12 +73,16 @@ class BaseBEVBackbone(nn.Module):
stride = int(spatial_features.shape[2] / x.shape[2]) stride = int(spatial_features.shape[2] / x.shape[2])
ret_dict['spatial_features_%dx' % stride] = x ret_dict['spatial_features_%dx' % stride] = x
ups.append(self.deblocks[i](x)) if len(self.deblocks) > 0:
ups.append(self.deblocks[i](x))
else:
ups.append(x)
if len(ups) > 1: if len(ups) > 1:
x = torch.cat(ups, dim=1) x = torch.cat(ups, dim=1)
else: elif len(ups) == 1:
x = ups[0] x = ups[0]
if len(self.deblocks) > len(self.blocks): if len(self.deblocks) > len(self.blocks):
x = self.deblocks[-1](x) x = self.deblocks[-1](x)
......
...@@ -22,8 +22,16 @@ class PFNLayer(nn.Module): ...@@ -22,8 +22,16 @@ class PFNLayer(nn.Module):
else: else:
self.linear = nn.Linear(in_channels, out_channels, bias=True) self.linear = nn.Linear(in_channels, out_channels, bias=True)
self.part = 50000
def forward(self, inputs): def forward(self, inputs):
x = self.linear(inputs) if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
x = F.relu(x) x = F.relu(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment