Commit 19e73bbe authored by Yan Yan's avatar Yan Yan
Browse files

format code with clang-format, better c++ code

parent c336139f
......@@ -10,8 +10,8 @@
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
......@@ -24,7 +24,6 @@
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
......@@ -35,65 +34,70 @@
#include <ratio>
#include <stdexcept>
#ifdef TSL_DEBUG
# define tsl_rh_assert(expr) assert(expr)
#define tsl_rh_assert(expr) assert(expr)
#else
# define tsl_rh_assert(expr) (static_cast<void>(0))
#define tsl_rh_assert(expr) (static_cast<void>(0))
#endif
/**
* If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate.
* If exceptions are enabled, throw the exception passed in parameter, otherwise
* call std::terminate.
*/
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS)
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || \
(defined(_MSC_VER) && defined(_CPPUNWIND))) && \
!defined(TSL_NO_EXCEPTIONS)
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#else
#ifdef NDEBUG
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#else
# ifdef NDEBUG
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
# else
# include <cstdio>
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) do { std::fprintf(stderr, msg); std::terminate(); } while(0)
# endif
#include <cstdio>
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) \
do { \
std::fprintf(stderr, msg); \
std::terminate(); \
} while (0)
#endif
#endif
#if defined(__GNUC__) || defined(__clang__)
# define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#else
# define TSL_RH_LIKELY(exp) (exp)
#define TSL_RH_LIKELY(exp) (exp)
#endif
namespace tsl {
namespace rh {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template<std::size_t GrowthFactor>
class power_of_two_growth_policy {
template <std::size_t GrowthFactor> class power_of_two_growth_policy {
public:
/**
* Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter.
* This number is a minimum, the policy may update this value with a higher value if needed (but not lower).
* Called on the hash table creation and on rehash. The number of buckets for
* the table is passed in parameter. This number is a minimum, the policy may
* update this value with a higher value if needed (but not lower).
*
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
* bucket_for_hash must always return 0 in this case.
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy
* creation and bucket_for_hash must always return 0 in this case.
*/
explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
explicit power_of_two_growth_policy(std::size_t &min_bucket_count_in_out) {
if (min_bucket_count_in_out > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
if(min_bucket_count_in_out > 0) {
min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out =
round_up_to_power_of_two(min_bucket_count_in_out);
m_mask = min_bucket_count_in_out - 1;
}
else {
} else {
m_mask = 0;
}
}
......@@ -110,8 +114,9 @@ public:
* Return the number of buckets that should be used on next growth.
*/
std::size_t next_bucket_count() const {
if((m_mask + 1) > max_bucket_count() / GrowthFactor) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
if ((m_mask + 1) > max_bucket_count() / GrowthFactor) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
return (m_mask + 1) * GrowthFactor;
......@@ -127,24 +132,23 @@ public:
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is called.
* After a clear, the policy must always return 0 when bucket_for_hash is
* called.
*/
void clear() noexcept {
m_mask = 0;
}
void clear() noexcept { m_mask = 0; }
private:
static std::size_t round_up_to_power_of_two(std::size_t value) {
if(is_power_of_two(value)) {
if (is_power_of_two(value)) {
return value;
}
if(value == 0) {
if (value == 0) {
return 1;
}
--value;
for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
value |= value >> i;
}
......@@ -156,28 +160,28 @@ private:
}
protected:
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
"GrowthFactor must be a power of two >= 2.");
std::size_t m_mask;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
* to a bucket. Slower but it can be useful if you want a slower growth.
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/
template<class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
template <class GrowthFactor = std::ratio<3, 2>> class mod_growth_policy {
public:
explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
explicit mod_growth_policy(std::size_t &min_bucket_count_in_out) {
if (min_bucket_count_in_out > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
if(min_bucket_count_in_out > 0) {
if (min_bucket_count_in_out > 0) {
m_mod = min_bucket_count_in_out;
}
else {
} else {
m_mod = 1;
}
}
......@@ -187,73 +191,79 @@ public:
}
std::size_t next_bucket_count() const {
if(m_mod == max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
if (m_mod == max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
if(!std::isnormal(next_bucket_count)) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
const double next_bucket_count =
std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
if (!std::isnormal(next_bucket_count)) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
if(next_bucket_count > double(max_bucket_count())) {
if (next_bucket_count > double(max_bucket_count())) {
return max_bucket_count();
}
else {
} else {
return std::size_t(next_bucket_count);
}
}
std::size_t max_bucket_count() const {
return MAX_BUCKET_COUNT;
}
std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }
void clear() noexcept {
m_mod = 1;
}
void clear() noexcept { m_mod = 1; }
private:
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
1.0 * GrowthFactor::num / GrowthFactor::den;
static const std::size_t MAX_BUCKET_COUNT =
std::size_t(double(
std::numeric_limits<std::size_t>::max() / REHASH_SIZE_MULTIPLICATION_FACTOR
));
std::size_t(double(std::numeric_limits<std::size_t>::max() /
REHASH_SIZE_MULTIPLICATION_FACTOR));
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
"Growth factor should be >= 1.1.");
std::size_t m_mod;
};
namespace detail {
static constexpr const std::array<std::size_t, 40> PRIMES = {{
1ul, 5ul, 17ul, 29ul, 37ul, 53ul, 67ul, 79ul, 97ul, 131ul, 193ul, 257ul, 389ul, 521ul, 769ul, 1031ul,
1543ul, 2053ul, 3079ul, 6151ul, 12289ul, 24593ul, 49157ul, 98317ul, 196613ul, 393241ul, 786433ul,
1572869ul, 3145739ul, 6291469ul, 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul
}};
template<unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
// compiler can optimize the modulo code better with a constant known at the compilation.
static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {{
&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
&mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
&mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37> , &mod<38>, &mod<39>
}};
static constexpr const std::array<std::size_t, 40> PRIMES = {
{1ul, 5ul, 17ul, 29ul, 37ul,
53ul, 67ul, 79ul, 97ul, 131ul,
193ul, 257ul, 389ul, 521ul, 769ul,
1031ul, 1543ul, 2053ul, 3079ul, 6151ul,
12289ul, 24593ul, 49157ul, 98317ul, 196613ul,
393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul,
12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}};
template <unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) {
return hash % PRIMES[IPrime];
}
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
static constexpr const std::array<std::size_t (*)(std::size_t), 40> MOD_PRIME =
{{&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>,
&mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>,
&mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>,
&mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
&mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}};
} // namespace detail
/**
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in
* general but will probably distribute the values around better in the buckets with a poor hash function.
* Grow the hash table by using prime numbers as bucket count. Slower than
* tsl::rh::power_of_two_growth_policy in general but will probably distribute
* the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers.
* To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant primes numbers.
*
* With a switch the code would look like:
* \code
......@@ -268,25 +278,27 @@ static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize the operation
* by a series of multiplications, substractions and shifts.
* Due to the constant variable in the modulo the compiler is able to optimize
* the operation by a series of multiplications, substractions and shifts.
*
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environement.
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64 bits environement.
*/
class prime_growth_policy {
public:
explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
auto it_prime = std::lower_bound(detail::PRIMES.begin(),
detail::PRIMES.end(), min_bucket_count_in_out);
if(it_prime == detail::PRIMES.end()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
explicit prime_growth_policy(std::size_t &min_bucket_count_in_out) {
auto it_prime = std::lower_bound(
detail::PRIMES.begin(), detail::PRIMES.end(), min_bucket_count_in_out);
if (it_prime == detail::PRIMES.end()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
if(min_bucket_count_in_out > 0) {
m_iprime = static_cast<unsigned int>(
std::distance(detail::PRIMES.begin(), it_prime));
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out = *it_prime;
}
else {
} else {
min_bucket_count_in_out = 0;
}
}
......@@ -296,29 +308,27 @@ public:
}
std::size_t next_bucket_count() const {
if(m_iprime + 1 >= detail::PRIMES.size()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size.");
if (m_iprime + 1 >= detail::PRIMES.size()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
}
return detail::PRIMES[m_iprime + 1];
}
std::size_t max_bucket_count() const {
return detail::PRIMES.back();
}
std::size_t max_bucket_count() const { return detail::PRIMES.back(); }
void clear() noexcept {
m_iprime = 0;
}
void clear() noexcept { m_iprime = 0; }
private:
unsigned int m_iprime;
static_assert(std::numeric_limits<decltype(m_iprime)>::max() >= detail::PRIMES.size(),
static_assert(std::numeric_limits<decltype(m_iprime)>::max() >=
detail::PRIMES.size(),
"The type of m_iprime is not big enough.");
};
}
}
} // namespace rh
} // namespace tsl
#endif
This diff is collapsed.
This diff is collapsed.
......@@ -14,14 +14,14 @@
#pragma once
#include <chrono>
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
#include <iostream>
namespace spconv {
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
template <typename TimeT = std::chrono::microseconds> struct CudaContextTimer {
CudaContextTimer() {
cudaDeviceSynchronize();
......
import os
import re
import sys
import platform
import re
import subprocess
import torch
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
import sys
from distutils.version import LooseVersion
from pathlib import Path
import torch
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
# if 'LIBTORCH_ROOT' not in os.environ:
# raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.")
......@@ -100,4 +100,3 @@ setup(
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,
)
......@@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from pathlib import Path
import platform
import numpy as np
import torch
from spconv import utils
from spconv.conv import SparseConv2d, SparseConv3d, SubMConv2d, SubMConv3d
from spconv.conv import SparseConvTranspose2d, SparseConvTranspose3d
from spconv.conv import SparseInverseConv2d, SparseInverseConv3d
from spconv import ops, utils
from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d)
from spconv.identity import Identity
from spconv.modules import SparseModule, SparseSequential
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import ConcatTable, JoinTable, AddTable
from spconv.identity import Identity
from spconv import ops
from spconv.tables import AddTable, ConcatTable, JoinTable
_LIB_FILE_NAME = "libspconv.so"
if platform.system() == "Windows":
......@@ -34,6 +33,7 @@ if platform.system() == "Windows":
_LIB_PATH = str(Path(__file__).parent / _LIB_FILE_NAME)
torch.ops.load_library(_LIB_PATH)
def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
......@@ -49,8 +49,10 @@ def scatter_nd(indices, updates, shape):
ret[slices] = updates.view(*output_shape)
return ret
class SparseConvTensor(object):
def __init__(self, features, indices, spatial_shape, batch_size, grid=None):
def __init__(self, features, indices, spatial_shape, batch_size,
grid=None):
"""
Args:
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
......@@ -77,7 +79,8 @@ class SparseConvTensor(object):
return None
def dense(self, channels_first=True):
output_shape = [self.batch_size] + list(self.spatial_shape) + [self.features.shape[1]]
output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(self.indices.long(), self.features, output_shape)
if not channels_first:
return res
......@@ -88,7 +91,8 @@ class SparseConvTensor(object):
@property
def sparity(self):
return self.indices.shape[0] / np.prod(self.spatial_shape) / self.batch_size
return self.indices.shape[0] / np.prod(
self.spatial_shape) / self.batch_size
class ToDense(SparseModule):
......@@ -97,6 +101,7 @@ class ToDense(SparseModule):
def forward(self, x: SparseConvTensor):
return x.dense()
class RemoveGrid(SparseModule):
"""remove pre-allocated grid buffer.
"""
......
......@@ -16,15 +16,16 @@ import math
import time
import numpy as np
import spconv
import spconv.functional as Fsp
import torch
from spconv import ops
from spconv.modules import SparseModule
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter
import spconv
import spconv.functional as Fsp
from spconv import ops
from spconv.modules import SparseModule
def _calculate_fan_in_and_fan_out_hwio(tensor):
dimensions = tensor.ndimension()
......@@ -146,8 +147,9 @@ class SparseConvolution(SparseModule):
self.weight.view(self.in_channels, self.out_channels))
if self.bias is not None:
features += self.bias
out_tensor = spconv.SparseConvTensor(
features, input.indices, input.spatial_shape, input.batch_size)
out_tensor = spconv.SparseConvTensor(features, input.indices,
input.spatial_shape,
input.batch_size)
out_tensor.indice_dict = input.indice_dict
out_tensor.grid = input.grid
return out_tensor
......@@ -181,9 +183,12 @@ class SparseConvolution(SparseModule):
spatial_shape)
if self.fused_bn:
assert self.bias is not None
out_features = ops.fused_indice_conv(
features, self.weight, self.bias, indice_pairs.to(device),
indice_pair_num, outids.shape[0], self.inverse, self.subm)
out_features = ops.fused_indice_conv(features, self.weight,
self.bias,
indice_pairs.to(device),
indice_pair_num,
outids.shape[0], self.inverse,
self.subm)
else:
if self.subm:
out_features = Fsp.indice_subm_conv(features, self.weight,
......@@ -222,8 +227,7 @@ class SparseConv2d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SparseConv2d, self).__init__(
2,
super(SparseConv2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
......@@ -248,8 +252,7 @@ class SparseConv3d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SparseConv3d, self).__init__(
3,
super(SparseConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
......@@ -274,8 +277,7 @@ class SparseConv4d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SparseConv4d, self).__init__(
4,
super(SparseConv4d, self).__init__(4,
in_channels,
out_channels,
kernel_size,
......@@ -300,8 +302,7 @@ class SparseConvTranspose2d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SparseConvTranspose2d, self).__init__(
2,
super(SparseConvTranspose2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
......@@ -327,8 +328,7 @@ class SparseConvTranspose3d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SparseConvTranspose3d, self).__init__(
3,
super(SparseConvTranspose3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
......@@ -349,8 +349,7 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size,
indice_key,
bias=True):
super(SparseInverseConv2d, self).__init__(
2,
super(SparseInverseConv2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
......@@ -366,8 +365,7 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size,
indice_key,
bias=True):
super(SparseInverseConv3d, self).__init__(
3,
super(SparseInverseConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
......@@ -388,8 +386,7 @@ class SubMConv2d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SubMConv2d, self).__init__(
2,
super(SubMConv2d, self).__init__(2,
in_channels,
out_channels,
kernel_size,
......@@ -415,8 +412,7 @@ class SubMConv3d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SubMConv3d, self).__init__(
3,
super(SubMConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
......@@ -442,8 +438,7 @@ class SubMConv4d(SparseConvolution):
bias=True,
indice_key=None,
use_hash=False):
super(SubMConv4d, self).__init__(
4,
super(SubMConv4d, self).__init__(4,
in_channels,
out_channels,
kernel_size,
......
This diff is collapsed.
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections import OrderedDict
import spconv
import torch
from torch import nn
import time
import spconv
def is_spconv_module(module):
......@@ -81,7 +82,6 @@ class SparseSequential(SparseModule):
relu2=nn.ReLU()
)
"""
def __init__(self, *args, **kwargs):
super(SparseSequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
......@@ -148,7 +148,8 @@ class SparseSequential(SparseModule):
idx = 0
while idx < len(mods):
if is_sparse_conv(mods[idx]):
if idx < len(mods) - 1 and isinstance(mods[idx + 1], nn.BatchNorm1d):
if idx < len(mods) - 1 and isinstance(mods[idx + 1],
nn.BatchNorm1d):
new_module = SparseConvolution(
ndim=mods[idx].ndim,
in_channels=mods[idx].in_channels,
......
This diff is collapsed.
......@@ -12,20 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
import numpy as np
import spconv
import spconv.functional as Fsp
import torch
from spconv import ops
from spconv.modules import SparseModule
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter
import spconv
import spconv.functional as Fsp
from spconv import ops
from spconv.modules import SparseModule
class SparseMaxPool(SparseModule):
def __init__(self,
......@@ -61,15 +61,17 @@ class SparseMaxPool(SparseModule):
batch_size = input.batch_size
if not self.subm:
out_spatial_shape = ops.get_conv_output_size(
spatial_shape, self.kernel_size, self.stride, self.padding, self.dilation)
spatial_shape, self.kernel_size, self.stride, self.padding,
self.dilation)
else:
out_spatial_shape = spatial_shape
outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs(
indices, batch_size, spatial_shape, self.kernel_size,
self.stride, self.padding, self.dilation, 0, self.subm)
indices, batch_size, spatial_shape, self.kernel_size, self.stride,
self.padding, self.dilation, 0, self.subm)
out_features = Fsp.indice_maxpool(features, indice_pairs.to(device),
indice_pairs_num.to(device), outids.shape[0])
indice_pairs_num.to(device),
outids.shape[0])
out_tensor = spconv.SparseConvTensor(out_features, outids,
out_spatial_shape, batch_size)
out_tensor.indice_dict = input.indice_dict
......@@ -78,28 +80,12 @@ class SparseMaxPool(SparseModule):
class SparseMaxPool2d(SparseMaxPool):
def __init__(self,
kernel_size,
stride=1,
padding=0,
dilation=1):
super(SparseMaxPool2d, self).__init__(
2,
kernel_size,
stride,
padding,
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
dilation)
class SparseMaxPool3d(SparseMaxPool):
def __init__(self,
kernel_size,
stride=1,
padding=0,
dilation=1):
super(SparseMaxPool3d, self).__init__(
3,
kernel_size,
stride,
padding,
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
dilation)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment