Commit 19c66f79 authored by Gus-Guo's avatar Gus-Guo Committed by Shaoshuai Shi
Browse files

support more flexible setting of conv head; slice inputs when batch size is...

support more flexible setting of conv head; slice inputs when batch size is too large in PFNLayer to avoid bugs (#124)

* support more flexible setting

* slice inputs of nn.Linear when batch size is too large
parent 6b3e52c9
......@@ -7,8 +7,8 @@ class BaseBEVBackbone(nn.Module):
super().__init__()
self.model_cfg = model_cfg
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == \
len(self.model_cfg.NUM_FILTERS) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS
......@@ -36,7 +36,7 @@ class BaseBEVBackbone(nn.Module):
nn.ReLU()
])
self.blocks.append(nn.Sequential(*cur_layers))
if len(upsample_strides) > 0:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
......@@ -73,12 +73,16 @@ class BaseBEVBackbone(nn.Module):
stride = int(spatial_features.shape[2] / x.shape[2])
ret_dict['spatial_features_%dx' % stride] = x
if len(self.deblocks) > 0:
ups.append(self.deblocks[i](x))
else:
ups.append(x)
if len(ups) > 1:
x = torch.cat(ups, dim=1)
else:
elif len(ups) == 1:
x = ups[0]
if len(self.deblocks) > len(self.blocks):
x = self.deblocks[-1](x)
......
......@@ -22,7 +22,15 @@ class PFNLayer(nn.Module):
else:
self.linear = nn.Linear(in_channels, out_channels, bias=True)
self.part = 50000
def forward(self, inputs):
if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment