Commit e6f6151c authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

add RETURN_ENCODED_TENSOR config for Spconv UNet

parent b6fbc50d
......@@ -91,16 +91,18 @@ class UNetV2(nn.Module):
block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'),
)
last_pad = 0
last_pad = self.model_cfg.get('last_pad', last_pad)
self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
bias=False, indice_key='spconv_down2'),
norm_fn(128),
nn.ReLU(),
)
if self.model_cfg.get('RETURN_ENCODED_TENSOR', True):
last_pad = self.model_cfg.get('last_pad', 0)
self.conv_out = spconv.SparseSequential(
# [200, 150, 5] -> [200, 150, 2]
spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad,
bias=False, indice_key='spconv_down2'),
norm_fn(128),
nn.ReLU(),
)
else:
self.conv_out = None
# decoder
# [400, 352, 11] <- [200, 176, 5]
......@@ -181,9 +183,12 @@ class UNetV2(nn.Module):
x_conv3 = self.conv3(x_conv2)
x_conv4 = self.conv4(x_conv3)
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
if self.conv_out is not None:
# for detection head
# [200, 176, 5] -> [200, 176, 2]
out = self.conv_out(x_conv4)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
# for segmentation head
# [400, 352, 11] <- [200, 176, 5]
......@@ -201,6 +206,4 @@ class UNetV2(nn.Module):
point_cloud_range=self.point_cloud_range
)
batch_dict['point_coords'] = torch.cat((x_up1.indices[:, 0:1].float(), point_coords), dim=1)
batch_dict['encoded_spconv_tensor'] = out
batch_dict['encoded_spconv_tensor_stride'] = 8
return batch_dict
......@@ -12,6 +12,7 @@ MODEL:
BACKBONE_3D:
NAME: UNetV2
RETURN_ENCODED_TENSOR: False
POINT_HEAD:
NAME: PointHeadBox
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment