Commit 19e73bbe authored by Yan Yan's avatar Yan Yan
Browse files

format code with clang-format, better c++ code

parent c336139f
...@@ -33,7 +33,7 @@ if (SPCONV_BuildCUDA) ...@@ -33,7 +33,7 @@ if (SPCONV_BuildCUDA)
torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA) torch_cuda_get_nvcc_gencode_flag(NVCC_FLAGS_EXTRA)
string (REPLACE ";" " " NVCC_FLAGS_EXTRA_STR "${NVCC_FLAGS_EXTRA}") string (REPLACE ";" " " NVCC_FLAGS_EXTRA_STR "${NVCC_FLAGS_EXTRA}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA_STR}") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA_STR}")
add_compile_definitions(SPCONV_CUDA) add_compile_definitions(TV_CUDA)
endif() endif()
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) # add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_subdirectory(third_party/pybind11) add_subdirectory(third_party/pybind11)
......
isort -rc --atomic ./spconv && \
isort -rc --atomic ./test && \
yapf -i --recursive -vv ./spconv ./test
find ./src -regex '.*\.\(cpp\|hpp\|cc\|cxx\|cu\|cuh\|h\)' | xargs clang-format -i
find ./include -regex '.*\.\(cpp\|hpp\|cc\|cxx\|cu\|cuh\|h\)' | xargs clang-format -i
\ No newline at end of file
...@@ -2,46 +2,50 @@ ...@@ -2,46 +2,50 @@
#define _CUDA_UTIL_H_ #define _CUDA_UTIL_H_
#if CUDART_VERSION >= 4000 #if CUDART_VERSION >= 4000
#define CUDA_DEVICE_SYNCHRONIZE( ) cudaDeviceSynchronize(); #define CUDA_DEVICE_SYNCHRONIZE() cudaDeviceSynchronize();
#else #else
#define CUDA_DEVICE_SYNCHRONIZE( ) cudaThreadSynchronize(); #define CUDA_DEVICE_SYNCHRONIZE() cudaThreadSynchronize();
#endif #endif
# define CUDA_SAFE_CALL_NO_SYNC( call) { \ #define CUDA_SAFE_CALL_NO_SYNC(call) \
cudaError err = call; \ { \
if( cudaSuccess != err) { \ cudaError err = call; \
fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", \ if (cudaSuccess != err) { \
__FILE__, __LINE__, cudaGetErrorString( err) ); \ fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__, \
exit(EXIT_FAILURE); \ __LINE__, cudaGetErrorString(err)); \
} } exit(EXIT_FAILURE); \
} \
}
# define CUDA_SAFE_CALL( call) CUDA_SAFE_CALL_NO_SYNC(call); #define CUDA_SAFE_CALL(call) CUDA_SAFE_CALL_NO_SYNC(call);
//! Check for CUDA error //! Check for CUDA error
#ifdef _DEBUG #ifdef _DEBUG
# define CUDA_CHECK_ERROR(errorMessage) { \ #define CUDA_CHECK_ERROR(errorMessage) \
cudaError_t err = cudaGetLastError(); \ { \
if( cudaSuccess != err) { \ cudaError_t err = cudaGetLastError(); \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ if (cudaSuccess != err) { \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
exit(EXIT_FAILURE); \ errorMessage, __FILE__, __LINE__, cudaGetErrorString(err)); \
} \ exit(EXIT_FAILURE); \
} \
err = CUDA_DEVICE_SYNCHRONIZE(); \ err = CUDA_DEVICE_SYNCHRONIZE(); \
if( cudaSuccess != err) { \ if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ errorMessage, __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \ exit(EXIT_FAILURE); \
} \ } \
} }
#else #else
# define CUDA_CHECK_ERROR(errorMessage) { \ #define CUDA_CHECK_ERROR(errorMessage) \
cudaError_t err = cudaGetLastError(); \ { \
if( cudaSuccess != err) { \ cudaError_t err = cudaGetLastError(); \
fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \ if (cudaSuccess != err) { \
errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\ fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
exit(EXIT_FAILURE); \ errorMessage, __FILE__, __LINE__, cudaGetErrorString(err)); \
} \ exit(EXIT_FAILURE); \
} } \
}
#endif #endif
#endif #endif
\ No newline at end of file
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
// ------------------------------------------------------------- // -------------------------------------------------------------
// $Revision:$ // $Revision:$
// $Date:$ // $Date:$
// ------------------------------------------------------------- // -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in // This source code is distributed under the terms of license.txt in
// the root directory of this source distribution. // the root directory of this source distribution.
// ------------------------------------------------------------- // -------------------------------------------------------------
/** /**
* @file * @file
...@@ -29,44 +29,44 @@ namespace cuhash { ...@@ -29,44 +29,44 @@ namespace cuhash {
//! @name Debugging functions //! @name Debugging functions
/// @{ /// @{
void TakeHashFunctionStatistics(const unsigned num_keys, void TakeHashFunctionStatistics(const unsigned num_keys, const unsigned *d_keys,
const unsigned *d_keys, const unsigned table_size,
const unsigned table_size, const uint2 *constants,
const uint2 *constants, const unsigned kNumHashFunctions);
const unsigned kNumHashFunctions);
//! Output how many probes were required by each thread to perform the
//! Output how many probes were required by each thread to perform the retrieval. //! retrieval.
/*! @param[in] n_queries Number of queries being performed. /*! @param[in] n_queries Number of queries being performed.
* @param[in] d_retrieval_probes Device array: the number of probes taken for each thread's retrieval. * @param[in] d_retrieval_probes Device array: the number of probes taken for
* each thread's retrieval.
* @param[in] n_functions Number of hash functions used. * @param[in] n_functions Number of hash functions used.
*/ */
void OutputRetrievalStatistics(const unsigned n_queries, void OutputRetrievalStatistics(const unsigned n_queries,
const unsigned *d_retrieval_probes, const unsigned *d_retrieval_probes,
const unsigned n_functions); const unsigned n_functions);
//! Outputs information about how many iterations threads required to successfully cuckoo hash. //! Outputs information about how many iterations threads required to
//! successfully cuckoo hash.
/*! @param[in] n Number of keys in the input. /*! @param[in] n Number of keys in the input.
* @param[in] d_iterations_taken Device mem: Number of iterations each thread took. * @param[in] d_iterations_taken Device mem: Number of iterations each
* @param[in] d_max_iterations_taken Device mem: Largest number of iterations taken by any thread. * thread took.
* @param[in] d_max_iterations_taken Device mem: Largest number of iterations
* taken by any thread.
*/ */
void OutputBuildStatistics(const unsigned n, void OutputBuildStatistics(const unsigned n,
const unsigned *d_iterations_taken); const unsigned *d_iterations_taken);
//! Prints out the contents of the stash. //! Prints out the contents of the stash.
void PrintStashContents(const Entry *d_stash); void PrintStashContents(const Entry *d_stash);
//! Checks if a key is assigned the same slot by different hash functions. //! Checks if a key is assigned the same slot by different hash functions.
bool CheckAssignedSameSlot(const unsigned N, bool CheckAssignedSameSlot(const unsigned N, const unsigned num_keys,
const unsigned num_keys, const unsigned *d_keys, const unsigned table_size,
const unsigned *d_keys, uint2 *constants);
const unsigned table_size,
uint2 *constants);
/// @} /// @}
}; // namespace CuckooHashing }; // namespace cuhash
#endif #endif
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#ifndef HASH_FUNCTIONS__H #ifndef HASH_FUNCTIONS__H
#define HASH_FUNCTIONS__H #define HASH_FUNCTIONS__H
#include "definitions.h"
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
#include <vector_types.h> #include <vector_types.h>
#include "definitions.h"
namespace cuhash { namespace cuhash {
...@@ -23,30 +23,28 @@ const unsigned kPrimeDivisor = 4294967291u; ...@@ -23,30 +23,28 @@ const unsigned kPrimeDivisor = 4294967291u;
/*! @param[in] N Number of hash functions. /*! @param[in] N Number of hash functions.
@param[out] constants CPU pointer to the constants. @param[out] constants CPU pointer to the constants.
@param[in] num_keys Debug only: How many keys are in the input. @param[in] num_keys Debug only: How many keys are in the input.
@param[in] d_keys Debug only: Device memory array containing the input keys. @param[in] d_keys Debug only: Device memory array containing the input
keys.
@param[in] table_size Debug only: Size of the hash table. @param[in] table_size Debug only: Size of the hash table.
*/ */
void GenerateFunctions(const unsigned N, void GenerateFunctions(const unsigned N, const unsigned num_keys,
const unsigned num_keys, const unsigned *d_keys, const unsigned table_size,
const unsigned *d_keys, uint2 *constants);
const unsigned table_size,
uint2 *constants);
//! Container for all of the hash functions. //! Container for all of the hash functions.
template <unsigned N> template <unsigned N> struct Functions {
struct Functions { //! The constants required for all of the hash functions, including the stash.
//! The constants required for all of the hash functions, including the stash. Each function requires 2. //! Each function requires 2.
uint2 constants[N]; uint2 constants[N];
//! Generate new hash function constants. //! Generate new hash function constants.
/*! The parameters are only used for debugging and examining the key distribution. /*! The parameters are only used for debugging and examining the key
\param[in] num_keys Debug: Number of keys in the input. distribution. \param[in] num_keys Debug: Number of keys in the input.
\param[in] d_keys Debug: Device array of the input keys. \param[in] d_keys Debug: Device array of the input keys.
\param[in] table_size Debug: Size of the hash table. \param[in] table_size Debug: Size of the hash table.
*/ */
void Generate(const unsigned num_keys, void Generate(const unsigned num_keys, const unsigned *d_keys,
const unsigned *d_keys, const unsigned table_size) {
const unsigned table_size) {
GenerateFunctions(N, num_keys, d_keys, table_size, constants); GenerateFunctions(N, num_keys, d_keys, table_size, constants);
} }
}; };
...@@ -56,17 +54,16 @@ struct Functions { ...@@ -56,17 +54,16 @@ struct Functions {
! \param[in] key Key being hashed. ! \param[in] key Key being hashed.
! \returns The value of the hash function for the key. ! \returns The value of the hash function for the key.
*/ */
inline __device__ __host__ inline __device__ __host__ unsigned hash_function_inner(const uint2 constants,
unsigned hash_function_inner(const uint2 constants, const unsigned key) {
const unsigned key) { #if 1
#if 1 // Fast version.
// Fast version.
return ((constants.x ^ key) + constants.y) % kPrimeDivisor; return ((constants.x ^ key) + constants.y) % kPrimeDivisor;
#else #else
// Slow version. // Slow version.
return ((unsigned long long)constants.x * key + constants.y) % kPrimeDivisor; return ((unsigned long long)constants.x * key + constants.y) % kPrimeDivisor;
#endif #endif
} }
//! Computes the value of a hash function for a given key. //! Computes the value of a hash function for a given key.
/*! \param[in] functions All of the constants used by the hash functions. /*! \param[in] functions All of the constants used by the hash functions.
...@@ -75,22 +72,20 @@ unsigned hash_function_inner(const uint2 constants, ...@@ -75,22 +72,20 @@ unsigned hash_function_inner(const uint2 constants,
! \returns The value of a hash function with a given key. ! \returns The value of a hash function with a given key.
*/ */
template <unsigned kNumHashFunctions> template <unsigned kNumHashFunctions>
TV_HOST_DEVICE_INLINE TV_HOST_DEVICE_INLINE unsigned
unsigned hash_function(const Functions<kNumHashFunctions> functions, hash_function(const Functions<kNumHashFunctions> functions,
const unsigned which_function, const unsigned which_function, const unsigned key) {
const unsigned key) {
return hash_function_inner(functions.constants[which_function], key); return hash_function_inner(functions.constants[which_function], key);
} }
//! Simple hash function used by the stash. //! Simple hash function used by the stash.
TV_HOST_DEVICE_INLINE TV_HOST_DEVICE_INLINE
unsigned stash_hash_function(const uint2 stash_constants, unsigned stash_hash_function(const uint2 stash_constants, const unsigned key) {
const unsigned key) {
return (stash_constants.x ^ key + stash_constants.y) % kStashSize; return (stash_constants.x ^ key + stash_constants.y) % kStashSize;
} }
unsigned generate_random_uint32(); unsigned generate_random_uint32();
}; // namespace CuckooHashing }; // namespace cuhash
#endif #endif
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
// ------------------------------------------------------------- // -------------------------------------------------------------
// $Revision:$ // $Revision:$
// $Date:$ // $Date:$
// ------------------------------------------------------------- // -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in // This source code is distributed under the terms of license.txt in
// the root directory of this source distribution. // the root directory of this source distribution.
// ------------------------------------------------------------- // -------------------------------------------------------------
/** /**
* @file hash_table.cuh * @file hash_table.cuh
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "definitions.h" #include "definitions.h"
#include "hash_table.h" #include "hash_table.h"
#include <tensorview/tensorview.h>
#include <driver_types.h> #include <driver_types.h>
#include <tensorview/tensorview.h>
namespace cuhash { namespace cuhash {
...@@ -31,51 +31,42 @@ TV_HOST_DEVICE_INLINE Entry make_entry(unsigned key, unsigned value) { ...@@ -31,51 +31,42 @@ TV_HOST_DEVICE_INLINE Entry make_entry(unsigned key, unsigned value) {
//! Returns the key of an Entry. //! Returns the key of an Entry.
TV_HOST_DEVICE_INLINE unsigned get_key(Entry entry) { TV_HOST_DEVICE_INLINE unsigned get_key(Entry entry) {
return (unsigned) (entry >> 32); return (unsigned)(entry >> 32);
} }
//! Returns the value of an Entry. //! Returns the value of an Entry.
TV_HOST_DEVICE_INLINE unsigned get_value(Entry entry) { TV_HOST_DEVICE_INLINE unsigned get_value(Entry entry) {
return (unsigned) (entry & 0xffffffff); return (unsigned)(entry & 0xffffffff);
} }
//! @name Internal //! @name Internal
//! @brief Functions used for building the hash table. //! @brief Functions used for building the hash table.
//! @{ //! @{
//! Fills the entire array with a specific value. //! Fills the entire array with a specific value.
template <class T> __global__ template <class T>
void clear_table(const unsigned table_size, __global__ void clear_table(const unsigned table_size, const T value,
const T value, T *table) {
T *table) unsigned thread_index = threadIdx.x + blockIdx.x * blockDim.x +
{
unsigned thread_index = threadIdx.x +
blockIdx.x * blockDim.x +
blockIdx.y * blockDim.x * gridDim.x; blockIdx.y * blockDim.x * gridDim.x;
if (thread_index < table_size) { if (thread_index < table_size) {
table[thread_index] = value; table[thread_index] = value;
} }
} }
//! Determine where in the hash table the key could be located. //! Determine where in the hash table the key could be located.
template <unsigned kNumHashFunctions> template <unsigned kNumHashFunctions>
__device__ void __device__ void KeyLocations(const Functions<kNumHashFunctions> constants,
KeyLocations(const Functions<kNumHashFunctions> constants, const unsigned table_size, const unsigned key,
const unsigned table_size, unsigned locations[kNumHashFunctions]) {
const unsigned key, // Compute all possible locations for the key in the big table.
unsigned locations[kNumHashFunctions]) #pragma unroll
{
// Compute all possible locations for the key in the big table.
#pragma unroll
for (int i = 0; i < kNumHashFunctions; ++i) { for (int i = 0; i < kNumHashFunctions; ++i) {
locations[i] = hash_function(constants, i, key) % table_size; locations[i] = hash_function(constants, i, key) % table_size;
} }
} }
//! @} //! @}
/* -------------------------------------------------------------------------- /* --------------------------------------------------------------------------
Retrieval functions. Retrieval functions.
-------------------------------------------------------------------------- */ -------------------------------------------------------------------------- */
...@@ -87,28 +78,27 @@ KeyLocations(const Functions<kNumHashFunctions> constants, ...@@ -87,28 +78,27 @@ KeyLocations(const Functions<kNumHashFunctions> constants,
* @param[in] constants The hash functions used to build the table * @param[in] constants The hash functions used to build the table
* @param[in] stash_constants The hash function used to build the stash * @param[in] stash_constants The hash function used to build the stash
* @param[in] stash_count The number of items in the stash * @param[in] stash_count The number of items in the stash
* @param[out] num_probes_required Debug only: The number of probes required to resolve the query. * @param[out] num_probes_required Debug only: The number of probes required
* @returns The value of the query key, if the key exists in the table. Otherwise, \ref kNotFound will be returned. * to resolve the query.
* @returns The value of the query key, if the key exists in the table.
* Otherwise, \ref kNotFound will be returned.
*/ */
template <unsigned kNumHashFunctions> __device__ template <unsigned kNumHashFunctions>
unsigned retrieve(const unsigned query_key, __device__ unsigned
const unsigned table_size, retrieve(const unsigned query_key, const unsigned table_size,
const Entry *table, const Entry *table, const Functions<kNumHashFunctions> constants,
const Functions<kNumHashFunctions> constants, const uint2 stash_constants, const unsigned stash_count,
const uint2 stash_constants, unsigned *num_probes_required = NULL) {
const unsigned stash_count,
unsigned *num_probes_required = NULL)
{
// Identify all of the locations that the key can be located in. // Identify all of the locations that the key can be located in.
unsigned locations[kNumHashFunctions]; unsigned locations[kNumHashFunctions];
KeyLocations(constants, table_size, query_key, locations); KeyLocations(constants, table_size, query_key, locations);
// Check each location until the key is found. // Check each location until the key is found.
unsigned num_probes = 1; unsigned num_probes = 1;
Entry entry = table[locations[0]]; Entry entry = table[locations[0]];
unsigned key = get_key(entry); unsigned key = get_key(entry);
#pragma unroll #pragma unroll
for (unsigned i = 1; i < kNumHashFunctions; ++i) { for (unsigned i = 1; i < kNumHashFunctions; ++i) {
if (key != query_key && key != kNotFound) { if (key != query_key && key != kNotFound) {
num_probes++; num_probes++;
...@@ -138,37 +128,26 @@ unsigned retrieve(const unsigned query_key, ...@@ -138,37 +128,26 @@ unsigned retrieve(const unsigned query_key,
} }
} }
//! Perform a retrieval from a basic hash table. Each thread manages a single
//! Perform a retrieval from a basic hash table. Each thread manages a single query. //! query.
template <unsigned kNumHashFunctions> __global__ template <unsigned kNumHashFunctions>
void hash_retrieve(const unsigned n_queries, __global__ void hash_retrieve(const unsigned n_queries, const unsigned *keys_in,
const unsigned *keys_in, const unsigned table_size, const Entry *table,
const unsigned table_size, const Functions<kNumHashFunctions> constants,
const Entry *table, const uint2 stash_constants,
const Functions<kNumHashFunctions> constants, const unsigned stash_count, unsigned *values_out,
const uint2 stash_constants, unsigned *num_probes_required = NULL) {
const unsigned stash_count,
unsigned *values_out,
unsigned *num_probes_required = NULL)
{
// Get the key. // Get the key.
unsigned thread_index = threadIdx.x + unsigned thread_index = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.x * blockDim.x +
blockIdx.y * blockDim.x * gridDim.x; blockIdx.y * blockDim.x * gridDim.x;
if (thread_index >= n_queries) if (thread_index >= n_queries)
return; return;
unsigned key = keys_in[thread_index]; unsigned key = keys_in[thread_index];
values_out[thread_index] = retrieve<kNumHashFunctions> values_out[thread_index] = retrieve<kNumHashFunctions>(
(key, key, table_size, table, constants, stash_constants, stash_count,
table_size, (num_probes_required ? num_probes_required + thread_index : NULL));
table, }
constants,
stash_constants,
stash_count,
(num_probes_required ? num_probes_required + thread_index : NULL));
}
/* -------------------------------------------------------------------------- /* --------------------------------------------------------------------------
Build a cuckoo hash table. Build a cuckoo hash table.
...@@ -176,55 +155,53 @@ void hash_retrieve(const unsigned n_queries, ...@@ -176,55 +155,53 @@ void hash_retrieve(const unsigned n_queries,
//! @name Internal //! @name Internal
//! @{ //! @{
//! Determine where to insert the key next. The hash functions are used in round-robin order. //! Determine where to insert the key next. The hash functions are used in
template <unsigned kNumHashFunctions> __device__ //! round-robin order.
unsigned determine_next_location(const Functions<kNumHashFunctions> constants, template <unsigned kNumHashFunctions>
const unsigned table_size, __device__ unsigned
const unsigned key, determine_next_location(const Functions<kNumHashFunctions> constants,
const unsigned previous_location) { const unsigned table_size, const unsigned key,
const unsigned previous_location) {
// Identify all possible locations for the entry. // Identify all possible locations for the entry.
unsigned locations[kNumHashFunctions]; unsigned locations[kNumHashFunctions];
#pragma unroll #pragma unroll
for (unsigned i = 0; i < kNumHashFunctions; ++i) { for (unsigned i = 0; i < kNumHashFunctions; ++i) {
locations[i] = hash_function(constants, i, key) % table_size; locations[i] = hash_function(constants, i, key) % table_size;
} }
// Figure out where the item should be inserted next. // Figure out where the item should be inserted next.
unsigned next_location = locations[0]; unsigned next_location = locations[0];
#pragma unroll #pragma unroll
for (int i = kNumHashFunctions - 2; i >= 0; --i) { for (int i = kNumHashFunctions - 2; i >= 0; --i) {
next_location = (previous_location == locations[i] ? locations[i+1] next_location =
: next_location); (previous_location == locations[i] ? locations[i + 1] : next_location);
} }
return next_location; return next_location;
} }
//! Attempts to insert a single entry into the hash table. //! Attempts to insert a single entry into the hash table.
/*! This process stops after a certain number of iterations. If the thread is /*! This process stops after a certain number of iterations. If the thread is
still holding onto an item because of an eviction, it tries the stash. still holding onto an item because of an eviction, it tries the stash.
If it fails to enter the stash, it returns false. If it fails to enter the stash, it returns false.
Otherwise, it succeeds and returns true. Otherwise, it succeeds and returns true.
*/ */
template <unsigned kNumHashFunctions> __device__ template <unsigned kNumHashFunctions>
bool insert(const unsigned table_size, __device__ bool
const Functions<kNumHashFunctions> constants, insert(const unsigned table_size, const Functions<kNumHashFunctions> constants,
const uint2 stash_constants, const uint2 stash_constants, const unsigned max_iteration_attempts,
const unsigned max_iteration_attempts, Entry *table, unsigned *stash_count, Entry entry,
Entry *table, unsigned *iterations_used) {
unsigned *stash_count,
Entry entry,
unsigned *iterations_used) {
unsigned key = get_key(entry); unsigned key = get_key(entry);
// The key is always inserted into its first slot at the start. // The key is always inserted into its first slot at the start.
unsigned location = hash_function(constants, 0, key) % table_size; unsigned location = hash_function(constants, 0, key) % table_size;
// Keep inserting until an empty slot is found or the eviction chain grows too large. // Keep inserting until an empty slot is found or the eviction chain grows too
// large.
for (unsigned its = 1; its <= max_iteration_attempts; its++) { for (unsigned its = 1; its <= max_iteration_attempts; its++) {
// Insert the new entry. // Insert the new entry.
entry = atomicExch(&table[location], entry); entry = atomicExch(&table[location], entry);
key = get_key(entry); key = get_key(entry);
// If no key was evicted, we're done. // If no key was evicted, we're done.
if (key == kKeyEmpty) { if (key == kKeyEmpty) {
...@@ -251,54 +228,46 @@ bool insert(const unsigned table_size, ...@@ -251,54 +228,46 @@ bool insert(const unsigned table_size,
return true; return true;
} }
// Build a basic hash table, using one big table. // Build a basic hash table, using one big table.
template <unsigned kNumHashFunctions> __global__ template <unsigned kNumHashFunctions>
void CuckooHash(const unsigned n_entries, __global__ void CuckooHash(const unsigned n_entries, const unsigned *keys,
const unsigned *keys, const unsigned *values, const unsigned table_size,
const unsigned *values, const Functions<kNumHashFunctions> constants,
const unsigned table_size, const unsigned max_iteration_attempts, Entry *table,
const Functions<kNumHashFunctions> constants, uint2 stash_constants, unsigned *stash_count,
const unsigned max_iteration_attempts, unsigned *failures,
Entry *table, unsigned *iterations_taken = nullptr) {
uint2 stash_constants,
unsigned *stash_count,
unsigned *failures,
unsigned *iterations_taken = nullptr) {
// Check if this thread has an item and if any previous threads failed. // Check if this thread has an item and if any previous threads failed.
unsigned thread_index = threadIdx.x + unsigned thread_index = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.x * blockDim.x +
blockIdx.y * blockDim.x * gridDim.x; blockIdx.y * blockDim.x * gridDim.x;
if (thread_index >= n_entries || *failures) if (thread_index >= n_entries || *failures)
return; return;
Entry entry = make_entry(keys[thread_index], values[thread_index]); Entry entry = make_entry(keys[thread_index], values[thread_index]);
unsigned iterations = 0; unsigned iterations = 0;
bool success = insert<kNumHashFunctions> bool success = insert<kNumHashFunctions>(
(table_size, constants, stash_constants, table_size, constants, stash_constants, max_iteration_attempts, table,
max_iteration_attempts, table, stash_count, entry, &iterations); stash_count, entry, &iterations);
if (success == false) { if (success == false) {
// The eviction chain grew too large. Report failure. // The eviction chain grew too large. Report failure.
#ifdef COUNT_UNINSERTED #ifdef COUNT_UNINSERTED
atomicAdd(failures, 1); atomicAdd(failures, 1);
#else #else
*failures = 1; *failures = 1;
#endif #endif
} }
#ifdef TRACK_ITERATIONS #ifdef TRACK_ITERATIONS
iterations_taken[thread_index] = iterations; iterations_taken[thread_index] = iterations;
#endif #endif
} }
//! @} //! @}
}; // namespace CuckooHashing }; // namespace cuhash
#endif #endif
// Leave this at the end of the file // Leave this at the end of the file
// Local Variables: // Local Variables:
// mode:c++ // mode:c++
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
// ------------------------------------------------------------- // -------------------------------------------------------------
// $Revision:$ // $Revision:$
// $Date:$ // $Date:$
// ------------------------------------------------------------- // -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in // This source code is distributed under the terms of license.txt in
// the root directory of this source distribution. // the root directory of this source distribution.
// ------------------------------------------------------------- // -------------------------------------------------------------
/** /**
* @file hash_table.h * @file hash_table.h
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
#ifndef CUDAHT__CUCKOO__SRC__LIBRARY__HASH_TABLE__H #ifndef CUDAHT__CUCKOO__SRC__LIBRARY__HASH_TABLE__H
#define CUDAHT__CUCKOO__SRC__LIBRARY__HASH_TABLE__H #define CUDAHT__CUCKOO__SRC__LIBRARY__HASH_TABLE__H
#include "definitions.h" #include "definitions.h"
#include "hash_functions.h" #include "hash_functions.h"
#include <cstdio> #include <cstdio>
/** \addtogroup cudpp_app /** \addtogroup cudpp_app
* @{ * @{
*/ */
/** \addtogroup cudpp_hash_data_structures /** \addtogroup cudpp_hash_data_structures
* @{ * @{
...@@ -50,7 +49,8 @@ namespace cuhash { ...@@ -50,7 +49,8 @@ namespace cuhash {
//! Compute how many thread blocks are required for the given number of threads. //! Compute how many thread blocks are required for the given number of threads.
dim3 ComputeGridDim(unsigned threads); dim3 ComputeGridDim(unsigned threads);
//! Compute how long an eviction chain is allowed to become for a given input size. //! Compute how long an eviction chain is allowed to become for a given input
//! size.
/*! \param[in] num_keys Number of keys in the input. /*! \param[in] num_keys Number of keys in the input.
* \param[in] table_size Number of slots in the hash table. * \param[in] table_size Number of slots in the hash table.
* \param[in] num_functions Number of hash functions being used. * \param[in] num_functions Number of hash functions being used.
...@@ -72,10 +72,10 @@ unsigned ComputeMaxIterations(const unsigned num_keys, ...@@ -72,10 +72,10 @@ unsigned ComputeMaxIterations(const unsigned num_keys,
* @ingroup cudpp_app * @ingroup cudpp_app
*/ */
class HashTable { class HashTable {
public: public:
HashTable(); HashTable();
virtual ~HashTable() {Release();} virtual ~HashTable() { Release(); }
//! Initialize the hash table's memory. Must be called before \ref //! Initialize the hash table's memory. Must be called before \ref
//! Build() and after the random number generator has been seeded. //! Build() and after the random number generator has been seeded.
...@@ -87,7 +87,7 @@ class HashTable { ...@@ -87,7 +87,7 @@ class HashTable {
* 2-5. More hash functions make it easier * 2-5. More hash functions make it easier
* to build the table, but increase * to build the table, but increase
* retrieval times. * retrieval times.
* @returns Whether the hash table was initialized successfully (true) * @returns Whether the hash table was initialized successfully (true)
* or not (false). * or not (false).
* *
* The minimum space usage is dependent on the number of functions * The minimum space usage is dependent on the number of functions
...@@ -95,28 +95,27 @@ class HashTable { ...@@ -95,28 +95,27 @@ class HashTable {
* usage is 2.1, 1.1, 1.03, and 1.02 respectively. * usage is 2.1, 1.1, 1.03, and 1.02 respectively.
*/ */
virtual bool Initialize(const unsigned max_input_size, virtual bool Initialize(const unsigned max_input_size,
const float space_usage = 1.25, const float space_usage = 1.25,
const unsigned num_functions = 4); const unsigned num_functions = 4);
//! Free all memory. //! Free all memory.
virtual void Release(); virtual void Release();
//! Build the hash table. //! Build the hash table.
/*! @param[in] input_size Number of key-value pairs being inserted. /*! @param[in] input_size Number of key-value pairs being inserted.
* @param[in] d_keys Device memory array containing all of the input * @param[in] d_keys Device memory array containing all of the input
* keys. * keys.
* @param[in] d_vals Device memory array containing the keys' values. * @param[in] d_vals Device memory array containing the keys' values.
* @returns Whether the hash table was built successfully (true) or * @returns Whether the hash table was built successfully (true) or
* not (false). * not (false).
* *
* Several attempts are allowed to build the hash table in case of failure. * Several attempts are allowed to build the hash table in case of failure.
* The input keys are expected to be completely unique. * The input keys are expected to be completely unique.
* To reduce the chance of a failure, increase the space usage or number of * To reduce the chance of a failure, increase the space usage or number of
* functions. * functions.
* Keys are not allowed to be equal to cuhash::kKeyEmpty. * Keys are not allowed to be equal to cuhash::kKeyEmpty.
*/ */
virtual bool Build(const unsigned input_size, virtual bool Build(const unsigned input_size, const unsigned *d_keys,
const unsigned *d_keys,
const unsigned *d_vals); const unsigned *d_vals);
//! Query the hash table. //! Query the hash table.
...@@ -128,9 +127,8 @@ class HashTable { ...@@ -128,9 +127,8 @@ class HashTable {
* kNotFound is returned for any query key that failed to be found * kNotFound is returned for any query key that failed to be found
* in the table. * in the table.
*/ */
virtual void Retrieve(const unsigned n_queries, virtual void Retrieve(const unsigned n_queries, const unsigned *d_query_keys,
const unsigned *d_query_keys, unsigned *d_query_results);
unsigned *d_query_results);
//! @name Accessors //! @name Accessors
/// @brief Mainly needed to use the __device__ CudaHT::retrieve() /// @brief Mainly needed to use the __device__ CudaHT::retrieve()
...@@ -138,96 +136,85 @@ class HashTable { ...@@ -138,96 +136,85 @@ class HashTable {
/// @{ /// @{
//! Returns how many slots the hash table has. //! Returns how many slots the hash table has.
inline unsigned get_table_size() const {return table_size_;} inline unsigned get_table_size() const { return table_size_; }
//! Returns how many items are stored in the stash. //! Returns how many items are stored in the stash.
inline unsigned get_stash_count() const {return stash_count_;} inline unsigned get_stash_count() const { return stash_count_; }
//! Returns the constants used by the stash. //! Returns the constants used by the stash.
inline uint2 get_stash_constants() const {return stash_constants_;} inline uint2 get_stash_constants() const { return stash_constants_; }
//! Returns the hash table contents. //! Returns the hash table contents.
inline const Entry* get_contents() const {return d_contents_;} inline const Entry *get_contents() const { return d_contents_; }
//! Returns the number of hash functions being used. //! Returns the number of hash functions being used.
inline unsigned get_num_hash_functions() const {return inline unsigned get_num_hash_functions() const { return num_hash_functions_; }
num_hash_functions_;}
//! When using two hash functions, returns the constants. //! When using two hash functions, returns the constants.
inline Functions<2> get_constants_2() const {return constants_2_;} inline Functions<2> get_constants_2() const { return constants_2_; }
//! When using three hash functions, returns the constants. //! When using three hash functions, returns the constants.
inline Functions<3> get_constants_3() const {return constants_3_;} inline Functions<3> get_constants_3() const { return constants_3_; }
//! When using four hash functions, returns the constants. //! When using four hash functions, returns the constants.
inline Functions<4> get_constants_4() const {return constants_4_;} inline Functions<4> get_constants_4() const { return constants_4_; }
//! When using five hash functions, returns the constants. //! When using five hash functions, returns the constants.
inline Functions<5> get_constants_5() const {return constants_5_;} inline Functions<5> get_constants_5() const { return constants_5_; }
/// @} /// @}
inline Entry * data(){return d_contents_;} inline Entry *data() { return d_contents_; }
inline const Entry * data() const {return d_contents_;} inline const Entry *data() const { return d_contents_; }
protected: protected:
unsigned table_size_; //!< Size of the hash table. unsigned table_size_; //!< Size of the hash table.
unsigned num_hash_functions_; //!< Number of hash functions being used. unsigned num_hash_functions_; //!< Number of hash functions being used.
Entry *d_contents_; //!< Device memory: The hash table contents. The stash is stored at the end. Entry *d_contents_; //!< Device memory: The hash table contents. The stash is
unsigned stash_count_; //!< Number of key-value pairs currently stored. //!< stored at the end.
uint2 stash_constants_; //!< Hash function constants for the stash. unsigned stash_count_; //!< Number of key-value pairs currently stored.
uint2 stash_constants_; //!< Hash function constants for the stash.
Functions<2> constants_2_; //!< Constants for a set of two hash functions.
Functions<3> constants_3_; //!< Constants for a set of three hash functions. Functions<2> constants_2_; //!< Constants for a set of two hash functions.
Functions<4> constants_4_; //!< Constants for a set of four hash functions. Functions<3> constants_3_; //!< Constants for a set of three hash functions.
Functions<5> constants_5_; //!< Constants for a set of five hash functions. Functions<4> constants_4_; //!< Constants for a set of four hash functions.
Functions<5> constants_5_; //!< Constants for a set of five hash functions.
unsigned *d_failures_; //!< Device memory: General use error flag.
unsigned *d_failures_; //!< Device memory: General use error flag.
}; };
/*! @name Internal /*! @name Internal
* @{ * @{
*/ */
namespace CUDAWrapper { namespace CUDAWrapper {
//! Fills a 64-bit array with a particular value. //! Fills a 64-bit array with a particular value.
void ClearTable(const unsigned slots_in_table, void ClearTable(const unsigned slots_in_table, const Entry fill_value,
const Entry fill_value, Entry *d_array);
Entry *d_array);
//! Calls the Cuckoo Hash construction kernel. //! Calls the Cuckoo Hash construction kernel.
void CallCuckooHash(const unsigned n_entries, void CallCuckooHash(const unsigned n_entries, const unsigned num_hash_functions,
const unsigned num_hash_functions, const unsigned *d_keys, const unsigned *d_values,
const unsigned *d_keys, const unsigned table_size, const Functions<2> constants_2,
const unsigned *d_values, const Functions<3> constants_3,
const unsigned table_size, const Functions<4> constants_4,
const Functions<2> constants_2, const Functions<5> constants_5,
const Functions<3> constants_3, const unsigned max_iteration_attempts, Entry *d_contents,
const Functions<4> constants_4, uint2 stash_constants, unsigned *d_stash_count,
const Functions<5> constants_5, unsigned *d_failures, unsigned *d_iterations_taken);
const unsigned max_iteration_attempts,
Entry *d_contents,
uint2 stash_constants,
unsigned *d_stash_count,
unsigned *d_failures,
unsigned *d_iterations_taken);
//! Calls the kernel that performs retrievals. //! Calls the kernel that performs retrievals.
void CallHashRetrieve(const unsigned n_queries, void CallHashRetrieve(const unsigned n_queries,
const unsigned num_hash_functions, const unsigned num_hash_functions,
const unsigned *keys_in, const unsigned *keys_in, const unsigned table_size,
const unsigned table_size, const Entry *table, const Functions<2> constants_2,
const Entry *table, const Functions<3> constants_3,
const Functions<2> constants_2, const Functions<4> constants_4,
const Functions<3> constants_3, const Functions<5> constants_5,
const Functions<4> constants_4, const uint2 stash_constants, const unsigned stash_count,
const Functions<5> constants_5, unsigned *values_out);
const uint2 stash_constants, }; // namespace CUDAWrapper
const unsigned stash_count,
unsigned *values_out);
};
/// @} /// @}
}; // namespace CuckooHashing }; // namespace cuhash
/** @} */ // end hash table data structures /** @} */ // end hash table data structures
/** @} */ // end cudpp_app /** @} */ // end cudpp_app
......
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// This file is used for c++ unit test, but pytorch jit ops don't support c++ debug build. // This file is used for c++ unit test, but pytorch jit ops don't support c++
// debug build.
#ifndef PARAMS_GRID_H_ #ifndef PARAMS_GRID_H_
#define PARAMS_GRID_H_ #define PARAMS_GRID_H_
......
// Copyright Louis Delacroix 2010 - 2014.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// A pretty printing library for C++
//
// Usage:
// Include this header, and operator<< will "just work".
#ifndef H_PRETTY_PRINT
#define H_PRETTY_PRINT
#include <cstddef>
#include <iterator>
#include <memory>
#include <ostream>
#include <set>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <valarray>
namespace pretty_print
{
namespace detail
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct sfinae_base
{
using yes = char;
using no = yes[2];
};
template <typename T>
struct has_const_iterator : private sfinae_base
{
private:
template <typename C> static yes & test(typename C::const_iterator*);
template <typename C> static no & test(...);
public:
static const bool value = sizeof(test<T>(nullptr)) == sizeof(yes);
using type = T;
};
template <typename T>
struct has_begin_end : private sfinae_base
{
private:
template <typename C>
static yes & f(typename std::enable_if<
std::is_same<decltype(static_cast<typename C::const_iterator(C::*)() const>(&C::begin)),
typename C::const_iterator(C::*)() const>::value>::type *);
template <typename C> static no & f(...);
template <typename C>
static yes & g(typename std::enable_if<
std::is_same<decltype(static_cast<typename C::const_iterator(C::*)() const>(&C::end)),
typename C::const_iterator(C::*)() const>::value, void>::type*);
template <typename C> static no & g(...);
public:
static bool const beg_value = sizeof(f<T>(nullptr)) == sizeof(yes);
static bool const end_value = sizeof(g<T>(nullptr)) == sizeof(yes);
};
} // namespace detail
// Holds the delimiter values for a specific character type
template <typename TChar>
struct delimiters_values
{
using char_type = TChar;
const char_type * prefix;
const char_type * delimiter;
const char_type * postfix;
};
// Defines the delimiter values for a specific container and character type
template <typename T, typename TChar>
struct delimiters
{
using type = delimiters_values<TChar>;
static const type values;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template <typename T,
typename TChar = char,
typename TCharTraits = ::std::char_traits<TChar>,
typename TDelimiters = delimiters<T, TChar>>
struct print_container_helper
{
using delimiters_type = TDelimiters;
using ostream_type = std::basic_ostream<TChar, TCharTraits>;
template <typename U>
struct printer
{
static void print_body(const U & c, ostream_type & stream)
{
using std::begin;
using std::end;
auto it = begin(c);
const auto the_end = end(c);
if (it != the_end)
{
for ( ; ; )
{
stream << *it;
if (++it == the_end) break;
if (delimiters_type::values.delimiter != NULL)
stream << delimiters_type::values.delimiter;
}
}
}
};
print_container_helper(const T & container)
: container_(container)
{ }
inline void operator()(ostream_type & stream) const
{
if (delimiters_type::values.prefix != NULL)
stream << delimiters_type::values.prefix;
printer<T>::print_body(container_, stream);
if (delimiters_type::values.postfix != NULL)
stream << delimiters_type::values.postfix;
}
private:
const T & container_;
};
// Specialization for pairs
template <typename T, typename TChar, typename TCharTraits, typename TDelimiters>
template <typename T1, typename T2>
struct print_container_helper<T, TChar, TCharTraits, TDelimiters>::printer<std::pair<T1, T2>>
{
using ostream_type = typename print_container_helper<T, TChar, TCharTraits, TDelimiters>::ostream_type;
static void print_body(const std::pair<T1, T2> & c, ostream_type & stream)
{
stream << c.first;
if (print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter;
stream << c.second;
}
};
// Specialization for tuples
template <typename T, typename TChar, typename TCharTraits, typename TDelimiters>
template <typename ...Args>
struct print_container_helper<T, TChar, TCharTraits, TDelimiters>::printer<std::tuple<Args...>>
{
using ostream_type = typename print_container_helper<T, TChar, TCharTraits, TDelimiters>::ostream_type;
using element_type = std::tuple<Args...>;
template <std::size_t I> struct Int { };
static void print_body(const element_type & c, ostream_type & stream)
{
tuple_print(c, stream, Int<0>());
}
static void tuple_print(const element_type &, ostream_type &, Int<sizeof...(Args)>)
{
}
static void tuple_print(const element_type & c, ostream_type & stream,
typename std::conditional<sizeof...(Args) != 0, Int<0>, std::nullptr_t>::type)
{
stream << std::get<0>(c);
tuple_print(c, stream, Int<1>());
}
template <std::size_t N>
static void tuple_print(const element_type & c, ostream_type & stream, Int<N>)
{
if (print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits, TDelimiters>::delimiters_type::values.delimiter;
stream << std::get<N>(c);
tuple_print(c, stream, Int<N + 1>());
}
};
// Prints a print_container_helper to the specified stream.
template<typename T, typename TChar, typename TCharTraits, typename TDelimiters>
inline std::basic_ostream<TChar, TCharTraits> & operator<<(
std::basic_ostream<TChar, TCharTraits> & stream,
const print_container_helper<T, TChar, TCharTraits, TDelimiters> & helper)
{
helper(stream);
return stream;
}
// Basic is_container template; specialize to derive from std::true_type for all desired container types
template <typename T>
struct is_container : public std::integral_constant<bool,
detail::has_const_iterator<T>::value &&
detail::has_begin_end<T>::beg_value &&
detail::has_begin_end<T>::end_value> { };
template <typename T, std::size_t N>
struct is_container<T[N]> : std::true_type { };
template <std::size_t N>
struct is_container<char[N]> : std::false_type { };
template <typename T>
struct is_container<std::valarray<T>> : std::true_type { };
template <typename T1, typename T2>
struct is_container<std::pair<T1, T2>> : std::true_type { };
template <typename ...Args>
struct is_container<std::tuple<Args...>> : std::true_type { };
// Default delimiters
template <typename T> struct delimiters<T, char> { static const delimiters_values<char> values; };
template <typename T> const delimiters_values<char> delimiters<T, char>::values = { "[", ", ", "]" };
template <typename T> struct delimiters<T, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T> const delimiters_values<wchar_t> delimiters<T, wchar_t>::values = { L"[", L", ", L"]" };
// Delimiters for (multi)set and unordered_(multi)set
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::set<T, TComp, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char> delimiters< ::std::set<T, TComp, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::set<T, TComp, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::set<T, TComp, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::multiset<T, TComp, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char> delimiters< ::std::multiset<T, TComp, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename TComp, typename TAllocator>
struct delimiters< ::std::multiset<T, TComp, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::multiset<T, TComp, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, char> { static const delimiters_values<char> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, char>::values = { "{", ", ", "}" };
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters< ::std::unordered_multiset<T, THash, TEqual, TAllocator>, wchar_t>::values = { L"{", L", ", L"}" };
// Delimiters for pair and tuple
template <typename T1, typename T2> struct delimiters<std::pair<T1, T2>, char> { static const delimiters_values<char> values; };
template <typename T1, typename T2> const delimiters_values<char> delimiters<std::pair<T1, T2>, char>::values = { "(", ", ", ")" };
template <typename T1, typename T2> struct delimiters< ::std::pair<T1, T2>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename T1, typename T2> const delimiters_values<wchar_t> delimiters< ::std::pair<T1, T2>, wchar_t>::values = { L"(", L", ", L")" };
template <typename ...Args> struct delimiters<std::tuple<Args...>, char> { static const delimiters_values<char> values; };
template <typename ...Args> const delimiters_values<char> delimiters<std::tuple<Args...>, char>::values = { "(", ", ", ")" };
template <typename ...Args> struct delimiters< ::std::tuple<Args...>, wchar_t> { static const delimiters_values<wchar_t> values; };
template <typename ...Args> const delimiters_values<wchar_t> delimiters< ::std::tuple<Args...>, wchar_t>::values = { L"(", L", ", L")" };
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t, and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
struct custom_delims_base
{
virtual ~custom_delims_base() { }
virtual std::ostream & stream(::std::ostream &) = 0;
virtual std::wostream & stream(::std::wostream &) = 0;
};
template <typename T, typename Delims>
struct custom_delims_wrapper : custom_delims_base
{
custom_delims_wrapper(const T & t_) : t(t_) { }
std::ostream & stream(std::ostream & s)
{
return s << print_container_helper<T, char, std::char_traits<char>, Delims>(t);
}
std::wostream & stream(std::wostream & s)
{
return s << print_container_helper<T, wchar_t, std::char_traits<wchar_t>, Delims>(t);
}
private:
const T & t;
};
template <typename Delims>
struct custom_delims
{
template <typename Container>
custom_delims(const Container & c) : base(new custom_delims_wrapper<Container, Delims>(c)) { }
std::unique_ptr<custom_delims_base> base;
};
template <typename TChar, typename TCharTraits, typename Delims>
inline std::basic_ostream<TChar, TCharTraits> & operator<<(std::basic_ostream<TChar, TCharTraits> & s, const custom_delims<Delims> & p)
{
return p.base->stream(s);
}
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template<typename T>
struct array_wrapper_n
{
typedef const T * const_iterator;
typedef T value_type;
array_wrapper_n(const T * const a, size_t n) : _array(a), _n(n) { }
inline const_iterator begin() const { return _array; }
inline const_iterator end() const { return _array + _n; }
private:
const T * const _array;
size_t _n;
};
// A wrapper for hash-table based containers that offer local iterators to each bucket.
// Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket 5 of container m.)
template <typename T>
struct bucket_print_wrapper
{
typedef typename T::const_local_iterator const_iterator;
typedef typename T::size_type size_type;
const_iterator begin() const
{
return m_map.cbegin(n);
}
const_iterator end() const
{
return m_map.cend(n);
}
bucket_print_wrapper(const T & m, size_type bucket) : m_map(m), n(bucket) { }
private:
const T & m_map;
const size_type n;
};
} // namespace pretty_print
// Global accessor functions for the convenience wrappers
template<typename T>
inline pretty_print::array_wrapper_n<T> pretty_print_array(const T * const a, size_t n)
{
return pretty_print::array_wrapper_n<T>(a, n);
}
template <typename T> pretty_print::bucket_print_wrapper<T>
bucket_print(const T & m, typename T::size_type n)
{
return pretty_print::bucket_print_wrapper<T>(m, n);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
namespace std
{
// Prints a container to the stream using default delimiters
template<typename T, typename TChar, typename TCharTraits>
inline typename enable_if< ::pretty_print::is_container<T>::value,
basic_ostream<TChar, TCharTraits> &>::type
operator<<(basic_ostream<TChar, TCharTraits> & stream, const T & container)
{
return stream << ::pretty_print::print_container_helper<T, TChar, TCharTraits>(container);
}
}
#endif // H_PRETTY_PRINT
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <tensorview/tensorview.h>
#include <tensorview/tensor.h>
#include <algorithm>
#include <array>
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace tv {
template <typename T> TensorView<T> arrayt2tv(py::array_t<T> arr) {
Shape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
return TensorView<T>(arr.mutable_data(), shape);
}
template <typename T> TensorView<const T> carrayt2tv(py::array_t<T> arr) {
Shape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
return TensorView<const T>(arr.data(), shape);
}
template <typename T> TensorView<T> vector2tv(std::vector<T> &arr) {
return TensorView<T>(arr.data(), {arr.size()});
}
template <typename T>
TensorView<T> vector2tv(std::vector<T> &arr, Shape shape) {
TV_ASSERT_INVALID_ARG(shape.prod() == arr.size(), "error");
return TensorView<T>(arr.data(), shape);
}
template <typename T> TensorView<const T> vector2tv(const std::vector<T> &arr) {
return TensorView<const T>(arr.data(), {arr.size()});
}
template <typename T>
std::vector<T> shape2stride(const std::vector<T> &shape, T itemsize) {
T p = T(1);
std::vector<T> res;
for (auto iter = shape.rbegin(); iter != shape.rend(); ++iter) {
res.push_back(p * itemsize);
p *= *iter;
}
std::reverse(res.begin(), res.end());
return res;
}
tv::DType get_array_tv_dtype(const py::array& arr){
//
switch (arr.dtype().kind()){
case 'b': return tv::bool_;
case 'i': {
switch (arr.itemsize()){
case 1: return tv::int8;
case 2: return tv::int16;
case 4: return tv::int32;
case 8: return tv::int64;
default: break;
}
}
case 'u': {
switch (arr.itemsize()){
case 1: return tv::uint8;
case 2: return tv::uint16;
case 4: return tv::uint32;
case 8: return tv::uint64;
default: break;
}
}
case 'f': {
switch (arr.itemsize()){
case 4: return tv::float32;
case 8: return tv::float64;
default: break;
}
}
}
TV_THROW_RT_ERR("unknown dtype", arr.dtype().kind(), arr.itemsize());
}
Tensor array2tensor(py::array& arr) {
Shape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
return tv::from_blob(arr.mutable_data(), shape, get_array_tv_dtype(arr), -1);
}
} // namespace tv
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef BOX_IOU_H #ifndef BOX_IOU_H
#define BOX_IOU_H #define BOX_IOU_H
...@@ -99,9 +98,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners, ...@@ -99,9 +98,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
} }
template <typename DType> template <typename DType>
py::array_t<DType> py::array_t<DType> rbbox_intersection(py::array_t<DType> box_corners,
rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners, py::array_t<DType> qbox_corners,
py::array_t<DType> standup_iou, DType standup_thresh) { py::array_t<DType> standup_iou,
DType standup_thresh) {
namespace bg = boost::geometry; namespace bg = boost::geometry;
typedef bg::model::point<DType, 2, bg::cs::cartesian> point_t; typedef bg::model::point<DType, 2, bg::cs::cartesian> point_t;
typedef bg::model::polygon<point_t> polygon_t; typedef bg::model::polygon<point_t> polygon_t;
...@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne ...@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
return overlaps; return overlaps;
} }
} // namespace spconv } // namespace spconv
#endif #endif
\ No newline at end of file
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -17,17 +17,19 @@ ...@@ -17,17 +17,19 @@
#include <spconv/indice.h> #include <spconv/indice.h>
#include <spconv/reordering.h> #include <spconv/reordering.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h> #include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h> #include <utility/timer.h>
namespace spconv { namespace spconv {
// torch.jit's doc says only support int64, so we need to convert to int32. // torch.jit's doc says only support int64, so we need to convert to int32.
template <typename T> template <typename T>
torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters, torch::Tensor bias, torch::Tensor
torch::Tensor indicePairs, torch::Tensor indiceNum, fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters,
int64_t numActOut, int64_t _inverse, int64_t _subM) { torch::Tensor bias, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numActOut,
int64_t _inverse, int64_t _subM) {
bool subM = _subM != 0; bool subM = _subM != 0;
bool inverse = _inverse != 0; bool inverse = _inverse != 0;
auto device = features.device().type(); auto device = features.device().type();
...@@ -36,13 +38,16 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil ...@@ -36,13 +38,16 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
auto numInPlanes = features.size(1); auto numInPlanes = features.size(1);
auto numOutPlanes = filters.size(ndim + 1); auto numOutPlanes = filters.size(ndim + 1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU}); auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto indicePairMaxSizeIter = std::max_element( auto indicePairMaxSizeIter =
indicePairNumCpu.data_ptr<int>(), indicePairNumCpu.data_ptr<int>() + kernelVolume); std::max_element(indicePairNumCpu.data_ptr<int>(),
int indicePairMaxOffset = indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>(); indicePairNumCpu.data_ptr<int>() + kernelVolume);
int indicePairMaxOffset =
indicePairMaxSizeIter - indicePairNumCpu.data_ptr<int>();
int indicePairMaxSize = *indicePairMaxSizeIter; int indicePairMaxSize = *indicePairMaxSizeIter;
/*if (_subM){ /*if (_subM){
std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(), indicePairNumCpu.data_ptr<int>() + kernelVolume); std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset); indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset);
auto indicePairVecMaxSizeIter = std::max_element( auto indicePairVecMaxSizeIter = std::max_element(
...@@ -55,8 +60,10 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil ...@@ -55,8 +60,10 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
// auto indicePairOptions = // auto indicePairOptions =
// torch::TensorOptions().dtype(torch::kInt64).device(indicePairs.device()); // torch::TensorOptions().dtype(torch::kInt64).device(indicePairs.device());
torch::Tensor output = torch::zeros({numActOut, numOutPlanes}, options).copy_(bias); torch::Tensor output =
torch::Tensor inputBuffer = torch::zeros({indicePairMaxSize, numInPlanes}, options); torch::zeros({numActOut, numOutPlanes}, options).copy_(bias);
torch::Tensor inputBuffer =
torch::zeros({indicePairMaxSize, numInPlanes}, options);
torch::Tensor outputBuffer = torch::Tensor outputBuffer =
torch::zeros({indicePairMaxSize, numOutPlanes}, options); torch::zeros({indicePairMaxSize, numOutPlanes}, options);
filters = filters.view({-1, numInPlanes, numOutPlanes}); filters = filters.view({-1, numInPlanes, numOutPlanes});
...@@ -73,30 +80,31 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil ...@@ -73,30 +80,31 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
continue; continue;
} }
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
auto outputBufferBlob = auto outputBufferBlob = torch::from_blob(outputBuffer.data_ptr<T>(),
torch::from_blob(outputBuffer.data_ptr<T>(), {nHot, numOutPlanes}, options); {nHot, numOutPlanes}, options);
auto inputBufferBlob = auto inputBufferBlob = torch::from_blob(inputBuffer.data_ptr<T>(),
torch::from_blob(inputBuffer.data_ptr<T>(), {nHot, numInPlanes}, options); {nHot, numInPlanes}, options);
if (device == torch::kCPU) { if (device == torch::kCPU) {
functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor; functor::SparseGatherFunctor<tv::CPU, T, int> gatherFtor;
gatherFtor(tv::CPU(), tv::torch2tv<T>(inputBuffer), gatherFtor(tv::CPU(), tv::torch2tv<T>(inputBuffer),
tv::torch2tv<const T>(features), tv::torch2tv<const T>(features),
tv::torch2tv<const int>(indicePairs).subview(i, inverse), nHot); tv::torch2tv<const int>(indicePairs).subview(i, inverse),
} nHot);
#ifdef SPCONV_CUDA }
#ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
functor::SparseGatherFunctor<tv::GPU, T, int> gatherFtor; functor::SparseGatherFunctor<tv::GPU, T, int> gatherFtor;
gatherFtor(tv::TorchGPU(), tv::torch2tv<T>(inputBuffer), gatherFtor(tv::TorchGPU(), tv::torch2tv<T>(inputBuffer),
tv::torch2tv<const T>(features), tv::torch2tv<const T>(features),
tv::torch2tv<const int>(indicePairs).subview(i, inverse), nHot); tv::torch2tv<const int>(indicePairs).subview(i, inverse),
nHot);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
/* slower than SparseGatherFunctor, may due to int->long conversion /* slower than SparseGatherFunctor, may due to int->long conversion
auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64); auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64);
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(), {nHot}, auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(),
indicePairOptions); {nHot}, indicePairOptions); torch::index_select_out(inputBufferBlob,
torch::index_select_out(inputBufferBlob, features, 0, features, 0, indicePairBlob);*/
indicePairBlob);*/
} }
#endif #endif
else { else {
...@@ -111,16 +119,16 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil ...@@ -111,16 +119,16 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
functor::SparseScatterAddFunctor<tv::CPU, T, int> scatterFtor; functor::SparseScatterAddFunctor<tv::CPU, T, int> scatterFtor;
scatterFtor(tv::CPU(), tv::torch2tv<T>(output), scatterFtor(tv::CPU(), tv::torch2tv<T>(output),
tv::torch2tv<const T>(outputBuffer), tv::torch2tv<const T>(outputBuffer),
tv::torch2tv<const int>(indicePairs).subview(i, !inverse), nHot, tv::torch2tv<const int>(indicePairs).subview(i, !inverse),
true); nHot, true);
} }
#ifdef SPCONV_CUDA #ifdef TV_CUDA
else if (device == torch::kCUDA) { else if (device == torch::kCUDA) {
functor::SparseScatterAddFunctor<tv::GPU, T, int> scatterFtor; functor::SparseScatterAddFunctor<tv::GPU, T, int> scatterFtor;
scatterFtor(tv::TorchGPU(), tv::torch2tv<T>(output), scatterFtor(tv::TorchGPU(), tv::torch2tv<T>(output),
tv::torch2tv<const T>(outputBuffer), tv::torch2tv<const T>(outputBuffer),
tv::torch2tv<const int>(indicePairs).subview(i, !inverse), nHot, tv::torch2tv<const int>(indicePairs).subview(i, !inverse),
true); nHot, true);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
#endif #endif
......
...@@ -26,34 +26,27 @@ namespace detail { ...@@ -26,34 +26,27 @@ namespace detail {
template <typename T> struct ToUnsigned; template <typename T> struct ToUnsigned;
template <> struct ToUnsigned<int>{ template <> struct ToUnsigned<int> { using type = uint32_t; };
using type = uint32_t;
};
template <> struct ToUnsigned<long>{ template <> struct ToUnsigned<long> { using type = uint64_t; };
using type = uint64_t;
};
template <typename T> struct FNVInternal; template <typename T> struct FNVInternal;
template <> struct FNVInternal<uint32_t> template <> struct FNVInternal<uint32_t> {
{
constexpr static uint32_t defaultOffsetBasis = 0x811C9DC5; constexpr static uint32_t defaultOffsetBasis = 0x811C9DC5;
constexpr static uint32_t prime = 0x01000193; constexpr static uint32_t prime = 0x01000193;
}; };
template <> struct FNVInternal<uint64_t> template <> struct FNVInternal<uint64_t> {
{
constexpr static uint64_t defaultOffsetBasis = 0xcbf29ce484222325; constexpr static uint64_t defaultOffsetBasis = 0xcbf29ce484222325;
constexpr static uint64_t prime = 0x100000001b3; constexpr static uint64_t prime = 0x100000001b3;
}; };
} } // namespace detail
template <typename T> template <typename T>
using to_unsigned_t = typename detail::ToUnsigned<std::remove_const_t<T>>::type; using to_unsigned_t = typename detail::ToUnsigned<std::remove_const_t<T>>::type;
template <typename T> template <typename T> struct FNV1a : detail::FNVInternal<T> {
struct FNV1a : detail::FNVInternal<T>{ std::size_t operator()(const T *data, std::size_t size) {
std::size_t operator()(const T* data, std::size_t size){
to_unsigned_t<T> hash = detail::FNVInternal<T>::defaultOffsetBasis; to_unsigned_t<T> hash = detail::FNVInternal<T>::defaultOffsetBasis;
for (std::size_t i = 0; i < size; ++i) { for (std::size_t i = 0; i < size; ++i) {
hash *= detail::FNVInternal<T>::prime; hash *= detail::FNVInternal<T>::prime;
......
...@@ -16,15 +16,14 @@ ...@@ -16,15 +16,14 @@
#define INDICE_CU_H_ #define INDICE_CU_H_
#include <cuhash/hash_table.cuh> #include <cuhash/hash_table.cuh>
#include <spconv/geometry.h> #include <spconv/geometry.h>
#include <tensorview/helper_kernel.cu.h> #include <tensorview/kernel_utils.h>
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
namespace spconv { namespace spconv {
template <typename Index, typename IndexGrid, unsigned NDim, template <typename Index, unsigned NDim, int KernelMaxVolume = 256,
int KernelMaxVolume = 256, typename Index1D=int> typename Index1D = int>
__global__ void prepareIndicePairsKernel( __global__ void prepareIndicePairsKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicesOut, tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicePairs,
tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indiceNum, tv::TensorView<Index1D> indicePairUnique, tv::TensorView<Index> indiceNum, tv::TensorView<Index1D> indicePairUnique,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
...@@ -65,11 +64,9 @@ __global__ void prepareIndicePairsKernel( ...@@ -65,11 +64,9 @@ __global__ void prepareIndicePairsKernel(
} }
} }
template <typename Index, typename IndexGrid, unsigned NDim, template <typename Index, unsigned NDim, int KernelMaxVolume = 256>
int KernelMaxVolume = 256>
__global__ void prepareDeConvIndicePairsKernel( __global__ void prepareDeConvIndicePairsKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicesOut, tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicePairs,
tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indiceNum, tv::TensorView<Index> indicePairUnique, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
...@@ -128,12 +125,12 @@ __global__ void assignGridAndIndiceOutKernel( ...@@ -128,12 +125,12 @@ __global__ void assignGridAndIndiceOutKernel(
} }
} }
template <typename Index, unsigned NDim, template <typename Index, unsigned NDim, unsigned kNumHashFunctions = 4>
unsigned kNumHashFunctions = 4> __global__ void
__global__ void assignIndiceOutKernel( assignIndiceOutKernel(tv::TensorView<Index> indicesOut, int numAct,
tv::TensorView<Index> indicesOut, int numAct, tv::TensorView<Index> indicePairUnique,
tv::TensorView<Index> indicePairUnique, const tv::SimpleVector<Index, NDim> outSpatialShape,
const tv::SimpleVector<Index, NDim> outSpatialShape, int batchSize) { int batchSize) {
Index index; Index index;
auto indicesOutPtr = indicesOut.data(); auto indicesOutPtr = indicesOut.data();
...@@ -145,8 +142,7 @@ __global__ void assignIndiceOutKernel( ...@@ -145,8 +142,7 @@ __global__ void assignIndiceOutKernel(
} }
} }
template <typename Index, typename IndexGrid, unsigned NDim, template <typename Index, unsigned NDim, unsigned kNumHashFunctions = 4>
unsigned kNumHashFunctions = 4>
__global__ void __global__ void
assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn, assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indicePairs,
...@@ -161,9 +157,8 @@ assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn, ...@@ -161,9 +157,8 @@ assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn,
for (int i = 0; i < kernelVolume; ++i) { for (int i = 0; i < kernelVolume; ++i) {
index = indicePairs(i, 1, ix); index = indicePairs(i, 1, ix);
if (index > -1) { if (index > -1) {
auto val = auto val = cuhash::retrieve((unsigned)(index), table_size, table,
cuhash::retrieve((unsigned)(index), table_size, constants, stash_constants, stash_count);
table, constants, stash_constants, stash_count);
assert(val != cuhash::kNotFound); assert(val != cuhash::kNotFound);
indicePairs(i, 1, ix) = (unsigned)val; indicePairs(i, 1, ix) = (unsigned)val;
} }
...@@ -213,9 +208,8 @@ prepareSubMGridKernel(tv::TensorView<const Index> indicesIn, ...@@ -213,9 +208,8 @@ prepareSubMGridKernel(tv::TensorView<const Index> indicesIn,
template <typename Index, unsigned NDim> template <typename Index, unsigned NDim>
__global__ void __global__ void
prepareSubMHashKernel(tv::TensorView<const Index> indicesIn, prepareSubMHashKernel(tv::TensorView<const Index> indicesIn, unsigned *keys,
unsigned* keys, unsigned *values,
unsigned* values,
const tv::SimpleVector<Index, NDim> outSpatialShape) { const tv::SimpleVector<Index, NDim> outSpatialShape) {
auto numActIn = indicesIn.dim(0); auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1; Index spatialVolume = 1;
...@@ -233,7 +227,6 @@ prepareSubMHashKernel(tv::TensorView<const Index> indicesIn, ...@@ -233,7 +227,6 @@ prepareSubMHashKernel(tv::TensorView<const Index> indicesIn,
} }
} }
template <typename Index, typename IndexGrid, unsigned NDim, template <typename Index, typename IndexGrid, unsigned NDim,
int KernelMaxVolume = 256> int KernelMaxVolume = 256>
__global__ void getSubMIndicePairsKernel( __global__ void getSubMIndicePairsKernel(
...@@ -273,18 +266,17 @@ __global__ void getSubMIndicePairsKernel( ...@@ -273,18 +266,17 @@ __global__ void getSubMIndicePairsKernel(
} }
} }
template <typename Index, unsigned NDim, template <typename Index, unsigned NDim, int KernelMaxVolume = 256,
int KernelMaxVolume = 256, unsigned kNumHashFunctions=4> unsigned kNumHashFunctions = 4>
__global__ void getSubMIndicePairsHashKernel( __global__ void getSubMIndicePairsHashKernel(
tv::TensorView<const Index> indicesIn, tv::TensorView<const Index> indicesIn, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, const tv::SimpleVector<Index, NDim> outSpatialShape, unsigned table_size,
unsigned table_size, const cuhash::Entry *table, const cuhash::Entry *table, cuhash::Functions<kNumHashFunctions> constants,
cuhash::Functions<kNumHashFunctions> constants,
uint2 stash_constants, unsigned stash_count) { uint2 stash_constants, unsigned stash_count) {
auto numActIn = indicesIn.dim(0); auto numActIn = indicesIn.dim(0);
Index spatialVolume = 1; Index spatialVolume = 1;
...@@ -306,9 +298,8 @@ __global__ void getSubMIndicePairsHashKernel( ...@@ -306,9 +298,8 @@ __global__ void getSubMIndicePairsHashKernel(
auto offset = pointPtr[NDim]; auto offset = pointPtr[NDim];
index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) + index = tv::rowArrayIdx<Index, NDim>(pointPtr, outSpatialShape.data()) +
spatialVolume * indicesIn(ix, 0); spatialVolume * indicesIn(ix, 0);
auto val = auto val = cuhash::retrieve((unsigned)(index), table_size, table,
cuhash::retrieve((unsigned)(index), table_size, constants, stash_constants, stash_count);
table, constants, stash_constants, stash_count);
if (val != cuhash::kNotFound) { if (val != cuhash::kNotFound) {
auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1)); auto oldNum = atomicAdd(indiceNum.data() + offset, Index(1));
indicePairs(offset, 1, oldNum) = val; indicePairs(offset, 1, oldNum) = val;
...@@ -318,7 +309,6 @@ __global__ void getSubMIndicePairsHashKernel( ...@@ -318,7 +309,6 @@ __global__ void getSubMIndicePairsHashKernel(
} }
} }
template <typename Index, typename IndexGrid, unsigned NDim> template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void resetGridKernel(const Index *indicePairUnique, __global__ void resetGridKernel(const Index *indicePairUnique,
tv::TensorView<IndexGrid> gridsOut, tv::TensorView<IndexGrid> gridsOut,
...@@ -328,14 +318,12 @@ __global__ void resetGridKernel(const Index *indicePairUnique, ...@@ -328,14 +318,12 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
} }
} }
template <typename T> template <typename T> __global__ void arangeKernel(T *data, int size) {
__global__ void arangeKernel(T *data, int size) {
for (int ix : tv::KernelLoopX<int>(size)) { for (int ix : tv::KernelLoopX<int>(size)) {
data[ix] = ix; data[ix] = ix;
} }
} }
template <typename Index, typename IndexGrid, unsigned NDim> template <typename Index, typename IndexGrid, unsigned NDim>
__global__ void __global__ void
resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut, resetGridSubMKernel(const Index *indices, tv::TensorView<IndexGrid> gridsOut,
......
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -15,67 +15,88 @@ ...@@ -15,67 +15,88 @@
#ifndef SPARSE_CONV_INDICE_FUNCTOR_H_ #ifndef SPARSE_CONV_INDICE_FUNCTOR_H_
#define SPARSE_CONV_INDICE_FUNCTOR_H_ #define SPARSE_CONV_INDICE_FUNCTOR_H_
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
#include <torch/script.h>
namespace spconv namespace spconv {
{ namespace functor {
namespace functor
{
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateConvIndicePairFunctorP1 struct CreateConvIndicePairFunctorP1 {
{ Index operator()(const Device &d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<Index> indicesOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
tv::TensorView<Index> indicePairUnique, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose); const tv::SimpleVector<Index, NDim> outSpatialShape,
bool transpose);
}; };
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateConvIndicePairFunctorP2 struct CreateConvIndicePairFunctorP2 {
{ Index operator()(const Device &d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<Index> indicesOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
tv::TensorView<Index> indicePairUnique, tv::TensorView<Index> indicePairUnique,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose, const tv::SimpleVector<Index, NDim> outSpatialShape,
bool resetGrid=false, bool useHash=true); bool transpose, bool resetGrid = false, bool useHash = true);
}; };
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateConvIndicePairFunctor struct CreateConvIndicePairFunctor {
{ Index operator()(const Device &d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<Index> indicesOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut,
tv::TensorView<Index> indicesOut, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose, bool resetGrid=false, const tv::SimpleVector<Index, NDim> outSpatialShape,
bool useHash=true); bool transpose, bool resetGrid = false, bool useHash = true);
}; };
template <typename Device, typename Index, typename IndexGrid, unsigned NDim> template <typename Device, typename Index, typename IndexGrid, unsigned NDim>
struct CreateSubMIndicePairFunctor struct CreateSubMIndicePairFunctor {
{ Index operator()(const Device &d, tv::TensorView<const Index> indicesIn,
Index operator()( tv::TensorView<IndexGrid> gridsOut,
const Device& d, tv::TensorView<const Index> indicesIn, tv::TensorView<IndexGrid> gridsOut, tv::TensorView<Index> indicePairs,
tv::TensorView<Index> indicePairs, tv::TensorView<Index> indiceNum, tv::TensorView<Index> indiceNum,
const tv::SimpleVector<Index, NDim> kernelSize, const tv::SimpleVector<Index, NDim> kernelSize,
const tv::SimpleVector<Index, NDim> stride, const tv::SimpleVector<Index, NDim> stride,
const tv::SimpleVector<Index, NDim> padding, const tv::SimpleVector<Index, NDim> padding,
const tv::SimpleVector<Index, NDim> dilation, const tv::SimpleVector<Index, NDim> dilation,
const tv::SimpleVector<Index, NDim> outSpatialShape, bool transpose, bool resetGrid=false, const tv::SimpleVector<Index, NDim> outSpatialShape,
bool useHash=true); bool transpose, bool resetGrid = false, bool useHash = true);
}; };
} // namespace functor } // namespace functor
int create_conv_indice_pair_p1_cuda(
torch::Tensor indicesIn, torch::Tensor indicePairs, torch::Tensor indiceNum,
torch::Tensor indicePairUnique, std::vector<int64_t> kernelSize,
std::vector<int64_t> stride, std::vector<int64_t> padding,
std::vector<int64_t> dilation, std::vector<int64_t> outSpatialShape,
bool transpose);
int create_conv_indice_pair_p2_cuda(
torch::Tensor indicesIn, torch::Tensor indicesOut, torch::Tensor gridsOut,
torch::Tensor indicePairs, torch::Tensor indiceNum,
torch::Tensor indicePairUnique, std::vector<int64_t> outSpatialShape,
bool transpose, bool resetGrid, bool useHash);
int create_submconv_indice_pair_cuda(
torch::Tensor indicesIn, torch::Tensor gridsOut, torch::Tensor indicePairs,
torch::Tensor indiceNum, std::vector<int64_t> kernelSize,
std::vector<int64_t> stride, std::vector<int64_t> padding,
std::vector<int64_t> dilation, std::vector<int64_t> outSpatialShape,
bool transpose, bool resetGrid, bool useHash);
} // namespace spconv } // namespace spconv
#endif #endif
\ No newline at end of file
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,25 +16,20 @@ ...@@ -16,25 +16,20 @@
#define SPARSE_MAXPOOL_FUNCTOR_H_ #define SPARSE_MAXPOOL_FUNCTOR_H_
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
namespace spconv namespace spconv {
{ namespace functor {
namespace functor
{
template <typename Device, typename T, typename Index> template <typename Device, typename T, typename Index>
struct SparseMaxPoolForwardFunctor struct SparseMaxPoolForwardFunctor {
{ void operator()(const Device &d, tv::TensorView<T> outFeatures,
void operator()(const Device& d, tv::TensorView<T> outFeatures,
tv::TensorView<const T> inFeatures, tv::TensorView<const T> inFeatures,
tv::TensorView<const Index> indices, int size); tv::TensorView<const Index> indices, int size);
}; };
template <typename Device, typename T, typename Index> template <typename Device, typename T, typename Index>
struct SparseMaxPoolBackwardFunctor struct SparseMaxPoolBackwardFunctor {
{ void operator()(const Device &d, tv::TensorView<const T> outFeatures,
void operator()(const Device& d, tv::TensorView<const T> outFeatures,
tv::TensorView<const T> inFeatures, tv::TensorView<const T> inFeatures,
tv::TensorView<const T> dout, tv::TensorView<const T> dout, tv::TensorView<T> din,
tv::TensorView<T> din,
tv::TensorView<const Index> indices, int size); tv::TensorView<const Index> indices, int size);
}; };
......
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
#define NMS_CPU_H #define NMS_CPU_H
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
// must include pybind11/stl.h if using containers in STL in arguments. // must include pybind11/stl.h if using containers in STL in arguments.
#include "box_iou.h"
#include "nms_gpu.h"
#include <algorithm> #include <algorithm>
#include <boost/geometry.hpp> #include <boost/geometry.hpp>
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <vector> #include <vector>
#include "box_iou.h"
#include "nms_gpu.h"
namespace spconv { namespace spconv {
namespace py = pybind11; namespace py = pybind11;
using namespace pybind11::literals; using namespace pybind11::literals;
...@@ -181,7 +181,7 @@ std::vector<int> rotate_non_max_suppression_cpu(py::array_t<DType> box_corners, ...@@ -181,7 +181,7 @@ std::vector<int> rotate_non_max_suppression_cpu(py::array_t<DType> box_corners,
} }
return keep; return keep;
} }
#ifdef SPCONV_CUDA #ifdef TV_CUDA
constexpr int const threadsPerBlock = sizeof(unsigned long long) * 8; constexpr int const threadsPerBlock = sizeof(unsigned long long) * 8;
template <typename DType> template <typename DType>
......
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,24 +16,19 @@ ...@@ -16,24 +16,19 @@
#define NMS_FUNCTOR_H_ #define NMS_FUNCTOR_H_
#include <tensorview/tensorview.h> #include <tensorview/tensorview.h>
namespace spconv namespace spconv {
{ namespace functor {
namespace functor
{
template <typename Device, typename T, typename Index> template <typename Device, typename T, typename Index>
struct NonMaxSupressionFunctor struct NonMaxSupressionFunctor {
{ Index operator()(const Device &d, tv::TensorView<Index> keep,
Index operator()(const Device& d, tv::TensorView<Index> keep, tv::TensorView<const T> boxes, T threshold, T eps);
tv::TensorView<const T> boxes,
T threshold, T eps);
}; };
template <typename Device, typename T, typename Index> template <typename Device, typename T, typename Index>
struct rotateNonMaxSupressionFunctor struct rotateNonMaxSupressionFunctor {
{ Index operator()(const Device &d, tv::TensorView<Index> keep,
Index operator()(const Device& d, tv::TensorView<Index> keep, tv::TensorView<const T> boxCorners,
tv::TensorView<const T> boxCorners, tv::TensorView<const T> standupIoU, T threshold);
tv::TensorView<const T> standupIoU, T threshold);
}; };
} // namespace functor } // namespace functor
......
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
// Copyright 2019 Yan Yan // Copyright 2019 Yan Yan
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
// //
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
// //
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,35 +16,35 @@ ...@@ -16,35 +16,35 @@
#define NMS_TORCH_OP_H_ #define NMS_TORCH_OP_H_
#include <spconv/indice.h> #include <spconv/indice.h>
#include <spconv/nms_functor.h>
#include <spconv/reordering.h> #include <spconv/reordering.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h> #include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h> #include <utility/timer.h>
#include <spconv/nms_functor.h>
namespace spconv { namespace spconv {
// torch.jit's doc says only support int64, so we need to convert to int32. // torch.jit's doc says only support int64, so we need to convert to int32.
template <typename T> template <typename T>
torch::Tensor torch::Tensor nonMaxSuppression(torch::Tensor boxes, torch::Tensor scores,
nonMaxSuppression(torch::Tensor boxes, torch::Tensor scores, int64_t preMaxSize, int64_t preMaxSize, int64_t postMaxSize,
int64_t postMaxSize, double thresh, double eps) { double thresh, double eps) {
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
tv::check_torch_dtype<T>(boxes); tv::check_torch_dtype<T>(boxes);
auto resOptions = auto resOptions =
torch::TensorOptions().dtype(torch::kInt64).device(boxes.device()); torch::TensorOptions().dtype(torch::kInt64).device(boxes.device());
if (boxes.size(0) == 0){ if (boxes.size(0) == 0) {
return torch::zeros({0}, resOptions); return torch::zeros({0}, resOptions);
} }
torch::Tensor indices; torch::Tensor indices;
if (preMaxSize > 0){ if (preMaxSize > 0) {
auto numKeepedScores = scores.size(0); auto numKeepedScores = scores.size(0);
preMaxSize = std::min(numKeepedScores, preMaxSize); preMaxSize = std::min(numKeepedScores, preMaxSize);
auto res = torch::topk(scores, preMaxSize); auto res = torch::topk(scores, preMaxSize);
indices = std::get<1>(res); indices = std::get<1>(res);
boxes = torch::index_select(boxes, 0, indices); boxes = torch::index_select(boxes, 0, indices);
}else{ } else {
indices = std::get<1>(torch::sort(scores)); indices = std::get<1>(torch::sort(scores));
boxes = torch::index_select(boxes, 0, indices); boxes = torch::index_select(boxes, 0, indices);
} }
if (boxes.size(0) == 0) if (boxes.size(0) == 0)
return torch::zeros({0}, resOptions); return torch::zeros({0}, resOptions);
...@@ -54,16 +54,16 @@ nonMaxSuppression(torch::Tensor boxes, torch::Tensor scores, int64_t preMaxSize, ...@@ -54,16 +54,16 @@ nonMaxSuppression(torch::Tensor boxes, torch::Tensor scores, int64_t preMaxSize,
if (boxes.device().type() == torch::kCPU) { if (boxes.device().type() == torch::kCPU) {
auto nmsFunctor = functor::NonMaxSupressionFunctor<tv::CPU, T, int64_t>(); auto nmsFunctor = functor::NonMaxSupressionFunctor<tv::CPU, T, int64_t>();
keepNum = nmsFunctor(tv::CPU(), tv::torch2tv<int64_t>(keep), keepNum = nmsFunctor(tv::CPU(), tv::torch2tv<int64_t>(keep),
tv::torch2tv<const T>(boxes), T(thresh), T(eps)); tv::torch2tv<const T>(boxes), T(thresh), T(eps));
}else{ } else {
TV_ASSERT_RT_ERR(false, "not implemented"); TV_ASSERT_RT_ERR(false, "not implemented");
} }
if (postMaxSize <= 0){ if (postMaxSize <= 0) {
postMaxSize = keepNum; postMaxSize = keepNum;
} }
// std::cout << keep << std::endl; // std::cout << keep << std::endl;
keep = keep.slice(0, 0, std::min(keepNum, postMaxSize)); keep = keep.slice(0, 0, std::min(keepNum, postMaxSize));
if (preMaxSize > 0){ if (preMaxSize > 0) {
return torch::index_select(indices, 0, keep); return torch::index_select(indices, 0, keep);
} }
return keep; return keep;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment