Commit 28089d88 authored by benzlxs's avatar benzlxs
Browse files

add changes for ResNet block

parent 14b8b661
......@@ -23,7 +23,7 @@ from spconv.conv import SparseConvTranspose2d, SparseConvTranspose3d
from spconv.conv import SparseInverseConv2d, SparseInverseConv3d
from spconv.modules import SparseModule, SparseSequential
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import ConcatTable, JoinTable
from spconv.tables import ConcatTable, JoinTable, AddTable
from spconv.identity import Identity
from spconv import ops
......
......@@ -72,7 +72,7 @@ class SparseSequential(SparseModule):
('conv2', SparseConv2d(20,64,5)),
('relu2', nn.ReLU())
]))
# Example of using Sequential with kwargs(python 3.6+)
model = SparseSequential(
conv1=SparseConv2d(1,20,5),
......@@ -125,9 +125,12 @@ class SparseSequential(SparseModule):
def forward(self, input):
for k, module in self._modules.items():
if is_spconv_module(module): # use SpConvTensor as input
assert isinstance(input, spconv.SparseConvTensor)
self._sparity_dict[k] = input.sparity
input = module(input)
if isinstance(input, list):
input = module(input)
else:
assert isinstance(input, spconv.SparseConvTensor)
self._sparity_dict[k] = input.sparity
input = module(input)
else:
if isinstance(input, spconv.SparseConvTensor):
if input.indices.shape[0] != 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment