Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
19e73bbe
"vscode:/vscode.git/clone" did not exist on "1a2f45fcf82af9da11e2920bcb62f0b04875e946"
Commit
19e73bbe
authored
May 20, 2020
by
Yan Yan
Browse files
format code with clang-format, better c++ code
parent
c336139f
Changes
77
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
441 additions
and
1041 deletions
+441
-1041
CMakeLists.txt
CMakeLists.txt
+1
-1
format_all.sh
format_all.sh
+5
-0
include/cuhash/cuda_util.h
include/cuhash/cuda_util.h
+35
-31
include/cuhash/debugging.h
include/cuhash/debugging.h
+23
-23
include/cuhash/hash_functions.h
include/cuhash/hash_functions.h
+23
-28
include/cuhash/hash_table.cuh
include/cuhash/hash_table.cuh
+78
-109
include/cuhash/hash_table.h
include/cuhash/hash_table.h
+66
-79
include/paramsgrid.h
include/paramsgrid.h
+5
-4
include/prettyprint.h
include/prettyprint.h
+0
-445
include/pybind11_utils.h
include/pybind11_utils.h
+0
-115
include/spconv/box_iou.h
include/spconv/box_iou.h
+7
-8
include/spconv/fused_spconv_ops.h
include/spconv/fused_spconv_ops.h
+39
-31
include/spconv/geometry.h
include/spconv/geometry.h
+7
-14
include/spconv/indice.cu.h
include/spconv/indice.cu.h
+26
-38
include/spconv/indice.h
include/spconv/indice.h
+72
-51
include/spconv/maxpool.h
include/spconv/maxpool.h
+10
-15
include/spconv/nms.h
include/spconv/nms.h
+6
-6
include/spconv/nms_functor.h
include/spconv/nms_functor.h
+12
-17
include/spconv/nms_gpu.h
include/spconv/nms_gpu.h
+3
-3
include/spconv/nms_ops.h
include/spconv/nms_ops.h
+23
-23
No files found.
CMakeLists.txt
View file @
19e73bbe
...
...
@@ -33,7 +33,7 @@ if (SPCONV_BuildCUDA)
torch_cuda_get_nvcc_gencode_flag
(
NVCC_FLAGS_EXTRA
)
string
(
REPLACE
";"
" "
NVCC_FLAGS_EXTRA_STR
"
${
NVCC_FLAGS_EXTRA
}
"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
${
NVCC_FLAGS_EXTRA_STR
}
"
)
add_compile_definitions
(
SPCON
V_CUDA
)
add_compile_definitions
(
T
V_CUDA
)
endif
()
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_subdirectory
(
third_party/pybind11
)
...
...
format_all.sh
0 → 100644
View file @
19e73bbe
#!/usr/bin/env bash
# Formats the repository in place.
#   * Python (./spconv, ./test): isort for import order, then yapf.
#   * C++/CUDA (./src, ./include): clang-format on every source/header file.
# Run from the repository root; all paths below are relative to it.

# Python formatting. Each step only runs if the previous one succeeded.
isort -rc --atomic ./spconv && \
isort -rc --atomic ./test && \
yapf -i --recursive -vv ./spconv ./test

# C++/CUDA formatting. `find -exec ... {} +` hands the file names straight to
# clang-format, so paths containing whitespace are not split apart and
# clang-format is not started at all when a directory has no matching files.
cxx_sources_regex='.*\.\(cpp\|hpp\|cc\|cxx\|cu\|cuh\|h\)'
for source_dir in ./src ./include; do
  find "$source_dir" -regex "$cxx_sources_regex" -exec clang-format -i {} +
done
\ No newline at end of file
include/cuhash/cuda_util.h
View file @
19e73bbe
...
...
@@ -2,46 +2,50 @@
#define _CUDA_UTIL_H_
#if CUDART_VERSION >= 4000
#define CUDA_DEVICE_SYNCHRONIZE(
)
cudaDeviceSynchronize();
#define CUDA_DEVICE_SYNCHRONIZE(
)
cudaDeviceSynchronize();
#else
#define CUDA_DEVICE_SYNCHRONIZE(
)
cudaThreadSynchronize();
#define CUDA_DEVICE_SYNCHRONIZE(
)
cudaThreadSynchronize();
#endif
//! Evaluates a CUDA runtime call exactly once and, if the result is not
//! cudaSuccess, prints the file, line and error string to stderr and exits
//! with EXIT_FAILURE.
//! The result is kept in a macro-private local (not `err`) so that a `call`
//! expression mentioning a caller variable named `err` is evaluated against
//! the caller's variable rather than the macro's own uninitialized local.
#define CUDA_SAFE_CALL_NO_SYNC(call)                                           \
  {                                                                            \
    const cudaError cuda_safe_call_err_ = (call);                              \
    if (cudaSuccess != cuda_safe_call_err_) {                                  \
      fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n", __FILE__,  \
              __LINE__, cudaGetErrorString(cuda_safe_call_err_));              \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  }
#
define CUDA_SAFE_CALL(
call)
CUDA_SAFE_CALL_NO_SYNC(call);
#define CUDA_SAFE_CALL(call) CUDA_SAFE_CALL_NO_SYNC(call);
//! Check for CUDA error
//! Reads cudaGetLastError(); on anything other than cudaSuccess, prints
//! errorMessage together with the file, line and CUDA error string to stderr
//! and exits with EXIT_FAILURE. When _DEBUG is defined the macro additionally
//! checks the result of CUDA_DEVICE_SYNCHRONIZE() in the same way.
//! The status is held in a macro-private local (not `err`) so an
//! `errorMessage` expression that names a caller variable `err` still refers
//! to the caller's variable.
#ifdef _DEBUG
#define CUDA_CHECK_ERROR(errorMessage)                                         \
  {                                                                            \
    cudaError_t cuda_check_err_ = cudaGetLastError();                          \
    if (cudaSuccess != cuda_check_err_) {                                      \
      fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n",        \
              errorMessage, __FILE__, __LINE__,                                \
              cudaGetErrorString(cuda_check_err_));                            \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
    /* CUDA_DEVICE_SYNCHRONIZE() already expands with a trailing ';'. */       \
    cuda_check_err_ = CUDA_DEVICE_SYNCHRONIZE();                               \
    if (cudaSuccess != cuda_check_err_) {                                      \
      fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n",        \
              errorMessage, __FILE__, __LINE__,                                \
              cudaGetErrorString(cuda_check_err_));                            \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  }
#else
#define CUDA_CHECK_ERROR(errorMessage)                                         \
  {                                                                            \
    cudaError_t cuda_check_err_ = cudaGetLastError();                          \
    if (cudaSuccess != cuda_check_err_) {                                      \
      fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n",        \
              errorMessage, __FILE__, __LINE__,                                \
              cudaGetErrorString(cuda_check_err_));                            \
      exit(EXIT_FAILURE);                                                      \
    }                                                                          \
  }
#endif
#endif
\ No newline at end of file
include/cuhash/debugging.h
View file @
19e73bbe
...
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// $Revision:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
* @file
...
...
@@ -29,44 +29,44 @@ namespace cuhash {
//! @name Debugging functions
/// @{
void
TakeHashFunctionStatistics
(
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
const
uint2
*
constants
,
const
unsigned
kNumHashFunctions
);
void
TakeHashFunctionStatistics
(
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
const
uint2
*
constants
,
const
unsigned
kNumHashFunctions
);
//!
Output how many probes were required by each thread to perform the
retrieval.
//! Output how many probes were required by each thread to perform the
//! retrieval.
/*! @param[in] n_queries Number of queries being performed.
* @param[in] d_retrieval_probes Device array: the number of probes taken for each thread's retrieval.
* @param[in] d_retrieval_probes Device array: the number of probes taken for
* each thread's retrieval.
* @param[in] n_functions Number of hash functions used.
*/
void
OutputRetrievalStatistics
(
const
unsigned
n_queries
,
void
OutputRetrievalStatistics
(
const
unsigned
n_queries
,
const
unsigned
*
d_retrieval_probes
,
const
unsigned
n_functions
);
const
unsigned
n_functions
);
//! Outputs information about how many iterations threads required to successfully cuckoo hash.
//! Outputs information about how many iterations threads required to
//! successfully cuckoo hash.
/*! @param[in] n Number of keys in the input.
* @param[in] d_iterations_taken Device mem: Number of iterations each thread took.
* @param[in] d_max_iterations_taken Device mem: Largest number of iterations taken by any thread.
* @param[in] d_iterations_taken Device mem: Number of iterations each
* thread took.
* @param[in] d_max_iterations_taken Device mem: Largest number of iterations
* taken by any thread.
*/
void
OutputBuildStatistics
(
const
unsigned
n
,
void
OutputBuildStatistics
(
const
unsigned
n
,
const
unsigned
*
d_iterations_taken
);
//! Prints out the contents of the stash.
void
PrintStashContents
(
const
Entry
*
d_stash
);
//! Checks if a key is assigned the same slot by different hash functions.
bool
CheckAssignedSameSlot
(
const
unsigned
N
,
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
uint2
*
constants
);
bool
CheckAssignedSameSlot
(
const
unsigned
N
,
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
uint2
*
constants
);
/// @}
};
// namespace
CuckooHashing
};
// namespace
cuhash
#endif
...
...
include/cuhash/hash_functions.h
View file @
19e73bbe
...
...
@@ -5,9 +5,9 @@
#ifndef HASH_FUNCTIONS__H
#define HASH_FUNCTIONS__H
#include "definitions.h"
#include <tensorview/tensorview.h>
#include <vector_types.h>
#include "definitions.h"
namespace
cuhash
{
...
...
@@ -23,30 +23,28 @@ const unsigned kPrimeDivisor = 4294967291u;
/*! @param[in] N Number of hash functions.
@param[out] constants CPU pointer to the constants.
@param[in] num_keys Debug only: How many keys are in the input.
@param[in] d_keys Debug only: Device memory array containing the input keys.
@param[in] d_keys Debug only: Device memory array containing the input
keys.
@param[in] table_size Debug only: Size of the hash table.
*/
void
GenerateFunctions
(
const
unsigned
N
,
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
uint2
*
constants
);
void
GenerateFunctions
(
const
unsigned
N
,
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
uint2
*
constants
);
//! Container for all of the hash functions.
template
<
unsigned
N
>
struct
Functions
{
//!
The constants required for all of the hash functions, including the stash.
Each function requires 2.
template
<
unsigned
N
>
struct
Functions
{
//! The constants required for all of the hash functions, including the stash.
//! Each function requires 2.
uint2
constants
[
N
];
//! Generate new hash function constants.
/*! The parameters are only used for debugging and examining the key
distribution.
\param[in] num_keys Debug: Number of keys in the input.
/*! The parameters are only used for debugging and examining the key
distribution.
\param[in] num_keys Debug: Number of keys in the input.
\param[in] d_keys Debug: Device array of the input keys.
\param[in] table_size Debug: Size of the hash table.
*/
void
Generate
(
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
)
{
void
Generate
(
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
)
{
GenerateFunctions
(
N
,
num_keys
,
d_keys
,
table_size
,
constants
);
}
};
...
...
@@ -56,17 +54,16 @@ struct Functions {
! \param[in] key Key being hashed.
! \returns The value of the hash function for the key.
*/
inline __device__ __host__ unsigned hash_function_inner(const uint2 constants,
                                                        const unsigned key) {
#if 1
  // Fast version: XOR the key with the first constant, add the second
  // constant, then reduce modulo kPrimeDivisor.
  const auto mixed = (constants.x ^ key) + constants.y;
  return mixed % kPrimeDivisor;
#else
  // Slow version: multiply in 64 bits, add the second constant, then reduce
  // modulo kPrimeDivisor.
  const auto scaled = (unsigned long long)constants.x * key + constants.y;
  return scaled % kPrimeDivisor;
#endif
}
//! Computes the value of a hash function for a given key.
/*! \param[in] functions All of the constants used by the hash functions.
...
...
@@ -75,22 +72,20 @@ unsigned hash_function_inner(const uint2 constants,
! \returns The value of a hash function with a given key.
*/
template <unsigned kNumHashFunctions>
TV_HOST_DEVICE_INLINE unsigned
hash_function(const Functions<kNumHashFunctions> functions,
              const unsigned which_function, const unsigned key) {
  // Pick the constant pair of the requested function and hash the key with it.
  const uint2 selected = functions.constants[which_function];
  return hash_function_inner(selected, key);
}
//! Simple hash function used by the stash.
TV_HOST_DEVICE_INLINE unsigned stash_hash_function(const uint2 stash_constants,
                                                   const unsigned key) {
  // `+` binds tighter than `^`: the key is first offset by the second
  // constant and the sum is then XORed with the first constant. Note that
  // this grouping differs from hash_function_inner(), which XORs first.
  return (stash_constants.x ^ (key + stash_constants.y)) % kStashSize;
}
unsigned
generate_random_uint32
();
};
// namespace
CuckooHashing
};
// namespace
cuhash
#endif
include/cuhash/hash_table.cuh
View file @
19e73bbe
...
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// $Revision:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
* @file hash_table.cuh
...
...
@@ -19,8 +19,8 @@
#include "definitions.h"
#include "hash_table.h"
#include <tensorview/tensorview.h>
#include <driver_types.h>
#include <tensorview/tensorview.h>
namespace
cuhash
{
...
...
@@ -31,51 +31,42 @@ TV_HOST_DEVICE_INLINE Entry make_entry(unsigned key, unsigned value) {
//! Returns the key of an Entry.
TV_HOST_DEVICE_INLINE unsigned get_key(Entry entry) {
  // Drop the low 32 bits; what remains is the key.
  const auto shifted = entry >> 32;
  return static_cast<unsigned>(shifted);
}
//! Returns the value of an Entry.
TV_HOST_DEVICE_INLINE unsigned get_value(Entry entry) {
  // Keep only the low 32 bits; those hold the value.
  const auto masked = entry & 0xffffffff;
  return static_cast<unsigned>(masked);
}
//! @name Internal
//! @brief Functions used for building the hash table.
//! @{
//! Fills the entire array with a specific value.
/*! Each thread writes at most one slot. Only threadIdx.x, blockIdx.x and
 *  blockIdx.y contribute to the slot index.
 *  @param[in]  table_size  Number of slots in \p table.
 *  @param[in]  value       Value stored into every slot.
 *  @param[out] table       Array being filled.
 */
template <class T>
__global__ void clear_table(const unsigned table_size, const T value,
                            T *table) {
  // Linear index built from the thread's position in its block, the block's
  // x position, and one full grid row of threads per blockIdx.y step.
  const unsigned within_block = threadIdx.x;
  const unsigned block_offset = blockIdx.x * blockDim.x;
  const unsigned row_offset = blockIdx.y * blockDim.x * gridDim.x;
  const unsigned slot = within_block + block_offset + row_offset;

  // Threads past the end of the table have nothing to do.
  if (slot >= table_size) {
    return;
  }
  table[slot] = value;
}
//! Determine where in the hash table the key could be located.
/*! @param[in]  constants   Hash function constants.
 *  @param[in]  table_size  Number of slots in the table; every location is
 *                          reduced modulo this value.
 *  @param[in]  key         Key whose candidate slots are computed.
 *  @param[out] locations   Receives one candidate slot per hash function.
 */
template <unsigned kNumHashFunctions>
__device__ void KeyLocations(const Functions<kNumHashFunctions> constants,
                             const unsigned table_size, const unsigned key,
                             unsigned locations[kNumHashFunctions]) {
  // Compute all possible locations for the key in the big table.
  // The index is unsigned to match kNumHashFunctions (no signed/unsigned
  // comparison) and the other loops in this file.
#pragma unroll
  for (unsigned i = 0; i < kNumHashFunctions; ++i) {
    locations[i] = hash_function(constants, i, key) % table_size;
  }
}
//! @}
/* --------------------------------------------------------------------------
Retrieval functions.
-------------------------------------------------------------------------- */
...
...
@@ -87,28 +78,27 @@ KeyLocations(const Functions<kNumHashFunctions> constants,
* @param[in] constants The hash functions used to build the table
* @param[in] stash_constants The hash function used to build the stash
* @param[in] stash_count The number of items in the stash
* @param[out] num_probes_required Debug only: The number of probes required to resolve the query.
* @returns The value of the query key, if the key exists in the table. Otherwise, \ref kNotFound will be returned.
* @param[out] num_probes_required Debug only: The number of probes required
* to resolve the query.
* @returns The value of the query key, if the key exists in the table.
* Otherwise, \ref kNotFound will be returned.
*/
template
<
unsigned
kNumHashFunctions
>
__device__
unsigned
retrieve
(
const
unsigned
query_key
,
const
unsigned
table_size
,
const
Entry
*
table
,
const
Functions
<
kNumHashFunctions
>
constants
,
const
uint2
stash_constants
,
const
unsigned
stash_count
,
unsigned
*
num_probes_required
=
NULL
)
{
template
<
unsigned
kNumHashFunctions
>
__device__
unsigned
retrieve
(
const
unsigned
query_key
,
const
unsigned
table_size
,
const
Entry
*
table
,
const
Functions
<
kNumHashFunctions
>
constants
,
const
uint2
stash_constants
,
const
unsigned
stash_count
,
unsigned
*
num_probes_required
=
NULL
)
{
// Identify all of the locations that the key can be located in.
unsigned
locations
[
kNumHashFunctions
];
KeyLocations
(
constants
,
table_size
,
query_key
,
locations
);
// Check each location until the key is found.
unsigned
num_probes
=
1
;
Entry
entry
=
table
[
locations
[
0
]];
unsigned
key
=
get_key
(
entry
);
Entry
entry
=
table
[
locations
[
0
]];
unsigned
key
=
get_key
(
entry
);
#pragma unroll
#pragma unroll
for
(
unsigned
i
=
1
;
i
<
kNumHashFunctions
;
++
i
)
{
if
(
key
!=
query_key
&&
key
!=
kNotFound
)
{
num_probes
++
;
...
...
@@ -138,37 +128,26 @@ unsigned retrieve(const unsigned query_key,
}
}
//! Perform a retrieval from a basic hash table. Each thread manages a single
//! query.
/*! @param[in]  n_queries            Number of query keys.
 *  @param[in]  keys_in              Query keys, one per thread.
 *  @param[in]  table_size           Number of slots in the table.
 *  @param[in]  table                Hash table contents.
 *  @param[in]  constants            Hash function constants.
 *  @param[in]  stash_constants      Stash hash function constants.
 *  @param[in]  stash_count          Number of items in the stash.
 *  @param[out] values_out           Receives the result of retrieve() for
 *                                   each query.
 *  @param[out] num_probes_required  Optional per-query probe counts; nothing
 *                                   is recorded when this is null.
 */
template <unsigned kNumHashFunctions>
__global__ void hash_retrieve(const unsigned n_queries, const unsigned *keys_in,
                              const unsigned table_size, const Entry *table,
                              const Functions<kNumHashFunctions> constants,
                              const uint2 stash_constants,
                              const unsigned stash_count, unsigned *values_out,
                              unsigned *num_probes_required = nullptr) {
  // Get the key.
  unsigned thread_index = threadIdx.x + blockIdx.x * blockDim.x +
                          blockIdx.y * blockDim.x * gridDim.x;
  if (thread_index >= n_queries)
    return;
  const unsigned key = keys_in[thread_index];

  // Hand retrieve() this thread's own probe counter, or null if the caller
  // did not ask for probe statistics.
  unsigned *probe_slot =
      num_probes_required ? num_probes_required + thread_index : nullptr;

  values_out[thread_index] = retrieve<kNumHashFunctions>(
      key, table_size, table, constants, stash_constants, stash_count,
      probe_slot);
}
/* --------------------------------------------------------------------------
Build a cuckoo hash table.
...
...
@@ -176,55 +155,53 @@ void hash_retrieve(const unsigned n_queries,
//! @name Internal
//! @{
//! Determine where to insert the key next. The hash functions are used in
//! round-robin order.
/*! If \p previous_location equals the slot of hash function i (for i below
 *  the last function), the slot of function i + 1 is returned; when several
 *  functions map to that slot, the lowest such i decides. In every other
 *  case, including a match on the last function only, the slot of function 0
 *  is returned.
 */
template <unsigned kNumHashFunctions>
__device__ unsigned
determine_next_location(const Functions<kNumHashFunctions> constants,
                        const unsigned table_size, const unsigned key,
                        const unsigned previous_location) {
  // Identify all possible locations for the entry.
  unsigned candidates[kNumHashFunctions];
#pragma unroll
  for (unsigned fn = 0; fn < kNumHashFunctions; ++fn) {
    candidates[fn] = hash_function(constants, fn, key) % table_size;
  }

  // Figure out where the item should be inserted next. Walking from the
  // highest index downwards lets the lowest matching index overwrite any
  // earlier choice.
  unsigned chosen = candidates[0];
#pragma unroll
  for (int fn = kNumHashFunctions - 2; fn >= 0; --fn) {
    if (previous_location == candidates[fn]) {
      chosen = candidates[fn + 1];
    }
  }
  return chosen;
}
//! Attempts to insert a single entry into the hash table.
/*! This process stops after a certain number of iterations. If the thread is
still holding onto an item because of an eviction, it tries the stash.
If it fails to enter the stash, it returns false.
Otherwise, it succeeds and returns true.
*/
template
<
unsigned
kNumHashFunctions
>
__device__
bool
insert
(
const
unsigned
table_size
,
const
Functions
<
kNumHashFunctions
>
constants
,
const
uint2
stash_constants
,
const
unsigned
max_iteration_attempts
,
Entry
*
table
,
unsigned
*
stash_count
,
Entry
entry
,
unsigned
*
iterations_used
)
{
template
<
unsigned
kNumHashFunctions
>
__device__
bool
insert
(
const
unsigned
table_size
,
const
Functions
<
kNumHashFunctions
>
constants
,
const
uint2
stash_constants
,
const
unsigned
max_iteration_attempts
,
Entry
*
table
,
unsigned
*
stash_count
,
Entry
entry
,
unsigned
*
iterations_used
)
{
unsigned
key
=
get_key
(
entry
);
// The key is always inserted into its first slot at the start.
unsigned
location
=
hash_function
(
constants
,
0
,
key
)
%
table_size
;
// Keep inserting until an empty slot is found or the eviction chain grows too large.
// Keep inserting until an empty slot is found or the eviction chain grows too
// large.
for
(
unsigned
its
=
1
;
its
<=
max_iteration_attempts
;
its
++
)
{
// Insert the new entry.
entry
=
atomicExch
(
&
table
[
location
],
entry
);
key
=
get_key
(
entry
);
key
=
get_key
(
entry
);
// If no key was evicted, we're done.
if
(
key
==
kKeyEmpty
)
{
...
...
@@ -251,54 +228,46 @@ bool insert(const unsigned table_size,
return
true
;
}
// Build a basic hash table, using one big table.
/*! Each thread packs one key/value pair with make_entry() and tries to place
 *  it with insert(). Threads return early when their index is past
 *  \p n_entries or when \p *failures is already non-zero.
 *  @param[in]     n_entries               Number of key/value pairs.
 *  @param[in]     keys, values            Input arrays, indexed per thread.
 *  @param[in]     table_size              Number of slots in the table.
 *  @param[in]     constants               Hash function constants.
 *  @param[in]     max_iteration_attempts  Passed through to insert().
 *  @param[in,out] table                   Hash table contents.
 *  @param[in]     stash_constants         Stash hash function constants.
 *  @param[in,out] stash_count             Passed through to insert().
 *  @param[in,out] failures                Set to 1 (or incremented when
 *                                         COUNT_UNINSERTED is defined) if an
 *                                         insertion fails.
 *  @param[out]    iterations_taken        Optional; with TRACK_ITERATIONS
 *                                         defined, receives the iteration
 *                                         count of each thread. May be null.
 */
template <unsigned kNumHashFunctions>
__global__ void CuckooHash(const unsigned n_entries, const unsigned *keys,
                           const unsigned *values, const unsigned table_size,
                           const Functions<kNumHashFunctions> constants,
                           const unsigned max_iteration_attempts, Entry *table,
                           uint2 stash_constants, unsigned *stash_count,
                           unsigned *failures,
                           unsigned *iterations_taken = nullptr) {
  // Check if this thread has an item and if any previous threads failed.
  unsigned thread_index = threadIdx.x + blockIdx.x * blockDim.x +
                          blockIdx.y * blockDim.x * gridDim.x;
  if (thread_index >= n_entries || *failures)
    return;
  Entry entry = make_entry(keys[thread_index], values[thread_index]);

  unsigned iterations = 0;
  bool success = insert<kNumHashFunctions>(
      table_size, constants, stash_constants, max_iteration_attempts, table,
      stash_count, entry, &iterations);

  if (success == false) {
    // The eviction chain grew too large. Report failure.
#ifdef COUNT_UNINSERTED
    atomicAdd(failures, 1);
#else
    *failures = 1;
#endif
  }

#ifdef TRACK_ITERATIONS
  // iterations_taken defaults to nullptr, so only record when the caller
  // actually supplied a buffer.
  if (iterations_taken != nullptr) {
    iterations_taken[thread_index] = iterations;
  }
#endif
}
//! @}
};
// namespace
CuckooHashing
};
// namespace
cuhash
#endif
// Leave this at the end of the file
// Local Variables:
// mode:c++
...
...
include/cuhash/hash_table.h
View file @
19e73bbe
...
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// $Revision:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
* @file hash_table.h
...
...
@@ -17,15 +17,14 @@
#ifndef CUDAHT__CUCKOO__SRC__LIBRARY__HASH_TABLE__H
#define CUDAHT__CUCKOO__SRC__LIBRARY__HASH_TABLE__H
#include "definitions.h"
#include "hash_functions.h"
#include <cstdio>
/** \addtogroup cudpp_app
* @{
*/
/** \addtogroup cudpp_app
* @{
*/
/** \addtogroup cudpp_hash_data_structures
* @{
...
...
@@ -50,7 +49,8 @@ namespace cuhash {
//! Compute how many thread blocks are required for the given number of threads.
dim3
ComputeGridDim
(
unsigned
threads
);
//! Compute how long an eviction chain is allowed to become for a given input size.
//! Compute how long an eviction chain is allowed to become for a given input
//! size.
/*! \param[in] num_keys Number of keys in the input.
* \param[in] table_size Number of slots in the hash table.
* \param[in] num_functions Number of hash functions being used.
...
...
@@ -72,10 +72,10 @@ unsigned ComputeMaxIterations(const unsigned num_keys,
* @ingroup cudpp_app
*/
class
HashTable
{
public:
public:
HashTable
();
virtual
~
HashTable
()
{
Release
();}
virtual
~
HashTable
()
{
Release
();
}
//! Initialize the hash table's memory. Must be called before \ref
//! Build() and after the random number generator has been seeded.
...
...
@@ -87,7 +87,7 @@ class HashTable {
* 2-5. More hash functions make it easier
* to build the table, but increase
* retrieval times.
* @returns Whether the hash table was initialized successfully (true)
* @returns Whether the hash table was initialized successfully (true)
* or not (false).
*
* The minimum space usage is dependent on the number of functions
...
...
@@ -95,28 +95,27 @@ class HashTable {
* usage is 2.1, 1.1, 1.03, and 1.02 respectively.
*/
virtual
bool
Initialize
(
const
unsigned
max_input_size
,
const
float
space_usage
=
1.25
,
const
unsigned
num_functions
=
4
);
const
float
space_usage
=
1.25
,
const
unsigned
num_functions
=
4
);
//! Free all memory.
virtual
void
Release
();
//! Build the hash table.
/*! @param[in] input_size Number of key-value pairs being inserted.
* @param[in] d_keys Device memory array containing all of the input
* @param[in] d_keys Device memory array containing all of the input
* keys.
* @param[in] d_vals Device memory array containing the keys' values.
* @returns Whether the hash table was built successfully (true) or
* @returns Whether the hash table was built successfully (true) or
* not (false).
*
* Several attempts are allowed to build the hash table in case of failure.
* The input keys are expected to be completely unique.
* To reduce the chance of a failure, increase the space usage or number of
* To reduce the chance of a failure, increase the space usage or number of
* functions.
* Keys are not allowed to be equal to cuhash::kKeyEmpty.
*/
virtual
bool
Build
(
const
unsigned
input_size
,
const
unsigned
*
d_keys
,
virtual
bool
Build
(
const
unsigned
input_size
,
const
unsigned
*
d_keys
,
const
unsigned
*
d_vals
);
//! Query the hash table.
...
...
@@ -128,9 +127,8 @@ class HashTable {
* kNotFound is returned for any query key that failed to be found
* in the table.
*/
virtual
void
Retrieve
(
const
unsigned
n_queries
,
const
unsigned
*
d_query_keys
,
unsigned
*
d_query_results
);
virtual
void
Retrieve
(
const
unsigned
n_queries
,
const
unsigned
*
d_query_keys
,
unsigned
*
d_query_results
);
//! @name Accessors
/// @brief Mainly needed to use the __device__ CudaHT::retrieve()
...
...
@@ -138,96 +136,85 @@ class HashTable {
/// @{
//! Returns how many slots the hash table has.
inline
unsigned
get_table_size
()
const
{
return
table_size_
;}
inline
unsigned
get_table_size
()
const
{
return
table_size_
;
}
//! Returns how many items are stored in the stash.
inline
unsigned
get_stash_count
()
const
{
return
stash_count_
;}
inline
unsigned
get_stash_count
()
const
{
return
stash_count_
;
}
//! Returns the constants used by the stash.
inline
uint2
get_stash_constants
()
const
{
return
stash_constants_
;}
inline
uint2
get_stash_constants
()
const
{
return
stash_constants_
;
}
//! Returns the hash table contents.
inline
const
Entry
*
get_contents
()
const
{
return
d_contents_
;}
inline
const
Entry
*
get_contents
()
const
{
return
d_contents_
;
}
//! Returns the number of hash functions being used.
inline
unsigned
get_num_hash_functions
()
const
{
return
num_hash_functions_
;}
inline
unsigned
get_num_hash_functions
()
const
{
return
num_hash_functions_
;
}
//! When using two hash functions, returns the constants.
inline
Functions
<
2
>
get_constants_2
()
const
{
return
constants_2_
;}
inline
Functions
<
2
>
get_constants_2
()
const
{
return
constants_2_
;
}
//! When using three hash functions, returns the constants.
inline
Functions
<
3
>
get_constants_3
()
const
{
return
constants_3_
;}
inline
Functions
<
3
>
get_constants_3
()
const
{
return
constants_3_
;
}
//! When using four hash functions, returns the constants.
inline
Functions
<
4
>
get_constants_4
()
const
{
return
constants_4_
;}
inline
Functions
<
4
>
get_constants_4
()
const
{
return
constants_4_
;
}
//! When using five hash functions, returns the constants.
inline
Functions
<
5
>
get_constants_5
()
const
{
return
constants_5_
;}
inline
Functions
<
5
>
get_constants_5
()
const
{
return
constants_5_
;
}
/// @}
inline
Entry
*
data
(){
return
d_contents_
;}
inline
const
Entry
*
data
()
const
{
return
d_contents_
;}
protected:
unsigned
table_size_
;
//!< Size of the hash table.
unsigned
num_hash_functions_
;
//!< Number of hash functions being used.
Entry
*
d_contents_
;
//!< Device memory: The hash table contents. The stash is stored at the end.
unsigned
stash_count_
;
//!< Number of key-value pairs currently stored.
uint2
stash_constants_
;
//!< Hash function constants for the stash.
Functions
<
2
>
constants_2_
;
//!< Constants for a set of two hash functions.
Functions
<
3
>
constants_3_
;
//!< Constants for a set of three hash functions.
Functions
<
4
>
constants_4_
;
//!< Constants for a set of four hash functions.
Functions
<
5
>
constants_5_
;
//!< Constants for a set of five hash functions.
unsigned
*
d_failures_
;
//!< Device memory: General use error flag.
inline
Entry
*
data
()
{
return
d_contents_
;
}
inline
const
Entry
*
data
()
const
{
return
d_contents_
;
}
protected:
unsigned
table_size_
;
//!< Size of the hash table.
unsigned
num_hash_functions_
;
//!< Number of hash functions being used.
Entry
*
d_contents_
;
//!< Device memory: The hash table contents. The stash is
//!< stored at the end.
unsigned
stash_count_
;
//!< Number of key-value pairs currently stored.
uint2
stash_constants_
;
//!< Hash function constants for the stash.
Functions
<
2
>
constants_2_
;
//!< Constants for a set of two hash functions.
Functions
<
3
>
constants_3_
;
//!< Constants for a set of three hash functions.
Functions
<
4
>
constants_4_
;
//!< Constants for a set of four hash functions.
Functions
<
5
>
constants_5_
;
//!< Constants for a set of five hash functions.
unsigned
*
d_failures_
;
//!< Device memory: General use error flag.
};
/*! @name Internal
* @{
*/
namespace
CUDAWrapper
{
//! Fills a 64-bit array with a particular value.
void
ClearTable
(
const
unsigned
slots_in_table
,
const
Entry
fill_value
,
Entry
*
d_array
);
void
ClearTable
(
const
unsigned
slots_in_table
,
const
Entry
fill_value
,
Entry
*
d_array
);
//! Calls the Cuckoo Hash construction kernel.
void
CallCuckooHash
(
const
unsigned
n_entries
,
const
unsigned
num_hash_functions
,
const
unsigned
*
d_keys
,
const
unsigned
*
d_values
,
const
unsigned
table_size
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
5
>
constants_5
,
const
unsigned
max_iteration_attempts
,
Entry
*
d_contents
,
uint2
stash_constants
,
unsigned
*
d_stash_count
,
unsigned
*
d_failures
,
unsigned
*
d_iterations_taken
);
void
CallCuckooHash
(
const
unsigned
n_entries
,
const
unsigned
num_hash_functions
,
const
unsigned
*
d_keys
,
const
unsigned
*
d_values
,
const
unsigned
table_size
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
5
>
constants_5
,
const
unsigned
max_iteration_attempts
,
Entry
*
d_contents
,
uint2
stash_constants
,
unsigned
*
d_stash_count
,
unsigned
*
d_failures
,
unsigned
*
d_iterations_taken
);
//! Calls the kernel that performs retrievals.
void
CallHashRetrieve
(
const
unsigned
n_queries
,
const
unsigned
num_hash_functions
,
const
unsigned
*
keys_in
,
const
unsigned
table_size
,
const
Entry
*
table
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
5
>
constants_5
,
const
uint2
stash_constants
,
const
unsigned
stash_count
,
unsigned
*
values_out
);
};
void
CallHashRetrieve
(
const
unsigned
n_queries
,
const
unsigned
num_hash_functions
,
const
unsigned
*
keys_in
,
const
unsigned
table_size
,
const
Entry
*
table
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
5
>
constants_5
,
const
uint2
stash_constants
,
const
unsigned
stash_count
,
unsigned
*
values_out
);
};
// namespace CUDAWrapper
/// @}
};
// namespace
CuckooHashing
};
// namespace
cuhash
/** @} */
// end hash table data structures
/** @} */
// end cudpp_app
...
...
include/paramsgrid.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file is used for c++ unit test, but pytorch jit ops don't support c++ debug build.
// This file is used for c++ unit test, but pytorch jit ops don't support c++
// debug build.
#ifndef PARAMS_GRID_H_
#define PARAMS_GRID_H_
...
...
include/prettyprint.h
deleted
100644 → 0
View file @
c336139f
// Copyright Louis Delacroix 2010 - 2014.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// A pretty printing library for C++
//
// Usage:
// Include this header, and operator<< will "just work".
#ifndef H_PRETTY_PRINT
#define H_PRETTY_PRINT
#include <cstddef>
#include <iterator>
#include <memory>
#include <ostream>
#include <set>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <valarray>
namespace
pretty_print
{
namespace detail {

// Return types shared by the sizeof-based SFINAE probes below. `yes` and `no`
// have different sizes, so sizeof(probe(...)) reveals which overload won.
struct sfinae_base {
  typedef char yes;
  typedef yes no[2];
};

// Trait: `value` is true when T declares a nested type `const_iterator`.
// `type` is simply T.
template <typename T>
struct has_const_iterator : private sfinae_base {
private:
  // Viable only when U::const_iterator names a type.
  template <typename U>
  static yes &probe(typename U::const_iterator *);
  // Catch-all selected when the overload above is discarded by SFINAE.
  template <typename U>
  static no &probe(...);

public:
  static const bool value = sizeof(probe<T>(nullptr)) == sizeof(yes);
  using type = T;
};

// Trait: `beg_value` / `end_value` are true when T has const member functions
// begin() / end() whose return type is T::const_iterator.
template <typename T>
struct has_begin_end : private sfinae_base {
private:
  // Pointer-to-member signature required of C::begin and C::end.
  template <typename C>
  using const_getter = typename C::const_iterator (C::*)() const;

  // Viable only when &C::begin can be cast to the required signature.
  template <typename C>
  static yes &check_begin(
      typename std::enable_if<
          std::is_same<decltype(static_cast<const_getter<C>>(&C::begin)),
                       const_getter<C>>::value>::type *);
  template <typename C>
  static no &check_begin(...);

  // Same probe for C::end.
  template <typename C>
  static yes &check_end(
      typename std::enable_if<
          std::is_same<decltype(static_cast<const_getter<C>>(&C::end)),
                       const_getter<C>>::value,
          void>::type *);
  template <typename C>
  static no &check_end(...);

public:
  static bool const beg_value =
      sizeof(check_begin<T>(nullptr)) == sizeof(yes);
  static bool const end_value = sizeof(check_end<T>(nullptr)) == sizeof(yes);
};

} // namespace detail
// Holds the delimiter values for a specific character type
template <typename TChar>
struct delimiters_values {
  typedef TChar char_type;
  // Text emitted before the first element.
  char_type const *prefix;
  // Text emitted between two consecutive elements.
  char_type const *delimiter;
  // Text emitted after the last element.
  char_type const *postfix;
};
// Defines the delimiter values for a specific container and character type
// Primary template: associates a container type T and a character type TChar
// with a static set of delimiter strings. `values` is only declared here.
template <typename T, typename TChar>
struct delimiters {
  typedef delimiters_values<TChar> type;
  static const type values;
};
// Functor to print containers. You can use this directly if you want
// to specify a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template <typename T, typename TChar = char,
          typename TCharTraits = ::std::char_traits<TChar>,
          typename TDelimiters = delimiters<T, TChar>>
struct print_container_helper {
  typedef TDelimiters delimiters_type;
  typedef std::basic_ostream<TChar, TCharTraits> ostream_type;

  // Writes the elements of a value of type U, without prefix/postfix.
  // This generic version walks the range [begin(c), end(c)).
  template <typename U>
  struct printer {
    static void print_body(const U &c, ostream_type &stream) {
      using std::begin;
      using std::end;

      auto cur = begin(c);
      const auto last = end(c);
      if (cur == last)
        return; // empty range: nothing to write

      // First element has no leading delimiter; every later one does.
      stream << *cur;
      for (++cur; cur != last; ++cur) {
        if (delimiters_type::values.delimiter != nullptr)
          stream << delimiters_type::values.delimiter;
        stream << *cur;
      }
    }
  };

  // Keeps a reference to `container`; the helper must not outlive it.
  print_container_helper(const T &container) : container_(container) {}

  // Writes prefix, body and postfix to `stream`. A null prefix or postfix
  // pointer is skipped.
  inline void operator()(ostream_type &stream) const {
    const auto &delims = delimiters_type::values;

    if (delims.prefix != nullptr)
      stream << delims.prefix;

    printer<T>::print_body(container_, stream);

    if (delims.postfix != nullptr)
      stream << delims.postfix;
  }

private:
  const T &container_;
};
// Specialization for pairs
// Member specialization of printer for std::pair: writes `first`, the
// delimiter (when non-null), then `second`.
template <typename T, typename TChar, typename TCharTraits,
          typename TDelimiters>
template <typename T1, typename T2>
struct print_container_helper<T, TChar, TCharTraits,
                              TDelimiters>::printer<std::pair<T1, T2>> {
  // Same type as the enclosing helper's ostream_type.
  using ostream_type = std::basic_ostream<TChar, TCharTraits>;

  static void print_body(const std::pair<T1, T2> &c, ostream_type &stream) {
    stream << c.first;
    // TDelimiters is the enclosing helper's delimiters_type.
    if (TDelimiters::values.delimiter != nullptr)
      stream << TDelimiters::values.delimiter;
    stream << c.second;
  }
};
// Specialization for tuples
// Member specialization of printer for std::tuple. Elements are written in
// index order through tag-dispatched recursion over Int<I>.
template <typename T, typename TChar, typename TCharTraits,
          typename TDelimiters>
template <typename... Args>
struct print_container_helper<T, TChar, TCharTraits,
                              TDelimiters>::printer<std::tuple<Args...>> {
  // Same type as the enclosing helper's ostream_type.
  using ostream_type = std::basic_ostream<TChar, TCharTraits>;
  using element_type = std::tuple<Args...>;

  // Empty tag type carrying the index of the next element to print.
  template <std::size_t I>
  struct Int {};

  static void print_body(const element_type &c, ostream_type &stream) {
    tuple_print(c, stream, Int<0>());
  }

  // End of recursion: the index equals the tuple size, nothing left to print.
  static void tuple_print(const element_type &, ostream_type &,
                          Int<sizeof...(Args)>) {}

  // Element 0 is written without a leading delimiter. For an empty tuple the
  // tag parameter becomes std::nullptr_t, so this overload does not clash
  // with the end-of-recursion overload above.
  static void tuple_print(
      const element_type &c, ostream_type &stream,
      typename std::conditional<sizeof...(Args) != 0, Int<0>,
                                std::nullptr_t>::type) {
    stream << std::get<0>(c);
    tuple_print(c, stream, Int<1>());
  }

  // Any other element N: delimiter (when non-null) followed by the element.
  template <std::size_t N>
  static void tuple_print(const element_type &c, ostream_type &stream,
                          Int<N>) {
    // TDelimiters is the enclosing helper's delimiters_type.
    if (TDelimiters::values.delimiter != nullptr)
      stream << TDelimiters::values.delimiter;

    stream << std::get<N>(c);

    tuple_print(c, stream, Int<N + 1>());
  }
};
// Prints a print_container_helper to the specified stream.
// Invokes `helper` on `stream` (which writes the wrapped container) and
// returns the stream so insertions can be chained.
template <typename T, typename TChar, typename TCharTraits,
          typename TDelimiters>
inline auto operator<<(
    std::basic_ostream<TChar, TCharTraits> &stream,
    const print_container_helper<T, TChar, TCharTraits, TDelimiters> &helper)
    -> std::basic_ostream<TChar, TCharTraits> & {
  helper(stream);
  return stream;
}
// Basic is_container template; specialize to derive from std::true_type for all desired container types
// True when T has a nested const_iterator type and const begin()/end()
// members returning it (all three detail probes must succeed).
template <typename T>
struct is_container
    : public std::integral_constant<
          bool, detail::has_const_iterator<T>::value &&
                    detail::has_begin_end<T>::beg_value &&
                    detail::has_begin_end<T>::end_value> {};
// Built-in arrays are printed as containers ...
template <typename Elem, std::size_t Extent>
struct is_container<Elem[Extent]> : std::true_type {};

// ... except arrays of char, which are excluded here.
template <std::size_t Extent>
struct is_container<char[Extent]> : std::false_type {};

// std::valarray is opted in explicitly.
template <typename Elem>
struct is_container<std::valarray<Elem>> : std::true_type {};

// Pairs and tuples are routed through the container printer as well; their
// bodies are written by the dedicated printer specializations.
template <typename First, typename Second>
struct is_container<std::pair<First, Second>> : std::true_type {};

template <typename... Elems>
struct is_container<std::tuple<Elems...>> : std::true_type {};
// Default delimiters
// Narrow-character delimiters for any type without a more specific
// specialization: "[", ", ", "]".
template <typename Container>
struct delimiters<Container, char> {
  static const delimiters_values<char> values;
};

template <typename Container>
const delimiters_values<char> delimiters<Container, char>::values = {
    "[", ", ", "]"};

// Wide-character counterpart of the default delimiters.
template <typename Container>
struct delimiters<Container, wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename Container>
const delimiters_values<wchar_t> delimiters<Container, wchar_t>::values = {
    L"[", L", ", L"]"};
// Delimiters for (multi)set and unordered_(multi)set
// std::set, narrow characters: "{", ", ", "}".
template <typename Key, typename Compare, typename Alloc>
struct delimiters<std::set<Key, Compare, Alloc>, char> {
  static const delimiters_values<char> values;
};

template <typename Key, typename Compare, typename Alloc>
const delimiters_values<char>
    delimiters<std::set<Key, Compare, Alloc>, char>::values = {"{", ", ",
                                                               "}"};

// std::set, wide characters.
template <typename Key, typename Compare, typename Alloc>
struct delimiters<std::set<Key, Compare, Alloc>, wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename Key, typename Compare, typename Alloc>
const delimiters_values<wchar_t>
    delimiters<std::set<Key, Compare, Alloc>, wchar_t>::values = {
        L"{", L", ", L"}"};
// std::multiset, narrow characters: "{", ", ", "}".
template <typename Key, typename Compare, typename Alloc>
struct delimiters<std::multiset<Key, Compare, Alloc>, char> {
  static const delimiters_values<char> values;
};

template <typename Key, typename Compare, typename Alloc>
const delimiters_values<char>
    delimiters<std::multiset<Key, Compare, Alloc>, char>::values = {
        "{", ", ", "}"};

// std::multiset, wide characters.
template <typename Key, typename Compare, typename Alloc>
struct delimiters<std::multiset<Key, Compare, Alloc>, wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename Key, typename Compare, typename Alloc>
const delimiters_values<wchar_t>
    delimiters<std::multiset<Key, Compare, Alloc>, wchar_t>::values = {
        L"{", L", ", L"}"};
// std::unordered_set, narrow characters: "{", ", ", "}".
template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
struct delimiters<std::unordered_set<Key, Hash, KeyEqual, Alloc>, char> {
  static const delimiters_values<char> values;
};

template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
const delimiters_values<char>
    delimiters<std::unordered_set<Key, Hash, KeyEqual, Alloc>, char>::values =
        {"{", ", ", "}"};

// std::unordered_set, wide characters.
template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
struct delimiters<std::unordered_set<Key, Hash, KeyEqual, Alloc>, wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
const delimiters_values<wchar_t>
    delimiters<std::unordered_set<Key, Hash, KeyEqual, Alloc>,
               wchar_t>::values = {L"{", L", ", L"}"};
// std::unordered_multiset, narrow characters: "{", ", ", "}".
template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
struct delimiters<std::unordered_multiset<Key, Hash, KeyEqual, Alloc>, char> {
  static const delimiters_values<char> values;
};

template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
const delimiters_values<char>
    delimiters<std::unordered_multiset<Key, Hash, KeyEqual, Alloc>,
               char>::values = {"{", ", ", "}"};

// std::unordered_multiset, wide characters.
template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
struct delimiters<std::unordered_multiset<Key, Hash, KeyEqual, Alloc>,
                  wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename Key, typename Hash, typename KeyEqual, typename Alloc>
const delimiters_values<wchar_t>
    delimiters<std::unordered_multiset<Key, Hash, KeyEqual, Alloc>,
               wchar_t>::values = {L"{", L", ", L"}"};
// Delimiters for pair and tuple
// std::pair, narrow characters: "(", ", ", ")".
template <typename First, typename Second>
struct delimiters<std::pair<First, Second>, char> {
  static const delimiters_values<char> values;
};

template <typename First, typename Second>
const delimiters_values<char>
    delimiters<std::pair<First, Second>, char>::values = {"(", ", ", ")"};

// std::pair, wide characters.
template <typename First, typename Second>
struct delimiters<std::pair<First, Second>, wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename First, typename Second>
const delimiters_values<wchar_t>
    delimiters<std::pair<First, Second>, wchar_t>::values = {L"(", L", ",
                                                             L")"};
// std::tuple, narrow characters: "(", ", ", ")".
template <typename... Elems>
struct delimiters<std::tuple<Elems...>, char> {
  static const delimiters_values<char> values;
};

template <typename... Elems>
const delimiters_values<char> delimiters<std::tuple<Elems...>, char>::values =
    {"(", ", ", ")"};

// std::tuple, wide characters.
template <typename... Elems>
struct delimiters<std::tuple<Elems...>, wchar_t> {
  static const delimiters_values<wchar_t> values;
};

template <typename... Elems>
const delimiters_values<wchar_t>
    delimiters<std::tuple<Elems...>, wchar_t>::values = {L"(", L", ", L")"};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t, and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
// Abstract interface that hides the wrapped container type: implementations
// write themselves to either a narrow or a wide output stream.
struct custom_delims_base {
  virtual ~custom_delims_base() = default;

  // Writes to a narrow stream and returns a stream reference.
  virtual std::ostream &stream(std::ostream &) = 0;
  // Writes to a wide stream and returns a stream reference.
  virtual std::wostream &stream(std::wostream &) = 0;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{
}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
// Type-erased handle: owns a heap-allocated custom_delims_wrapper for
// whatever container type is passed to the constructor.
template <typename Delims>
struct custom_delims {
  // The wrapper stores a reference to `c`; `c` must outlive this object.
  template <typename Container>
  custom_delims(const Container &c)
      : base{new custom_delims_wrapper<Container, Delims>(c)} {}

  std::unique_ptr<custom_delims_base> base;
};
// Forwards to the virtual stream() of the wrapper owned by `p` and returns
// its result.
template <typename TChar, typename TCharTraits, typename Delims>
inline auto operator<<(std::basic_ostream<TChar, TCharTraits> &s,
                       const custom_delims<Delims> &p)
    -> std::basic_ostream<TChar, TCharTraits> & {
  return p.base->stream(s);
}
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
// Non-owning view over `n` consecutive elements starting at `a`, exposing
// the begin()/end()/const_iterator interface the container printer needs.
template <typename T>
struct array_wrapper_n {
  using const_iterator = const T *;
  using value_type = T;

  array_wrapper_n(const T *const a, size_t n) : _array(a), _n(n) {}

  // Pointer to the first element.
  inline const_iterator begin() const { return _array; }
  // Pointer one past the last element.
  inline const_iterator end() const { return begin() + _n; }

private:
  const T *const _array;
  size_t _n;
};
// A wrapper for hash-table based containers that offer local iterators to each bucket.
// Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints the bucket with index 4, i.e. the fifth bucket, of container m.)
// View over a single bucket of a hash-based container: begin()/end() return
// the container's local iterators for bucket `n`. Holds a reference to the
// container, which must outlive the wrapper.
template <typename T>
struct bucket_print_wrapper {
  using const_iterator = typename T::const_local_iterator;
  using size_type = typename T::size_type;

  bucket_print_wrapper(const T &m, size_type bucket) : m_map(m), n(bucket) {}

  const_iterator begin() const { return m_map.cbegin(n); }
  const_iterator end() const { return m_map.cend(n); }

private:
  const T &m_map;
  const size_type n;
};
}
// namespace pretty_print
// Global accessor functions for the convenience wrappers
// Wraps the pointer-plus-length pair (a, n) in an array_wrapper_n so it can
// be streamed like a container.
template <typename T>
inline pretty_print::array_wrapper_n<T> pretty_print_array(const T *const a,
                                                           size_t n) {
  typedef pretty_print::array_wrapper_n<T> wrapper_type;
  return wrapper_type(a, n);
}
// Wraps bucket `n` of the hash-based container `m` in a bucket_print_wrapper
// so that the bucket's contents can be streamed.
template <typename T>
pretty_print::bucket_print_wrapper<T> bucket_print(const T &m,
                                                   typename T::size_type n) {
  typedef pretty_print::bucket_print_wrapper<T> wrapper_type;
  return wrapper_type(m, n);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
namespace std {
// Streams any type for which pretty_print::is_container is true, using the
// default delimiters for that type.
// NOTE(review): the C++ standard does not permit adding function overloads
// to namespace std; this placement is what lets ADL find the operator for
// standard containers, so it is kept as-is.
template <typename T, typename TChar, typename TCharTraits>
inline typename enable_if< ::pretty_print::is_container<T>::value,
                          basic_ostream<TChar, TCharTraits> &>::type
operator<<(basic_ostream<TChar, TCharTraits> &stream, const T &container) {
  typedef ::pretty_print::print_container_helper<T, TChar, TCharTraits>
      helper_type;
  return stream << helper_type(container);
}
}
#endif // H_PRETTY_PRINT
include/pybind11_utils.h
deleted
100644 → 0
View file @
c336139f
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <tensorview/tensorview.h>
#include <tensorview/tensor.h>
#include <algorithm>
#include <array>
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace
py
=
pybind11
;
namespace
tv
{
// Builds a mutable TensorView over the buffer of `arr`, with one shape entry
// per array dimension. The view refers to the array's own data (no copy is
// made here).
template <typename T>
TensorView<T> arrayt2tv(py::array_t<T> arr) {
  Shape dims;
  const auto ndim = arr.ndim();
  for (int axis = 0; axis < ndim; ++axis) {
    dims.push_back(arr.shape(axis));
  }
  return TensorView<T>(arr.mutable_data(), dims);
}
// Read-only counterpart of arrayt2tv: builds a TensorView<const T> over the
// buffer of `arr`, with one shape entry per array dimension.
template <typename T>
TensorView<const T> carrayt2tv(py::array_t<T> arr) {
  Shape dims;
  const auto ndim = arr.ndim();
  for (int axis = 0; axis < ndim; ++axis) {
    dims.push_back(arr.shape(axis));
  }
  return TensorView<const T>(arr.data(), dims);
}
template
<
typename
T
>
TensorView
<
T
>
vector2tv
(
std
::
vector
<
T
>
&
arr
)
{
return
TensorView
<
T
>
(
arr
.
data
(),
{
arr
.
size
()});
}
template
<
typename
T
>
TensorView
<
T
>
vector2tv
(
std
::
vector
<
T
>
&
arr
,
Shape
shape
)
{
TV_ASSERT_INVALID_ARG
(
shape
.
prod
()
==
arr
.
size
(),
"error"
);
return
TensorView
<
T
>
(
arr
.
data
(),
shape
);
}
template
<
typename
T
>
TensorView
<
const
T
>
vector2tv
(
const
std
::
vector
<
T
>
&
arr
)
{
return
TensorView
<
const
T
>
(
arr
.
data
(),
{
arr
.
size
()});
}
// Computes strides for a densely packed array whose last axis varies
// fastest.
//
// shape:    extent of each axis, outermost first.
// itemsize: size of one element; every stride is a multiple of it.
// returns:  one stride per axis, in the same order as `shape`. For example
//           shape {2, 3, 4} with itemsize 4 gives {48, 16, 4}. An empty
//           shape gives an empty result.
template <typename T>
std::vector<T> shape2stride(const std::vector<T> &shape, T itemsize) {
  std::vector<T> res;
  // The result has exactly one entry per axis, so allocate once up front
  // instead of growing the vector during the loop.
  res.reserve(shape.size());

  // Walk from the innermost axis outwards; `p` is the number of elements
  // spanned by one step along the current axis.
  T p = T(1);
  for (auto iter = shape.rbegin(); iter != shape.rend(); ++iter) {
    res.push_back(p * itemsize);
    p *= *iter;
  }

  // Strides were produced innermost-first; restore the order of `shape`.
  std::reverse(res.begin(), res.end());
  return res;
}
// Maps the dtype of a NumPy array to the matching tv::DType, based on the
// dtype's kind character and the array's item size in bytes:
//   'b'                  -> bool_
//   'i' with 1/2/4/8     -> int8 / int16 / int32 / int64
//   'u' with 1/2/4/8     -> uint8 / uint16 / uint32 / uint64
//   'f' with 4/8         -> float32 / float64
// Any other combination is reported through TV_THROW_RT_ERR.
//
// Declared inline because this is a non-template function defined in a
// header; without it, including the header from more than one translation
// unit produces multiple-definition link errors.
inline tv::DType get_array_tv_dtype(const py::array &arr) {
  switch (arr.dtype().kind()) {
  case 'b':
    return tv::bool_;
  case 'i':
    switch (arr.itemsize()) {
    case 1:
      return tv::int8;
    case 2:
      return tv::int16;
    case 4:
      return tv::int32;
    case 8:
      return tv::int64;
    default:
      break;
    }
    // Unsupported size: go straight to the error below instead of falling
    // through into the next kind's size table.
    break;
  case 'u':
    switch (arr.itemsize()) {
    case 1:
      return tv::uint8;
    case 2:
      return tv::uint16;
    case 4:
      return tv::uint32;
    case 8:
      return tv::uint64;
    default:
      break;
    }
    break;
  case 'f':
    switch (arr.itemsize()) {
    case 4:
      return tv::float32;
    case 8:
      return tv::float64;
    default:
      break;
    }
    break;
  default:
    break;
  }
  TV_THROW_RT_ERR("unknown dtype", arr.dtype().kind(), arr.itemsize());
}
// Wraps the buffer of a NumPy array in a tv::Tensor via tv::from_blob, using
// the array's shape and the dtype reported by get_array_tv_dtype. The data
// pointer comes from arr.mutable_data(); no copy is made in this function.
//
// Declared inline because this is a non-template function defined in a
// header; without it, including the header from more than one translation
// unit produces multiple-definition link errors.
inline Tensor array2tensor(py::array &arr) {
  Shape shape;
  const auto ndim = arr.ndim();
  for (int axis = 0; axis < ndim; ++axis) {
    shape.push_back(arr.shape(axis));
  }
  // NOTE(review): the trailing -1 is presumably the device index (host
  // memory) — confirm against tv::from_blob.
  return tv::from_blob(arr.mutable_data(), shape, get_array_tv_dtype(arr),
                       -1);
}
}
// namespace tv
include/spconv/box_iou.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef BOX_IOU_H
#define BOX_IOU_H
...
...
@@ -99,9 +98,10 @@ rbbox_iou(py::array_t<DType> box_corners, py::array_t<DType> qbox_corners,
}
template
<
typename
DType
>
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
py
::
array_t
<
DType
>
rbbox_intersection
(
py
::
array_t
<
DType
>
box_corners
,
py
::
array_t
<
DType
>
qbox_corners
,
py
::
array_t
<
DType
>
standup_iou
,
DType
standup_thresh
)
{
namespace
bg
=
boost
::
geometry
;
typedef
bg
::
model
::
point
<
DType
,
2
,
bg
::
cs
::
cartesian
>
point_t
;
typedef
bg
::
model
::
polygon
<
point_t
>
polygon_t
;
...
...
@@ -152,6 +152,5 @@ rbbox_intersection(py::array_t<DType> box_corners, py::array_t<DType> qbox_corne
return
overlaps
;
}
}
// namespace spconv
#endif
\ No newline at end of file
include/spconv/fused_spconv_ops.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -17,17 +17,19 @@
#include <spconv/indice.h>
#include <spconv/reordering.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h>
namespace
spconv
{
// torch.jit's doc says only support int64, so we need to convert to int32.
template
<
typename
T
>
torch
::
Tensor
fusedIndiceConvBatchNorm
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
bias
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
int64_t
numActOut
,
int64_t
_inverse
,
int64_t
_subM
)
{
torch
::
Tensor
fusedIndiceConvBatchNorm
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
bias
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
int64_t
numActOut
,
int64_t
_inverse
,
int64_t
_subM
)
{
bool
subM
=
_subM
!=
0
;
bool
inverse
=
_inverse
!=
0
;
auto
device
=
features
.
device
().
type
();
...
...
@@ -36,13 +38,16 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
auto
numInPlanes
=
features
.
size
(
1
);
auto
numOutPlanes
=
filters
.
size
(
ndim
+
1
);
auto
indicePairNumCpu
=
indiceNum
.
to
({
torch
::
kCPU
});
auto
indicePairMaxSizeIter
=
std
::
max_element
(
indicePairNumCpu
.
data_ptr
<
int
>
(),
indicePairNumCpu
.
data_ptr
<
int
>
()
+
kernelVolume
);
int
indicePairMaxOffset
=
indicePairMaxSizeIter
-
indicePairNumCpu
.
data_ptr
<
int
>
();
auto
indicePairMaxSizeIter
=
std
::
max_element
(
indicePairNumCpu
.
data_ptr
<
int
>
(),
indicePairNumCpu
.
data_ptr
<
int
>
()
+
kernelVolume
);
int
indicePairMaxOffset
=
indicePairMaxSizeIter
-
indicePairNumCpu
.
data_ptr
<
int
>
();
int
indicePairMaxSize
=
*
indicePairMaxSizeIter
;
/*if (_subM){
std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(), indicePairNumCpu.data_ptr<int>() + kernelVolume);
std::vector<int> indicePairNumVec(indicePairNumCpu.data_ptr<int>(),
indicePairNumCpu.data_ptr<int>() + kernelVolume);
indicePairNumVec.erase(indicePairNumVec.begin() + indicePairMaxOffset);
auto indicePairVecMaxSizeIter = std::max_element(
...
...
@@ -55,8 +60,10 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
// auto indicePairOptions =
// torch::TensorOptions().dtype(torch::kInt64).device(indicePairs.device());
torch
::
Tensor
output
=
torch
::
zeros
({
numActOut
,
numOutPlanes
},
options
).
copy_
(
bias
);
torch
::
Tensor
inputBuffer
=
torch
::
zeros
({
indicePairMaxSize
,
numInPlanes
},
options
);
torch
::
Tensor
output
=
torch
::
zeros
({
numActOut
,
numOutPlanes
},
options
).
copy_
(
bias
);
torch
::
Tensor
inputBuffer
=
torch
::
zeros
({
indicePairMaxSize
,
numInPlanes
},
options
);
torch
::
Tensor
outputBuffer
=
torch
::
zeros
({
indicePairMaxSize
,
numOutPlanes
},
options
);
filters
=
filters
.
view
({
-
1
,
numInPlanes
,
numOutPlanes
});
...
...
@@ -73,30 +80,31 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
continue
;
}
// auto timer = spconv::CudaContextTimer<>();
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numInPlanes
},
options
);
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numInPlanes
},
options
);
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseGatherFunctor
<
tv
::
CPU
,
T
,
int
>
gatherFtor
;
gatherFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
}
#ifdef SPCONV_CUDA
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseGatherFunctor
<
tv
::
GPU
,
T
,
int
>
gatherFtor
;
gatherFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
TV_CHECK_CUDA_ERR
();
/* slower than SparseGatherFunctor, may due to int->long conversion
auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64);
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(), {nHot},
indicePairOptions);
torch::index_select_out(inputBufferBlob, features, 0,
indicePairBlob);*/
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(),
{nHot}, indicePairOptions); torch::index_select_out(inputBufferBlob,
features, 0, indicePairBlob);*/
}
#endif
else
{
...
...
@@ -111,16 +119,16 @@ torch::Tensor fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor fil
functor
::
SparseScatterAddFunctor
<
tv
::
CPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
output
),
tv
::
torch2tv
<
const
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
}
#ifdef
SPCON
V_CUDA
#ifdef
T
V_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
GPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
output
),
tv
::
torch2tv
<
const
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
TV_CHECK_CUDA_ERR
();
}
#endif
...
...
include/spconv/geometry.h
View file @
19e73bbe
...
...
@@ -26,34 +26,27 @@ namespace detail {
template
<
typename
T
>
struct
ToUnsigned
;
template
<
>
struct
ToUnsigned
<
int
>
{
using
type
=
uint32_t
;
};
template
<
>
struct
ToUnsigned
<
int
>
{
using
type
=
uint32_t
;
};
template
<
>
struct
ToUnsigned
<
long
>
{
using
type
=
uint64_t
;
};
template
<
>
struct
ToUnsigned
<
long
>
{
using
type
=
uint64_t
;
};
template
<
typename
T
>
struct
FNVInternal
;
template
<
>
struct
FNVInternal
<
uint32_t
>
{
template
<
>
struct
FNVInternal
<
uint32_t
>
{
constexpr
static
uint32_t
defaultOffsetBasis
=
0x811C9DC5
;
constexpr
static
uint32_t
prime
=
0x01000193
;
};
template
<
>
struct
FNVInternal
<
uint64_t
>
{
template
<
>
struct
FNVInternal
<
uint64_t
>
{
constexpr
static
uint64_t
defaultOffsetBasis
=
0xcbf29ce484222325
;
constexpr
static
uint64_t
prime
=
0x100000001b3
;
};
}
}
// namespace detail
template
<
typename
T
>
using
to_unsigned_t
=
typename
detail
::
ToUnsigned
<
std
::
remove_const_t
<
T
>>::
type
;
template
<
typename
T
>
struct
FNV1a
:
detail
::
FNVInternal
<
T
>
{
std
::
size_t
operator
()(
const
T
*
data
,
std
::
size_t
size
){
template
<
typename
T
>
struct
FNV1a
:
detail
::
FNVInternal
<
T
>
{
std
::
size_t
operator
()(
const
T
*
data
,
std
::
size_t
size
)
{
to_unsigned_t
<
T
>
hash
=
detail
::
FNVInternal
<
T
>::
defaultOffsetBasis
;
for
(
std
::
size_t
i
=
0
;
i
<
size
;
++
i
)
{
hash
*=
detail
::
FNVInternal
<
T
>::
prime
;
...
...
include/spconv/indice.cu.h
View file @
19e73bbe
...
...
@@ -16,15 +16,14 @@
#define INDICE_CU_H_
#include <cuhash/hash_table.cuh>
#include <spconv/geometry.h>
#include <tensorview/
helper_kernel.cu
.h>
#include <tensorview/
kernel_utils
.h>
#include <tensorview/tensorview.h>
namespace
spconv
{
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
,
typename
Index1D
=
int
>
template
<
typename
Index
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
,
typename
Index1D
=
int
>
__global__
void
prepareIndicePairsKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index1D
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
...
...
@@ -65,11 +64,9 @@ __global__ void prepareIndicePairsKernel(
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
>
template
<
typename
Index
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
>
__global__
void
prepareDeConvIndicePairsKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
...
...
@@ -128,12 +125,12 @@ __global__ void assignGridAndIndiceOutKernel(
}
}
template
<
typename
Index
,
unsigned
NDim
,
unsigned
kNumHashFunctions
=
4
>
__global__
void
assignIndiceOutKernel
(
tv
::
TensorView
<
Index
>
indice
sOut
,
int
numAct
,
tv
::
TensorView
<
Index
>
indicePairUniqu
e
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
int
batchSize
)
{
template
<
typename
Index
,
unsigned
NDim
,
unsigned
kNumHashFunctions
=
4
>
__global__
void
assignIndiceOutKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
int
numAct
,
tv
::
TensorView
<
Index
>
indice
PairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShap
e
,
int
batchSize
)
{
Index
index
;
auto
indicesOutPtr
=
indicesOut
.
data
();
...
...
@@ -145,8 +142,7 @@ __global__ void assignIndiceOutKernel(
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
unsigned
kNumHashFunctions
=
4
>
template
<
typename
Index
,
unsigned
NDim
,
unsigned
kNumHashFunctions
=
4
>
__global__
void
assignIndicePairsHashKernel
(
tv
::
TensorView
<
Index
>
indicesOut
,
int
numActIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
...
...
@@ -161,9 +157,8 @@ assignIndicePairsHashKernel(tv::TensorView<Index> indicesOut, int numActIn,
for
(
int
i
=
0
;
i
<
kernelVolume
;
++
i
)
{
index
=
indicePairs
(
i
,
1
,
ix
);
if
(
index
>
-
1
)
{
auto
val
=
cuhash
::
retrieve
((
unsigned
)(
index
),
table_size
,
table
,
constants
,
stash_constants
,
stash_count
);
auto
val
=
cuhash
::
retrieve
((
unsigned
)(
index
),
table_size
,
table
,
constants
,
stash_constants
,
stash_count
);
assert
(
val
!=
cuhash
::
kNotFound
);
indicePairs
(
i
,
1
,
ix
)
=
(
unsigned
)
val
;
}
...
...
@@ -213,9 +208,8 @@ prepareSubMGridKernel(tv::TensorView<const Index> indicesIn,
template
<
typename
Index
,
unsigned
NDim
>
__global__
void
prepareSubMHashKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
unsigned
*
keys
,
unsigned
*
values
,
prepareSubMHashKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
unsigned
*
keys
,
unsigned
*
values
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
)
{
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
spatialVolume
=
1
;
...
...
@@ -233,7 +227,6 @@ prepareSubMHashKernel(tv::TensorView<const Index> indicesIn,
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
>
__global__
void
getSubMIndicePairsKernel
(
...
...
@@ -273,18 +266,17 @@ __global__ void getSubMIndicePairsKernel(
}
}
template
<
typename
Index
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
,
unsigned
kNumHashFunctions
=
4
>
template
<
typename
Index
,
unsigned
NDim
,
int
KernelMaxVolume
=
256
,
unsigned
kNumHashFunctions
=
4
>
__global__
void
getSubMIndicePairsHashKernel
(
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
unsigned
table_size
,
const
cuhash
::
Entry
*
table
,
cuhash
::
Functions
<
kNumHashFunctions
>
constants
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
unsigned
table_size
,
const
cuhash
::
Entry
*
table
,
cuhash
::
Functions
<
kNumHashFunctions
>
constants
,
uint2
stash_constants
,
unsigned
stash_count
)
{
auto
numActIn
=
indicesIn
.
dim
(
0
);
Index
spatialVolume
=
1
;
...
...
@@ -306,9 +298,8 @@ __global__ void getSubMIndicePairsHashKernel(
auto
offset
=
pointPtr
[
NDim
];
index
=
tv
::
rowArrayIdx
<
Index
,
NDim
>
(
pointPtr
,
outSpatialShape
.
data
())
+
spatialVolume
*
indicesIn
(
ix
,
0
);
auto
val
=
cuhash
::
retrieve
((
unsigned
)(
index
),
table_size
,
table
,
constants
,
stash_constants
,
stash_count
);
auto
val
=
cuhash
::
retrieve
((
unsigned
)(
index
),
table_size
,
table
,
constants
,
stash_constants
,
stash_count
);
if
(
val
!=
cuhash
::
kNotFound
)
{
auto
oldNum
=
atomicAdd
(
indiceNum
.
data
()
+
offset
,
Index
(
1
));
indicePairs
(
offset
,
1
,
oldNum
)
=
val
;
...
...
@@ -318,7 +309,6 @@ __global__ void getSubMIndicePairsHashKernel(
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
resetGridKernel
(
const
Index
*
indicePairUnique
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
...
...
@@ -328,14 +318,12 @@ __global__ void resetGridKernel(const Index *indicePairUnique,
}
}
template
<
typename
T
>
__global__
void
arangeKernel
(
T
*
data
,
int
size
)
{
template
<
typename
T
>
__global__
void
arangeKernel
(
T
*
data
,
int
size
)
{
for
(
int
ix
:
tv
::
KernelLoopX
<
int
>
(
size
))
{
data
[
ix
]
=
ix
;
}
}
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
__global__
void
resetGridSubMKernel
(
const
Index
*
indices
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
...
...
include/spconv/indice.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -15,67 +15,88 @@
#ifndef SPARSE_CONV_INDICE_FUNCTOR_H_
#define SPARSE_CONV_INDICE_FUNCTOR_H_
#include <tensorview/tensorview.h>
#include <torch/script.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP1
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
struct
CreateConvIndicePairFunctorP1
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctorP2
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indice
sOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
,
bool
useHash
=
true
);
struct
CreateConvIndicePairFunctorP2
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
Pairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
tv
::
TensorView
<
Index
>
indicePairUnique
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
,
bool
useHash
=
true
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indice
sOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
,
bool
useHash
=
true
);
struct
CreateConvIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
Index
>
indicesOut
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
Pairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
,
bool
useHash
=
true
);
};
template
<
typename
Device
,
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateSubMIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indice
sIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indicePairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
,
bool
useHash
=
true
);
struct
CreateSubMIndicePairFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
Index
>
indicesIn
,
tv
::
TensorView
<
IndexGrid
>
gridsOut
,
tv
::
TensorView
<
Index
>
indice
Pairs
,
tv
::
TensorView
<
Index
>
indiceNum
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
kernelSize
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
stride
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
padding
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
dilation
,
const
tv
::
SimpleVector
<
Index
,
NDim
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
=
false
,
bool
useHash
=
true
);
};
}
// namespace functor
int
create_conv_indice_pair_p1_cuda
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
torch
::
Tensor
indicePairUnique
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
);
int
create_conv_indice_pair_p2_cuda
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
indicesOut
,
torch
::
Tensor
gridsOut
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
torch
::
Tensor
indicePairUnique
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
);
int
create_submconv_indice_pair_cuda
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
gridsOut
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
);
}
// namespace spconv
#endif
\ No newline at end of file
include/spconv/maxpool.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -16,25 +16,20 @@
#define SPARSE_MAXPOOL_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolForwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
struct
SparseMaxPoolForwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseMaxPoolBackwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
struct
SparseMaxPoolBackwardFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
const
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
inFeatures
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
T
>
dout
,
tv
::
TensorView
<
T
>
din
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
...
...
include/spconv/nms.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -16,13 +16,13 @@
#define NMS_CPU_H
#include <pybind11/pybind11.h>
// must include pybind11/stl.h if using containers in STL in arguments.
#include "box_iou.h"
#include "nms_gpu.h"
#include <algorithm>
#include <boost/geometry.hpp>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <vector>
#include "box_iou.h"
#include "nms_gpu.h"
namespace
spconv
{
namespace
py
=
pybind11
;
using
namespace
pybind11
::
literals
;
...
...
@@ -181,7 +181,7 @@ std::vector<int> rotate_non_max_suppression_cpu(py::array_t<DType> box_corners,
}
return
keep
;
}
#ifdef
SPCON
V_CUDA
#ifdef
T
V_CUDA
constexpr
int
const
threadsPerBlock
=
sizeof
(
unsigned
long
long
)
*
8
;
template
<
typename
DType
>
...
...
include/spconv/nms_functor.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -16,24 +16,19 @@
#define NMS_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace
spconv
{
namespace
functor
{
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
NonMaxSupressionFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
Index
>
keep
,
tv
::
TensorView
<
const
T
>
boxes
,
T
threshold
,
T
eps
);
struct
NonMaxSupressionFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
Index
>
keep
,
tv
::
TensorView
<
const
T
>
boxes
,
T
threshold
,
T
eps
);
};
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
rotateNonMaxSupressionFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
Index
>
keep
,
tv
::
TensorView
<
const
T
>
boxCorners
,
tv
::
TensorView
<
const
T
>
standupIoU
,
T
threshold
);
struct
rotateNonMaxSupressionFunctor
{
Index
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
Index
>
keep
,
tv
::
TensorView
<
const
T
>
boxCorners
,
tv
::
TensorView
<
const
T
>
standupIoU
,
T
threshold
);
};
}
// namespace functor
...
...
include/spconv/nms_gpu.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
include/spconv/nms_ops.h
View file @
19e73bbe
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -16,35 +16,35 @@
#define NMS_TORCH_OP_H_
#include <spconv/indice.h>
#include <spconv/nms_functor.h>
#include <spconv/reordering.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h>
#include <spconv/nms_functor.h>
namespace
spconv
{
// torch.jit's doc says only support int64, so we need to convert to int32.
template
<
typename
T
>
torch
::
Tensor
nonMaxSuppression
(
torch
::
Tensor
boxes
,
torch
::
Tensor
scores
,
int64_t
p
re
MaxSize
,
int64_t
postMaxSize
,
double
thresh
,
double
eps
)
{
torch
::
Tensor
nonMaxSuppression
(
torch
::
Tensor
boxes
,
torch
::
Tensor
scores
,
int64_t
preMaxSize
,
int64_t
p
ost
MaxSize
,
double
thresh
,
double
eps
)
{
// auto timer = spconv::CudaContextTimer<>();
tv
::
check_torch_dtype
<
T
>
(
boxes
);
auto
resOptions
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
boxes
.
device
());
if
(
boxes
.
size
(
0
)
==
0
){
return
torch
::
zeros
({
0
},
resOptions
);
if
(
boxes
.
size
(
0
)
==
0
)
{
return
torch
::
zeros
({
0
},
resOptions
);
}
torch
::
Tensor
indices
;
if
(
preMaxSize
>
0
){
auto
numKeepedScores
=
scores
.
size
(
0
);
preMaxSize
=
std
::
min
(
numKeepedScores
,
preMaxSize
);
auto
res
=
torch
::
topk
(
scores
,
preMaxSize
);
indices
=
std
::
get
<
1
>
(
res
);
boxes
=
torch
::
index_select
(
boxes
,
0
,
indices
);
}
else
{
indices
=
std
::
get
<
1
>
(
torch
::
sort
(
scores
));
boxes
=
torch
::
index_select
(
boxes
,
0
,
indices
);
if
(
preMaxSize
>
0
)
{
auto
numKeepedScores
=
scores
.
size
(
0
);
preMaxSize
=
std
::
min
(
numKeepedScores
,
preMaxSize
);
auto
res
=
torch
::
topk
(
scores
,
preMaxSize
);
indices
=
std
::
get
<
1
>
(
res
);
boxes
=
torch
::
index_select
(
boxes
,
0
,
indices
);
}
else
{
indices
=
std
::
get
<
1
>
(
torch
::
sort
(
scores
));
boxes
=
torch
::
index_select
(
boxes
,
0
,
indices
);
}
if
(
boxes
.
size
(
0
)
==
0
)
return
torch
::
zeros
({
0
},
resOptions
);
...
...
@@ -54,16 +54,16 @@ nonMaxSuppression(torch::Tensor boxes, torch::Tensor scores, int64_t preMaxSize,
if
(
boxes
.
device
().
type
()
==
torch
::
kCPU
)
{
auto
nmsFunctor
=
functor
::
NonMaxSupressionFunctor
<
tv
::
CPU
,
T
,
int64_t
>
();
keepNum
=
nmsFunctor
(
tv
::
CPU
(),
tv
::
torch2tv
<
int64_t
>
(
keep
),
tv
::
torch2tv
<
const
T
>
(
boxes
),
T
(
thresh
),
T
(
eps
));
}
else
{
tv
::
torch2tv
<
const
T
>
(
boxes
),
T
(
thresh
),
T
(
eps
));
}
else
{
TV_ASSERT_RT_ERR
(
false
,
"not implemented"
);
}
if
(
postMaxSize
<=
0
){
if
(
postMaxSize
<=
0
)
{
postMaxSize
=
keepNum
;
}
// std::cout << keep << std::endl;
keep
=
keep
.
slice
(
0
,
0
,
std
::
min
(
keepNum
,
postMaxSize
));
if
(
preMaxSize
>
0
){
if
(
preMaxSize
>
0
)
{
return
torch
::
index_select
(
indices
,
0
,
keep
);
}
return
keep
;
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment