Commit 28089d88 authored by benzlxs's avatar benzlxs
Browse files

add changes for ResNet block

parent 14b8b661
...@@ -23,7 +23,7 @@ from spconv.conv import SparseConvTranspose2d, SparseConvTranspose3d ...@@ -23,7 +23,7 @@ from spconv.conv import SparseConvTranspose2d, SparseConvTranspose3d
from spconv.conv import SparseInverseConv2d, SparseInverseConv3d from spconv.conv import SparseInverseConv2d, SparseInverseConv3d
from spconv.modules import SparseModule, SparseSequential from spconv.modules import SparseModule, SparseSequential
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import ConcatTable, JoinTable from spconv.tables import ConcatTable, JoinTable, AddTable
from spconv.identity import Identity from spconv.identity import Identity
from spconv import ops from spconv import ops
......
...@@ -125,6 +125,9 @@ class SparseSequential(SparseModule): ...@@ -125,6 +125,9 @@ class SparseSequential(SparseModule):
def forward(self, input): def forward(self, input):
for k, module in self._modules.items(): for k, module in self._modules.items():
if is_spconv_module(module): # use SpConvTensor as input if is_spconv_module(module): # use SpConvTensor as input
if isinstance(input, list):
input = module(input)
else:
assert isinstance(input, spconv.SparseConvTensor) assert isinstance(input, spconv.SparseConvTensor)
self._sparity_dict[k] = input.sparity self._sparity_dict[k] = input.sparity
input = module(input) input = module(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment