Commit 6aa820ee authored by wuyuefeng's avatar wuyuefeng
Browse files

little changes

parent ceb89df4
...@@ -65,7 +65,7 @@ class SparseUNet(nn.Module): ...@@ -65,7 +65,7 @@ class SparseUNet(nn.Module):
3, 3,
padding=1, padding=1,
bias=False, bias=False,
indice_key='subm1'), ) indice_key='subm1'))
make_block = self.pre_act_block make_block = self.pre_act_block
else: else:
self.conv_input = spconv.SparseSequential( self.conv_input = spconv.SparseSequential(
...@@ -76,9 +76,7 @@ class SparseUNet(nn.Module): ...@@ -76,9 +76,7 @@ class SparseUNet(nn.Module):
padding=1, padding=1,
bias=False, bias=False,
indice_key='subm1'), indice_key='subm1'),
build_norm_layer(norm_cfg, self.base_channels)[1], build_norm_layer(norm_cfg, self.base_channels)[1], nn.ReLU())
nn.ReLU(),
)
make_block = self.post_act_block make_block = self.post_act_block
encoder_out_channels = self.make_encoder_layers( encoder_out_channels = self.make_encoder_layers(
...@@ -95,8 +93,7 @@ class SparseUNet(nn.Module): ...@@ -95,8 +93,7 @@ class SparseUNet(nn.Module):
bias=False, bias=False,
indice_key='spconv_down2'), indice_key='spconv_down2'),
build_norm_layer(norm_cfg, self.output_channels)[1], build_norm_layer(norm_cfg, self.output_channels)[1],
nn.ReLU(), nn.ReLU())
)
def forward(self, voxel_features, coors, batch_size): def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseUNet """Forward of SparseUNet
...@@ -136,13 +133,10 @@ class SparseUNet(nn.Module): ...@@ -136,13 +133,10 @@ class SparseUNet(nn.Module):
decode_features = [] decode_features = []
x = encode_features[-1] x = encode_features[-1]
for i in range(self.stage_num, 0, -1): for i in range(self.stage_num, 0, -1):
x = self.decoder_layer_forward( x = self.decoder_layer_forward(encode_features[i - 1], x,
encode_features[i - 1], getattr(self, f'lateral_layer{i}'),
x, getattr(self, f'merge_layer{i}'),
getattr(self, f'lateral_layer{i}'), getattr(self, f'upsample_layer{i}'))
getattr(self, f'merge_layer{i}'),
getattr(self, f'upsample_layer{i}'),
)
decode_features.append(x) decode_features.append(x)
seg_features = decode_features[-1].features seg_features = decode_features[-1].features
...@@ -320,8 +314,7 @@ class SparseUNet(nn.Module): ...@@ -320,8 +314,7 @@ class SparseUNet(nn.Module):
bias=False, bias=False,
indice_key=indice_key), indice_key=indice_key),
build_norm_layer(norm_cfg, out_channels)[1], build_norm_layer(norm_cfg, out_channels)[1],
nn.ReLU(inplace=True), nn.ReLU(inplace=True))
)
else: else:
raise NotImplementedError raise NotImplementedError
return m return m
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment