Commit 14b8b661 authored by benzlxs's avatar benzlxs
Browse files

add ConcaTable, JoinTable, Identity functions

parent 16491dc0
# Copyright 2019 Yan Yan
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -23,6 +23,8 @@ from spconv.conv import SparseConvTranspose2d, SparseConvTranspose3d
from spconv.conv import SparseInverseConv2d, SparseInverseConv3d
from spconv.modules import SparseModule, SparseSequential
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import ConcatTable, JoinTable
from spconv.identity import Identity
from spconv import ops
......@@ -55,7 +57,7 @@ class SparseConvTensor(object):
is very large.
"""
self.features = features
self.indices = indices
self.indices = indices
if self.indices.dtype != torch.int32:
self.indices.int()
self.spatial_shape = spatial_shape
......@@ -69,7 +71,7 @@ class SparseConvTensor(object):
def find_indice_pair(self, key):
if key is None:
return None
return None
if key in self.indice_dict:
return self.indice_dict[key]
return None
......@@ -100,4 +102,4 @@ class RemoveGrid(SparseModule):
"""
def forward(self, x: SparseConvTensor):
x.grid = None
return x
\ No newline at end of file
return x
# Copyright 2016-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from torch.nn import Module
class Identity(Module):
def forward(self, input):
return input
def input_spatial_size(self, out_size):
return out_size
from torch.autograd import Function
#from torch.nn import Module
from spconv.modules import SparseModule
import spconv
import torch
class JoinTable(SparseModule):# Module):
def forward(self, input):
output = spconv.SparseConvTensor(
torch.cat([i.features for i in input],1), input[1].indices,
input[1].spatial_shape, input[0].batch_size )
output.indice_dict = input[1].indice_dict
output.grid = input[1].grid
return output
def input_spatial_size(self, out_size):
return out_size
class AddTable(SparseModule): # Module):
def forward(self, input):
output = spconv.SparseConvTensor(
sum([i.features for i in input]), input[1].indices,
input[1].spatial_shape, input[1].batch_size )
output.indice_dict = input[1].indice_dict
output.grid = input[1].grid
return output
def input_spatial_size(self, out_size):
return out_size
class ConcatTable(SparseModule): # Module):
def forward(self, input):
return [module(input) for module in self._modules.values()]
def add(self, module):
self._modules[str(len(self._modules))] = module
return self
def input_spatial_size(self, out_size):
return self._modules['0'].input_spatial_size(out_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment