Commit 8da6f967 authored by traveller59's avatar traveller59
Browse files

fix #17

parent 20557e83
...@@ -134,7 +134,7 @@ class SparseConvolution(SparseModule): ...@@ -134,7 +134,7 @@ class SparseConvolution(SparseModule):
input.features = torch.mm( input.features = torch.mm(
input.features, input.features,
self.weight.view(self.in_channels, self.out_channels)) self.weight.view(self.in_channels, self.out_channels))
if self.bias: if self.bias is not None:
input.features += self.bias input.features += self.bias
return input return input
datas = input.find_indice_pair(self.indice_key) datas = input.find_indice_pair(self.indice_key)
...@@ -165,7 +165,7 @@ class SparseConvolution(SparseModule): ...@@ -165,7 +165,7 @@ class SparseConvolution(SparseModule):
self.weight, indice_pairs.to(device), self.weight, indice_pairs.to(device),
indice_pair_num, outids.shape[0]) indice_pair_num, outids.shape[0])
if self.bias: if self.bias is not None:
out_features += self.bias out_features += self.bias
out_tensor = spconv.SparseConvTensor(out_features, outids, out_tensor = spconv.SparseConvTensor(out_features, outids,
out_spatial_shape, batch_size) out_spatial_shape, batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment