Commit 19e73bbe authored by Yan Yan's avatar Yan Yan
Browse files

format code with clang-format, better c++ code

parent c336139f
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
* copies of the Software, and to permit persons to whom the Software is * copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions: * furnished to do so, subject to the following conditions:
* *
* The above copyright notice and this permission notice shall be included in all * The above copyright notice and this permission notice shall be included in
* copies or substantial portions of the Software. * all copies or substantial portions of the Software.
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#ifndef TSL_ROBIN_GROWTH_POLICY_H #ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H #define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <climits> #include <climits>
...@@ -35,65 +34,70 @@ ...@@ -35,65 +34,70 @@
#include <ratio> #include <ratio>
#include <stdexcept> #include <stdexcept>
#ifdef TSL_DEBUG #ifdef TSL_DEBUG
# define tsl_rh_assert(expr) assert(expr) #define tsl_rh_assert(expr) assert(expr)
#else #else
# define tsl_rh_assert(expr) (static_cast<void>(0)) #define tsl_rh_assert(expr) (static_cast<void>(0))
#endif #endif
/** /**
* If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate. * If exceptions are enabled, throw the exception passed in parameter, otherwise
* call std::terminate.
*/ */
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS) #if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || \
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg) (defined(_MSC_VER) && defined(_CPPUNWIND))) && \
!defined(TSL_NO_EXCEPTIONS)
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#else
#ifdef NDEBUG
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#else #else
# ifdef NDEBUG #include <cstdio>
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate() #define TSL_RH_THROW_OR_TERMINATE(ex, msg) \
# else do { \
# include <cstdio> std::fprintf(stderr, msg); \
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) do { std::fprintf(stderr, msg); std::terminate(); } while(0) std::terminate(); \
# endif } while (0)
#endif
#endif #endif
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
# define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true)) #define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#else #else
# define TSL_RH_LIKELY(exp) (exp) #define TSL_RH_LIKELY(exp) (exp)
#endif #endif
namespace tsl { namespace tsl {
namespace rh { namespace rh {
/** /**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows * Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket. * power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
* *
* GrowthFactor must be a power of two >= 2. * GrowthFactor must be a power of two >= 2.
*/ */
template<std::size_t GrowthFactor> template <std::size_t GrowthFactor> class power_of_two_growth_policy {
class power_of_two_growth_policy {
public: public:
/** /**
* Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter. * Called on the hash table creation and on rehash. The number of buckets for
* This number is a minimum, the policy may update this value with a higher value if needed (but not lower). * the table is passed in parameter. This number is a minimum, the policy may
* update this value with a higher value if needed (but not lower).
* *
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and * If 0 is given, min_bucket_count_in_out must still be 0 after the policy
* bucket_for_hash must always return 0 in this case. * creation and bucket_for_hash must always return 0 in this case.
*/ */
explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) { explicit power_of_two_growth_policy(std::size_t &min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) { if (min_bucket_count_in_out > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
if(min_bucket_count_in_out > 0) { if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out); min_bucket_count_in_out =
round_up_to_power_of_two(min_bucket_count_in_out);
m_mask = min_bucket_count_in_out - 1; m_mask = min_bucket_count_in_out - 1;
} } else {
else {
m_mask = 0; m_mask = 0;
} }
} }
...@@ -110,8 +114,9 @@ public: ...@@ -110,8 +114,9 @@ public:
* Return the number of buckets that should be used on next growth. * Return the number of buckets that should be used on next growth.
*/ */
std::size_t next_bucket_count() const { std::size_t next_bucket_count() const {
if((m_mask + 1) > max_bucket_count() / GrowthFactor) { if ((m_mask + 1) > max_bucket_count() / GrowthFactor) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
return (m_mask + 1) * GrowthFactor; return (m_mask + 1) * GrowthFactor;
...@@ -127,24 +132,23 @@ public: ...@@ -127,24 +132,23 @@ public:
/** /**
* Reset the growth policy as if it was created with a bucket count of 0. * Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is called. * After a clear, the policy must always return 0 when bucket_for_hash is
* called.
*/ */
void clear() noexcept { void clear() noexcept { m_mask = 0; }
m_mask = 0;
}
private: private:
static std::size_t round_up_to_power_of_two(std::size_t value) { static std::size_t round_up_to_power_of_two(std::size_t value) {
if(is_power_of_two(value)) { if (is_power_of_two(value)) {
return value; return value;
} }
if(value == 0) { if (value == 0) {
return 1; return 1;
} }
--value; --value;
for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) { for (std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
value |= value >> i; value |= value >> i;
} }
...@@ -156,28 +160,28 @@ private: ...@@ -156,28 +160,28 @@ private:
} }
protected: protected:
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2."); static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
"GrowthFactor must be a power of two >= 2.");
std::size_t m_mask; std::size_t m_mask;
}; };
/** /**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash * Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to a bucket. Slower but it can be useful if you want a slower growth. * to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/ */
template<class GrowthFactor = std::ratio<3, 2>> template <class GrowthFactor = std::ratio<3, 2>> class mod_growth_policy {
class mod_growth_policy {
public: public:
explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) { explicit mod_growth_policy(std::size_t &min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) { if (min_bucket_count_in_out > max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
if(min_bucket_count_in_out > 0) { if (min_bucket_count_in_out > 0) {
m_mod = min_bucket_count_in_out; m_mod = min_bucket_count_in_out;
} } else {
else {
m_mod = 1; m_mod = 1;
} }
} }
...@@ -187,73 +191,79 @@ public: ...@@ -187,73 +191,79 @@ public:
} }
std::size_t next_bucket_count() const { std::size_t next_bucket_count() const {
if(m_mod == max_bucket_count()) { if (m_mod == max_bucket_count()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR); const double next_bucket_count =
if(!std::isnormal(next_bucket_count)) { std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); if (!std::isnormal(next_bucket_count)) {
TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
if(next_bucket_count > double(max_bucket_count())) { if (next_bucket_count > double(max_bucket_count())) {
return max_bucket_count(); return max_bucket_count();
} } else {
else {
return std::size_t(next_bucket_count); return std::size_t(next_bucket_count);
} }
} }
std::size_t max_bucket_count() const { std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }
return MAX_BUCKET_COUNT;
}
void clear() noexcept { void clear() noexcept { m_mod = 1; }
m_mod = 1;
}
private: private:
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den; static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
1.0 * GrowthFactor::num / GrowthFactor::den;
static const std::size_t MAX_BUCKET_COUNT = static const std::size_t MAX_BUCKET_COUNT =
std::size_t(double( std::size_t(double(std::numeric_limits<std::size_t>::max() /
std::numeric_limits<std::size_t>::max() / REHASH_SIZE_MULTIPLICATION_FACTOR REHASH_SIZE_MULTIPLICATION_FACTOR));
));
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1."); static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
"Growth factor should be >= 1.1.");
std::size_t m_mod; std::size_t m_mod;
}; };
namespace detail { namespace detail {
static constexpr const std::array<std::size_t, 40> PRIMES = {{ static constexpr const std::array<std::size_t, 40> PRIMES = {
1ul, 5ul, 17ul, 29ul, 37ul, 53ul, 67ul, 79ul, 97ul, 131ul, 193ul, 257ul, 389ul, 521ul, 769ul, 1031ul, {1ul, 5ul, 17ul, 29ul, 37ul,
1543ul, 2053ul, 3079ul, 6151ul, 12289ul, 24593ul, 49157ul, 98317ul, 196613ul, 393241ul, 786433ul, 53ul, 67ul, 79ul, 97ul, 131ul,
1572869ul, 3145739ul, 6291469ul, 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul, 193ul, 257ul, 389ul, 521ul, 769ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul 1031ul, 1543ul, 2053ul, 3079ul, 6151ul,
}}; 12289ul, 24593ul, 49157ul, 98317ul, 196613ul,
393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul,
template<unsigned int IPrime> 12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; } 402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}};
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the template <unsigned int IPrime>
// compiler can optimize the modulo code better with a constant known at the compilation. static constexpr std::size_t mod(std::size_t hash) {
static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {{ return hash % PRIMES[IPrime];
&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
&mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
&mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37> , &mod<38>, &mod<39>
}};
} }
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
static constexpr const std::array<std::size_t (*)(std::size_t), 40> MOD_PRIME =
{{&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>,
&mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>,
&mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>,
&mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
&mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}};
} // namespace detail
/** /**
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in * Grow the hash table by using prime numbers as bucket count. Slower than
* general but will probably distribute the values around better in the buckets with a poor hash function. * tsl::rh::power_of_two_growth_policy in general but will probably distribute
* the values around better in the buckets with a poor hash function.
* *
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers. * To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant primes numbers.
* *
* With a switch the code would look like: * With a switch the code would look like:
* \code * \code
...@@ -268,25 +278,27 @@ static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = { ...@@ -268,25 +278,27 @@ static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {
* } * }
* \endcode * \endcode
* *
* Due to the constant variable in the modulo the compiler is able to optimize the operation * Due to the constant variable in the modulo the compiler is able to optimize
* by a series of multiplications, substractions and shifts. * the operation by a series of multiplications, substractions and shifts.
* *
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environement. * The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64 bits environement.
*/ */
class prime_growth_policy { class prime_growth_policy {
public: public:
explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) { explicit prime_growth_policy(std::size_t &min_bucket_count_in_out) {
auto it_prime = std::lower_bound(detail::PRIMES.begin(), auto it_prime = std::lower_bound(
detail::PRIMES.end(), min_bucket_count_in_out); detail::PRIMES.begin(), detail::PRIMES.end(), min_bucket_count_in_out);
if(it_prime == detail::PRIMES.end()) { if (it_prime == detail::PRIMES.end()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime)); m_iprime = static_cast<unsigned int>(
if(min_bucket_count_in_out > 0) { std::distance(detail::PRIMES.begin(), it_prime));
if (min_bucket_count_in_out > 0) {
min_bucket_count_in_out = *it_prime; min_bucket_count_in_out = *it_prime;
} } else {
else {
min_bucket_count_in_out = 0; min_bucket_count_in_out = 0;
} }
} }
...@@ -296,29 +308,27 @@ public: ...@@ -296,29 +308,27 @@ public:
} }
std::size_t next_bucket_count() const { std::size_t next_bucket_count() const {
if(m_iprime + 1 >= detail::PRIMES.size()) { if (m_iprime + 1 >= detail::PRIMES.size()) {
TSL_RH_THROW_OR_TERMINATE(std::length_error, "The hash table exceeds its maxmimum size."); TSL_RH_THROW_OR_TERMINATE(std::length_error,
"The hash table exceeds its maxmimum size.");
} }
return detail::PRIMES[m_iprime + 1]; return detail::PRIMES[m_iprime + 1];
} }
std::size_t max_bucket_count() const { std::size_t max_bucket_count() const { return detail::PRIMES.back(); }
return detail::PRIMES.back();
}
void clear() noexcept { void clear() noexcept { m_iprime = 0; }
m_iprime = 0;
}
private: private:
unsigned int m_iprime; unsigned int m_iprime;
static_assert(std::numeric_limits<decltype(m_iprime)>::max() >= detail::PRIMES.size(), static_assert(std::numeric_limits<decltype(m_iprime)>::max() >=
detail::PRIMES.size(),
"The type of m_iprime is not big enough."); "The type of m_iprime is not big enough.");
}; };
} } // namespace rh
} } // namespace tsl
#endif #endif
This diff is collapsed.
This diff is collapsed.
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
#pragma once #pragma once
#include <chrono> #include <chrono>
#ifdef SPCONV_CUDA #ifdef TV_CUDA
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#endif #endif
#include <iostream> #include <iostream>
namespace spconv { namespace spconv {
#ifdef SPCONV_CUDA #ifdef TV_CUDA
template <typename TimeT = std::chrono::microseconds> struct CudaContextTimer { template <typename TimeT = std::chrono::microseconds> struct CudaContextTimer {
CudaContextTimer() { CudaContextTimer() {
cudaDeviceSynchronize(); cudaDeviceSynchronize();
......
import os import os
import re
import sys
import platform import platform
import re
import subprocess import subprocess
import torch import sys
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion from distutils.version import LooseVersion
from pathlib import Path from pathlib import Path
import torch
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
# if 'LIBTORCH_ROOT' not in os.environ: # if 'LIBTORCH_ROOT' not in os.environ:
# raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.") # raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.")
...@@ -100,4 +100,3 @@ setup( ...@@ -100,4 +100,3 @@ setup(
cmdclass=dict(build_ext=CMakeBuild), cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False, zip_safe=False,
) )
...@@ -12,21 +12,20 @@ ...@@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
from pathlib import Path from pathlib import Path
import platform
import numpy as np import numpy as np
import torch import torch
from spconv import utils
from spconv.conv import SparseConv2d, SparseConv3d, SubMConv2d, SubMConv3d from spconv import ops, utils
from spconv.conv import SparseConvTranspose2d, SparseConvTranspose3d from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
from spconv.conv import SparseInverseConv2d, SparseInverseConv3d SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d)
from spconv.identity import Identity
from spconv.modules import SparseModule, SparseSequential from spconv.modules import SparseModule, SparseSequential
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import ConcatTable, JoinTable, AddTable from spconv.tables import AddTable, ConcatTable, JoinTable
from spconv.identity import Identity
from spconv import ops
_LIB_FILE_NAME = "libspconv.so" _LIB_FILE_NAME = "libspconv.so"
if platform.system() == "Windows": if platform.system() == "Windows":
...@@ -34,6 +33,7 @@ if platform.system() == "Windows": ...@@ -34,6 +33,7 @@ if platform.system() == "Windows":
_LIB_PATH = str(Path(__file__).parent / _LIB_FILE_NAME) _LIB_PATH = str(Path(__file__).parent / _LIB_FILE_NAME)
torch.ops.load_library(_LIB_PATH) torch.ops.load_library(_LIB_PATH)
def scatter_nd(indices, updates, shape): def scatter_nd(indices, updates, shape):
"""pytorch edition of tensorflow scatter_nd. """pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully this function don't contain except handle code. so use this carefully
...@@ -49,8 +49,10 @@ def scatter_nd(indices, updates, shape): ...@@ -49,8 +49,10 @@ def scatter_nd(indices, updates, shape):
ret[slices] = updates.view(*output_shape) ret[slices] = updates.view(*output_shape)
return ret return ret
class SparseConvTensor(object): class SparseConvTensor(object):
def __init__(self, features, indices, spatial_shape, batch_size, grid=None): def __init__(self, features, indices, spatial_shape, batch_size,
grid=None):
""" """
Args: Args:
grid: pre-allocated grid tensor. should be used when the volume of spatial shape grid: pre-allocated grid tensor. should be used when the volume of spatial shape
...@@ -77,7 +79,8 @@ class SparseConvTensor(object): ...@@ -77,7 +79,8 @@ class SparseConvTensor(object):
return None return None
def dense(self, channels_first=True): def dense(self, channels_first=True):
output_shape = [self.batch_size] + list(self.spatial_shape) + [self.features.shape[1]] output_shape = [self.batch_size] + list(
self.spatial_shape) + [self.features.shape[1]]
res = scatter_nd(self.indices.long(), self.features, output_shape) res = scatter_nd(self.indices.long(), self.features, output_shape)
if not channels_first: if not channels_first:
return res return res
...@@ -88,7 +91,8 @@ class SparseConvTensor(object): ...@@ -88,7 +91,8 @@ class SparseConvTensor(object):
@property @property
def sparity(self): def sparity(self):
return self.indices.shape[0] / np.prod(self.spatial_shape) / self.batch_size return self.indices.shape[0] / np.prod(
self.spatial_shape) / self.batch_size
class ToDense(SparseModule): class ToDense(SparseModule):
...@@ -97,6 +101,7 @@ class ToDense(SparseModule): ...@@ -97,6 +101,7 @@ class ToDense(SparseModule):
def forward(self, x: SparseConvTensor): def forward(self, x: SparseConvTensor):
return x.dense() return x.dense()
class RemoveGrid(SparseModule): class RemoveGrid(SparseModule):
"""remove pre-allocated grid buffer. """remove pre-allocated grid buffer.
""" """
......
...@@ -16,15 +16,16 @@ import math ...@@ -16,15 +16,16 @@ import math
import time import time
import numpy as np import numpy as np
import spconv
import spconv.functional as Fsp
import torch import torch
from spconv import ops
from spconv.modules import SparseModule
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import spconv
import spconv.functional as Fsp
from spconv import ops
from spconv.modules import SparseModule
def _calculate_fan_in_and_fan_out_hwio(tensor): def _calculate_fan_in_and_fan_out_hwio(tensor):
dimensions = tensor.ndimension() dimensions = tensor.ndimension()
...@@ -146,8 +147,9 @@ class SparseConvolution(SparseModule): ...@@ -146,8 +147,9 @@ class SparseConvolution(SparseModule):
self.weight.view(self.in_channels, self.out_channels)) self.weight.view(self.in_channels, self.out_channels))
if self.bias is not None: if self.bias is not None:
features += self.bias features += self.bias
out_tensor = spconv.SparseConvTensor( out_tensor = spconv.SparseConvTensor(features, input.indices,
features, input.indices, input.spatial_shape, input.batch_size) input.spatial_shape,
input.batch_size)
out_tensor.indice_dict = input.indice_dict out_tensor.indice_dict = input.indice_dict
out_tensor.grid = input.grid out_tensor.grid = input.grid
return out_tensor return out_tensor
...@@ -181,9 +183,12 @@ class SparseConvolution(SparseModule): ...@@ -181,9 +183,12 @@ class SparseConvolution(SparseModule):
spatial_shape) spatial_shape)
if self.fused_bn: if self.fused_bn:
assert self.bias is not None assert self.bias is not None
out_features = ops.fused_indice_conv( out_features = ops.fused_indice_conv(features, self.weight,
features, self.weight, self.bias, indice_pairs.to(device), self.bias,
indice_pair_num, outids.shape[0], self.inverse, self.subm) indice_pairs.to(device),
indice_pair_num,
outids.shape[0], self.inverse,
self.subm)
else: else:
if self.subm: if self.subm:
out_features = Fsp.indice_subm_conv(features, self.weight, out_features = Fsp.indice_subm_conv(features, self.weight,
...@@ -222,8 +227,7 @@ class SparseConv2d(SparseConvolution): ...@@ -222,8 +227,7 @@ class SparseConv2d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SparseConv2d, self).__init__( super(SparseConv2d, self).__init__(2,
2,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -248,8 +252,7 @@ class SparseConv3d(SparseConvolution): ...@@ -248,8 +252,7 @@ class SparseConv3d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SparseConv3d, self).__init__( super(SparseConv3d, self).__init__(3,
3,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -274,8 +277,7 @@ class SparseConv4d(SparseConvolution): ...@@ -274,8 +277,7 @@ class SparseConv4d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SparseConv4d, self).__init__( super(SparseConv4d, self).__init__(4,
4,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -300,8 +302,7 @@ class SparseConvTranspose2d(SparseConvolution): ...@@ -300,8 +302,7 @@ class SparseConvTranspose2d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SparseConvTranspose2d, self).__init__( super(SparseConvTranspose2d, self).__init__(2,
2,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -327,8 +328,7 @@ class SparseConvTranspose3d(SparseConvolution): ...@@ -327,8 +328,7 @@ class SparseConvTranspose3d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SparseConvTranspose3d, self).__init__( super(SparseConvTranspose3d, self).__init__(3,
3,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -349,8 +349,7 @@ class SparseInverseConv2d(SparseConvolution): ...@@ -349,8 +349,7 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size, kernel_size,
indice_key, indice_key,
bias=True): bias=True):
super(SparseInverseConv2d, self).__init__( super(SparseInverseConv2d, self).__init__(2,
2,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -366,8 +365,7 @@ class SparseInverseConv3d(SparseConvolution): ...@@ -366,8 +365,7 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size, kernel_size,
indice_key, indice_key,
bias=True): bias=True):
super(SparseInverseConv3d, self).__init__( super(SparseInverseConv3d, self).__init__(3,
3,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -388,8 +386,7 @@ class SubMConv2d(SparseConvolution): ...@@ -388,8 +386,7 @@ class SubMConv2d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SubMConv2d, self).__init__( super(SubMConv2d, self).__init__(2,
2,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -415,8 +412,7 @@ class SubMConv3d(SparseConvolution): ...@@ -415,8 +412,7 @@ class SubMConv3d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SubMConv3d, self).__init__( super(SubMConv3d, self).__init__(3,
3,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
...@@ -442,8 +438,7 @@ class SubMConv4d(SparseConvolution): ...@@ -442,8 +438,7 @@ class SubMConv4d(SparseConvolution):
bias=True, bias=True,
indice_key=None, indice_key=None,
use_hash=False): use_hash=False):
super(SubMConv4d, self).__init__( super(SubMConv4d, self).__init__(4,
4,
in_channels, in_channels,
out_channels, out_channels,
kernel_size, kernel_size,
......
This diff is collapsed.
...@@ -12,12 +12,13 @@ ...@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import time
from collections import OrderedDict from collections import OrderedDict
import spconv
import torch import torch
from torch import nn from torch import nn
import time
import spconv
def is_spconv_module(module): def is_spconv_module(module):
...@@ -81,7 +82,6 @@ class SparseSequential(SparseModule): ...@@ -81,7 +82,6 @@ class SparseSequential(SparseModule):
relu2=nn.ReLU() relu2=nn.ReLU()
) )
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SparseSequential, self).__init__() super(SparseSequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict): if len(args) == 1 and isinstance(args[0], OrderedDict):
...@@ -148,7 +148,8 @@ class SparseSequential(SparseModule): ...@@ -148,7 +148,8 @@ class SparseSequential(SparseModule):
idx = 0 idx = 0
while idx < len(mods): while idx < len(mods):
if is_sparse_conv(mods[idx]): if is_sparse_conv(mods[idx]):
if idx < len(mods) - 1 and isinstance(mods[idx + 1], nn.BatchNorm1d): if idx < len(mods) - 1 and isinstance(mods[idx + 1],
nn.BatchNorm1d):
new_module = SparseConvolution( new_module = SparseConvolution(
ndim=mods[idx].ndim, ndim=mods[idx].ndim,
in_channels=mods[idx].in_channels, in_channels=mods[idx].in_channels,
......
This diff is collapsed.
...@@ -12,20 +12,20 @@ ...@@ -12,20 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
import time import time
import numpy as np import numpy as np
import spconv
import spconv.functional as Fsp
import torch import torch
from spconv import ops
from spconv.modules import SparseModule
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import spconv
import spconv.functional as Fsp
from spconv import ops
from spconv.modules import SparseModule
class SparseMaxPool(SparseModule): class SparseMaxPool(SparseModule):
def __init__(self, def __init__(self,
...@@ -61,15 +61,17 @@ class SparseMaxPool(SparseModule): ...@@ -61,15 +61,17 @@ class SparseMaxPool(SparseModule):
batch_size = input.batch_size batch_size = input.batch_size
if not self.subm: if not self.subm:
out_spatial_shape = ops.get_conv_output_size( out_spatial_shape = ops.get_conv_output_size(
spatial_shape, self.kernel_size, self.stride, self.padding, self.dilation) spatial_shape, self.kernel_size, self.stride, self.padding,
self.dilation)
else: else:
out_spatial_shape = spatial_shape out_spatial_shape = spatial_shape
outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs( outids, indice_pairs, indice_pairs_num = ops.get_indice_pairs(
indices, batch_size, spatial_shape, self.kernel_size, indices, batch_size, spatial_shape, self.kernel_size, self.stride,
self.stride, self.padding, self.dilation, 0, self.subm) self.padding, self.dilation, 0, self.subm)
out_features = Fsp.indice_maxpool(features, indice_pairs.to(device), out_features = Fsp.indice_maxpool(features, indice_pairs.to(device),
indice_pairs_num.to(device), outids.shape[0]) indice_pairs_num.to(device),
outids.shape[0])
out_tensor = spconv.SparseConvTensor(out_features, outids, out_tensor = spconv.SparseConvTensor(out_features, outids,
out_spatial_shape, batch_size) out_spatial_shape, batch_size)
out_tensor.indice_dict = input.indice_dict out_tensor.indice_dict = input.indice_dict
...@@ -78,28 +80,12 @@ class SparseMaxPool(SparseModule): ...@@ -78,28 +80,12 @@ class SparseMaxPool(SparseModule):
class SparseMaxPool2d(SparseMaxPool): class SparseMaxPool2d(SparseMaxPool):
def __init__(self, def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
kernel_size, super(SparseMaxPool2d, self).__init__(2, kernel_size, stride, padding,
stride=1,
padding=0,
dilation=1):
super(SparseMaxPool2d, self).__init__(
2,
kernel_size,
stride,
padding,
dilation) dilation)
class SparseMaxPool3d(SparseMaxPool): class SparseMaxPool3d(SparseMaxPool):
def __init__(self, def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
kernel_size, super(SparseMaxPool3d, self).__init__(3, kernel_size, stride, padding,
stride=1,
padding=0,
dilation=1):
super(SparseMaxPool3d, self).__init__(
3,
kernel_size,
stride,
padding,
dilation) dilation)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment