"pcdet/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "eed89a4d6c6cc6125cf551391a9819a737fd5b20"
Commit 19de178d authored by acivgin1's avatar acivgin1
Browse files

switch to replace_feature method

parent 3817a135
...@@ -50,17 +50,17 @@ class SparseBasicBlock(spconv.SparseModule): ...@@ -50,17 +50,17 @@ class SparseBasicBlock(spconv.SparseModule):
identity = x identity = x
out = self.conv1(x) out = self.conv1(x)
out.features = self.bn1(out.features) out = out.replace_feature(self.bn1(out.features))
out.features = self.relu(out.features) out = out.replace_feature(self.relu(out.features))
out = self.conv2(out) out = self.conv2(out)
out.features = self.bn2(out.features) out = out.replace_feature(self.bn2(out.features))
if self.downsample is not None: if self.downsample is not None:
identity = self.downsample(x) identity = self.downsample(x)
out.features += identity.features out = out.replace_feature(out.features + identity.features)
out.features = self.relu(out.features) out = out.replace_feature(self.relu(out.features))
return out return out
......
...@@ -31,17 +31,17 @@ class SparseBasicBlock(spconv.SparseModule): ...@@ -31,17 +31,17 @@ class SparseBasicBlock(spconv.SparseModule):
assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim() assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim()
out = self.conv1(x) out = self.conv1(x)
out.features = self.bn1(out.features) out = out.replace_feature(self.bn1(out.features))
out.features = self.relu(out.features) out = out.replace_feature(self.relu(out.features))
out = self.conv2(out) out = self.conv2(out)
out.features = self.bn2(out.features) out = out.replace_feature(self.bn2(out.features))
if self.downsample is not None: if self.downsample is not None:
identity = self.downsample(x) identity = self.downsample(x)
out.features += identity out = out.replace_feature(out.features + identity)
out.features = self.relu(out.features) out = out.replace_feature(self.relu(out.features))
return out return out
...@@ -134,10 +134,10 @@ class UNetV2(nn.Module): ...@@ -134,10 +134,10 @@ class UNetV2(nn.Module):
def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv): def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
x_trans = conv_t(x_lateral) x_trans = conv_t(x_lateral)
x = x_trans x = x_trans
x.features = torch.cat((x_bottom.features, x_trans.features), dim=1) x = x.replace_feature(torch.cat((x_bottom.features, x_trans.features), dim=1))
x_m = conv_m(x) x_m = conv_m(x)
x = self.channel_reduction(x, x_m.features.shape[1]) x = self.channel_reduction(x, x_m.features.shape[1])
x.features = x_m.features + x.features x = x.replace_feature(x_m.features + x.features)
x = conv_inv(x) x = conv_inv(x)
return x return x
...@@ -155,7 +155,7 @@ class UNetV2(nn.Module): ...@@ -155,7 +155,7 @@ class UNetV2(nn.Module):
n, in_channels = features.shape n, in_channels = features.shape
assert (in_channels % out_channels == 0) and (in_channels >= out_channels) assert (in_channels % out_channels == 0) and (in_channels >= out_channels)
x.features = features.view(n, out_channels, -1).sum(dim=2) x = x.replace_feature(features.view(n, out_channels, -1).sum(dim=2))
return x return x
def forward(self, batch_dict): def forward(self, batch_dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment