Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
19e73bbe
"libraries/vscode:/vscode.git/clone" did not exist on "691ee61997aa13f8a0bf0c7fffd72a818a0d0c8d"
Commit
19e73bbe
authored
May 20, 2020
by
Yan Yan
Browse files
format code with clang-format, better c++ code
parent
c336139f
Changes
77
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2825 additions
and
2975 deletions
+2825
-2975
include/tsl/robin_growth_policy.h
include/tsl/robin_growth_policy.h
+238
-228
include/tsl/robin_hash.h
include/tsl/robin_hash.h
+1219
-1274
include/tsl/robin_map.h
include/tsl/robin_map.h
+654
-623
include/utility/timer.h
include/utility/timer.h
+2
-2
setup.py
setup.py
+6
-7
spconv/__init__.py
spconv/__init__.py
+17
-12
spconv/conv.py
spconv/conv.py
+121
-126
spconv/functional.py
spconv/functional.py
+42
-63
spconv/modules.py
spconv/modules.py
+5
-4
spconv/ops.py
spconv/ops.py
+62
-54
spconv/pool.py
spconv/pool.py
+21
-35
spconv/tables.py
spconv/tables.py
+13
-11
spconv/test_utils.py
spconv/test_utils.py
+50
-46
spconv/utils/__init__.py
spconv/utils/__init__.py
+50
-41
src/cuhash/debugging.cpp
src/cuhash/debugging.cpp
+17
-18
src/cuhash/debugging.cu
src/cuhash/debugging.cu
+57
-74
src/cuhash/hash_functions.cpp
src/cuhash/hash_functions.cpp
+3
-5
src/cuhash/hash_functions.cu
src/cuhash/hash_functions.cu
+11
-13
src/cuhash/hash_table.cpp
src/cuhash/hash_table.cpp
+154
-183
src/cuhash/hash_table.cu
src/cuhash/hash_table.cu
+83
-156
No files found.
include/tsl/robin_growth_policy.h
View file @
19e73bbe
/**
/**
* MIT License
* MIT License
*
*
* Copyright (c) 2017 Tessil
* Copyright (c) 2017 Tessil
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
* furnished to do so, subject to the following conditions:
*
*
* The above copyright notice and this permission notice shall be included in
all
* The above copyright notice and this permission notice shall be included in
* copies or substantial portions of the Software.
*
all
copies or substantial portions of the Software.
*
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
...
@@ -22,8 +22,7 @@
...
@@ -22,8 +22,7 @@
* SOFTWARE.
* SOFTWARE.
*/
*/
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm>
#include <algorithm>
#include <array>
#include <array>
...
@@ -35,226 +34,237 @@
...
@@ -35,226 +34,237 @@
#include <ratio>
#include <ratio>
#include <stdexcept>
#include <stdexcept>
#ifdef TSL_DEBUG
#ifdef TSL_DEBUG
#
define tsl_rh_assert(expr) assert(expr)
#define tsl_rh_assert(expr) assert(expr)
#else
#else
#
define tsl_rh_assert(expr) (static_cast<void>(0))
#define tsl_rh_assert(expr) (static_cast<void>(0))
#endif
#endif
/**
/**
* If exceptions are enabled, throw the exception passed in parameter, otherwise call std::terminate.
* If exceptions are enabled, throw the exception passed in parameter, otherwise
* call std::terminate.
*/
*/
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (defined (_MSC_VER) && defined (_CPPUNWIND))) && !defined(TSL_NO_EXCEPTIONS)
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || \
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
(defined(_MSC_VER) && defined(_CPPUNWIND))) && \
!defined(TSL_NO_EXCEPTIONS)
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#else
#ifdef NDEBUG
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#else
#else
# ifdef NDEBUG
#include <cstdio>
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) \
# else
do { \
# include <cstdio>
std::fprintf(stderr, msg); \
# define TSL_RH_THROW_OR_TERMINATE(ex, msg) do { std::fprintf(stderr, msg); std::terminate(); } while(0)
std::terminate(); \
# endif
} while (0)
#endif
#endif
#endif
#if defined(__GNUC__) || defined(__clang__)
#if defined(__GNUC__) || defined(__clang__)
#
define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#else
#else
#
define TSL_RH_LIKELY(exp) (exp)
#define TSL_RH_LIKELY(exp) (exp)
#endif
#endif
namespace
tsl
{
namespace
tsl
{
namespace
rh
{
namespace
rh
{
/**
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
* power of two. It allows the table to use a mask operation instead of a modulo
*
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
* GrowthFactor must be a power of two >= 2.
*/
*/
template
<
std
::
size_t
GrowthFactor
>
template
<
std
::
size_t
GrowthFactor
>
class
power_of_two_growth_policy
{
class
power_of_two_growth_policy
{
public:
public:
/**
/**
* Called on the hash table creation and on rehash. The number of buckets for the table is passed in parameter.
* Called on the hash table creation and on rehash. The number of buckets for
* This number is a minimum, the policy may update this value with a higher value if needed (but not lower).
* the table is passed in parameter. This number is a minimum, the policy may
*
* update this value with a higher value if needed (but not lower).
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
*
* bucket_for_hash must always return 0 in this case.
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy
*/
* creation and bucket_for_hash must always return 0 in this case.
explicit
power_of_two_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
*/
if
(
min_bucket_count_in_out
>
max_bucket_count
())
{
explicit
power_of_two_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
if
(
min_bucket_count_in_out
>
max_bucket_count
())
{
}
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
if
(
min_bucket_count_in_out
>
0
)
{
min_bucket_count_in_out
=
round_up_to_power_of_two
(
min_bucket_count_in_out
);
m_mask
=
min_bucket_count_in_out
-
1
;
}
else
{
m_mask
=
0
;
}
}
/**
* Return the bucket [0, bucket_count()) to which the hash belongs.
* If bucket_count() is 0, it must always return 0.
*/
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
hash
&
m_mask
;
}
/**
* Return the number of buckets that should be used on next growth.
*/
std
::
size_t
next_bucket_count
()
const
{
if
((
m_mask
+
1
)
>
max_bucket_count
()
/
GrowthFactor
)
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
return
(
m_mask
+
1
)
*
GrowthFactor
;
}
}
/**
if
(
min_bucket_count_in_out
>
0
)
{
* Return the maximum number of buckets supported by the policy.
min_bucket_count_in_out
=
*/
round_up_to_power_of_two
(
min_bucket_count_in_out
);
std
::
size_t
max
_bucket_count
()
const
{
m_mask
=
min
_bucket_count
_in_out
-
1
;
// Largest power of two.
}
else
{
return
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()
/
2
)
+
1
;
m_mask
=
0
;
}
}
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
/**
* After a clear, the policy must always return 0 when bucket_for_hash is called.
* Return the bucket [0, bucket_count()) to which the hash belongs.
*/
* If bucket_count() is 0, it must always return 0.
void
clear
()
noexcept
{
*/
m_mask
=
0
;
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
hash
&
m_mask
;
}
/**
* Return the number of buckets that should be used on next growth.
*/
std
::
size_t
next_bucket_count
()
const
{
if
((
m_mask
+
1
)
>
max_bucket_count
()
/
GrowthFactor
)
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
}
return
(
m_mask
+
1
)
*
GrowthFactor
;
}
/**
* Return the maximum number of buckets supported by the policy.
*/
std
::
size_t
max_bucket_count
()
const
{
// Largest power of two.
return
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()
/
2
)
+
1
;
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is
* called.
*/
void
clear
()
noexcept
{
m_mask
=
0
;
}
private:
private:
static
std
::
size_t
round_up_to_power_of_two
(
std
::
size_t
value
)
{
static
std
::
size_t
round_up_to_power_of_two
(
std
::
size_t
value
)
{
if
(
is_power_of_two
(
value
))
{
if
(
is_power_of_two
(
value
))
{
return
value
;
return
value
;
}
}
if
(
value
==
0
)
{
if
(
value
==
0
)
{
return
1
;
return
1
;
}
--
value
;
for
(
std
::
size_t
i
=
1
;
i
<
sizeof
(
std
::
size_t
)
*
CHAR_BIT
;
i
*=
2
)
{
value
|=
value
>>
i
;
}
return
value
+
1
;
}
}
static
constexpr
bool
is_power_of_two
(
std
::
size_t
value
)
{
--
value
;
return
value
!=
0
&&
(
value
&
(
value
-
1
))
==
0
;
for
(
std
::
size_t
i
=
1
;
i
<
sizeof
(
std
::
size_t
)
*
CHAR_BIT
;
i
*=
2
)
{
value
|=
value
>>
i
;
}
}
return
value
+
1
;
}
static
constexpr
bool
is_power_of_two
(
std
::
size_t
value
)
{
return
value
!=
0
&&
(
value
&
(
value
-
1
))
==
0
;
}
protected:
protected:
static_assert
(
is_power_of_two
(
GrowthFactor
)
&&
GrowthFactor
>=
2
,
"GrowthFactor must be a power of two >= 2."
);
static_assert
(
is_power_of_two
(
GrowthFactor
)
&&
GrowthFactor
>=
2
,
"GrowthFactor must be a power of two >= 2."
);
std
::
size_t
m_mask
;
};
std
::
size_t
m_mask
;
};
/**
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to a bucket. Slower but it can be useful if you want a slower growth.
* to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/
*/
template
<
class
GrowthFactor
=
std
::
ratio
<
3
,
2
>
>
template
<
class
GrowthFactor
=
std
::
ratio
<
3
,
2
>
>
class
mod_growth_policy
{
class
mod_growth_policy
{
public:
public:
explicit
mod_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
explicit
mod_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
if
(
min_bucket_count_in_out
>
max_bucket_count
())
{
if
(
min_bucket_count_in_out
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
}
"The hash table exceeds its maxmimum size."
);
if
(
min_bucket_count_in_out
>
0
)
{
m_mod
=
min_bucket_count_in_out
;
}
else
{
m_mod
=
1
;
}
}
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
if
(
min_bucket_count_in_out
>
0
)
{
return
hash
%
m_mod
;
m_mod
=
min_bucket_count_in_out
;
}
else
{
m_mod
=
1
;
}
}
}
std
::
size_t
next_bucket_count
()
const
{
if
(
m_mod
==
max_bucket_count
())
{
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
return
hash
%
m_mod
;
}
}
const
double
next_bucket_count
=
std
::
ceil
(
double
(
m_mod
)
*
REHASH_SIZE_MULTIPLICATION_FACTOR
);
std
::
size_t
next_bucket_count
()
const
{
if
(
!
std
::
isnormal
(
next_bucket_count
))
{
if
(
m_mod
==
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
}
"The hash table exceeds its maxmimum size."
);
if
(
next_bucket_count
>
double
(
max_bucket_count
()))
{
return
max_bucket_count
();
}
else
{
return
std
::
size_t
(
next_bucket_count
);
}
}
}
std
::
size_t
max_bucket_count
()
const
{
const
double
next_bucket_count
=
return
MAX_BUCKET_COUNT
;
std
::
ceil
(
double
(
m_mod
)
*
REHASH_SIZE_MULTIPLICATION_FACTOR
);
if
(
!
std
::
isnormal
(
next_bucket_count
))
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
}
void
clear
()
noexcept
{
if
(
next_bucket_count
>
double
(
max_bucket_count
()))
{
m_mod
=
1
;
return
max_bucket_count
();
}
else
{
return
std
::
size_t
(
next_bucket_count
);
}
}
}
std
::
size_t
max_bucket_count
()
const
{
return
MAX_BUCKET_COUNT
;
}
void
clear
()
noexcept
{
m_mod
=
1
;
}
private:
private:
static
constexpr
double
REHASH_SIZE_MULTIPLICATION_FACTOR
=
1.0
*
GrowthFactor
::
num
/
GrowthFactor
::
den
;
static
constexpr
double
REHASH_SIZE_MULTIPLICATION_FACTOR
=
static
const
std
::
size_t
MAX_BUCKET_COUNT
=
1.0
*
GrowthFactor
::
num
/
GrowthFactor
::
den
;
std
::
size_t
(
double
(
static
const
std
::
size_t
MAX_BUCKET_COUNT
=
std
::
numeric_limits
<
std
::
size_t
>::
max
()
/
REHASH_SIZE_MULTIPLICATION_FACTOR
std
::
size_t
(
double
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()
/
));
REHASH_SIZE_MULTIPLICATION_FACTOR
));
static_assert
(
REHASH_SIZE_MULTIPLICATION_FACTOR
>=
1.1
,
"Growth factor should be >= 1.1."
);
std
::
size_t
m_mod
;
};
static_assert
(
REHASH_SIZE_MULTIPLICATION_FACTOR
>=
1.1
,
"Growth factor should be >= 1.1."
);
std
::
size_t
m_mod
;
};
namespace
detail
{
namespace
detail
{
static
constexpr
const
std
::
array
<
std
::
size_t
,
40
>
PRIMES
=
{{
static
constexpr
const
std
::
array
<
std
::
size_t
,
40
>
PRIMES
=
{
1ul
,
5ul
,
17ul
,
29ul
,
37ul
,
53ul
,
67ul
,
79ul
,
97ul
,
131ul
,
193ul
,
257ul
,
389ul
,
521ul
,
769ul
,
1031ul
,
{
1ul
,
5ul
,
17ul
,
29ul
,
37ul
,
1543ul
,
2053ul
,
3079ul
,
6151ul
,
12289ul
,
24593ul
,
49157ul
,
98317ul
,
196613ul
,
393241ul
,
786433ul
,
53ul
,
67ul
,
79ul
,
97ul
,
131ul
,
1572869ul
,
3145739ul
,
6291469ul
,
12582917ul
,
25165843ul
,
50331653ul
,
100663319ul
,
201326611ul
,
193ul
,
257ul
,
389ul
,
521ul
,
769ul
,
402653189ul
,
805306457ul
,
1610612741ul
,
3221225473ul
,
4294967291ul
1031ul
,
1543ul
,
2053ul
,
3079ul
,
6151ul
,
}};
12289ul
,
24593ul
,
49157ul
,
98317ul
,
196613ul
,
393241ul
,
786433ul
,
1572869ul
,
3145739ul
,
6291469ul
,
template
<
unsigned
int
IPrime
>
12582917ul
,
25165843ul
,
50331653ul
,
100663319ul
,
201326611ul
,
static
constexpr
std
::
size_t
mod
(
std
::
size_t
hash
)
{
return
hash
%
PRIMES
[
IPrime
];
}
402653189ul
,
805306457ul
,
1610612741ul
,
3221225473ul
,
4294967291ul
}};
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
// compiler can optimize the modulo code better with a constant known at the compilation.
static
constexpr
const
std
::
array
<
std
::
size_t
(
*
)(
std
::
size_t
),
40
>
MOD_PRIME
=
{{
&
mod
<
0
>
,
&
mod
<
1
>
,
&
mod
<
2
>
,
&
mod
<
3
>
,
&
mod
<
4
>
,
&
mod
<
5
>
,
&
mod
<
6
>
,
&
mod
<
7
>
,
&
mod
<
8
>
,
&
mod
<
9
>
,
&
mod
<
10
>
,
&
mod
<
11
>
,
&
mod
<
12
>
,
&
mod
<
13
>
,
&
mod
<
14
>
,
&
mod
<
15
>
,
&
mod
<
16
>
,
&
mod
<
17
>
,
&
mod
<
18
>
,
&
mod
<
19
>
,
&
mod
<
20
>
,
&
mod
<
21
>
,
&
mod
<
22
>
,
&
mod
<
23
>
,
&
mod
<
24
>
,
&
mod
<
25
>
,
&
mod
<
26
>
,
&
mod
<
27
>
,
&
mod
<
28
>
,
&
mod
<
29
>
,
&
mod
<
30
>
,
&
mod
<
31
>
,
&
mod
<
32
>
,
&
mod
<
33
>
,
&
mod
<
34
>
,
&
mod
<
35
>
,
&
mod
<
36
>
,
&
mod
<
37
>
,
&
mod
<
38
>
,
&
mod
<
39
>
}};
template
<
unsigned
int
IPrime
>
static
constexpr
std
::
size_t
mod
(
std
::
size_t
hash
)
{
return
hash
%
PRIMES
[
IPrime
];
}
}
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
static
constexpr
const
std
::
array
<
std
::
size_t
(
*
)(
std
::
size_t
),
40
>
MOD_PRIME
=
{{
&
mod
<
0
>
,
&
mod
<
1
>
,
&
mod
<
2
>
,
&
mod
<
3
>
,
&
mod
<
4
>
,
&
mod
<
5
>
,
&
mod
<
6
>
,
&
mod
<
7
>
,
&
mod
<
8
>
,
&
mod
<
9
>
,
&
mod
<
10
>
,
&
mod
<
11
>
,
&
mod
<
12
>
,
&
mod
<
13
>
,
&
mod
<
14
>
,
&
mod
<
15
>
,
&
mod
<
16
>
,
&
mod
<
17
>
,
&
mod
<
18
>
,
&
mod
<
19
>
,
&
mod
<
20
>
,
&
mod
<
21
>
,
&
mod
<
22
>
,
&
mod
<
23
>
,
&
mod
<
24
>
,
&
mod
<
25
>
,
&
mod
<
26
>
,
&
mod
<
27
>
,
&
mod
<
28
>
,
&
mod
<
29
>
,
&
mod
<
30
>
,
&
mod
<
31
>
,
&
mod
<
32
>
,
&
mod
<
33
>
,
&
mod
<
34
>
,
&
mod
<
35
>
,
&
mod
<
36
>
,
&
mod
<
37
>
,
&
mod
<
38
>
,
&
mod
<
39
>
}};
}
// namespace detail
/**
/**
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::rh::power_of_two_growth_policy in
* Grow the hash table by using prime numbers as bucket count. Slower than
* general but will probably distribute the values around better in the buckets with a poor hash function.
* tsl::rh::power_of_two_growth_policy in general but will probably distribute
*
* the values around better in the buckets with a poor hash function.
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers.
*
*
* To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant primes numbers.
*
* With a switch the code would look like:
* With a switch the code would look like:
* \code
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* switch(iprime) { // iprime is the current prime of the hash table
...
@@ -265,60 +275,60 @@ static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {
...
@@ -265,60 +275,60 @@ static constexpr const std::array<std::size_t(*)(std::size_t), 40> MOD_PRIME = {
* case 2: hash % 29ul;
* case 2: hash % 29ul;
* break;
* break;
* ...
* ...
* }
* }
* \endcode
* \endcode
*
*
* Due to the constant variable in the modulo the compiler is able to optimize the operation
* Due to the constant variable in the modulo the compiler is able to optimize
* by a series of multiplications, substractions and shifts.
* the operation by a series of multiplications, substractions and shifts.
*
*
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environement.
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64 bits environement.
*/
*/
class
prime_growth_policy
{
class
prime_growth_policy
{
public:
public:
explicit
prime_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
explicit
prime_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
auto
it_prime
=
std
::
lower_bound
(
detail
::
PRIMES
.
begin
(),
auto
it_prime
=
std
::
lower_bound
(
detail
::
PRIMES
.
end
(),
min_bucket_count_in_out
);
detail
::
PRIMES
.
begin
(),
detail
::
PRIMES
.
end
(),
min_bucket_count_in_out
);
if
(
it_prime
==
detail
::
PRIMES
.
end
())
{
if
(
it_prime
==
detail
::
PRIMES
.
end
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
}
"The hash table exceeds its maxmimum size."
);
m_iprime
=
static_cast
<
unsigned
int
>
(
std
::
distance
(
detail
::
PRIMES
.
begin
(),
it_prime
));
if
(
min_bucket_count_in_out
>
0
)
{
min_bucket_count_in_out
=
*
it_prime
;
}
else
{
min_bucket_count_in_out
=
0
;
}
}
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
m_iprime
=
static_cast
<
unsigned
int
>
(
return
detail
::
MOD_PRIME
[
m_iprime
](
hash
);
std
::
distance
(
detail
::
PRIMES
.
begin
(),
it_prime
));
}
if
(
min_bucket_count_in_out
>
0
)
{
min_bucket_count_in_out
=
*
it_prime
;
std
::
size_t
next_bucket_count
()
const
{
}
else
{
if
(
m_iprime
+
1
>=
detail
::
PRIMES
.
size
())
{
min_bucket_count_in_out
=
0
;
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
return
detail
::
PRIMES
[
m_iprime
+
1
];
}
std
::
size_t
max_bucket_count
()
const
{
return
detail
::
PRIMES
.
back
();
}
}
}
void
clear
()
noexcept
{
m_iprime
=
0
;
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
detail
::
MOD_PRIME
[
m_iprime
](
hash
);
}
std
::
size_t
next_bucket_count
()
const
{
if
(
m_iprime
+
1
>=
detail
::
PRIMES
.
size
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
}
return
detail
::
PRIMES
[
m_iprime
+
1
];
}
std
::
size_t
max_bucket_count
()
const
{
return
detail
::
PRIMES
.
back
();
}
void
clear
()
noexcept
{
m_iprime
=
0
;
}
private:
private:
unsigned
int
m_iprime
;
unsigned
int
m_iprime
;
static_assert
(
std
::
numeric_limits
<
decltype
(
m_iprime
)
>::
max
()
>=
detail
::
PRIMES
.
size
(),
"The type of m_iprime is not big enough."
);
};
}
static_assert
(
std
::
numeric_limits
<
decltype
(
m_iprime
)
>::
max
()
>=
}
detail
::
PRIMES
.
size
(),
"The type of m_iprime is not big enough."
);
};
}
// namespace rh
}
// namespace tsl
#endif
#endif
include/tsl/robin_hash.h
View file @
19e73bbe
/**
/**
* MIT License
* MIT License
*
*
* Copyright (c) 2017 Tessil
* Copyright (c) 2017 Tessil
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
* furnished to do so, subject to the following conditions:
*
*
* The above copyright notice and this permission notice shall be included in
all
* The above copyright notice and this permission notice shall be included in
* copies or substantial portions of the Software.
*
all
copies or substantial portions of the Software.
*
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
...
@@ -22,9 +22,9 @@
...
@@ -22,9 +22,9 @@
* SOFTWARE.
* SOFTWARE.
*/
*/
#ifndef TSL_ROBIN_HASH_H
#ifndef TSL_ROBIN_HASH_H
#define TSL_ROBIN_HASH_H
#define TSL_ROBIN_HASH_H
#include "robin_growth_policy.h"
#include <algorithm>
#include <algorithm>
#include <cassert>
#include <cassert>
#include <cmath>
#include <cmath>
...
@@ -39,1377 +39,1322 @@
...
@@ -39,1377 +39,1322 @@
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "robin_growth_policy.h"
namespace
tsl
{
namespace
tsl
{
namespace
detail_robin_hash
{
namespace
detail_robin_hash
{
template
<
typename
T
>
template
<
typename
T
>
struct
make_void
{
using
type
=
void
;
};
struct
make_void
{
using
type
=
void
;
};
template
<
typename
T
,
typename
=
void
>
template
<
typename
T
,
typename
=
void
>
struct
has_is_transparent
:
std
::
false_type
{
struct
has_is_transparent
:
std
::
false_type
{};
};
template
<
typename
T
>
template
<
typename
T
>
struct
has_is_transparent
<
T
,
typename
make_void
<
typename
T
::
is_transparent
>::
type
>:
std
::
true_type
{
struct
has_is_transparent
<
T
,
};
typename
make_void
<
typename
T
::
is_transparent
>::
type
>
:
std
::
true_type
{};
template
<
typename
U
>
template
<
typename
U
>
struct
is_power_of_two_policy
:
std
::
false_type
{};
struct
is_power_of_two_policy
:
std
::
false_type
{
};
template
<
std
::
size_t
GrowthFactor
>
template
<
std
::
size_t
GrowthFactor
>
struct
is_power_of_two_policy
<
tsl
::
rh
::
power_of_two_growth_policy
<
GrowthFactor
>>
:
std
::
true_type
{
struct
is_power_of_two_policy
<
tsl
::
rh
::
power_of_two_growth_policy
<
GrowthFactor
>>
};
:
std
::
true_type
{
};
// Only available in C++17, we need to be compatible with C++11
// Only available in C++17, we need to be compatible with C++11
template
<
class
T
>
template
<
class
T
>
const
T
&
clamp
(
const
T
&
v
,
const
T
&
lo
,
const
T
&
hi
)
{
const
T
&
clamp
(
const
T
&
v
,
const
T
&
lo
,
const
T
&
hi
)
{
return
std
::
min
(
hi
,
std
::
max
(
lo
,
v
));
return
std
::
min
(
hi
,
std
::
max
(
lo
,
v
));
}
}
using
truncated_hash_type
=
std
::
uint_least32_t
;
using
truncated_hash_type
=
std
::
uint_least32_t
;
/**
/**
* Helper class that stores a truncated hash if StoreHash is true and nothing otherwise.
* Helper class that stores a truncated hash if StoreHash is true and nothing
* otherwise.
*/
*/
template
<
bool
StoreHash
>
template
<
bool
StoreHash
>
class
bucket_entry_hash
{
class
bucket_entry_hash
{
public:
public:
bool
bucket_hash_equal
(
std
::
size_t
/*hash*/
)
const
noexcept
{
bool
bucket_hash_equal
(
std
::
size_t
/*hash*/
)
const
noexcept
{
return
true
;
}
return
true
;
}
truncated_hash_type
truncated_hash
()
const
noexcept
{
return
0
;
}
truncated_hash_type
truncated_hash
()
const
noexcept
{
return
0
;
}
protected:
protected:
void
set_hash
(
truncated_hash_type
/*hash*/
)
noexcept
{
void
set_hash
(
truncated_hash_type
/*hash*/
)
noexcept
{}
}
};
};
template
<
>
template
<
>
class
bucket_entry_hash
<
true
>
{
class
bucket_entry_hash
<
true
>
{
public:
public:
bool
bucket_hash_equal
(
std
::
size_t
hash
)
const
noexcept
{
bool
bucket_hash_equal
(
std
::
size_t
hash
)
const
noexcept
{
return
m_hash
==
truncated_hash_type
(
hash
);
return
m_hash
==
truncated_hash_type
(
hash
);
}
}
truncated_hash_type
truncated_hash
()
const
noexcept
{
truncated_hash_type
truncated_hash
()
const
noexcept
{
return
m_hash
;
}
return
m_hash
;
}
protected:
protected:
void
set_hash
(
truncated_hash_type
hash
)
noexcept
{
void
set_hash
(
truncated_hash_type
hash
)
noexcept
{
m_hash
=
truncated_hash_type
(
hash
);
m_hash
=
truncated_hash_type
(
hash
);
}
}
private:
truncated_hash_type
m_hash
;
};
private:
truncated_hash_type
m_hash
;
};
/**
/**
* Each bucket entry has:
* Each bucket entry has:
* - A value of type `ValueType`.
* - A value of type `ValueType`.
* - An integer to store how far the value of the bucket, if any, is from its ideal bucket
* - An integer to store how far the value of the bucket, if any, is from its
* (ex: if the current bucket 5 has the value 'foo' and `hash('foo') % nb_buckets` == 3,
* ideal bucket (ex: if the current bucket 5 has the value 'foo' and
* `dist_from_ideal_bucket()` will return 2 as the current value of the bucket is two
* `hash('foo') % nb_buckets` == 3, `dist_from_ideal_bucket()` will return 2 as
* buckets away from its ideal bucket)
* the current value of the bucket is two buckets away from its ideal bucket) If
* If there is no value in the bucket (i.e. `empty()` is true) `dist_from_ideal_bucket()` will be < 0.
* there is no value in the bucket (i.e. `empty()` is true)
* - A marker which tells us if the bucket is the last bucket of the bucket array (useful for the
* `dist_from_ideal_bucket()` will be < 0.
* iterator of the hash table).
* - A marker which tells us if the bucket is the last bucket of the bucket
* - If `StoreHash` is true, 32 bits of the hash of the value, if any, are also stored in the bucket.
* array (useful for the iterator of the hash table).
* If the size of the hash is more than 32 bits, it is truncated. We don't store the full hash
* - If `StoreHash` is true, 32 bits of the hash of the value, if any, are also
* as storing the hash is a potential opportunity to use the unused space due to the alignement
* stored in the bucket. If the size of the hash is more than 32 bits, it is
* of the bucket_entry structure. We can thus potentially store the hash without any extra space
* truncated. We don't store the full hash as storing the hash is a potential
* opportunity to use the unused space due to the alignement of the bucket_entry
* structure. We can thus potentially store the hash without any extra space
* (which would not be possible with 64 bits of the hash).
* (which would not be possible with 64 bits of the hash).
*/
*/
template
<
typename
ValueType
,
bool
StoreHash
>
template
<
typename
ValueType
,
bool
StoreHash
>
class
bucket_entry
:
public
bucket_entry_hash
<
StoreHash
>
{
class
bucket_entry
:
public
bucket_entry_hash
<
StoreHash
>
{
using
bucket_hash
=
bucket_entry_hash
<
StoreHash
>
;
using
bucket_hash
=
bucket_entry_hash
<
StoreHash
>
;
public:
public:
using
value_type
=
ValueType
;
using
value_type
=
ValueType
;
using
distance_type
=
std
::
int_least16_t
;
using
distance_type
=
std
::
int_least16_t
;
bucket_entry
()
noexcept
bucket_entry
()
noexcept
:
bucket_hash
(),
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
:
bucket_hash
(),
m_last_bucket
(
false
)
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
{
m_last_bucket
(
false
)
{
tsl_rh_assert
(
empty
());
tsl_rh_assert
(
empty
());
}
}
bucket_entry
(
bool
last_bucket
)
noexcept
:
bucket_hash
(),
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
bucket_entry
(
bool
last_bucket
)
noexcept
m_last_bucket
(
last_bucket
)
:
bucket_hash
(),
{
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
tsl_rh_assert
(
empty
());
m_last_bucket
(
last_bucket
)
{
}
tsl_rh_assert
(
empty
());
}
bucket_entry
(
const
bucket_entry
&
other
)
noexcept
(
std
::
is_nothrow_copy_constructible
<
value_type
>::
value
)
:
bucket_hash
(
other
),
bucket_entry
(
const
bucket_entry
&
other
)
noexcept
(
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
std
::
is_nothrow_copy_constructible
<
value_type
>::
value
)
m_last_bucket
(
other
.
m_last_bucket
)
:
bucket_hash
(
other
),
{
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
if
(
!
other
.
empty
())
{
m_last_bucket
(
other
.
m_last_bucket
)
{
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
other
.
value
());
if
(
!
other
.
empty
())
{
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
}
value_type
(
other
.
value
());
}
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
}
/**
}
* Never really used, but still necessary as we must call resize on an empty `std::vector<bucket_entry>`.
* and we need to support move-only types. See robin_hash constructor for details.
/**
*/
* Never really used, but still necessary as we must call resize on an empty
bucket_entry
(
bucket_entry
&&
other
)
noexcept
(
std
::
is_nothrow_move_constructible
<
value_type
>::
value
)
:
* `std::vector<bucket_entry>`. and we need to support move-only types. See
bucket_hash
(
std
::
move
(
other
)),
* robin_hash constructor for details.
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
*/
m_last_bucket
(
other
.
m_last_bucket
)
bucket_entry
(
bucket_entry
&&
other
)
noexcept
(
{
std
::
is_nothrow_move_constructible
<
value_type
>::
value
)
if
(
!
other
.
empty
())
{
:
bucket_hash
(
std
::
move
(
other
)),
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
std
::
move
(
other
.
value
()));
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
m_last_bucket
(
other
.
m_last_bucket
)
{
}
if
(
!
other
.
empty
())
{
}
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
std
::
move
(
other
.
value
()));
bucket_entry
&
operator
=
(
const
bucket_entry
&
other
)
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
noexcept
(
std
::
is_nothrow_copy_constructible
<
value_type
>::
value
)
}
{
}
if
(
this
!=
&
other
)
{
clear
();
bucket_entry
&
operator
=
(
const
bucket_entry
&
other
)
noexcept
(
std
::
is_nothrow_copy_constructible
<
value_type
>::
value
)
{
bucket_hash
::
operator
=
(
other
);
if
(
this
!=
&
other
)
{
if
(
!
other
.
empty
())
{
clear
();
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
other
.
value
());
}
bucket_hash
::
operator
=
(
other
);
if
(
!
other
.
empty
())
{
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
m_last_bucket
=
other
.
m_last_bucket
;
value_type
(
other
.
value
());
}
}
return
*
this
;
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
}
m_last_bucket
=
other
.
m_last_bucket
;
bucket_entry
&
operator
=
(
bucket_entry
&&
)
=
delete
;
~
bucket_entry
()
noexcept
{
clear
();
}
void
clear
()
noexcept
{
if
(
!
empty
())
{
destroy_value
();
m_dist_from_ideal_bucket
=
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
;
}
}
bool
empty
()
const
noexcept
{
return
m_dist_from_ideal_bucket
==
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
;
}
value_type
&
value
()
noexcept
{
tsl_rh_assert
(
!
empty
());
return
*
reinterpret_cast
<
value_type
*>
(
std
::
addressof
(
m_value
));
}
const
value_type
&
value
()
const
noexcept
{
tsl_rh_assert
(
!
empty
());
return
*
reinterpret_cast
<
const
value_type
*>
(
std
::
addressof
(
m_value
));
}
distance_type
dist_from_ideal_bucket
()
const
noexcept
{
return
m_dist_from_ideal_bucket
;
}
bool
last_bucket
()
const
noexcept
{
return
m_last_bucket
;
}
void
set_as_last_bucket
()
noexcept
{
m_last_bucket
=
true
;
}
template
<
typename
...
Args
>
void
set_value_of_empty_bucket
(
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
Args
&&
...
value_type_args
)
{
tsl_rh_assert
(
dist_from_ideal_bucket
>=
0
);
tsl_rh_assert
(
empty
());
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
std
::
forward
<
Args
>
(
value_type_args
)...);
this
->
set_hash
(
hash
);
m_dist_from_ideal_bucket
=
dist_from_ideal_bucket
;
tsl_rh_assert
(
!
empty
());
}
void
swap_with_value_in_bucket
(
distance_type
&
dist_from_ideal_bucket
,
truncated_hash_type
&
hash
,
value_type
&
value
)
{
tsl_rh_assert
(
!
empty
());
using
std
::
swap
;
swap
(
value
,
this
->
value
());
swap
(
dist_from_ideal_bucket
,
m_dist_from_ideal_bucket
);
// Avoid warning of unused variable if StoreHash is false
(
void
)
hash
;
if
(
StoreHash
)
{
const
truncated_hash_type
tmp_hash
=
this
->
truncated_hash
();
this
->
set_hash
(
hash
);
hash
=
tmp_hash
;
}
}
}
static
truncated_hash_type
truncate_hash
(
std
::
size_t
hash
)
noexcept
{
return
*
this
;
return
truncated_hash_type
(
hash
);
}
bucket_entry
&
operator
=
(
bucket_entry
&&
)
=
delete
;
~
bucket_entry
()
noexcept
{
clear
();
}
void
clear
()
noexcept
{
if
(
!
empty
())
{
destroy_value
();
m_dist_from_ideal_bucket
=
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
;
}
}
}
bool
empty
()
const
noexcept
{
return
m_dist_from_ideal_bucket
==
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
;
}
value_type
&
value
()
noexcept
{
tsl_rh_assert
(
!
empty
());
return
*
reinterpret_cast
<
value_type
*>
(
std
::
addressof
(
m_value
));
}
const
value_type
&
value
()
const
noexcept
{
tsl_rh_assert
(
!
empty
());
return
*
reinterpret_cast
<
const
value_type
*>
(
std
::
addressof
(
m_value
));
}
distance_type
dist_from_ideal_bucket
()
const
noexcept
{
return
m_dist_from_ideal_bucket
;
}
bool
last_bucket
()
const
noexcept
{
return
m_last_bucket
;
}
void
set_as_last_bucket
()
noexcept
{
m_last_bucket
=
true
;
}
template
<
typename
...
Args
>
void
set_value_of_empty_bucket
(
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
Args
&&
...
value_type_args
)
{
tsl_rh_assert
(
dist_from_ideal_bucket
>=
0
);
tsl_rh_assert
(
empty
());
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
std
::
forward
<
Args
>
(
value_type_args
)...);
this
->
set_hash
(
hash
);
m_dist_from_ideal_bucket
=
dist_from_ideal_bucket
;
tsl_rh_assert
(
!
empty
());
}
void
swap_with_value_in_bucket
(
distance_type
&
dist_from_ideal_bucket
,
truncated_hash_type
&
hash
,
value_type
&
value
)
{
tsl_rh_assert
(
!
empty
());
using
std
::
swap
;
swap
(
value
,
this
->
value
());
swap
(
dist_from_ideal_bucket
,
m_dist_from_ideal_bucket
);
// Avoid warning of unused variable if StoreHash is false
(
void
)
hash
;
if
(
StoreHash
)
{
const
truncated_hash_type
tmp_hash
=
this
->
truncated_hash
();
this
->
set_hash
(
hash
);
hash
=
tmp_hash
;
}
}
static
truncated_hash_type
truncate_hash
(
std
::
size_t
hash
)
noexcept
{
return
truncated_hash_type
(
hash
);
}
private:
private:
void
destroy_value
()
noexcept
{
void
destroy_value
()
noexcept
{
tsl_rh_assert
(
!
empty
());
tsl_rh_assert
(
!
empty
());
value
().
~
value_type
();
value
().
~
value_type
();
}
}
private:
private:
using
storage
=
typename
std
::
aligned_storage
<
sizeof
(
value_type
),
alignof
(
value_type
)
>::
type
;
using
storage
=
typename
std
::
aligned_storage
<
sizeof
(
value_type
),
alignof
(
value_type
)
>::
type
;
static
const
distance_type
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
=
-
1
;
distance_type
m_dist_from_ideal_bucket
;
bool
m_last_bucket
;
storage
m_value
;
};
static
const
distance_type
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
=
-
1
;
distance_type
m_dist_from_ideal_bucket
;
bool
m_last_bucket
;
storage
m_value
;
};
/**
/**
* Internal common class used by `robin_map` and `robin_set`.
* Internal common class used by `robin_map` and `robin_set`.
*
*
* ValueType is what will be stored by `robin_hash` (usually `std::pair<Key, T>` for map and `Key` for set).
* ValueType is what will be stored by `robin_hash` (usually `std::pair<Key, T>`
*
* for map and `Key` for set).
* `KeySelect` should be a `FunctionObject` which takes a `ValueType` in parameter and returns a
*
* reference to the key.
* `KeySelect` should be a `FunctionObject` which takes a `ValueType` in
*
* parameter and returns a reference to the key.
* `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in parameter and returns a
*
* reference to the value. `ValueSelect` should be void if there is no value (in a set for example).
* `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in
*
* parameter and returns a reference to the value. `ValueSelect` should be void
* The strong exception guarantee only holds if the expression
* if there is no value (in a set for example).
* `std::is_nothrow_swappable<ValueType>::value && std::is_nothrow_move_constructible<ValueType>::value` is true.
*
*
* The strong exception guarantee only holds if the expression
* `std::is_nothrow_swappable<ValueType>::value &&
* std::is_nothrow_move_constructible<ValueType>::value` is true.
*
* Behaviour is undefined if the destructor of `ValueType` throws.
* Behaviour is undefined if the destructor of `ValueType` throws.
*/
*/
template
<
class
ValueType
,
template
<
class
ValueType
,
class
KeySelect
,
class
ValueSelect
,
class
Hash
,
class
KeySelect
,
class
KeyEqual
,
class
Allocator
,
bool
StoreHash
,
class
GrowthPolicy
>
class
ValueSelect
,
class
robin_hash
:
private
Hash
,
private
KeyEqual
,
private
GrowthPolicy
{
class
Hash
,
private:
class
KeyEqual
,
template
<
typename
U
>
class
Allocator
,
using
has_mapped_type
=
bool
StoreHash
,
typename
std
::
integral_constant
<
bool
,
!
std
::
is_same
<
U
,
void
>::
value
>
;
class
GrowthPolicy
>
class
robin_hash
:
private
Hash
,
private
KeyEqual
,
private
GrowthPolicy
{
static_assert
(
private:
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
bucket_for_hash
(
std
::
size_t
(
0
))),
template
<
typename
U
>
"GrowthPolicy::bucket_for_hash must be noexcept."
);
using
has_mapped_type
=
typename
std
::
integral_constant
<
bool
,
!
std
::
is_same
<
U
,
void
>::
value
>
;
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
clear
()),
"GrowthPolicy::clear must be noexcept."
);
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
bucket_for_hash
(
std
::
size_t
(
0
))),
"GrowthPolicy::bucket_for_hash must be noexcept."
);
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
clear
()),
"GrowthPolicy::clear must be noexcept."
);
public:
public:
template
<
bool
IsConst
>
template
<
bool
IsConst
>
class
robin_iterator
;
class
robin_iterator
;
using
key_type
=
typename
KeySelect
::
key_type
;
using
key_type
=
typename
KeySelect
::
key_type
;
using
value_type
=
ValueType
;
using
value_type
=
ValueType
;
using
size_type
=
std
::
size_t
;
using
size_type
=
std
::
size_t
;
using
difference_type
=
std
::
ptrdiff_t
;
using
difference_type
=
std
::
ptrdiff_t
;
using
hasher
=
Hash
;
using
hasher
=
Hash
;
using
key_equal
=
KeyEqual
;
using
key_equal
=
KeyEqual
;
using
allocator_type
=
Allocator
;
using
allocator_type
=
Allocator
;
using
reference
=
value_type
&
;
using
reference
=
value_type
&
;
using
const_reference
=
const
value_type
&
;
using
const_reference
=
const
value_type
&
;
using
pointer
=
value_type
*
;
using
pointer
=
value_type
*
;
using
const_pointer
=
const
value_type
*
;
using
const_pointer
=
const
value_type
*
;
using
iterator
=
robin_iterator
<
false
>
;
using
iterator
=
robin_iterator
<
false
>
;
using
const_iterator
=
robin_iterator
<
true
>
;
using
const_iterator
=
robin_iterator
<
true
>
;
private:
private:
/**
/**
* Either store the hash because we are asked by the `StoreHash` template parameter
* Either store the hash because we are asked by the `StoreHash` template
* or store the hash because it doesn't cost us anything in size and can be used to speed up rehash.
* parameter or store the hash because it doesn't cost us anything in size and
*/
* can be used to speed up rehash.
static
constexpr
bool
STORE_HASH
=
StoreHash
||
*/
(
static
constexpr
bool
STORE_HASH
=
(
sizeof
(
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
true
>
)
==
StoreHash
||
sizeof
(
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
false
>
))
((
sizeof
(
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
true
>
)
==
&&
sizeof
(
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
false
>
))
&&
(
sizeof
(
std
::
size_t
)
==
sizeof
(
truncated_hash_type
)
||
(
sizeof
(
std
::
size_t
)
==
sizeof
(
truncated_hash_type
)
||
is_power_of_two_policy
<
GrowthPolicy
>::
value
)
is_power_of_two_policy
<
GrowthPolicy
>::
value
)
&&
&&
// Don't store the hash for primitive types with default hash.
// Don't store the hash for primitive types with default hash.
(
!
std
::
is_arithmetic
<
key_type
>::
value
||
(
!
std
::
is_arithmetic
<
key_type
>::
value
||
!
std
::
is_same
<
Hash
,
std
::
hash
<
key_type
>>::
value
));
!
std
::
is_same
<
Hash
,
std
::
hash
<
key_type
>>::
value
)
);
/**
* Only use the stored hash on lookup if we are explictly asked. We are not sure how slow
* the KeyEqual operation is. An extra comparison may slow things down with a fast KeyEqual.
*/
static
constexpr
bool
USE_STORED_HASH_ON_LOOKUP
=
StoreHash
;
/**
/**
* We can only use the hash on rehash if the size of the hash type is the same as the stored one or
* Only use the stored hash on lookup if we are explictly asked. We are not
* if we use a power of two modulo. In the case of the power of two modulo, we just mask
* sure how slow the KeyEqual operation is. An extra comparison may slow
* the least significant bytes, we just have to check that the truncated_hash_type didn't truncated
* things down with a fast KeyEqual.
* more bytes.
*/
*/
static
constexpr
bool
USE_STORED_HASH_ON_LOOKUP
=
StoreHash
;
static
bool
USE_STORED_HASH_ON_REHASH
(
size_type
bucket_count
)
{
(
void
)
bucket_count
;
if
(
STORE_HASH
&&
sizeof
(
std
::
size_t
)
==
sizeof
(
truncated_hash_type
))
{
return
true
;
}
else
if
(
STORE_HASH
&&
is_power_of_two_policy
<
GrowthPolicy
>::
value
)
{
tsl_rh_assert
(
bucket_count
>
0
);
return
(
bucket_count
-
1
)
<=
std
::
numeric_limits
<
truncated_hash_type
>::
max
();
}
else
{
return
false
;
}
}
using
bucket_entry
=
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
STORE_HASH
>
;
using
distance_type
=
typename
bucket_entry
::
distance_type
;
using
buckets_allocator
=
typename
std
::
allocator_traits
<
allocator_type
>::
template
rebind_alloc
<
bucket_entry
>;
using
buckets_container_type
=
std
::
vector
<
bucket_entry
,
buckets_allocator
>
;
public:
/**
* The 'operator*()' and 'operator->()' methods return a const reference and const pointer respectively to the
* stored value type.
*
* In case of a map, to get a mutable reference to the value associated to a key (the '.second' in the
* stored pair), you have to call 'value()'.
*
* The main reason for this is that if we returned a `std::pair<Key, T>&` instead
* of a `const std::pair<Key, T>&`, the user may modify the key which will put the map in a undefined state.
*/
template
<
bool
IsConst
>
class
robin_iterator
{
friend
class
robin_hash
;
private:
using
bucket_entry_ptr
=
typename
std
::
conditional
<
IsConst
,
const
bucket_entry
*
,
bucket_entry
*>::
type
;
robin_iterator
(
bucket_entry_ptr
bucket
)
noexcept
:
m_bucket
(
bucket
)
{
}
public:
using
iterator_category
=
std
::
forward_iterator_tag
;
using
value_type
=
const
typename
robin_hash
::
value_type
;
using
difference_type
=
std
::
ptrdiff_t
;
using
reference
=
value_type
&
;
using
pointer
=
value_type
*
;
robin_iterator
()
noexcept
{
}
// Copy constructor from iterator to const_iterator.
template
<
bool
TIsConst
=
IsConst
,
typename
std
::
enable_if
<
TIsConst
>
::
type
*
=
nullptr
>
robin_iterator
(
const
robin_iterator
<!
TIsConst
>&
other
)
noexcept
:
m_bucket
(
other
.
m_bucket
)
{
}
robin_iterator
(
const
robin_iterator
&
other
)
=
default
;
robin_iterator
(
robin_iterator
&&
other
)
=
default
;
robin_iterator
&
operator
=
(
const
robin_iterator
&
other
)
=
default
;
robin_iterator
&
operator
=
(
robin_iterator
&&
other
)
=
default
;
const
typename
robin_hash
::
key_type
&
key
()
const
{
return
KeySelect
()(
m_bucket
->
value
());
}
template
<
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
&&
IsConst
>::
type
*
=
nullptr
>
/**
const
typename
U
::
value_type
&
value
()
const
{
* We can only use the hash on rehash if the size of the hash type is the same
return
U
()(
m_bucket
->
value
());
* as the stored one or if we use a power of two modulo. In the case of the
}
* power of two modulo, we just mask the least significant bytes, we just have
* to check that the truncated_hash_type didn't truncated more bytes.
*/
static
bool
USE_STORED_HASH_ON_REHASH
(
size_type
bucket_count
)
{
(
void
)
bucket_count
;
if
(
STORE_HASH
&&
sizeof
(
std
::
size_t
)
==
sizeof
(
truncated_hash_type
))
{
return
true
;
}
else
if
(
STORE_HASH
&&
is_power_of_two_policy
<
GrowthPolicy
>::
value
)
{
tsl_rh_assert
(
bucket_count
>
0
);
return
(
bucket_count
-
1
)
<=
std
::
numeric_limits
<
truncated_hash_type
>::
max
();
}
else
{
return
false
;
}
}
template
<
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
&&
!
IsConst
>::
type
*
=
nullptr
>
using
bucket_entry
=
typename
U
::
value_type
&
value
()
{
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
STORE_HASH
>
;
return
U
()(
m_bucket
->
value
());
using
distance_type
=
typename
bucket_entry
::
distance_type
;
}
using
buckets_allocator
=
typename
std
::
allocator_traits
<
reference
operator
*
()
const
{
allocator_type
>::
template
rebind_alloc
<
bucket_entry
>;
return
m_bucket
->
value
();
using
buckets_container_type
=
std
::
vector
<
bucket_entry
,
buckets_allocator
>
;
}
pointer
operator
->
()
const
{
return
std
::
addressof
(
m_bucket
->
value
());
}
robin_iterator
&
operator
++
()
{
while
(
true
)
{
if
(
m_bucket
->
last_bucket
())
{
++
m_bucket
;
return
*
this
;
}
++
m_bucket
;
if
(
!
m_bucket
->
empty
())
{
return
*
this
;
}
}
}
robin_iterator
operator
++
(
int
)
{
robin_iterator
tmp
(
*
this
);
++*
this
;
return
tmp
;
}
friend
bool
operator
==
(
const
robin_iterator
&
lhs
,
const
robin_iterator
&
rhs
)
{
return
lhs
.
m_bucket
==
rhs
.
m_bucket
;
}
friend
bool
operator
!=
(
const
robin_iterator
&
lhs
,
const
robin_iterator
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
private:
bucket_entry_ptr
m_bucket
;
};
public:
public:
#if defined(__cplusplus) && __cplusplus >= 201402L
/**
robin_hash
(
size_type
bucket_count
,
* The 'operator*()' and 'operator->()' methods return a const reference and
const
Hash
&
hash
,
* const pointer respectively to the stored value type.
const
KeyEqual
&
equal
,
*
const
Allocator
&
alloc
,
* In case of a map, to get a mutable reference to the value associated to a
float
min_load_factor
=
DEFAULT_MIN_LOAD_FACTOR
,
* key (the '.second' in the stored pair), you have to call 'value()'.
float
max_load_factor
=
DEFAULT_MAX_LOAD_FACTOR
)
:
*
Hash
(
hash
),
* The main reason for this is that if we returned a `std::pair<Key, T>&`
KeyEqual
(
equal
),
* instead of a `const std::pair<Key, T>&`, the user may modify the key which
GrowthPolicy
(
bucket_count
),
* will put the map in a undefined state.
m_buckets_data
(
*/
[
&
]()
{
template
<
bool
IsConst
>
class
robin_iterator
{
if
(
bucket_count
>
max_bucket_count
())
{
friend
class
robin_hash
;
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The map exceeds its maximum bucket count."
);
private:
}
using
bucket_entry_ptr
=
typename
std
::
conditional
<
IsConst
,
const
bucket_entry
*
,
return
bucket_count
;
bucket_entry
*>::
type
;
}(),
alloc
),
robin_iterator
(
bucket_entry_ptr
bucket
)
noexcept
:
m_bucket
(
bucket
)
{}
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
bucket_count
),
public:
m_nb_elements
(
0
),
using
iterator_category
=
std
::
forward_iterator_tag
;
m_grow_on_next_insert
(
false
),
using
value_type
=
const
typename
robin_hash
::
value_type
;
m_try_skrink_on_next_insert
(
false
)
using
difference_type
=
std
::
ptrdiff_t
;
{
using
reference
=
value_type
&
;
if
(
m_bucket_count
>
0
)
{
using
pointer
=
value_type
*
;
tsl_rh_assert
(
!
m_buckets_data
.
empty
());
m_buckets_data
.
back
().
set_as_last_bucket
();
robin_iterator
()
noexcept
{}
}
// Copy constructor from iterator to const_iterator.
this
->
min_load_factor
(
min_load_factor
);
template
<
bool
TIsConst
=
IsConst
,
this
->
max_load_factor
(
max_load_factor
);
typename
std
::
enable_if
<
TIsConst
>
::
type
*
=
nullptr
>
}
robin_iterator
(
const
robin_iterator
<!
TIsConst
>
&
other
)
noexcept
#else
:
m_bucket
(
other
.
m_bucket
)
{}
/**
* C++11 doesn't support the creation of a std::vector with a custom allocator and 'count' default-inserted elements.
robin_iterator
(
const
robin_iterator
&
other
)
=
default
;
* The needed contructor `explicit vector(size_type count, const Allocator& alloc = Allocator());` is only
robin_iterator
(
robin_iterator
&&
other
)
=
default
;
* available in C++14 and later. We thus must resize after using the `vector(const Allocator& alloc)` constructor.
robin_iterator
&
operator
=
(
const
robin_iterator
&
other
)
=
default
;
*
robin_iterator
&
operator
=
(
robin_iterator
&&
other
)
=
default
;
* We can't use `vector(size_type count, const T& value, const Allocator& alloc)` as it requires the
* value T to be copyable.
const
typename
robin_hash
::
key_type
&
key
()
const
{
*/
return
KeySelect
()(
m_bucket
->
value
());
robin_hash
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
KeyEqual
&
equal
,
const
Allocator
&
alloc
,
float
min_load_factor
=
DEFAULT_MIN_LOAD_FACTOR
,
float
max_load_factor
=
DEFAULT_MAX_LOAD_FACTOR
)
:
Hash
(
hash
),
KeyEqual
(
equal
),
GrowthPolicy
(
bucket_count
),
m_buckets_data
(
alloc
),
m_buckets
(
static_empty_bucket_ptr
()),
m_bucket_count
(
bucket_count
),
m_nb_elements
(
0
),
m_grow_on_next_insert
(
false
),
m_try_skrink_on_next_insert
(
false
)
{
if
(
bucket_count
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The map exceeds its maxmimum bucket count."
);
}
if
(
m_bucket_count
>
0
)
{
m_buckets_data
.
resize
(
m_bucket_count
);
m_buckets
=
m_buckets_data
.
data
();
tsl_rh_assert
(
!
m_buckets_data
.
empty
());
m_buckets_data
.
back
().
set_as_last_bucket
();
}
this
->
min_load_factor
(
min_load_factor
);
this
->
max_load_factor
(
max_load_factor
);
}
#endif
robin_hash
(
const
robin_hash
&
other
)
:
Hash
(
other
),
KeyEqual
(
other
),
GrowthPolicy
(
other
),
m_buckets_data
(
other
.
m_buckets_data
),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
other
.
m_bucket_count
),
m_nb_elements
(
other
.
m_nb_elements
),
m_load_threshold
(
other
.
m_load_threshold
),
m_max_load_factor
(
other
.
m_max_load_factor
),
m_grow_on_next_insert
(
other
.
m_grow_on_next_insert
),
m_min_load_factor
(
other
.
m_min_load_factor
),
m_try_skrink_on_next_insert
(
other
.
m_try_skrink_on_next_insert
)
{
}
robin_hash
(
robin_hash
&&
other
)
noexcept
(
std
::
is_nothrow_move_constructible
<
Hash
>::
value
&&
std
::
is_nothrow_move_constructible
<
KeyEqual
>::
value
&&
std
::
is_nothrow_move_constructible
<
GrowthPolicy
>::
value
&&
std
::
is_nothrow_move_constructible
<
buckets_container_type
>::
value
)
:
Hash
(
std
::
move
(
static_cast
<
Hash
&>
(
other
))),
KeyEqual
(
std
::
move
(
static_cast
<
KeyEqual
&>
(
other
))),
GrowthPolicy
(
std
::
move
(
static_cast
<
GrowthPolicy
&>
(
other
))),
m_buckets_data
(
std
::
move
(
other
.
m_buckets_data
)),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
other
.
m_bucket_count
),
m_nb_elements
(
other
.
m_nb_elements
),
m_load_threshold
(
other
.
m_load_threshold
),
m_max_load_factor
(
other
.
m_max_load_factor
),
m_grow_on_next_insert
(
other
.
m_grow_on_next_insert
),
m_min_load_factor
(
other
.
m_min_load_factor
),
m_try_skrink_on_next_insert
(
other
.
m_try_skrink_on_next_insert
)
{
other
.
GrowthPolicy
::
clear
();
other
.
m_buckets_data
.
clear
();
other
.
m_buckets
=
static_empty_bucket_ptr
();
other
.
m_bucket_count
=
0
;
other
.
m_nb_elements
=
0
;
other
.
m_load_threshold
=
0
;
other
.
m_grow_on_next_insert
=
false
;
other
.
m_try_skrink_on_next_insert
=
false
;
}
robin_hash
&
operator
=
(
const
robin_hash
&
other
)
{
if
(
&
other
!=
this
)
{
Hash
::
operator
=
(
other
);
KeyEqual
::
operator
=
(
other
);
GrowthPolicy
::
operator
=
(
other
);
m_buckets_data
=
other
.
m_buckets_data
;
m_buckets
=
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
();
m_bucket_count
=
other
.
m_bucket_count
;
m_nb_elements
=
other
.
m_nb_elements
;
m_load_threshold
=
other
.
m_load_threshold
;
m_max_load_factor
=
other
.
m_max_load_factor
;
m_grow_on_next_insert
=
other
.
m_grow_on_next_insert
;
m_min_load_factor
=
other
.
m_min_load_factor
;
m_try_skrink_on_next_insert
=
other
.
m_try_skrink_on_next_insert
;
}
return
*
this
;
}
}
robin_hash
&
operator
=
(
robin_hash
&&
other
)
{
template
<
class
U
=
ValueSelect
,
other
.
swap
(
*
this
);
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
&&
other
.
clear
();
IsConst
>::
type
*
=
nullptr
>
const
typename
U
::
value_type
&
value
()
const
{
return
*
this
;
return
U
()(
m_bucket
->
value
())
;
}
}
allocator_type
get_allocator
()
const
{
template
<
class
U
=
ValueSelect
,
return
m_buckets_data
.
get_allocator
();
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
&&
!
IsConst
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
value
()
{
return
U
()(
m_bucket
->
value
());
}
}
reference
operator
*
()
const
{
return
m_bucket
->
value
();
}
/*
* Iterators
pointer
operator
->
()
const
{
return
std
::
addressof
(
m_bucket
->
value
());
}
*/
iterator
begin
()
noexcept
{
robin_iterator
&
operator
++
()
{
std
::
size_t
i
=
0
;
while
(
true
)
{
while
(
i
<
m_bucket_count
&&
m_buckets
[
i
].
empty
())
{
if
(
m_bucket
->
last_bucket
())
{
i
++
;
++
m_bucket
;
return
*
this
;
}
}
return
iterator
(
m_buckets
+
i
);
++
m_bucket
;
}
if
(
!
m_bucket
->
empty
())
{
return
*
this
;
const_iterator
begin
()
const
noexcept
{
return
cbegin
();
}
const_iterator
cbegin
()
const
noexcept
{
std
::
size_t
i
=
0
;
while
(
i
<
m_bucket_count
&&
m_buckets
[
i
].
empty
())
{
i
++
;
}
}
}
return
const_iterator
(
m_buckets
+
i
);
}
}
iterator
end
()
noexcept
{
robin_iterator
operator
++
(
int
)
{
return
iterator
(
m_buckets
+
m_bucket_count
);
robin_iterator
tmp
(
*
this
);
++*
this
;
return
tmp
;
}
}
const_iterator
end
()
const
noexcept
{
friend
bool
operator
==
(
const
robin_iterator
&
lhs
,
return
cend
();
const
robin_iterator
&
rhs
)
{
return
lhs
.
m_bucket
==
rhs
.
m_bucket
;
}
}
const_iterator
cend
()
const
noexcept
{
friend
bool
operator
!=
(
const
robin_iterator
&
lhs
,
return
const_iterator
(
m_buckets
+
m_bucket_count
);
const
robin_iterator
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
}
private:
/*
bucket_entry_ptr
m_bucket
;
* Capacity
};
*/
bool
empty
()
const
noexcept
{
public:
return
m_nb_elements
==
0
;
#if defined(__cplusplus) && __cplusplus >= 201402L
robin_hash
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
KeyEqual
&
equal
,
const
Allocator
&
alloc
,
float
min_load_factor
=
DEFAULT_MIN_LOAD_FACTOR
,
float
max_load_factor
=
DEFAULT_MAX_LOAD_FACTOR
)
:
Hash
(
hash
),
KeyEqual
(
equal
),
GrowthPolicy
(
bucket_count
),
m_buckets_data
(
[
&
]()
{
if
(
bucket_count
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The map exceeds its maximum bucket count."
);
}
return
bucket_count
;
}(),
alloc
),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
bucket_count
),
m_nb_elements
(
0
),
m_grow_on_next_insert
(
false
),
m_try_skrink_on_next_insert
(
false
)
{
if
(
m_bucket_count
>
0
)
{
tsl_rh_assert
(
!
m_buckets_data
.
empty
());
m_buckets_data
.
back
().
set_as_last_bucket
();
}
}
size_type
size
()
const
noexcept
{
this
->
min_load_factor
(
min_load_factor
);
return
m_nb_elements
;
this
->
max_load_factor
(
max_load_factor
);
}
#else
/**
* C++11 doesn't support the creation of a std::vector with a custom allocator
* and 'count' default-inserted elements. The needed contructor `explicit
* vector(size_type count, const Allocator& alloc = Allocator());` is only
* available in C++14 and later. We thus must resize after using the
* `vector(const Allocator& alloc)` constructor.
*
* We can't use `vector(size_type count, const T& value, const Allocator&
* alloc)` as it requires the value T to be copyable.
*/
robin_hash
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
KeyEqual
&
equal
,
const
Allocator
&
alloc
,
float
min_load_factor
=
DEFAULT_MIN_LOAD_FACTOR
,
float
max_load_factor
=
DEFAULT_MAX_LOAD_FACTOR
)
:
Hash
(
hash
),
KeyEqual
(
equal
),
GrowthPolicy
(
bucket_count
),
m_buckets_data
(
alloc
),
m_buckets
(
static_empty_bucket_ptr
()),
m_bucket_count
(
bucket_count
),
m_nb_elements
(
0
),
m_grow_on_next_insert
(
false
),
m_try_skrink_on_next_insert
(
false
)
{
if
(
bucket_count
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The map exceeds its maxmimum bucket count."
);
}
}
size_type
max_size
()
const
noexcept
{
if
(
m_bucket_count
>
0
)
{
return
m_buckets_data
.
max_size
();
m_buckets_data
.
resize
(
m_bucket_count
);
m_buckets
=
m_buckets_data
.
data
();
tsl_rh_assert
(
!
m_buckets_data
.
empty
());
m_buckets_data
.
back
().
set_as_last_bucket
();
}
}
/*
this
->
min_load_factor
(
min_load_factor
);
* Modifiers
this
->
max_load_factor
(
max_load_factor
);
*/
}
void
clear
()
noexcept
{
#endif
for
(
auto
&
bucket
:
m_buckets_data
)
{
bucket
.
clear
();
robin_hash
(
const
robin_hash
&
other
)
}
:
Hash
(
other
),
KeyEqual
(
other
),
GrowthPolicy
(
other
),
m_buckets_data
(
other
.
m_buckets_data
),
m_nb_elements
=
0
;
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
m_grow_on_next_insert
=
false
;
:
m_buckets_data
.
data
()),
m_bucket_count
(
other
.
m_bucket_count
),
m_nb_elements
(
other
.
m_nb_elements
),
m_load_threshold
(
other
.
m_load_threshold
),
m_max_load_factor
(
other
.
m_max_load_factor
),
m_grow_on_next_insert
(
other
.
m_grow_on_next_insert
),
m_min_load_factor
(
other
.
m_min_load_factor
),
m_try_skrink_on_next_insert
(
other
.
m_try_skrink_on_next_insert
)
{}
robin_hash
(
robin_hash
&&
other
)
noexcept
(
std
::
is_nothrow_move_constructible
<
Hash
>::
value
&&
std
::
is_nothrow_move_constructible
<
KeyEqual
>::
value
&&
std
::
is_nothrow_move_constructible
<
GrowthPolicy
>::
value
&&
std
::
is_nothrow_move_constructible
<
buckets_container_type
>::
value
)
:
Hash
(
std
::
move
(
static_cast
<
Hash
&>
(
other
))),
KeyEqual
(
std
::
move
(
static_cast
<
KeyEqual
&>
(
other
))),
GrowthPolicy
(
std
::
move
(
static_cast
<
GrowthPolicy
&>
(
other
))),
m_buckets_data
(
std
::
move
(
other
.
m_buckets_data
)),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
other
.
m_bucket_count
),
m_nb_elements
(
other
.
m_nb_elements
),
m_load_threshold
(
other
.
m_load_threshold
),
m_max_load_factor
(
other
.
m_max_load_factor
),
m_grow_on_next_insert
(
other
.
m_grow_on_next_insert
),
m_min_load_factor
(
other
.
m_min_load_factor
),
m_try_skrink_on_next_insert
(
other
.
m_try_skrink_on_next_insert
)
{
other
.
GrowthPolicy
::
clear
();
other
.
m_buckets_data
.
clear
();
other
.
m_buckets
=
static_empty_bucket_ptr
();
other
.
m_bucket_count
=
0
;
other
.
m_nb_elements
=
0
;
other
.
m_load_threshold
=
0
;
other
.
m_grow_on_next_insert
=
false
;
other
.
m_try_skrink_on_next_insert
=
false
;
}
robin_hash
&
operator
=
(
const
robin_hash
&
other
)
{
if
(
&
other
!=
this
)
{
Hash
::
operator
=
(
other
);
KeyEqual
::
operator
=
(
other
);
GrowthPolicy
::
operator
=
(
other
);
m_buckets_data
=
other
.
m_buckets_data
;
m_buckets
=
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
();
m_bucket_count
=
other
.
m_bucket_count
;
m_nb_elements
=
other
.
m_nb_elements
;
m_load_threshold
=
other
.
m_load_threshold
;
m_max_load_factor
=
other
.
m_max_load_factor
;
m_grow_on_next_insert
=
other
.
m_grow_on_next_insert
;
m_min_load_factor
=
other
.
m_min_load_factor
;
m_try_skrink_on_next_insert
=
other
.
m_try_skrink_on_next_insert
;
}
}
return
*
this
;
}
template
<
typename
P
>
std
::
pair
<
iterator
,
bool
>
insert
(
P
&&
value
)
{
robin_hash
&
operator
=
(
robin_hash
&&
other
)
{
return
insert_impl
(
KeySelect
()(
value
),
std
::
forward
<
P
>
(
value
));
other
.
swap
(
*
this
);
other
.
clear
();
return
*
this
;
}
allocator_type
get_allocator
()
const
{
return
m_buckets_data
.
get_allocator
();
}
/*
* Iterators
*/
iterator
begin
()
noexcept
{
std
::
size_t
i
=
0
;
while
(
i
<
m_bucket_count
&&
m_buckets
[
i
].
empty
())
{
i
++
;
}
}
template
<
typename
P
>
return
iterator
(
m_buckets
+
i
);
iterator
insert_hint
(
const_iterator
hint
,
P
&&
value
)
{
}
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
KeySelect
()(
value
)))
{
return
mutable_iterator
(
hint
);
const_iterator
begin
()
const
noexcept
{
return
cbegin
();
}
}
const_iterator
cbegin
()
const
noexcept
{
return
insert
(
std
::
forward
<
P
>
(
value
)).
first
;
std
::
size_t
i
=
0
;
while
(
i
<
m_bucket_count
&&
m_buckets
[
i
].
empty
())
{
i
++
;
}
}
template
<
class
InputIt
>
return
const_iterator
(
m_buckets
+
i
);
void
insert
(
InputIt
first
,
InputIt
last
)
{
}
if
(
std
::
is_base_of
<
std
::
forward_iterator_tag
,
typename
std
::
iterator_traits
<
InputIt
>::
iterator_category
>::
value
)
iterator
end
()
noexcept
{
return
iterator
(
m_buckets
+
m_bucket_count
);
}
{
const
auto
nb_elements_insert
=
std
::
distance
(
first
,
last
);
const_iterator
end
()
const
noexcept
{
return
cend
();
}
const
size_type
nb_free_buckets
=
m_load_threshold
-
size
();
tsl_rh_assert
(
m_load_threshold
>=
size
());
const_iterator
cend
()
const
noexcept
{
return
const_iterator
(
m_buckets
+
m_bucket_count
);
if
(
nb_elements_insert
>
0
&&
nb_free_buckets
<
size_type
(
nb_elements_insert
))
{
}
reserve
(
size
()
+
size_type
(
nb_elements_insert
));
}
/*
}
* Capacity
*/
for
(;
first
!=
last
;
++
first
)
{
bool
empty
()
const
noexcept
{
return
m_nb_elements
==
0
;
}
insert
(
*
first
);
}
size_type
size
()
const
noexcept
{
return
m_nb_elements
;
}
size_type
max_size
()
const
noexcept
{
return
m_buckets_data
.
max_size
();
}
/*
* Modifiers
*/
void
clear
()
noexcept
{
for
(
auto
&
bucket
:
m_buckets_data
)
{
bucket
.
clear
();
}
}
m_nb_elements
=
0
;
m_grow_on_next_insert
=
false
;
template
<
class
K
,
class
M
>
}
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
K
&&
key
,
M
&&
obj
)
{
auto
it
=
try_emplace
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
M
>
(
obj
));
template
<
typename
P
>
std
::
pair
<
iterator
,
bool
>
insert
(
P
&&
value
)
{
if
(
!
it
.
second
)
{
return
insert_impl
(
KeySelect
()(
value
),
std
::
forward
<
P
>
(
value
));
it
.
first
.
value
()
=
std
::
forward
<
M
>
(
obj
);
}
}
template
<
typename
P
>
iterator
insert_hint
(
const_iterator
hint
,
P
&&
value
)
{
return
it
;
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
KeySelect
()(
value
)))
{
return
mutable_iterator
(
hint
);
}
}
template
<
class
K
,
class
M
>
return
insert
(
std
::
forward
<
P
>
(
value
)).
first
;
iterator
insert_or_assign
(
const_iterator
hint
,
K
&&
key
,
M
&&
obj
)
{
}
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
key
))
{
auto
it
=
mutable_iterator
(
hint
);
template
<
class
InputIt
>
void
insert
(
InputIt
first
,
InputIt
last
)
{
it
.
value
()
=
std
::
forward
<
M
>
(
obj
);
if
(
std
::
is_base_of
<
std
::
forward_iterator_tag
,
return
it
;
typename
std
::
iterator_traits
<
InputIt
>::
iterator_category
>::
value
)
{
}
const
auto
nb_elements_insert
=
std
::
distance
(
first
,
last
);
const
size_type
nb_free_buckets
=
m_load_threshold
-
size
();
return
insert_or_assign
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
M
>
(
obj
)).
first
;
tsl_rh_assert
(
m_load_threshold
>=
size
());
if
(
nb_elements_insert
>
0
&&
nb_free_buckets
<
size_type
(
nb_elements_insert
))
{
reserve
(
size
()
+
size_type
(
nb_elements_insert
));
}
}
}
for
(;
first
!=
last
;
++
first
)
{
template
<
class
...
Args
>
insert
(
*
first
);
std
::
pair
<
iterator
,
bool
>
emplace
(
Args
&&
...
args
)
{
return
insert
(
value_type
(
std
::
forward
<
Args
>
(
args
)...));
}
}
}
template
<
class
...
Args
>
iterator
emplace_hint
(
const_iterator
hint
,
Args
&&
...
args
)
{
template
<
class
K
,
class
M
>
return
insert_hint
(
hint
,
value_type
(
std
::
forward
<
Args
>
(
args
)...));
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
K
&&
key
,
M
&&
obj
)
{
auto
it
=
try_emplace
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
M
>
(
obj
));
if
(
!
it
.
second
)
{
it
.
first
.
value
()
=
std
::
forward
<
M
>
(
obj
);
}
}
return
it
;
}
template
<
class
K
,
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
K
&&
key
,
Args
&&
...
args
)
{
template
<
class
K
,
class
M
>
return
insert_impl
(
key
,
std
::
piecewise_construct
,
iterator
insert_or_assign
(
const_iterator
hint
,
K
&&
key
,
M
&&
obj
)
{
std
::
forward_as_tuple
(
std
::
forward
<
K
>
(
key
)),
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
key
))
{
std
::
forward_as_tuple
(
std
::
forward
<
Args
>
(
args
)...));
auto
it
=
mutable_iterator
(
hint
);
it
.
value
()
=
std
::
forward
<
M
>
(
obj
);
return
it
;
}
}
template
<
class
K
,
class
...
Args
>
return
insert_or_assign
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
M
>
(
obj
)).
first
;
iterator
try_emplace_hint
(
const_iterator
hint
,
K
&&
key
,
Args
&&
...
args
)
{
}
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
key
))
{
return
mutable_iterator
(
hint
);
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
emplace
(
Args
&&
...
args
)
{
}
return
insert
(
value_type
(
std
::
forward
<
Args
>
(
args
)...));
}
return
try_emplace
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
Args
>
(
args
)...).
first
;
template
<
class
...
Args
>
iterator
emplace_hint
(
const_iterator
hint
,
Args
&&
...
args
)
{
return
insert_hint
(
hint
,
value_type
(
std
::
forward
<
Args
>
(
args
)...));
}
template
<
class
K
,
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
K
&&
key
,
Args
&&
...
args
)
{
return
insert_impl
(
key
,
std
::
piecewise_construct
,
std
::
forward_as_tuple
(
std
::
forward
<
K
>
(
key
)),
std
::
forward_as_tuple
(
std
::
forward
<
Args
>
(
args
)...));
}
template
<
class
K
,
class
...
Args
>
iterator
try_emplace_hint
(
const_iterator
hint
,
K
&&
key
,
Args
&&
...
args
)
{
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
key
))
{
return
mutable_iterator
(
hint
);
}
}
return
try_emplace
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
Args
>
(
args
)...).
first
;
}
/**
* Here to avoid `template<class K> size_type erase(const K& key)` being used
* when we use an `iterator` instead of a `const_iterator`.
*/
iterator
erase
(
iterator
pos
)
{
erase_from_bucket
(
pos
);
/**
/**
* Here to avoid `template<class K> size_type erase(const K& key)` being used when
* Erase bucket used a backward shift after clearing the bucket.
* we use an `iterator` instead of a `const_iterator`.
* Check if there is a new value in the bucket, if not get the next
* non-empty.
*/
*/
iterator
erase
(
iterator
pos
)
{
if
(
pos
.
m_bucket
->
empty
())
{
erase_from_bucket
(
pos
);
++
pos
;
/**
* Erase bucket used a backward shift after clearing the bucket.
* Check if there is a new value in the bucket, if not get the next non-empty.
*/
if
(
pos
.
m_bucket
->
empty
())
{
++
pos
;
}
m_try_skrink_on_next_insert
=
true
;
return
pos
;
}
iterator
erase
(
const_iterator
pos
)
{
return
erase
(
mutable_iterator
(
pos
));
}
iterator
erase
(
const_iterator
first
,
const_iterator
last
)
{
if
(
first
==
last
)
{
return
mutable_iterator
(
first
);
}
auto
first_mutable
=
mutable_iterator
(
first
);
auto
last_mutable
=
mutable_iterator
(
last
);
for
(
auto
it
=
first_mutable
.
m_bucket
;
it
!=
last_mutable
.
m_bucket
;
++
it
)
{
if
(
!
it
->
empty
())
{
it
->
clear
();
m_nb_elements
--
;
}
}
if
(
last_mutable
==
end
())
{
return
end
();
}
/*
* Backward shift on the values which come after the deleted values.
* We try to move the values closer to their ideal bucket.
*/
std
::
size_t
icloser_bucket
=
static_cast
<
std
::
size_t
>
(
first_mutable
.
m_bucket
-
m_buckets
);
std
::
size_t
ito_move_closer_value
=
static_cast
<
std
::
size_t
>
(
last_mutable
.
m_bucket
-
m_buckets
);
tsl_rh_assert
(
ito_move_closer_value
>
icloser_bucket
);
const
std
::
size_t
ireturn_bucket
=
ito_move_closer_value
-
std
::
min
(
ito_move_closer_value
-
icloser_bucket
,
std
::
size_t
(
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()));
while
(
ito_move_closer_value
<
m_bucket_count
&&
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()
>
0
)
{
icloser_bucket
=
ito_move_closer_value
-
std
::
min
(
ito_move_closer_value
-
icloser_bucket
,
std
::
size_t
(
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()));
tsl_rh_assert
(
m_buckets
[
icloser_bucket
].
empty
());
const
distance_type
new_distance
=
distance_type
(
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()
-
(
ito_move_closer_value
-
icloser_bucket
));
m_buckets
[
icloser_bucket
].
set_value_of_empty_bucket
(
new_distance
,
m_buckets
[
ito_move_closer_value
].
truncated_hash
(),
std
::
move
(
m_buckets
[
ito_move_closer_value
].
value
()));
m_buckets
[
ito_move_closer_value
].
clear
();
++
icloser_bucket
;
++
ito_move_closer_value
;
}
m_try_skrink_on_next_insert
=
true
;
return
iterator
(
m_buckets
+
ireturn_bucket
);
}
}
m_try_skrink_on_next_insert
=
true
;
template
<
class
K
>
size_type
erase
(
const
K
&
key
)
{
return
pos
;
return
erase
(
key
,
hash_key
(
key
));
}
iterator
erase
(
const_iterator
pos
)
{
return
erase
(
mutable_iterator
(
pos
));
}
iterator
erase
(
const_iterator
first
,
const_iterator
last
)
{
if
(
first
==
last
)
{
return
mutable_iterator
(
first
);
}
}
template
<
class
K
>
auto
first_mutable
=
mutable_iterator
(
first
);
size_type
erase
(
const
K
&
key
,
std
::
size_t
hash
)
{
auto
last_mutable
=
mutable_iterator
(
last
);
auto
it
=
find
(
key
,
hash
);
for
(
auto
it
=
first_mutable
.
m_bucket
;
it
!=
last_mutable
.
m_bucket
;
++
it
)
{
if
(
it
!=
end
())
{
if
(
!
it
->
empty
())
{
erase_from_bucket
(
it
);
it
->
clear
();
m_try_skrink_on_next_insert
=
true
;
m_nb_elements
--
;
}
return
1
;
}
else
{
return
0
;
}
}
}
if
(
last_mutable
==
end
())
{
return
end
();
void
swap
(
robin_hash
&
other
)
{
using
std
::
swap
;
swap
(
static_cast
<
Hash
&>
(
*
this
),
static_cast
<
Hash
&>
(
other
));
swap
(
static_cast
<
KeyEqual
&>
(
*
this
),
static_cast
<
KeyEqual
&>
(
other
));
swap
(
static_cast
<
GrowthPolicy
&>
(
*
this
),
static_cast
<
GrowthPolicy
&>
(
other
));
swap
(
m_buckets_data
,
other
.
m_buckets_data
);
swap
(
m_buckets
,
other
.
m_buckets
);
swap
(
m_bucket_count
,
other
.
m_bucket_count
);
swap
(
m_nb_elements
,
other
.
m_nb_elements
);
swap
(
m_load_threshold
,
other
.
m_load_threshold
);
swap
(
m_max_load_factor
,
other
.
m_max_load_factor
);
swap
(
m_grow_on_next_insert
,
other
.
m_grow_on_next_insert
);
swap
(
m_min_load_factor
,
other
.
m_min_load_factor
);
swap
(
m_try_skrink_on_next_insert
,
other
.
m_try_skrink_on_next_insert
);
}
}
/*
/*
* Lookup
* Backward shift on the values which come after the deleted values.
* We try to move the values closer to their ideal bucket.
*/
*/
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
std
::
size_t
icloser_bucket
=
typename
U
::
value_type
&
at
(
const
K
&
key
)
{
static_cast
<
std
::
size_t
>
(
first_mutable
.
m_bucket
-
m_buckets
);
return
at
(
key
,
hash_key
(
key
));
std
::
size_t
ito_move_closer_value
=
}
static_cast
<
std
::
size_t
>
(
last_mutable
.
m_bucket
-
m_buckets
);
tsl_rh_assert
(
ito_move_closer_value
>
icloser_bucket
);
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
at
(
const
K
&
key
,
std
::
size_t
hash
)
{
const
std
::
size_t
ireturn_bucket
=
return
const_cast
<
typename
U
::
value_type
&>
(
static_cast
<
const
robin_hash
*>
(
this
)
->
at
(
key
,
hash
));
ito_move_closer_value
-
}
std
::
min
(
ito_move_closer_value
-
icloser_bucket
,
std
::
size_t
(
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()));
const
typename
U
::
value_type
&
at
(
const
K
&
key
)
const
{
return
at
(
key
,
hash_key
(
key
));
while
(
ito_move_closer_value
<
m_bucket_count
&&
}
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()
>
0
)
{
icloser_bucket
=
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
ito_move_closer_value
-
const
typename
U
::
value_type
&
at
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
std
::
min
(
auto
it
=
find
(
key
,
hash
);
ito_move_closer_value
-
icloser_bucket
,
if
(
it
!=
cend
())
{
std
::
size_t
(
return
it
.
value
();
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()));
}
else
{
tsl_rh_assert
(
m_buckets
[
icloser_bucket
].
empty
());
TSL_RH_THROW_OR_TERMINATE
(
std
::
out_of_range
,
"Couldn't find key."
);
const
distance_type
new_distance
=
distance_type
(
}
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()
-
}
(
ito_move_closer_value
-
icloser_bucket
));
m_buckets
[
icloser_bucket
].
set_value_of_empty_bucket
(
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
new_distance
,
m_buckets
[
ito_move_closer_value
].
truncated_hash
(),
typename
U
::
value_type
&
operator
[](
K
&&
key
)
{
std
::
move
(
m_buckets
[
ito_move_closer_value
].
value
()));
return
try_emplace
(
std
::
forward
<
K
>
(
key
)).
first
.
value
();
m_buckets
[
ito_move_closer_value
].
clear
();
}
++
icloser_bucket
;
++
ito_move_closer_value
;
template
<
class
K
>
size_type
count
(
const
K
&
key
)
const
{
return
count
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
size_type
count
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
if
(
find
(
key
,
hash
)
!=
cend
())
{
return
1
;
}
else
{
return
0
;
}
}
template
<
class
K
>
iterator
find
(
const
K
&
key
)
{
return
find_impl
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
iterator
find
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
find_impl
(
key
,
hash
);
}
template
<
class
K
>
const_iterator
find
(
const
K
&
key
)
const
{
return
find_impl
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
const_iterator
find
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
return
find_impl
(
key
,
hash
);
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
)
{
return
equal_range
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
{
iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
end
())
?
it
:
std
::
next
(
it
));
}
}
m_try_skrink_on_next_insert
=
true
;
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
)
const
{
return
iterator
(
m_buckets
+
ireturn_bucket
);
return
equal_range
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
size_type
erase
(
const
K
&
key
)
{
return
erase
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
size_type
erase
(
const
K
&
key
,
std
::
size_t
hash
)
{
auto
it
=
find
(
key
,
hash
);
if
(
it
!=
end
())
{
erase_from_bucket
(
it
);
m_try_skrink_on_next_insert
=
true
;
return
1
;
}
else
{
return
0
;
}
}
}
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
void
swap
(
robin_hash
&
other
)
{
const_iterator
it
=
find
(
key
,
hash
);
using
std
::
swap
;
return
std
::
make_pair
(
it
,
(
it
==
cend
())
?
it
:
std
::
next
(
it
));
swap
(
static_cast
<
Hash
&>
(
*
this
),
static_cast
<
Hash
&>
(
other
));
swap
(
static_cast
<
KeyEqual
&>
(
*
this
),
static_cast
<
KeyEqual
&>
(
other
));
swap
(
static_cast
<
GrowthPolicy
&>
(
*
this
),
static_cast
<
GrowthPolicy
&>
(
other
));
swap
(
m_buckets_data
,
other
.
m_buckets_data
);
swap
(
m_buckets
,
other
.
m_buckets
);
swap
(
m_bucket_count
,
other
.
m_bucket_count
);
swap
(
m_nb_elements
,
other
.
m_nb_elements
);
swap
(
m_load_threshold
,
other
.
m_load_threshold
);
swap
(
m_max_load_factor
,
other
.
m_max_load_factor
);
swap
(
m_grow_on_next_insert
,
other
.
m_grow_on_next_insert
);
swap
(
m_min_load_factor
,
other
.
m_min_load_factor
);
swap
(
m_try_skrink_on_next_insert
,
other
.
m_try_skrink_on_next_insert
);
}
/*
* Lookup
*/
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
at
(
const
K
&
key
)
{
return
at
(
key
,
hash_key
(
key
));
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
at
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
const_cast
<
typename
U
::
value_type
&>
(
static_cast
<
const
robin_hash
*>
(
this
)
->
at
(
key
,
hash
));
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
const
typename
U
::
value_type
&
at
(
const
K
&
key
)
const
{
return
at
(
key
,
hash_key
(
key
));
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
const
typename
U
::
value_type
&
at
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
auto
it
=
find
(
key
,
hash
);
if
(
it
!=
cend
())
{
return
it
.
value
();
}
else
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
out_of_range
,
"Couldn't find key."
);
}
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
operator
[](
K
&&
key
)
{
return
try_emplace
(
std
::
forward
<
K
>
(
key
)).
first
.
value
();
}
template
<
class
K
>
size_type
count
(
const
K
&
key
)
const
{
return
count
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
size_type
count
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
if
(
find
(
key
,
hash
)
!=
cend
())
{
return
1
;
}
else
{
return
0
;
}
}
}
/*
* Bucket interface
template
<
class
K
>
iterator
find
(
const
K
&
key
)
{
*/
return
find_impl
(
key
,
hash_key
(
key
));
size_type
bucket_count
()
const
{
}
return
m_bucket_count
;
template
<
class
K
>
iterator
find
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
find_impl
(
key
,
hash
);
}
template
<
class
K
>
const_iterator
find
(
const
K
&
key
)
const
{
return
find_impl
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
const_iterator
find
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
return
find_impl
(
key
,
hash
);
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
)
{
return
equal_range
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
{
iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
end
())
?
it
:
std
::
next
(
it
));
}
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
)
const
{
return
equal_range
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
const_iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
cend
())
?
it
:
std
::
next
(
it
));
}
/*
* Bucket interface
*/
size_type
bucket_count
()
const
{
return
m_bucket_count
;
}
size_type
max_bucket_count
()
const
{
return
std
::
min
(
GrowthPolicy
::
max_bucket_count
(),
m_buckets_data
.
max_size
());
}
/*
* Hash policy
*/
float
load_factor
()
const
{
if
(
bucket_count
()
==
0
)
{
return
0
;
}
}
size_type
max_bucket_count
()
const
{
return
float
(
m_nb_elements
)
/
float
(
bucket_count
());
return
std
::
min
(
GrowthPolicy
::
max_bucket_count
(),
m_buckets_data
.
max_size
());
}
float
min_load_factor
()
const
{
return
m_min_load_factor
;
}
float
max_load_factor
()
const
{
return
m_max_load_factor
;
}
void
min_load_factor
(
float
ml
)
{
m_min_load_factor
=
clamp
(
ml
,
float
(
MINIMUM_MIN_LOAD_FACTOR
),
float
(
MAXIMUM_MIN_LOAD_FACTOR
));
}
void
max_load_factor
(
float
ml
)
{
m_max_load_factor
=
clamp
(
ml
,
float
(
MINIMUM_MAX_LOAD_FACTOR
),
float
(
MAXIMUM_MAX_LOAD_FACTOR
));
m_load_threshold
=
size_type
(
float
(
bucket_count
())
*
m_max_load_factor
);
}
void
rehash
(
size_type
count
)
{
count
=
std
::
max
(
count
,
size_type
(
std
::
ceil
(
float
(
size
())
/
max_load_factor
())));
rehash_impl
(
count
);
}
void
reserve
(
size_type
count
)
{
rehash
(
size_type
(
std
::
ceil
(
float
(
count
)
/
max_load_factor
())));
}
/*
* Observers
*/
hasher
hash_function
()
const
{
return
static_cast
<
const
Hash
&>
(
*
this
);
}
key_equal
key_eq
()
const
{
return
static_cast
<
const
KeyEqual
&>
(
*
this
);
}
/*
* Other
*/
iterator
mutable_iterator
(
const_iterator
pos
)
{
return
iterator
(
const_cast
<
bucket_entry
*>
(
pos
.
m_bucket
));
}
private:
template
<
class
K
>
std
::
size_t
hash_key
(
const
K
&
key
)
const
{
return
Hash
::
operator
()(
key
);
}
template
<
class
K1
,
class
K2
>
bool
compare_keys
(
const
K1
&
key1
,
const
K2
&
key2
)
const
{
return
KeyEqual
::
operator
()(
key1
,
key2
);
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
{
const
std
::
size_t
bucket
=
GrowthPolicy
::
bucket_for_hash
(
hash
);
tsl_rh_assert
(
bucket
<
m_bucket_count
||
(
bucket
==
0
&&
m_bucket_count
==
0
));
return
bucket
;
}
template
<
class
U
=
GrowthPolicy
,
typename
std
::
enable_if
<
is_power_of_two_policy
<
U
>
::
value
>::
type
*
=
nullptr
>
std
::
size_t
next_bucket
(
std
::
size_t
index
)
const
noexcept
{
tsl_rh_assert
(
index
<
bucket_count
());
return
(
index
+
1
)
&
this
->
m_mask
;
}
template
<
class
U
=
GrowthPolicy
,
typename
std
::
enable_if
<!
is_power_of_two_policy
<
U
>
::
value
>::
type
*
=
nullptr
>
std
::
size_t
next_bucket
(
std
::
size_t
index
)
const
noexcept
{
tsl_rh_assert
(
index
<
bucket_count
());
index
++
;
return
(
index
!=
bucket_count
())
?
index
:
0
;
}
template
<
class
K
>
iterator
find_impl
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
mutable_iterator
(
static_cast
<
const
robin_hash
*>
(
this
)
->
find
(
key
,
hash
));
}
template
<
class
K
>
const_iterator
find_impl
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
std
::
size_t
ibucket
=
bucket_for_hash
(
hash
);
distance_type
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
TSL_RH_LIKELY
(
(
!
USE_STORED_HASH_ON_LOOKUP
||
m_buckets
[
ibucket
].
bucket_hash_equal
(
hash
))
&&
compare_keys
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()),
key
)))
{
return
const_iterator
(
m_buckets
+
ibucket
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
}
/*
return
cend
();
* Hash policy
}
void
erase_from_bucket
(
iterator
pos
)
{
pos
.
m_bucket
->
clear
();
m_nb_elements
--
;
/**
* Backward shift, swap the empty bucket, previous_ibucket, with the values
* on its right, ibucket, until we cross another empty bucket or if the
* other bucket has a distance_from_ideal_bucket == 0.
*
* We try to move the values closer to their ideal bucket.
*/
*/
float
load_factor
()
const
{
std
::
size_t
previous_ibucket
=
if
(
bucket_count
()
==
0
)
{
static_cast
<
std
::
size_t
>
(
pos
.
m_bucket
-
m_buckets
);
return
0
;
std
::
size_t
ibucket
=
next_bucket
(
previous_ibucket
);
}
while
(
m_buckets
[
ibucket
].
dist_from_ideal_bucket
()
>
0
)
{
return
float
(
m_nb_elements
)
/
float
(
bucket_count
());
tsl_rh_assert
(
m_buckets
[
previous_ibucket
].
empty
());
}
const
distance_type
new_distance
=
float
min_load_factor
()
const
{
distance_type
(
m_buckets
[
ibucket
].
dist_from_ideal_bucket
()
-
1
);
return
m_min_load_factor
;
m_buckets
[
previous_ibucket
].
set_value_of_empty_bucket
(
}
new_distance
,
m_buckets
[
ibucket
].
truncated_hash
(),
std
::
move
(
m_buckets
[
ibucket
].
value
()));
float
max_load_factor
()
const
{
m_buckets
[
ibucket
].
clear
();
return
m_max_load_factor
;
previous_ibucket
=
ibucket
;
ibucket
=
next_bucket
(
ibucket
);
}
}
}
void
min_load_factor
(
float
ml
)
{
m_min_load_factor
=
clamp
(
ml
,
float
(
MINIMUM_MIN_LOAD_FACTOR
),
template
<
class
K
,
class
...
Args
>
float
(
MAXIMUM_MIN_LOAD_FACTOR
));
std
::
pair
<
iterator
,
bool
>
insert_impl
(
const
K
&
key
,
Args
&&
...
value_type_args
)
{
const
std
::
size_t
hash
=
hash_key
(
key
);
std
::
size_t
ibucket
=
bucket_for_hash
(
hash
);
distance_type
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
((
!
USE_STORED_HASH_ON_LOOKUP
||
m_buckets
[
ibucket
].
bucket_hash_equal
(
hash
))
&&
compare_keys
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()),
key
))
{
return
std
::
make_pair
(
iterator
(
m_buckets
+
ibucket
),
false
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
}
void
max_load_factor
(
float
ml
)
{
if
(
rehash_on_extreme_load
())
{
m_max_load_factor
=
clamp
(
ml
,
float
(
MINIMUM_MAX_LOAD_FACTOR
),
ibucket
=
bucket_for_hash
(
hash
);
float
(
MAXIMUM_MAX_LOAD_FACTOR
));
dist_from_ideal_bucket
=
0
;
m_load_threshold
=
size_type
(
float
(
bucket_count
())
*
m_max_load_factor
);
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
}
}
void
rehash
(
size_type
count
)
{
if
(
m_buckets
[
ibucket
].
empty
())
{
count
=
std
::
max
(
count
,
size_type
(
std
::
ceil
(
float
(
size
())
/
max_load_factor
())));
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
rehash_impl
(
count
);
dist_from_ideal_bucket
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
forward
<
Args
>
(
value_type_args
)...);
}
else
{
insert_value
(
ibucket
,
dist_from_ideal_bucket
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
forward
<
Args
>
(
value_type_args
)...);
}
}
void
reserve
(
size_type
count
)
{
m_nb_elements
++
;
rehash
(
size_type
(
std
::
ceil
(
float
(
count
)
/
max_load_factor
())));
}
/*
/*
* Observers
* The value will be inserted in ibucket in any case, either because it was
* empty or by stealing the bucket (robin hood).
*/
*/
hasher
hash_function
()
const
{
return
std
::
make_pair
(
iterator
(
m_buckets
+
ibucket
),
true
);
return
static_cast
<
const
Hash
&>
(
*
this
);
}
}
template
<
class
...
Args
>
key_equal
key_eq
()
const
{
void
insert_value
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
return
static_cast
<
const
KeyEqual
&>
(
*
this
);
truncated_hash_type
hash
,
Args
&&
...
value_type_args
)
{
}
value_type
value
(
std
::
forward
<
Args
>
(
value_type_args
)...);
insert_value_impl
(
ibucket
,
dist_from_ideal_bucket
,
hash
,
value
);
}
/*
* Other
void
insert_value
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
*/
truncated_hash_type
hash
,
value_type
&&
value
)
{
iterator
mutable_iterator
(
const_iterator
pos
)
{
insert_value_impl
(
ibucket
,
dist_from_ideal_bucket
,
hash
,
value
);
return
iterator
(
const_cast
<
bucket_entry
*>
(
pos
.
m_bucket
));
}
}
/*
private:
* We don't use `value_type&& value` as last argument due to a bug in MSVC
template
<
class
K
>
* when `value_type` is a pointer, The compiler is not able to see the
std
::
size_t
hash_key
(
const
K
&
key
)
const
{
* difference between `std::string*` and `std::string*&&` resulting in compile
return
Hash
::
operator
()(
key
);
* error.
}
*
* The `value` will be in a moved state at the end of the function.
template
<
class
K1
,
class
K2
>
*/
bool
compare_keys
(
const
K1
&
key1
,
const
K2
&
key2
)
const
{
void
insert_value_impl
(
std
::
size_t
ibucket
,
return
KeyEqual
::
operator
()(
key1
,
key2
);
distance_type
dist_from_ideal_bucket
,
}
truncated_hash_type
hash
,
value_type
&
value
)
{
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
{
value
);
const
std
::
size_t
bucket
=
GrowthPolicy
::
bucket_for_hash
(
hash
);
ibucket
=
next_bucket
(
ibucket
);
tsl_rh_assert
(
bucket
<
m_bucket_count
||
(
bucket
==
0
&&
m_bucket_count
==
0
));
dist_from_ideal_bucket
++
;
return
bucket
;
while
(
!
m_buckets
[
ibucket
].
empty
())
{
}
if
(
dist_from_ideal_bucket
>
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
template
<
class
U
=
GrowthPolicy
,
typename
std
::
enable_if
<
is_power_of_two_policy
<
U
>
::
value
>::
type
*
=
nullptr
>
if
(
dist_from_ideal_bucket
>=
REHASH_ON_HIGH_NB_PROBES__NPROBES
&&
std
::
size_t
next_bucket
(
std
::
size_t
index
)
const
noexcept
{
load_factor
()
>=
REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR
)
{
tsl_rh_assert
(
index
<
bucket_count
());
/**
* The number of probes is really high, rehash the map on the next
return
(
index
+
1
)
&
this
->
m_mask
;
* insert. Difficult to do now as rehash may throw an exception.
}
*/
m_grow_on_next_insert
=
true
;
template
<
class
U
=
GrowthPolicy
,
typename
std
::
enable_if
<!
is_power_of_two_policy
<
U
>
::
value
>::
type
*
=
nullptr
>
std
::
size_t
next_bucket
(
std
::
size_t
index
)
const
noexcept
{
tsl_rh_assert
(
index
<
bucket_count
());
index
++
;
return
(
index
!=
bucket_count
())
?
index
:
0
;
}
template
<
class
K
>
iterator
find_impl
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
mutable_iterator
(
static_cast
<
const
robin_hash
*>
(
this
)
->
find
(
key
,
hash
));
}
template
<
class
K
>
const_iterator
find_impl
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
std
::
size_t
ibucket
=
bucket_for_hash
(
hash
);
distance_type
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
TSL_RH_LIKELY
((
!
USE_STORED_HASH_ON_LOOKUP
||
m_buckets
[
ibucket
].
bucket_hash_equal
(
hash
))
&&
compare_keys
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()),
key
)))
{
return
const_iterator
(
m_buckets
+
ibucket
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
return
cend
();
}
void
erase_from_bucket
(
iterator
pos
)
{
pos
.
m_bucket
->
clear
();
m_nb_elements
--
;
/**
* Backward shift, swap the empty bucket, previous_ibucket, with the values on its right, ibucket,
* until we cross another empty bucket or if the other bucket has a distance_from_ideal_bucket == 0.
*
* We try to move the values closer to their ideal bucket.
*/
std
::
size_t
previous_ibucket
=
static_cast
<
std
::
size_t
>
(
pos
.
m_bucket
-
m_buckets
);
std
::
size_t
ibucket
=
next_bucket
(
previous_ibucket
);
while
(
m_buckets
[
ibucket
].
dist_from_ideal_bucket
()
>
0
)
{
tsl_rh_assert
(
m_buckets
[
previous_ibucket
].
empty
());
const
distance_type
new_distance
=
distance_type
(
m_buckets
[
ibucket
].
dist_from_ideal_bucket
()
-
1
);
m_buckets
[
previous_ibucket
].
set_value_of_empty_bucket
(
new_distance
,
m_buckets
[
ibucket
].
truncated_hash
(),
std
::
move
(
m_buckets
[
ibucket
].
value
()));
m_buckets
[
ibucket
].
clear
();
previous_ibucket
=
ibucket
;
ibucket
=
next_bucket
(
ibucket
);
}
}
template
<
class
K
,
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
insert_impl
(
const
K
&
key
,
Args
&&
...
value_type_args
)
{
const
std
::
size_t
hash
=
hash_key
(
key
);
std
::
size_t
ibucket
=
bucket_for_hash
(
hash
);
distance_type
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
((
!
USE_STORED_HASH_ON_LOOKUP
||
m_buckets
[
ibucket
].
bucket_hash_equal
(
hash
))
&&
compare_keys
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()),
key
))
{
return
std
::
make_pair
(
iterator
(
m_buckets
+
ibucket
),
false
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
if
(
rehash_on_extreme_load
())
{
ibucket
=
bucket_for_hash
(
hash
);
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
}
if
(
m_buckets
[
ibucket
].
empty
())
{
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
forward
<
Args
>
(
value_type_args
)...);
}
else
{
insert_value
(
ibucket
,
dist_from_ideal_bucket
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
forward
<
Args
>
(
value_type_args
)...);
}
}
m_nb_elements
++
;
/*
* The value will be inserted in ibucket in any case, either because it was
* empty or by stealing the bucket (robin hood).
*/
return
std
::
make_pair
(
iterator
(
m_buckets
+
ibucket
),
true
);
}
template
<
class
...
Args
>
void
insert_value
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
Args
&&
...
value_type_args
)
{
value_type
value
(
std
::
forward
<
Args
>
(
value_type_args
)...);
insert_value_impl
(
ibucket
,
dist_from_ideal_bucket
,
hash
,
value
);
}
void
insert_value
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
value_type
&&
value
)
hash
,
value
);
{
}
insert_value_impl
(
ibucket
,
dist_from_ideal_bucket
,
hash
,
value
);
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
}
/*
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
hash
,
* We don't use `value_type&& value` as last argument due to a bug in MSVC when `value_type` is a pointer,
std
::
move
(
value
));
* The compiler is not able to see the difference between `std::string*` and `std::string*&&` resulting in
}
* compile error.
*
void
rehash_impl
(
size_type
count
)
{
* The `value` will be in a moved state at the end of the function.
robin_hash
new_table
(
count
,
static_cast
<
Hash
&>
(
*
this
),
*/
static_cast
<
KeyEqual
&>
(
*
this
),
get_allocator
(),
void
insert_value_impl
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
m_min_load_factor
,
m_max_load_factor
);
truncated_hash_type
hash
,
value_type
&
value
)
{
const
bool
use_stored_hash
=
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
USE_STORED_HASH_ON_REHASH
(
new_table
.
bucket_count
());
ibucket
=
next_bucket
(
ibucket
);
for
(
auto
&
bucket
:
m_buckets_data
)
{
dist_from_ideal_bucket
++
;
if
(
bucket
.
empty
())
{
continue
;
while
(
!
m_buckets
[
ibucket
].
empty
())
{
}
if
(
dist_from_ideal_bucket
>
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
dist_from_ideal_bucket
>=
REHASH_ON_HIGH_NB_PROBES__NPROBES
&&
const
std
::
size_t
hash
=
load_factor
()
>=
REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR
)
use_stored_hash
?
bucket
.
truncated_hash
()
{
:
new_table
.
hash_key
(
KeySelect
()(
bucket
.
value
()));
/**
* The number of probes is really high, rehash the map on the next insert.
new_table
.
insert_value_on_rehash
(
new_table
.
bucket_for_hash
(
hash
),
0
,
* Difficult to do now as rehash may throw an exception.
bucket_entry
::
truncate_hash
(
hash
),
*/
std
::
move
(
bucket
.
value
()));
m_grow_on_next_insert
=
true
;
}
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
hash
,
std
::
move
(
value
));
}
}
new_table
.
m_nb_elements
=
m_nb_elements
;
void
rehash_impl
(
size_type
count
)
{
new_table
.
swap
(
*
this
);
robin_hash
new_table
(
count
,
static_cast
<
Hash
&>
(
*
this
),
static_cast
<
KeyEqual
&>
(
*
this
),
}
get_allocator
(),
m_min_load_factor
,
m_max_load_factor
);
void
insert_value_on_rehash
(
std
::
size_t
ibucket
,
const
bool
use_stored_hash
=
USE_STORED_HASH_ON_REHASH
(
new_table
.
bucket_count
());
distance_type
dist_from_ideal_bucket
,
for
(
auto
&
bucket
:
m_buckets_data
)
{
truncated_hash_type
hash
,
value_type
&&
value
)
{
if
(
bucket
.
empty
())
{
while
(
true
)
{
continue
;
if
(
dist_from_ideal_bucket
>
}
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
m_buckets
[
ibucket
].
empty
())
{
const
std
::
size_t
hash
=
use_stored_hash
?
bucket
.
truncated_hash
()
:
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
new_table
.
hash_key
(
KeySelect
()(
bucket
.
value
()));
hash
,
std
::
move
(
value
));
return
;
new_table
.
insert_value_on_rehash
(
new_table
.
bucket_for_hash
(
hash
),
0
,
}
else
{
bucket_entry
::
truncate_hash
(
hash
),
std
::
move
(
bucket
.
value
()));
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
}
}
}
new_table
.
m_nb_elements
=
m_nb_elements
;
new_table
.
swap
(
*
this
);
dist_from_ideal_bucket
++
;
ibucket
=
next_bucket
(
ibucket
);
}
}
}
void
insert_value_on_rehash
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
value_type
&&
value
)
/**
{
* Grow the table if m_grow_on_next_insert is true or we reached the
while
(
true
)
{
* max_load_factor. Shrink the table if m_try_skrink_on_next_insert is true
if
(
dist_from_ideal_bucket
>
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
* (an erase occured) and we're below the min_load_factor.
if
(
m_buckets
[
ibucket
].
empty
())
{
*
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
hash
,
std
::
move
(
value
));
* Return true if the table has been rehashed.
return
;
*/
}
bool
rehash_on_extreme_load
()
{
else
{
if
(
m_grow_on_next_insert
||
size
()
>=
m_load_threshold
)
{
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
rehash_impl
(
GrowthPolicy
::
next_bucket_count
());
}
m_grow_on_next_insert
=
false
;
}
return
true
;
dist_from_ideal_bucket
++
;
ibucket
=
next_bucket
(
ibucket
);
}
}
}
if
(
m_try_skrink_on_next_insert
)
{
m_try_skrink_on_next_insert
=
false
;
/**
if
(
m_min_load_factor
!=
0.0
f
&&
load_factor
()
<
m_min_load_factor
)
{
* Grow the table if m_grow_on_next_insert is true or we reached the max_load_factor.
reserve
(
size
()
+
1
);
* Shrink the table if m_try_skrink_on_next_insert is true (an erase occured) and
* we're below the min_load_factor.
return
true
;
*
}
* Return true if the table has been rehashed.
*/
bool
rehash_on_extreme_load
()
{
if
(
m_grow_on_next_insert
||
size
()
>=
m_load_threshold
)
{
rehash_impl
(
GrowthPolicy
::
next_bucket_count
());
m_grow_on_next_insert
=
false
;
return
true
;
}
if
(
m_try_skrink_on_next_insert
)
{
m_try_skrink_on_next_insert
=
false
;
if
(
m_min_load_factor
!=
0.0
f
&&
load_factor
()
<
m_min_load_factor
)
{
reserve
(
size
()
+
1
);
return
true
;
}
}
return
false
;
}
}
return
false
;
}
public:
public:
static
const
size_type
DEFAULT_INIT_BUCKETS_SIZE
=
0
;
static
const
size_type
DEFAULT_INIT_BUCKETS_SIZE
=
0
;
static
constexpr
float
DEFAULT_MAX_LOAD_FACTOR
=
0.5
f
;
static
constexpr
float
DEFAULT_MAX_LOAD_FACTOR
=
0.5
f
;
static
constexpr
float
MINIMUM_MAX_LOAD_FACTOR
=
0.2
f
;
static
constexpr
float
MINIMUM_MAX_LOAD_FACTOR
=
0.2
f
;
static
constexpr
float
MAXIMUM_MAX_LOAD_FACTOR
=
0.95
f
;
static
constexpr
float
MAXIMUM_MAX_LOAD_FACTOR
=
0.95
f
;
static
constexpr
float
DEFAULT_MIN_LOAD_FACTOR
=
0.0
f
;
static
constexpr
float
DEFAULT_MIN_LOAD_FACTOR
=
0.0
f
;
static
constexpr
float
MINIMUM_MIN_LOAD_FACTOR
=
0.0
f
;
static
constexpr
float
MINIMUM_MIN_LOAD_FACTOR
=
0.0
f
;
static
constexpr
float
MAXIMUM_MIN_LOAD_FACTOR
=
0.15
f
;
static
constexpr
float
MAXIMUM_MIN_LOAD_FACTOR
=
0.15
f
;
static_assert
(
MINIMUM_MAX_LOAD_FACTOR
<
MAXIMUM_MAX_LOAD_FACTOR
,
static_assert
(
MINIMUM_MAX_LOAD_FACTOR
<
MAXIMUM_MAX_LOAD_FACTOR
,
"MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR"
);
"MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR"
);
static_assert
(
MINIMUM_MIN_LOAD_FACTOR
<
MAXIMUM_MIN_LOAD_FACTOR
,
static_assert
(
MINIMUM_MIN_LOAD_FACTOR
<
MAXIMUM_MIN_LOAD_FACTOR
,
"MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR"
);
"MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR"
);
static_assert
(
MAXIMUM_MIN_LOAD_FACTOR
<
MINIMUM_MAX_LOAD_FACTOR
,
static_assert
(
MAXIMUM_MIN_LOAD_FACTOR
<
MINIMUM_MAX_LOAD_FACTOR
,
"MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR"
);
"MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR"
);
private:
private:
static
const
distance_type
REHASH_ON_HIGH_NB_PROBES__NPROBES
=
128
;
static
const
distance_type
REHASH_ON_HIGH_NB_PROBES__NPROBES
=
128
;
static
constexpr
float
REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR
=
0.15
f
;
static
constexpr
float
REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR
=
0.15
f
;
/**
/**
* Return an always valid pointer to an static empty bucket_entry with
* Return an always valid pointer to an static empty bucket_entry with
last_bucket() == true.
*
last_bucket() == true.
*/
*/
bucket_entry
*
static_empty_bucket_ptr
()
{
bucket_entry
*
static_empty_bucket_ptr
()
{
static
bucket_entry
empty_bucket
(
true
);
static
bucket_entry
empty_bucket
(
true
);
return
&
empty_bucket
;
return
&
empty_bucket
;
}
}
private:
private:
buckets_container_type
m_buckets_data
;
buckets_container_type
m_buckets_data
;
/**
/**
* Points to m_buckets_data.data() if !m_buckets_data.empty() otherwise points to static_empty_bucket_ptr.
* Points to m_buckets_data.data() if !m_buckets_data.empty() otherwise points
* This variable is useful to avoid the cost of checking if m_buckets_data is empty when trying
* to static_empty_bucket_ptr. This variable is useful to avoid the cost of
* to find an element.
* checking if m_buckets_data is empty when trying to find an element.
*
*
* TODO Remove m_buckets_data and only use a pointer instead of a pointer+vector to save some space in the robin_hash object.
* TODO Remove m_buckets_data and only use a pointer instead of a
* Manage the Allocator manually.
* pointer+vector to save some space in the robin_hash object. Manage the
*/
* Allocator manually.
bucket_entry
*
m_buckets
;
*/
bucket_entry
*
m_buckets
;
/**
* Used a lot in find, avoid the call to m_buckets_data.size() which is a bit slower.
/**
*/
* Used a lot in find, avoid the call to m_buckets_data.size() which is a bit
size_type
m_bucket_count
;
* slower.
*/
size_type
m_nb_elements
;
size_type
m_bucket_count
;
size_type
m_load_threshold
;
size_type
m_nb_elements
;
float
m_max_load_factor
;
size_type
m_load_threshold
;
bool
m_grow_on_next_insert
;
float
m_max_load_factor
;
float
m_min_load_factor
;
bool
m_grow_on_next_insert
;
/**
float
m_min_load_factor
;
* We can't shrink down the map on erase operations as the erase methods need to return the next iterator.
* Shrinking the map would invalidate all the iterators and we could not return the next iterator in a meaningful way,
/**
* On erase, we thus just indicate on erase that we should try to shrink the hash table on the next insert
* We can't shrink down the map on erase operations as the erase methods need
* if we go below the min_load_factor.
* to return the next iterator. Shrinking the map would invalidate all the
*/
* iterators and we could not return the next iterator in a meaningful way, On
bool
m_try_skrink_on_next_insert
;
* erase, we thus just indicate on erase that we should try to shrink the hash
* table on the next insert if we go below the min_load_factor.
*/
bool
m_try_skrink_on_next_insert
;
};
};
}
}
// namespace detail_robin_hash
}
}
// namespace tsl
#endif
#endif
include/tsl/robin_map.h
View file @
19e73bbe
/**
/**
* MIT License
* MIT License
*
*
* Copyright (c) 2017 Tessil
* Copyright (c) 2017 Tessil
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
* furnished to do so, subject to the following conditions:
*
*
* The above copyright notice and this permission notice shall be included in
all
* The above copyright notice and this permission notice shall be included in
* copies or substantial portions of the Software.
*
all
copies or substantial portions of the Software.
*
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
...
@@ -22,662 +22,693 @@
...
@@ -22,662 +22,693 @@
* SOFTWARE.
* SOFTWARE.
*/
*/
#ifndef TSL_ROBIN_MAP_H
#ifndef TSL_ROBIN_MAP_H
#define TSL_ROBIN_MAP_H
#define TSL_ROBIN_MAP_H
#include "robin_hash.h"
#include <cstddef>
#include <cstddef>
#include <functional>
#include <functional>
#include <initializer_list>
#include <initializer_list>
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include "robin_hash.h"
namespace
tsl
{
namespace
tsl
{
/**
/**
* Implementation of a hash map using open-adressing and the robin hood hashing algorithm with backward shift deletion.
* Implementation of a hash map using open-adressing and the robin hood hashing
*
* algorithm with backward shift deletion.
* For operations modifying the hash map (insert, erase, rehash, ...), the strong exception guarantee
*
* is only guaranteed when the expression `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
* For operations modifying the hash map (insert, erase, rehash, ...), the
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true, otherwise if an exception
* strong exception guarantee is only guaranteed when the expression
* is thrown during the swap or the move, the hash map may end up in a undefined state. Per the standard
* `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
* a `Key` or `T` with a noexcept copy constructor and no move constructor also satisfies the
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true,
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and will thus guarantee the
* otherwise if an exception is thrown during the swap or the move, the hash map
* strong exception for the map).
* may end up in a undefined state. Per the standard a `Key` or `T` with a
*
* noexcept copy constructor and no move constructor also satisfies the
* When `StoreHash` is true, 32 bits of the hash are stored alongside the values. It can improve
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and
* the performance during lookups if the `KeyEqual` function takes time (if it engenders a cache-miss for example)
* will thus guarantee the strong exception for the map).
* as we then compare the stored hashes before comparing the keys. When `tsl::rh::power_of_two_growth_policy` is used
*
* as `GrowthPolicy`, it may also speed-up the rehash process as we can avoid to recalculate the hash.
* When `StoreHash` is true, 32 bits of the hash are stored alongside the
* When it is detected that storing the hash will not incur any memory penality due to alignement (i.e.
* values. It can improve the performance during lookups if the `KeyEqual`
* `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, true>) ==
* function takes time (if it engenders a cache-miss for example) as we then
* sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`) and `tsl::rh::power_of_two_growth_policy` is
* compare the stored hashes before comparing the keys. When
* used, the hash will be stored even if `StoreHash` is false so that we can speed-up the rehash (but it will
* `tsl::rh::power_of_two_growth_policy` is used as `GrowthPolicy`, it may also
* speed-up the rehash process as we can avoid to recalculate the hash. When it
* is detected that storing the hash will not incur any memory penality due to
* alignement (i.e. `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType,
* true>) == sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`)
* and `tsl::rh::power_of_two_growth_policy` is used, the hash will be stored
* even if `StoreHash` is false so that we can speed-up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
* not be used on lookups unless `StoreHash` is true).
*
*
* `GrowthPolicy` defines how the map grows and consequently how a hash value is mapped to a bucket.
* `GrowthPolicy` defines how the map grows and consequently how a hash value is
* By default the map uses `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of buckets
* mapped to a bucket. By default the map uses
* to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of
* Other growth policies are available and you may define your own growth policy,
* buckets to a power of two and uses a mask to map the hash to a bucket instead
* check `tsl::rh::power_of_two_growth_policy` for the interface.
* of the slow modulo. Other growth policies are available and you may define
*
* your own growth policy, check `tsl::rh::power_of_two_growth_policy` for the
* interface.
*
* `std::pair<Key, T>` must be swappable.
* `std::pair<Key, T>` must be swappable.
*
*
* `Key` and `T` must be copy and/or move constructible.
* `Key` and `T` must be copy and/or move constructible.
*
*
* If the destructor of `Key` or `T` throws an exception, the behaviour of the class is undefined.
* If the destructor of `Key` or `T` throws an exception, the behaviour of the
*
* class is undefined.
*
* Iterators invalidation:
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective
* insert, invalidate the iterators.
* - erase: always invalidate the iterators.
* - erase: always invalidate the iterators.
*/
*/
template
<
class
Key
,
template
<
class
Key
,
class
T
,
class
Hash
=
std
::
hash
<
Key
>,
class
T
,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
class
Hash
=
std
::
hash
<
Key
>,
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
bool
StoreHash
=
false
,
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
class
GrowthPolicy
=
tsl
::
rh
::
power_of_two_growth_policy
<
2
>>
bool
StoreHash
=
false
,
class
GrowthPolicy
=
tsl
::
rh
::
power_of_two_growth_policy
<
2
>>
class
robin_map
{
class
robin_map
{
private:
private:
template
<
typename
U
>
template
<
typename
U
>
using
has_is_transparent
=
tsl
::
detail_robin_hash
::
has_is_transparent
<
U
>
;
using
has_is_transparent
=
tsl
::
detail_robin_hash
::
has_is_transparent
<
U
>
;
class
KeySelect
{
public:
using
key_type
=
Key
;
const
key_type
&
operator
()(
const
std
::
pair
<
Key
,
T
>&
key_value
)
const
noexcept
{
return
key_value
.
first
;
}
key_type
&
operator
()(
std
::
pair
<
Key
,
T
>&
key_value
)
noexcept
{
return
key_value
.
first
;
}
};
class
ValueSelect
{
public:
using
value_type
=
T
;
const
value_type
&
operator
()(
const
std
::
pair
<
Key
,
T
>&
key_value
)
const
noexcept
{
return
key_value
.
second
;
}
value_type
&
operator
()(
std
::
pair
<
Key
,
T
>&
key_value
)
noexcept
{
return
key_value
.
second
;
}
};
using
ht
=
detail_robin_hash
::
robin_hash
<
std
::
pair
<
Key
,
T
>
,
KeySelect
,
ValueSelect
,
Hash
,
KeyEqual
,
Allocator
,
StoreHash
,
GrowthPolicy
>
;
public:
using
key_type
=
typename
ht
::
key_type
;
using
mapped_type
=
T
;
using
value_type
=
typename
ht
::
value_type
;
using
size_type
=
typename
ht
::
size_type
;
using
difference_type
=
typename
ht
::
difference_type
;
using
hasher
=
typename
ht
::
hasher
;
using
key_equal
=
typename
ht
::
key_equal
;
using
allocator_type
=
typename
ht
::
allocator_type
;
using
reference
=
typename
ht
::
reference
;
using
const_reference
=
typename
ht
::
const_reference
;
using
pointer
=
typename
ht
::
pointer
;
using
const_pointer
=
typename
ht
::
const_pointer
;
using
iterator
=
typename
ht
::
iterator
;
using
const_iterator
=
typename
ht
::
const_iterator
;
public:
/*
* Constructors
*/
robin_map
()
:
robin_map
(
ht
::
DEFAULT_INIT_BUCKETS_SIZE
)
{
}
explicit
robin_map
(
size_type
bucket_count
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
m_ht
(
bucket_count
,
hash
,
equal
,
alloc
)
{
}
robin_map
(
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{
}
robin_map
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{
}
explicit
robin_map
(
const
Allocator
&
alloc
)
:
robin_map
(
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
alloc
)
{
}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
=
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
robin_map
(
bucket_count
,
hash
,
equal
,
alloc
)
{
insert
(
first
,
last
);
}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
first
,
last
,
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{
}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
first
,
last
,
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{
}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
class
KeySelect
{
size_type
bucket_count
=
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
public:
const
Hash
&
hash
=
Hash
(),
using
key_type
=
Key
;
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
hash
,
equal
,
alloc
)
{
}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
const
key_type
&
size_type
bucket_count
,
operator
()(
const
std
::
pair
<
Key
,
T
>
&
key_value
)
const
noexcept
{
const
Allocator
&
alloc
)
:
return
key_value
.
first
;
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{
}
}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
key_type
&
operator
()(
std
::
pair
<
Key
,
T
>
&
key_value
)
noexcept
{
size_type
bucket_count
,
return
key_value
.
first
;
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{
}
robin_map
&
operator
=
(
std
::
initializer_list
<
value_type
>
ilist
)
{
m_ht
.
clear
();
m_ht
.
reserve
(
ilist
.
size
());
m_ht
.
insert
(
ilist
.
begin
(),
ilist
.
end
());
return
*
this
;
}
allocator_type
get_allocator
()
const
{
return
m_ht
.
get_allocator
();
}
/*
* Iterators
*/
iterator
begin
()
noexcept
{
return
m_ht
.
begin
();
}
const_iterator
begin
()
const
noexcept
{
return
m_ht
.
begin
();
}
const_iterator
cbegin
()
const
noexcept
{
return
m_ht
.
cbegin
();
}
iterator
end
()
noexcept
{
return
m_ht
.
end
();
}
const_iterator
end
()
const
noexcept
{
return
m_ht
.
end
();
}
const_iterator
cend
()
const
noexcept
{
return
m_ht
.
cend
();
}
/*
* Capacity
*/
bool
empty
()
const
noexcept
{
return
m_ht
.
empty
();
}
size_type
size
()
const
noexcept
{
return
m_ht
.
size
();
}
size_type
max_size
()
const
noexcept
{
return
m_ht
.
max_size
();
}
/*
* Modifiers
*/
void
clear
()
noexcept
{
m_ht
.
clear
();
}
std
::
pair
<
iterator
,
bool
>
insert
(
const
value_type
&
value
)
{
return
m_ht
.
insert
(
value
);
}
template
<
class
P
,
typename
std
::
enable_if
<
std
::
is_constructible
<
value_type
,
P
&&
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
bool
>
insert
(
P
&&
value
)
{
return
m_ht
.
emplace
(
std
::
forward
<
P
>
(
value
));
}
std
::
pair
<
iterator
,
bool
>
insert
(
value_type
&&
value
)
{
return
m_ht
.
insert
(
std
::
move
(
value
));
}
iterator
insert
(
const_iterator
hint
,
const
value_type
&
value
)
{
return
m_ht
.
insert_hint
(
hint
,
value
);
}
template
<
class
P
,
typename
std
::
enable_if
<
std
::
is_constructible
<
value_type
,
P
&&
>
::
value
>::
type
*
=
nullptr
>
iterator
insert
(
const_iterator
hint
,
P
&&
value
)
{
return
m_ht
.
emplace_hint
(
hint
,
std
::
forward
<
P
>
(
value
));
}
iterator
insert
(
const_iterator
hint
,
value_type
&&
value
)
{
return
m_ht
.
insert_hint
(
hint
,
std
::
move
(
value
));
}
template
<
class
InputIt
>
void
insert
(
InputIt
first
,
InputIt
last
)
{
m_ht
.
insert
(
first
,
last
);
}
void
insert
(
std
::
initializer_list
<
value_type
>
ilist
)
{
m_ht
.
insert
(
ilist
.
begin
(),
ilist
.
end
());
}
}
};
class
ValueSelect
{
public:
using
value_type
=
T
;
template
<
class
M
>
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
const
key_type
&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
k
,
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
const
value_type
&
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
key_type
&&
k
,
M
&&
obj
)
{
operator
()(
const
std
::
pair
<
Key
,
T
>
&
key_value
)
const
noexcept
{
return
m_ht
.
insert_or_assign
(
std
::
move
(
k
),
std
::
forward
<
M
>
(
obj
));
return
key_value
.
second
;
}
template
<
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
const
key_type
&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
hint
,
k
,
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
key_type
&&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
hint
,
std
::
move
(
k
),
std
::
forward
<
M
>
(
obj
));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the key-value once.
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
emplace
(
Args
&&
...
args
)
{
return
m_ht
.
emplace
(
std
::
forward
<
Args
>
(
args
)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template
<
class
...
Args
>
iterator
emplace_hint
(
const_iterator
hint
,
Args
&&
...
args
)
{
return
m_ht
.
emplace_hint
(
hint
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
const
key_type
&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace
(
k
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
key_type
&&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace
(
std
::
move
(
k
),
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
iterator
try_emplace
(
const_iterator
hint
,
const
key_type
&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace_hint
(
hint
,
k
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
iterator
try_emplace
(
const_iterator
hint
,
key_type
&&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace_hint
(
hint
,
std
::
move
(
k
),
std
::
forward
<
Args
>
(
args
)...);
}
iterator
erase
(
iterator
pos
)
{
return
m_ht
.
erase
(
pos
);
}
iterator
erase
(
const_iterator
pos
)
{
return
m_ht
.
erase
(
pos
);
}
iterator
erase
(
const_iterator
first
,
const_iterator
last
)
{
return
m_ht
.
erase
(
first
,
last
);
}
size_type
erase
(
const
key_type
&
key
)
{
return
m_ht
.
erase
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
*/
size_type
erase
(
const
key_type
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
erase
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
erase
(
const
K
&
key
)
{
return
m_ht
.
erase
(
key
);
}
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup to the value if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
erase
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
erase
(
key
,
precalculated_hash
);
}
void
swap
(
robin_map
&
other
)
{
other
.
m_ht
.
swap
(
m_ht
);
}
/*
* Lookup
*/
T
&
at
(
const
Key
&
key
)
{
return
m_ht
.
at
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
T
&
at
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
const
T
&
at
(
const
Key
&
key
)
const
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const
T
&
at
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
T
&
at
(
const
K
&
key
)
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
T
&
at
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
/**
* @copydoc at(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const
T
&
at
(
const
K
&
key
)
const
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const
T
&
at
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
T
&
operator
[](
const
Key
&
key
)
{
return
m_ht
[
key
];
}
T
&
operator
[](
Key
&&
key
)
{
return
m_ht
[
std
::
move
(
key
)];
}
size_type
count
(
const
Key
&
key
)
const
{
return
m_ht
.
count
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
size_type
count
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
count
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
count
(
const
K
&
key
)
const
{
return
m_ht
.
count
(
key
);
}
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
count
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
count
(
key
,
precalculated_hash
);
}
iterator
find
(
const
Key
&
key
)
{
return
m_ht
.
find
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
iterator
find
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
const_iterator
find
(
const
Key
&
key
)
const
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator
find
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
iterator
find
(
const
K
&
key
)
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
iterator
find
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
/**
* @copydoc find(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const_iterator
find
(
const
K
&
key
)
const
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const_iterator
find
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
Key
&
key
)
{
return
m_ht
.
equal_range
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
Key
&
key
)
const
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
}
/**
value_type
&
operator
()(
std
::
pair
<
Key
,
T
>
&
key_value
)
noexcept
{
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
return
key_value
.
second
;
* If so, K must be hashable and comparable to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
)
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Usefull to speed-up the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/**
* @copydoc equal_range(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
)
const
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/*
* Bucket interface
*/
size_type
bucket_count
()
const
{
return
m_ht
.
bucket_count
();
}
size_type
max_bucket_count
()
const
{
return
m_ht
.
max_bucket_count
();
}
/*
* Hash policy
*/
float
load_factor
()
const
{
return
m_ht
.
load_factor
();
}
float
min_load_factor
()
const
{
return
m_ht
.
min_load_factor
();
}
float
max_load_factor
()
const
{
return
m_ht
.
max_load_factor
();
}
/**
* Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes
* below `min_load_factor` after some erase operations, the map will be
* shrunk when an insertion occurs. The erase method itself never shrinks
* the map.
*
* The default value of `min_load_factor` is 0.0f, the map never shrinks by default.
*/
void
min_load_factor
(
float
ml
)
{
m_ht
.
min_load_factor
(
ml
);
}
void
max_load_factor
(
float
ml
)
{
m_ht
.
max_load_factor
(
ml
);
}
void
rehash
(
size_type
count
)
{
m_ht
.
rehash
(
count
);
}
void
reserve
(
size_type
count
)
{
m_ht
.
reserve
(
count
);
}
/*
* Observers
*/
hasher
hash_function
()
const
{
return
m_ht
.
hash_function
();
}
key_equal
key_eq
()
const
{
return
m_ht
.
key_eq
();
}
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator
mutable_iterator
(
const_iterator
pos
)
{
return
m_ht
.
mutable_iterator
(
pos
);
}
friend
bool
operator
==
(
const
robin_map
&
lhs
,
const
robin_map
&
rhs
)
{
if
(
lhs
.
size
()
!=
rhs
.
size
())
{
return
false
;
}
for
(
const
auto
&
element_lhs
:
lhs
)
{
const
auto
it_element_rhs
=
rhs
.
find
(
element_lhs
.
first
);
if
(
it_element_rhs
==
rhs
.
cend
()
||
element_lhs
.
second
!=
it_element_rhs
->
second
)
{
return
false
;
}
}
return
true
;
}
}
};
using
ht
=
detail_robin_hash
::
robin_hash
<
std
::
pair
<
Key
,
T
>
,
KeySelect
,
ValueSelect
,
Hash
,
KeyEqual
,
Allocator
,
StoreHash
,
GrowthPolicy
>
;
public:
using
key_type
=
typename
ht
::
key_type
;
using
mapped_type
=
T
;
using
value_type
=
typename
ht
::
value_type
;
using
size_type
=
typename
ht
::
size_type
;
using
difference_type
=
typename
ht
::
difference_type
;
using
hasher
=
typename
ht
::
hasher
;
using
key_equal
=
typename
ht
::
key_equal
;
using
allocator_type
=
typename
ht
::
allocator_type
;
using
reference
=
typename
ht
::
reference
;
using
const_reference
=
typename
ht
::
const_reference
;
using
pointer
=
typename
ht
::
pointer
;
using
const_pointer
=
typename
ht
::
const_pointer
;
using
iterator
=
typename
ht
::
iterator
;
using
const_iterator
=
typename
ht
::
const_iterator
;
public:
/*
* Constructors
*/
robin_map
()
:
robin_map
(
ht
::
DEFAULT_INIT_BUCKETS_SIZE
)
{}
explicit
robin_map
(
size_type
bucket_count
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
m_ht
(
bucket_count
,
hash
,
equal
,
alloc
)
{}
robin_map
(
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{}
robin_map
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{}
explicit
robin_map
(
const
Allocator
&
alloc
)
:
robin_map
(
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
alloc
)
{}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
=
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
robin_map
(
bucket_count
,
hash
,
equal
,
alloc
)
{
insert
(
first
,
last
);
}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
first
,
last
,
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
first
,
last
,
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
size_type
bucket_count
=
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
hash
,
equal
,
alloc
)
{}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{}
robin_map
&
operator
=
(
std
::
initializer_list
<
value_type
>
ilist
)
{
m_ht
.
clear
();
m_ht
.
reserve
(
ilist
.
size
());
m_ht
.
insert
(
ilist
.
begin
(),
ilist
.
end
());
return
*
this
;
}
allocator_type
get_allocator
()
const
{
return
m_ht
.
get_allocator
();
}
/*
* Iterators
*/
iterator
begin
()
noexcept
{
return
m_ht
.
begin
();
}
const_iterator
begin
()
const
noexcept
{
return
m_ht
.
begin
();
}
const_iterator
cbegin
()
const
noexcept
{
return
m_ht
.
cbegin
();
}
iterator
end
()
noexcept
{
return
m_ht
.
end
();
}
const_iterator
end
()
const
noexcept
{
return
m_ht
.
end
();
}
const_iterator
cend
()
const
noexcept
{
return
m_ht
.
cend
();
}
/*
* Capacity
*/
bool
empty
()
const
noexcept
{
return
m_ht
.
empty
();
}
size_type
size
()
const
noexcept
{
return
m_ht
.
size
();
}
size_type
max_size
()
const
noexcept
{
return
m_ht
.
max_size
();
}
/*
* Modifiers
*/
void
clear
()
noexcept
{
m_ht
.
clear
();
}
friend
bool
operator
!=
(
const
robin_map
&
lhs
,
const
robin_map
&
rhs
)
{
std
::
pair
<
iterator
,
bool
>
insert
(
const
value_type
&
value
)
{
return
!
operator
==
(
lhs
,
rhs
);
return
m_ht
.
insert
(
value
);
}
template
<
class
P
,
typename
std
::
enable_if
<
std
::
is_constructible
<
value_type
,
P
&&
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
bool
>
insert
(
P
&&
value
)
{
return
m_ht
.
emplace
(
std
::
forward
<
P
>
(
value
));
}
std
::
pair
<
iterator
,
bool
>
insert
(
value_type
&&
value
)
{
return
m_ht
.
insert
(
std
::
move
(
value
));
}
iterator
insert
(
const_iterator
hint
,
const
value_type
&
value
)
{
return
m_ht
.
insert_hint
(
hint
,
value
);
}
template
<
class
P
,
typename
std
::
enable_if
<
std
::
is_constructible
<
value_type
,
P
&&
>
::
value
>::
type
*
=
nullptr
>
iterator
insert
(
const_iterator
hint
,
P
&&
value
)
{
return
m_ht
.
emplace_hint
(
hint
,
std
::
forward
<
P
>
(
value
));
}
iterator
insert
(
const_iterator
hint
,
value_type
&&
value
)
{
return
m_ht
.
insert_hint
(
hint
,
std
::
move
(
value
));
}
template
<
class
InputIt
>
void
insert
(
InputIt
first
,
InputIt
last
)
{
m_ht
.
insert
(
first
,
last
);
}
void
insert
(
std
::
initializer_list
<
value_type
>
ilist
)
{
m_ht
.
insert
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
class
M
>
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
const
key_type
&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
k
,
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
key_type
&&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
std
::
move
(
k
),
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
const
key_type
&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
hint
,
k
,
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
key_type
&&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
hint
,
std
::
move
(
k
),
std
::
forward
<
M
>
(
obj
));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the
* key-value once. The method is equivalent to
* insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
emplace
(
Args
&&
...
args
)
{
return
m_ht
.
emplace
(
std
::
forward
<
Args
>
(
args
)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy
* the key-value once. The method is equivalent to insert(hint,
* value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template
<
class
...
Args
>
iterator
emplace_hint
(
const_iterator
hint
,
Args
&&
...
args
)
{
return
m_ht
.
emplace_hint
(
hint
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
const
key_type
&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace
(
k
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
key_type
&&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace
(
std
::
move
(
k
),
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
iterator
try_emplace
(
const_iterator
hint
,
const
key_type
&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace_hint
(
hint
,
k
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
iterator
try_emplace
(
const_iterator
hint
,
key_type
&&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace_hint
(
hint
,
std
::
move
(
k
),
std
::
forward
<
Args
>
(
args
)...);
}
iterator
erase
(
iterator
pos
)
{
return
m_ht
.
erase
(
pos
);
}
iterator
erase
(
const_iterator
pos
)
{
return
m_ht
.
erase
(
pos
);
}
iterator
erase
(
const_iterator
first
,
const_iterator
last
)
{
return
m_ht
.
erase
(
first
,
last
);
}
size_type
erase
(
const
key_type
&
key
)
{
return
m_ht
.
erase
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup to the value if you already have the hash.
*/
size_type
erase
(
const
key_type
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
erase
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
erase
(
const
K
&
key
)
{
return
m_ht
.
erase
(
key
);
}
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup to the value if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
erase
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
erase
(
key
,
precalculated_hash
);
}
void
swap
(
robin_map
&
other
)
{
other
.
m_ht
.
swap
(
m_ht
);
}
/*
* Lookup
*/
T
&
at
(
const
Key
&
key
)
{
return
m_ht
.
at
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
T
&
at
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
const
T
&
at
(
const
Key
&
key
)
const
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const
T
&
at
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
T
&
at
(
const
K
&
key
)
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
T
&
at
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
/**
* @copydoc at(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const
T
&
at
(
const
K
&
key
)
const
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const
T
&
at
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
T
&
operator
[](
const
Key
&
key
)
{
return
m_ht
[
key
];
}
T
&
operator
[](
Key
&&
key
)
{
return
m_ht
[
std
::
move
(
key
)];
}
size_type
count
(
const
Key
&
key
)
const
{
return
m_ht
.
count
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
size_type
count
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
count
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
count
(
const
K
&
key
)
const
{
return
m_ht
.
count
(
key
);
}
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
count
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
count
(
key
,
precalculated_hash
);
}
iterator
find
(
const
Key
&
key
)
{
return
m_ht
.
find
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
iterator
find
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
const_iterator
find
(
const
Key
&
key
)
const
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator
find
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
iterator
find
(
const
K
&
key
)
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
iterator
find
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
/**
* @copydoc find(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const_iterator
find
(
const
K
&
key
)
const
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const_iterator
find
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
Key
&
key
)
{
return
m_ht
.
equal_range
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
Key
&
key
)
const
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
)
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/**
* @copydoc equal_range(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
)
const
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/*
* Bucket interface
*/
size_type
bucket_count
()
const
{
return
m_ht
.
bucket_count
();
}
size_type
max_bucket_count
()
const
{
return
m_ht
.
max_bucket_count
();
}
/*
* Hash policy
*/
float
load_factor
()
const
{
return
m_ht
.
load_factor
();
}
float
min_load_factor
()
const
{
return
m_ht
.
min_load_factor
();
}
float
max_load_factor
()
const
{
return
m_ht
.
max_load_factor
();
}
/**
* Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes
* below `min_load_factor` after some erase operations, the map will be
* shrunk when an insertion occurs. The erase method itself never shrinks
* the map.
*
* The default value of `min_load_factor` is 0.0f, the map never shrinks by
* default.
*/
void
min_load_factor
(
float
ml
)
{
m_ht
.
min_load_factor
(
ml
);
}
void
max_load_factor
(
float
ml
)
{
m_ht
.
max_load_factor
(
ml
);
}
void
rehash
(
size_type
count
)
{
m_ht
.
rehash
(
count
);
}
void
reserve
(
size_type
count
)
{
m_ht
.
reserve
(
count
);
}
/*
* Observers
*/
hasher
hash_function
()
const
{
return
m_ht
.
hash_function
();
}
key_equal
key_eq
()
const
{
return
m_ht
.
key_eq
();
}
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator
mutable_iterator
(
const_iterator
pos
)
{
return
m_ht
.
mutable_iterator
(
pos
);
}
friend
bool
operator
==
(
const
robin_map
&
lhs
,
const
robin_map
&
rhs
)
{
if
(
lhs
.
size
()
!=
rhs
.
size
())
{
return
false
;
}
}
friend
void
swap
(
robin_map
&
lhs
,
robin_map
&
rhs
)
{
for
(
const
auto
&
element_lhs
:
lhs
)
{
lhs
.
swap
(
rhs
);
const
auto
it_element_rhs
=
rhs
.
find
(
element_lhs
.
first
);
if
(
it_element_rhs
==
rhs
.
cend
()
||
element_lhs
.
second
!=
it_element_rhs
->
second
)
{
return
false
;
}
}
}
return
true
;
}
friend
bool
operator
!=
(
const
robin_map
&
lhs
,
const
robin_map
&
rhs
)
{
return
!
operator
==
(
lhs
,
rhs
);
}
friend
void
swap
(
robin_map
&
lhs
,
robin_map
&
rhs
)
{
lhs
.
swap
(
rhs
);
}
private:
private:
ht
m_ht
;
ht
m_ht
;
};
};
/**
/**
* Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash, tsl::rh::prime_growth_policy>`.
* Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash,
* tsl::rh::prime_growth_policy>`.
*/
*/
template
<
class
Key
,
template
<
class
Key
,
class
T
,
class
Hash
=
std
::
hash
<
Key
>,
class
T
,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
class
Hash
=
std
::
hash
<
Key
>,
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
bool
StoreHash
=
false
>
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
using
robin_pg_map
=
robin_map
<
Key
,
T
,
Hash
,
KeyEqual
,
Allocator
,
StoreHash
,
bool
StoreHash
=
false
>
tsl
::
rh
::
prime_growth_policy
>
;
using
robin_pg_map
=
robin_map
<
Key
,
T
,
Hash
,
KeyEqual
,
Allocator
,
StoreHash
,
tsl
::
rh
::
prime_growth_policy
>
;
}
// end namespace tsl
}
// end namespace tsl
...
...
include/utility/timer.h
View file @
19e73bbe
...
@@ -14,14 +14,14 @@
...
@@ -14,14 +14,14 @@
#pragma once
#pragma once
#include <chrono>
#include <chrono>
#ifdef
SPCON
V_CUDA
#ifdef
T
V_CUDA
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#endif
#endif
#include <iostream>
#include <iostream>
namespace
spconv
{
namespace
spconv
{
#ifdef
SPCON
V_CUDA
#ifdef
T
V_CUDA
template
<
typename
TimeT
=
std
::
chrono
::
microseconds
>
struct
CudaContextTimer
{
template
<
typename
TimeT
=
std
::
chrono
::
microseconds
>
struct
CudaContextTimer
{
CudaContextTimer
()
{
CudaContextTimer
()
{
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
...
...
setup.py
View file @
19e73bbe
import
os
import
os
import
re
import
sys
import
platform
import
platform
import
re
import
subprocess
import
subprocess
import
torch
import
sys
from
setuptools
import
setup
,
Extension
,
find_packages
from
setuptools.command.build_ext
import
build_ext
from
distutils.version
import
LooseVersion
from
distutils.version
import
LooseVersion
from
pathlib
import
Path
from
pathlib
import
Path
import
torch
from
setuptools
import
Extension
,
find_packages
,
setup
from
setuptools.command.build_ext
import
build_ext
# if 'LIBTORCH_ROOT' not in os.environ:
# if 'LIBTORCH_ROOT' not in os.environ:
# raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.")
# raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.")
...
@@ -100,4 +100,3 @@ setup(
...
@@ -100,4 +100,3 @@ setup(
cmdclass
=
dict
(
build_ext
=
CMakeBuild
),
cmdclass
=
dict
(
build_ext
=
CMakeBuild
),
zip_safe
=
False
,
zip_safe
=
False
,
)
)
spconv/__init__.py
View file @
19e73bbe
...
@@ -12,21 +12,20 @@
...
@@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
platform
from
pathlib
import
Path
from
pathlib
import
Path
import
platform
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
spconv
import
utils
from
spconv.conv
import
SparseConv2d
,
SparseConv3d
,
SubMConv2d
,
SubMConv3d
from
spconv
import
ops
,
utils
from
spconv.conv
import
SparseConvTranspose2d
,
SparseConvTranspose3d
from
spconv.conv
import
(
SparseConv2d
,
SparseConv3d
,
SparseConvTranspose2d
,
from
spconv.conv
import
SparseInverseConv2d
,
SparseInverseConv3d
SparseConvTranspose3d
,
SparseInverseConv2d
,
SparseInverseConv3d
,
SubMConv2d
,
SubMConv3d
)
from
spconv.identity
import
Identity
from
spconv.modules
import
SparseModule
,
SparseSequential
from
spconv.modules
import
SparseModule
,
SparseSequential
from
spconv.pool
import
SparseMaxPool2d
,
SparseMaxPool3d
from
spconv.pool
import
SparseMaxPool2d
,
SparseMaxPool3d
from
spconv.tables
import
ConcatTable
,
JoinTable
,
AddTable
from
spconv.tables
import
AddTable
,
ConcatTable
,
JoinTable
from
spconv.identity
import
Identity
from
spconv
import
ops
_LIB_FILE_NAME
=
"libspconv.so"
_LIB_FILE_NAME
=
"libspconv.so"
if
platform
.
system
()
==
"Windows"
:
if
platform
.
system
()
==
"Windows"
:
...
@@ -34,6 +33,7 @@ if platform.system() == "Windows":
...
@@ -34,6 +33,7 @@ if platform.system() == "Windows":
_LIB_PATH
=
str
(
Path
(
__file__
).
parent
/
_LIB_FILE_NAME
)
_LIB_PATH
=
str
(
Path
(
__file__
).
parent
/
_LIB_FILE_NAME
)
torch
.
ops
.
load_library
(
_LIB_PATH
)
torch
.
ops
.
load_library
(
_LIB_PATH
)
def
scatter_nd
(
indices
,
updates
,
shape
):
def
scatter_nd
(
indices
,
updates
,
shape
):
"""pytorch edition of tensorflow scatter_nd.
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
this function don't contain except handle code. so use this carefully
...
@@ -49,8 +49,10 @@ def scatter_nd(indices, updates, shape):
...
@@ -49,8 +49,10 @@ def scatter_nd(indices, updates, shape):
ret
[
slices
]
=
updates
.
view
(
*
output_shape
)
ret
[
slices
]
=
updates
.
view
(
*
output_shape
)
return
ret
return
ret
class
SparseConvTensor
(
object
):
class
SparseConvTensor
(
object
):
def
__init__
(
self
,
features
,
indices
,
spatial_shape
,
batch_size
,
grid
=
None
):
def
__init__
(
self
,
features
,
indices
,
spatial_shape
,
batch_size
,
grid
=
None
):
"""
"""
Args:
Args:
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
...
@@ -77,7 +79,8 @@ class SparseConvTensor(object):
...
@@ -77,7 +79,8 @@ class SparseConvTensor(object):
return
None
return
None
def
dense
(
self
,
channels_first
=
True
):
def
dense
(
self
,
channels_first
=
True
):
output_shape
=
[
self
.
batch_size
]
+
list
(
self
.
spatial_shape
)
+
[
self
.
features
.
shape
[
1
]]
output_shape
=
[
self
.
batch_size
]
+
list
(
self
.
spatial_shape
)
+
[
self
.
features
.
shape
[
1
]]
res
=
scatter_nd
(
self
.
indices
.
long
(),
self
.
features
,
output_shape
)
res
=
scatter_nd
(
self
.
indices
.
long
(),
self
.
features
,
output_shape
)
if
not
channels_first
:
if
not
channels_first
:
return
res
return
res
...
@@ -88,7 +91,8 @@ class SparseConvTensor(object):
...
@@ -88,7 +91,8 @@ class SparseConvTensor(object):
@
property
@
property
def
sparity
(
self
):
def
sparity
(
self
):
return
self
.
indices
.
shape
[
0
]
/
np
.
prod
(
self
.
spatial_shape
)
/
self
.
batch_size
return
self
.
indices
.
shape
[
0
]
/
np
.
prod
(
self
.
spatial_shape
)
/
self
.
batch_size
class
ToDense
(
SparseModule
):
class
ToDense
(
SparseModule
):
...
@@ -97,6 +101,7 @@ class ToDense(SparseModule):
...
@@ -97,6 +101,7 @@ class ToDense(SparseModule):
def
forward
(
self
,
x
:
SparseConvTensor
):
def
forward
(
self
,
x
:
SparseConvTensor
):
return
x
.
dense
()
return
x
.
dense
()
class
RemoveGrid
(
SparseModule
):
class
RemoveGrid
(
SparseModule
):
"""remove pre-allocated grid buffer.
"""remove pre-allocated grid buffer.
"""
"""
...
...
spconv/conv.py
View file @
19e73bbe
...
@@ -16,15 +16,16 @@ import math
...
@@ -16,15 +16,16 @@ import math
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
spconv
import
spconv.functional
as
Fsp
import
torch
import
torch
from
spconv
import
ops
from
spconv.modules
import
SparseModule
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
init
from
torch.nn
import
init
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
spconv
import
spconv.functional
as
Fsp
from
spconv
import
ops
from
spconv.modules
import
SparseModule
def
_calculate_fan_in_and_fan_out_hwio
(
tensor
):
def
_calculate_fan_in_and_fan_out_hwio
(
tensor
):
dimensions
=
tensor
.
ndimension
()
dimensions
=
tensor
.
ndimension
()
...
@@ -146,8 +147,9 @@ class SparseConvolution(SparseModule):
...
@@ -146,8 +147,9 @@ class SparseConvolution(SparseModule):
self
.
weight
.
view
(
self
.
in_channels
,
self
.
out_channels
))
self
.
weight
.
view
(
self
.
in_channels
,
self
.
out_channels
))
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
features
+=
self
.
bias
features
+=
self
.
bias
out_tensor
=
spconv
.
SparseConvTensor
(
out_tensor
=
spconv
.
SparseConvTensor
(
features
,
input
.
indices
,
features
,
input
.
indices
,
input
.
spatial_shape
,
input
.
batch_size
)
input
.
spatial_shape
,
input
.
batch_size
)
out_tensor
.
indice_dict
=
input
.
indice_dict
out_tensor
.
indice_dict
=
input
.
indice_dict
out_tensor
.
grid
=
input
.
grid
out_tensor
.
grid
=
input
.
grid
return
out_tensor
return
out_tensor
...
@@ -181,9 +183,12 @@ class SparseConvolution(SparseModule):
...
@@ -181,9 +183,12 @@ class SparseConvolution(SparseModule):
spatial_shape
)
spatial_shape
)
if
self
.
fused_bn
:
if
self
.
fused_bn
:
assert
self
.
bias
is
not
None
assert
self
.
bias
is
not
None
out_features
=
ops
.
fused_indice_conv
(
out_features
=
ops
.
fused_indice_conv
(
features
,
self
.
weight
,
features
,
self
.
weight
,
self
.
bias
,
indice_pairs
.
to
(
device
),
self
.
bias
,
indice_pair_num
,
outids
.
shape
[
0
],
self
.
inverse
,
self
.
subm
)
indice_pairs
.
to
(
device
),
indice_pair_num
,
outids
.
shape
[
0
],
self
.
inverse
,
self
.
subm
)
else
:
else
:
if
self
.
subm
:
if
self
.
subm
:
out_features
=
Fsp
.
indice_subm_conv
(
features
,
self
.
weight
,
out_features
=
Fsp
.
indice_subm_conv
(
features
,
self
.
weight
,
...
@@ -222,18 +227,17 @@ class SparseConv2d(SparseConvolution):
...
@@ -222,18 +227,17 @@ class SparseConv2d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SparseConv2d
,
self
).
__init__
(
super
(
SparseConv2d
,
self
).
__init__
(
2
,
2
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SparseConv3d
(
SparseConvolution
):
class
SparseConv3d
(
SparseConvolution
):
...
@@ -248,18 +252,17 @@ class SparseConv3d(SparseConvolution):
...
@@ -248,18 +252,17 @@ class SparseConv3d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SparseConv3d
,
self
).
__init__
(
super
(
SparseConv3d
,
self
).
__init__
(
3
,
3
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SparseConv4d
(
SparseConvolution
):
class
SparseConv4d
(
SparseConvolution
):
...
@@ -274,18 +277,17 @@ class SparseConv4d(SparseConvolution):
...
@@ -274,18 +277,17 @@ class SparseConv4d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SparseConv4d
,
self
).
__init__
(
super
(
SparseConv4d
,
self
).
__init__
(
4
,
4
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SparseConvTranspose2d
(
SparseConvolution
):
class
SparseConvTranspose2d
(
SparseConvolution
):
...
@@ -300,19 +302,18 @@ class SparseConvTranspose2d(SparseConvolution):
...
@@ -300,19 +302,18 @@ class SparseConvTranspose2d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SparseConvTranspose2d
,
self
).
__init__
(
super
(
SparseConvTranspose2d
,
self
).
__init__
(
2
,
2
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
transposed
=
True
,
transposed
=
True
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SparseConvTranspose3d
(
SparseConvolution
):
class
SparseConvTranspose3d
(
SparseConvolution
):
...
@@ -327,19 +328,18 @@ class SparseConvTranspose3d(SparseConvolution):
...
@@ -327,19 +328,18 @@ class SparseConvTranspose3d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SparseConvTranspose3d
,
self
).
__init__
(
super
(
SparseConvTranspose3d
,
self
).
__init__
(
3
,
3
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
transposed
=
True
,
transposed
=
True
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SparseInverseConv2d
(
SparseConvolution
):
class
SparseInverseConv2d
(
SparseConvolution
):
...
@@ -349,14 +349,13 @@ class SparseInverseConv2d(SparseConvolution):
...
@@ -349,14 +349,13 @@ class SparseInverseConv2d(SparseConvolution):
kernel_size
,
kernel_size
,
indice_key
,
indice_key
,
bias
=
True
):
bias
=
True
):
super
(
SparseInverseConv2d
,
self
).
__init__
(
super
(
SparseInverseConv2d
,
self
).
__init__
(
2
,
2
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
bias
=
bias
,
bias
=
bias
,
inverse
=
True
,
inverse
=
True
,
indice_key
=
indice_key
)
indice_key
=
indice_key
)
class
SparseInverseConv3d
(
SparseConvolution
):
class
SparseInverseConv3d
(
SparseConvolution
):
...
@@ -366,14 +365,13 @@ class SparseInverseConv3d(SparseConvolution):
...
@@ -366,14 +365,13 @@ class SparseInverseConv3d(SparseConvolution):
kernel_size
,
kernel_size
,
indice_key
,
indice_key
,
bias
=
True
):
bias
=
True
):
super
(
SparseInverseConv3d
,
self
).
__init__
(
super
(
SparseInverseConv3d
,
self
).
__init__
(
3
,
3
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
bias
=
bias
,
bias
=
bias
,
inverse
=
True
,
inverse
=
True
,
indice_key
=
indice_key
)
indice_key
=
indice_key
)
class
SubMConv2d
(
SparseConvolution
):
class
SubMConv2d
(
SparseConvolution
):
...
@@ -388,19 +386,18 @@ class SubMConv2d(SparseConvolution):
...
@@ -388,19 +386,18 @@ class SubMConv2d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SubMConv2d
,
self
).
__init__
(
super
(
SubMConv2d
,
self
).
__init__
(
2
,
2
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
True
,
True
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SubMConv3d
(
SparseConvolution
):
class
SubMConv3d
(
SparseConvolution
):
...
@@ -415,19 +412,18 @@ class SubMConv3d(SparseConvolution):
...
@@ -415,19 +412,18 @@ class SubMConv3d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SubMConv3d
,
self
).
__init__
(
super
(
SubMConv3d
,
self
).
__init__
(
3
,
3
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
True
,
True
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
class
SubMConv4d
(
SparseConvolution
):
class
SubMConv4d
(
SparseConvolution
):
...
@@ -442,16 +438,15 @@ class SubMConv4d(SparseConvolution):
...
@@ -442,16 +438,15 @@ class SubMConv4d(SparseConvolution):
bias
=
True
,
bias
=
True
,
indice_key
=
None
,
indice_key
=
None
,
use_hash
=
False
):
use_hash
=
False
):
super
(
SubMConv4d
,
self
).
__init__
(
super
(
SubMConv4d
,
self
).
__init__
(
4
,
4
,
in_channels
,
in_channels
,
out_channels
,
out_channels
,
kernel_size
,
kernel_size
,
stride
,
stride
,
padding
,
padding
,
dilation
,
dilation
,
groups
,
groups
,
bias
,
bias
,
True
,
True
,
indice_key
=
indice_key
,
indice_key
=
indice_key
,
use_hash
=
use_hash
)
use_hash
=
use_hash
)
spconv/functional.py
View file @
19e73bbe
# Copyright 2019 Yan Yan
# Copyright 2019 Yan Yan
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
spconv.ops
as
ops
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.autograd
import
Function
from
torch.autograd
import
Function
import
spconv.ops
as
ops
class
SparseConvFunction
(
Function
):
class
SparseConvFunction
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
def
forward
(
ctx
,
features
,
filters
,
indice_pairs
,
indice_pair_num
,
ctx
,
num_activate_out
):
features
,
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
filters
,
return
ops
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
False
)
indice_pair_num
,
num_activate_out
):
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
return
ops
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
False
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
indice_pairs
,
indice_pair_num
,
features
,
filters
=
ctx
.
saved_tensors
indice_pairs
,
indice_pair_num
,
features
,
filters
=
ctx
.
saved_tensors
input_bp
,
filters_bp
=
ops
.
indice_conv_backward
(
features
,
filters
,
grad_output
,
indice_pairs
,
indice_pair_num
,
False
)
input_bp
,
filters_bp
=
ops
.
indice_conv_backward
(
features
,
filters
,
grad_output
,
indice_pairs
,
indice_pair_num
,
False
)
return
input_bp
,
filters_bp
,
None
,
None
,
None
return
input_bp
,
filters_bp
,
None
,
None
,
None
class
SparseInverseConvFunction
(
Function
):
class
SparseInverseConvFunction
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
def
forward
(
ctx
,
features
,
filters
,
indice_pairs
,
indice_pair_num
,
ctx
,
num_activate_out
):
features
,
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
filters
,
return
ops
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
True
,
False
)
indice_pair_num
,
num_activate_out
):
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
return
ops
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
True
,
False
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
indice_pairs
,
indice_pair_num
,
features
,
filters
=
ctx
.
saved_tensors
indice_pairs
,
indice_pair_num
,
features
,
filters
=
ctx
.
saved_tensors
input_bp
,
filters_bp
=
ops
.
indice_conv_backward
(
features
,
filters
,
grad_output
,
indice_pairs
,
indice_pair_num
,
True
,
False
)
input_bp
,
filters_bp
=
ops
.
indice_conv_backward
(
features
,
filters
,
grad_output
,
indice_pairs
,
indice_pair_num
,
True
,
False
)
return
input_bp
,
filters_bp
,
None
,
None
,
None
return
input_bp
,
filters_bp
,
None
,
None
,
None
class
SubMConvFunction
(
Function
):
class
SubMConvFunction
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
def
forward
(
ctx
,
features
,
filters
,
indice_pairs
,
indice_pair_num
,
ctx
,
num_activate_out
):
features
,
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
filters
,
return
ops
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
False
,
True
)
indice_pair_num
,
num_activate_out
):
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
filters
)
return
ops
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
False
,
True
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
indice_pairs
,
indice_pair_num
,
features
,
filters
=
ctx
.
saved_tensors
indice_pairs
,
indice_pair_num
,
features
,
filters
=
ctx
.
saved_tensors
input_bp
,
filters_bp
=
ops
.
indice_conv_backward
(
features
,
filters
,
grad_output
,
indice_pairs
,
indice_pair_num
,
False
,
True
)
input_bp
,
filters_bp
=
ops
.
indice_conv_backward
(
features
,
filters
,
grad_output
,
indice_pairs
,
indice_pair_num
,
False
,
True
)
return
input_bp
,
filters_bp
,
None
,
None
,
None
return
input_bp
,
filters_bp
,
None
,
None
,
None
class
SparseMaxPoolFunction
(
Function
):
class
SparseMaxPoolFunction
(
Function
):
@
staticmethod
@
staticmethod
def
forward
(
def
forward
(
ctx
,
features
,
indice_pairs
,
indice_pair_num
,
ctx
,
num_activate_out
):
features
,
out
=
ops
.
indice_maxpool
(
features
,
indice_pairs
,
indice_pair_num
,
indice_pairs
,
num_activate_out
)
indice_pair_num
,
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
out
)
num_activate_out
):
out
=
ops
.
indice_maxpool
(
features
,
indice_pairs
,
indice_pair_num
,
num_activate_out
)
ctx
.
save_for_backward
(
indice_pairs
,
indice_pair_num
,
features
,
out
)
return
out
return
out
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
indice_pairs
,
indice_pair_num
,
features
,
out
=
ctx
.
saved_tensors
indice_pairs
,
indice_pair_num
,
features
,
out
=
ctx
.
saved_tensors
input_bp
=
ops
.
indice_maxpool_backward
(
features
,
out
,
grad_output
,
indice_pairs
,
indice_pair_num
)
input_bp
=
ops
.
indice_maxpool_backward
(
features
,
out
,
grad_output
,
indice_pairs
,
indice_pair_num
)
return
input_bp
,
None
,
None
,
None
return
input_bp
,
None
,
None
,
None
indice_conv
=
SparseConvFunction
.
apply
indice_conv
=
SparseConvFunction
.
apply
indice_inverse_conv
=
SparseInverseConvFunction
.
apply
indice_inverse_conv
=
SparseInverseConvFunction
.
apply
indice_subm_conv
=
SubMConvFunction
.
apply
indice_subm_conv
=
SubMConvFunction
.
apply
...
...
spconv/modules.py
View file @
19e73bbe
...
@@ -12,12 +12,13 @@
...
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
time
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
spconv
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
import
time
import
spconv
def
is_spconv_module
(
module
):
def
is_spconv_module
(
module
):
...
@@ -81,7 +82,6 @@ class SparseSequential(SparseModule):
...
@@ -81,7 +82,6 @@ class SparseSequential(SparseModule):
relu2=nn.ReLU()
relu2=nn.ReLU()
)
)
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
SparseSequential
,
self
).
__init__
()
super
(
SparseSequential
,
self
).
__init__
()
if
len
(
args
)
==
1
and
isinstance
(
args
[
0
],
OrderedDict
):
if
len
(
args
)
==
1
and
isinstance
(
args
[
0
],
OrderedDict
):
...
@@ -148,7 +148,8 @@ class SparseSequential(SparseModule):
...
@@ -148,7 +148,8 @@ class SparseSequential(SparseModule):
idx
=
0
idx
=
0
while
idx
<
len
(
mods
):
while
idx
<
len
(
mods
):
if
is_sparse_conv
(
mods
[
idx
]):
if
is_sparse_conv
(
mods
[
idx
]):
if
idx
<
len
(
mods
)
-
1
and
isinstance
(
mods
[
idx
+
1
],
nn
.
BatchNorm1d
):
if
idx
<
len
(
mods
)
-
1
and
isinstance
(
mods
[
idx
+
1
],
nn
.
BatchNorm1d
):
new_module
=
SparseConvolution
(
new_module
=
SparseConvolution
(
ndim
=
mods
[
idx
].
ndim
,
ndim
=
mods
[
idx
].
ndim
,
in_channels
=
mods
[
idx
].
in_channels
,
in_channels
=
mods
[
idx
].
in_channels
,
...
...
spconv/ops.py
View file @
19e73bbe
# Copyright 2019 Yan Yan
# Copyright 2019 Yan Yan
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
spconv
import
torch
import
torch
import
spconv
def
get_conv_output_size
(
input_size
,
kernel_size
,
stride
,
padding
,
dilation
):
def
get_conv_output_size
(
input_size
,
kernel_size
,
stride
,
padding
,
dilation
):
ndim
=
len
(
input_size
)
ndim
=
len
(
input_size
)
...
@@ -30,7 +31,7 @@ def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
...
@@ -30,7 +31,7 @@ def get_conv_output_size(input_size, kernel_size, stride, padding, dilation):
def
get_deconv_output_size
(
input_size
,
kernel_size
,
stride
,
padding
,
dilation
,
def
get_deconv_output_size
(
input_size
,
kernel_size
,
stride
,
padding
,
dilation
,
output_padding
):
output_padding
):
ndim
=
len
(
input_size
)
ndim
=
len
(
input_size
)
output_size
=
[]
output_size
=
[]
for
i
in
range
(
ndim
):
for
i
in
range
(
ndim
):
...
@@ -43,17 +44,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
...
@@ -43,17 +44,17 @@ def get_deconv_output_size(input_size, kernel_size, stride, padding, dilation,
def
get_indice_pairs
(
indices
,
def
get_indice_pairs
(
indices
,
batch_size
,
batch_size
,
spatial_shape
,
spatial_shape
,
ksize
=
3
,
ksize
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
dilation
=
1
,
dilation
=
1
,
out_padding
=
0
,
out_padding
=
0
,
subm
=
False
,
subm
=
False
,
transpose
=
False
,
transpose
=
False
,
grid
=
None
,
grid
=
None
,
use_hash
=
True
):
use_hash
=
True
):
ndim
=
indices
.
shape
[
1
]
-
1
ndim
=
indices
.
shape
[
1
]
-
1
if
not
isinstance
(
ksize
,
(
list
,
tuple
)):
if
not
isinstance
(
ksize
,
(
list
,
tuple
)):
ksize
=
[
ksize
]
*
ndim
ksize
=
[
ksize
]
*
ndim
...
@@ -68,14 +69,14 @@ def get_indice_pairs(indices,
...
@@ -68,14 +69,14 @@ def get_indice_pairs(indices,
for
d
,
s
in
zip
(
dilation
,
stride
):
for
d
,
s
in
zip
(
dilation
,
stride
):
assert
any
([
s
==
1
,
d
==
1
]),
"don't support this."
assert
any
([
s
==
1
,
d
==
1
]),
"don't support this."
if
not
subm
:
if
not
subm
:
if
transpose
:
if
transpose
:
out_shape
=
get_deconv_output_size
(
spatial_shape
,
ksize
,
stride
,
padding
,
out_shape
=
get_deconv_output_size
(
spatial_shape
,
ksize
,
stride
,
dilation
,
out_padding
)
padding
,
dilation
,
out_padding
)
else
:
else
:
out_shape
=
get_conv_output_size
(
spatial_shape
,
ksize
,
stride
,
padding
,
out_shape
=
get_conv_output_size
(
spatial_shape
,
ksize
,
stride
,
dilation
)
padding
,
dilation
)
else
:
else
:
out_shape
=
spatial_shape
out_shape
=
spatial_shape
...
@@ -89,8 +90,10 @@ def get_indice_pairs(indices,
...
@@ -89,8 +90,10 @@ def get_indice_pairs(indices,
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
res
=
get_indice_pairs_func
(
indices
,
batch_size
,
out_shape
,
spatial_shape
,
ksize
,
res
=
get_indice_pairs_func
(
indices
,
batch_size
,
out_shape
,
stride
,
padding
,
dilation
,
out_padding
,
int
(
subm
),
int
(
transpose
),
int
(
use_hash
))
spatial_shape
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
int
(
subm
),
int
(
transpose
),
int
(
use_hash
))
return
res
return
res
else
:
else
:
if
ndim
==
2
:
if
ndim
==
2
:
...
@@ -99,26 +102,26 @@ def get_indice_pairs(indices,
...
@@ -99,26 +102,26 @@ def get_indice_pairs(indices,
get_indice_pairs_func
=
torch
.
ops
.
spconv
.
get_indice_pairs_grid_3d
get_indice_pairs_func
=
torch
.
ops
.
spconv
.
get_indice_pairs_grid_3d
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
return
get_indice_pairs_func
(
indices
,
grid
,
batch_size
,
out_shape
,
spatial_shape
,
ksize
,
return
get_indice_pairs_func
(
indices
,
grid
,
batch_size
,
out_shape
,
stride
,
padding
,
dilation
,
out_padding
,
int
(
subm
),
int
(
transpose
),
int
(
use_hash
))
spatial_shape
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
int
(
subm
),
int
(
transpose
),
int
(
use_hash
))
def
indice_conv
(
features
,
def
indice_conv
(
features
,
filters
,
filters
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
inverse
=
False
,
inverse
=
False
,
subm
=
False
):
subm
=
False
):
return
torch
.
ops
.
spconv
.
indice_conv
(
features
,
filters
,
indice_pairs
,
return
torch
.
ops
.
spconv
.
indice_conv
(
features
,
filters
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
indice_pair_num
,
num_activate_out
,
int
(
inverse
),
int
(
subm
))
int
(
inverse
),
int
(
subm
))
def
fused_indice_conv
(
features
,
filters
,
bias
,
indice_pairs
,
def
fused_indice_conv
(
features
,
filters
,
bias
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
inverse
,
subm
):
num_activate_out
,
inverse
,
subm
):
if
features
.
dtype
==
torch
.
half
:
if
features
.
dtype
==
torch
.
half
:
func
=
torch
.
ops
.
spconv
.
fused_indice_conv_half
func
=
torch
.
ops
.
spconv
.
fused_indice_conv_half
elif
filters
.
dtype
==
torch
.
float32
:
elif
filters
.
dtype
==
torch
.
float32
:
...
@@ -126,34 +129,37 @@ def fused_indice_conv(features, filters, bias,
...
@@ -126,34 +129,37 @@ def fused_indice_conv(features, filters, bias,
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
return
func
(
features
,
filters
,
bias
,
indice_pairs
,
return
func
(
features
,
filters
,
bias
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
num_activate_out
,
num_activate_out
,
int
(
inverse
),
int
(
subm
))
int
(
inverse
),
int
(
subm
))
def
indice_conv_backward
(
features
,
def
indice_conv_backward
(
features
,
filters
,
filters
,
out_bp
,
out_bp
,
indice_pairs
,
indice_pairs
,
indice_pair_num
,
indice_pair_num
,
inverse
=
False
,
inverse
=
False
,
subm
=
False
):
subm
=
False
):
return
torch
.
ops
.
spconv
.
indice_conv_backward
(
return
torch
.
ops
.
spconv
.
indice_conv_backward
(
features
,
filters
,
out_bp
,
features
,
filters
,
out_bp
,
indice_pairs
,
indice_pair_num
,
int
(
inverse
),
int
(
subm
))
indice_pairs
,
indice_pair_num
,
int
(
inverse
),
int
(
subm
))
def
indice_maxpool
(
features
,
indice_pairs
,
indice_pair_num
,
num_activate_out
):
def
indice_maxpool
(
features
,
indice_pairs
,
indice_pair_num
,
num_activate_out
):
if
features
.
dtype
==
torch
.
float32
:
if
features
.
dtype
==
torch
.
float32
:
return
torch
.
ops
.
spconv
.
indice_maxpool_fp32
(
features
,
indice_pairs
,
indice_pair_num
,
return
torch
.
ops
.
spconv
.
indice_maxpool_fp32
(
features
,
indice_pairs
,
num_activate_out
)
indice_pair_num
,
num_activate_out
)
elif
features
.
dtype
==
torch
.
half
:
elif
features
.
dtype
==
torch
.
half
:
return
torch
.
ops
.
spconv
.
indice_maxpool_half
(
features
,
indice_pairs
,
indice_pair_num
,
return
torch
.
ops
.
spconv
.
indice_maxpool_half
(
features
,
indice_pairs
,
num_activate_out
)
indice_pair_num
,
num_activate_out
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
def
indice_maxpool_backward
(
features
,
out_features
,
out_bp
,
indice_pairs
,
indice_pair_num
):
def
indice_maxpool_backward
(
features
,
out_features
,
out_bp
,
indice_pairs
,
indice_pair_num
):
if
features
.
dtype
==
torch
.
float32
:
if
features
.
dtype
==
torch
.
float32
:
return
torch
.
ops
.
spconv
.
indice_maxpool_backward_fp32
(
return
torch
.
ops
.
spconv
.
indice_maxpool_backward_fp32
(
features
,
out_features
,
out_bp
,
indice_pairs
,
indice_pair_num
)
features
,
out_features
,
out_bp
,
indice_pairs
,
indice_pair_num
)
...
@@ -163,11 +169,13 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, indice
...
@@ -163,11 +169,13 @@ def indice_maxpool_backward(features, out_features, out_bp, indice_pairs, indice
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
def
nms
(
boxes
,
scores
,
pre_max_size
,
post_max_size
,
thresh
,
eps
):
def
nms
(
boxes
,
scores
,
pre_max_size
,
post_max_size
,
thresh
,
eps
):
res
=
torch
.
ops
.
spconv
.
nms
(
res
=
torch
.
ops
.
spconv
.
nms
(
boxes
,
scores
,
pre_max_size
,
post_max_size
,
boxes
,
scores
,
pre_max_size
,
post_max_size
,
thresh
,
eps
)
thresh
,
eps
)
return
res
return
res
def
pillar_scatter
(
features
,
coors
,
shape
):
def
pillar_scatter
(
features
,
coors
,
shape
):
if
features
.
dtype
==
torch
.
float32
:
if
features
.
dtype
==
torch
.
float32
:
return
torch
.
ops
.
spconv
.
pillar_scatter_float
(
features
,
coors
,
shape
)
return
torch
.
ops
.
spconv
.
pillar_scatter_float
(
features
,
coors
,
shape
)
...
...
spconv/pool.py
View file @
19e73bbe
# Copyright 2019 Yan Yan
# Copyright 2019 Yan Yan
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
math
import
math
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
spconv
import
spconv.functional
as
Fsp
import
torch
import
torch
from
spconv
import
ops
from
spconv.modules
import
SparseModule
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
init
from
torch.nn
import
init
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
spconv
import
spconv.functional
as
Fsp
from
spconv
import
ops
from
spconv.modules
import
SparseModule
class
SparseMaxPool
(
SparseModule
):
class
SparseMaxPool
(
SparseModule
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -61,15 +61,17 @@ class SparseMaxPool(SparseModule):
...
@@ -61,15 +61,17 @@ class SparseMaxPool(SparseModule):
batch_size
=
input
.
batch_size
batch_size
=
input
.
batch_size
if
not
self
.
subm
:
if
not
self
.
subm
:
out_spatial_shape
=
ops
.
get_conv_output_size
(
out_spatial_shape
=
ops
.
get_conv_output_size
(
spatial_shape
,
self
.
kernel_size
,
self
.
stride
,
self
.
padding
,
self
.
dilation
)
spatial_shape
,
self
.
kernel_size
,
self
.
stride
,
self
.
padding
,
self
.
dilation
)
else
:
else
:
out_spatial_shape
=
spatial_shape
out_spatial_shape
=
spatial_shape
outids
,
indice_pairs
,
indice_pairs_num
=
ops
.
get_indice_pairs
(
outids
,
indice_pairs
,
indice_pairs_num
=
ops
.
get_indice_pairs
(
indices
,
batch_size
,
spatial_shape
,
self
.
kernel_size
,
indices
,
batch_size
,
spatial_shape
,
self
.
kernel_size
,
self
.
stride
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
0
,
self
.
subm
)
self
.
padding
,
self
.
dilation
,
0
,
self
.
subm
)
out_features
=
Fsp
.
indice_maxpool
(
features
,
indice_pairs
.
to
(
device
),
out_features
=
Fsp
.
indice_maxpool
(
features
,
indice_pairs
.
to
(
device
),
indice_pairs_num
.
to
(
device
),
outids
.
shape
[
0
])
indice_pairs_num
.
to
(
device
),
outids
.
shape
[
0
])
out_tensor
=
spconv
.
SparseConvTensor
(
out_features
,
outids
,
out_tensor
=
spconv
.
SparseConvTensor
(
out_features
,
outids
,
out_spatial_shape
,
batch_size
)
out_spatial_shape
,
batch_size
)
out_tensor
.
indice_dict
=
input
.
indice_dict
out_tensor
.
indice_dict
=
input
.
indice_dict
...
@@ -78,28 +80,12 @@ class SparseMaxPool(SparseModule):
...
@@ -78,28 +80,12 @@ class SparseMaxPool(SparseModule):
class
SparseMaxPool2d
(
SparseMaxPool
):
class
SparseMaxPool2d
(
SparseMaxPool
):
def
__init__
(
self
,
def
__init__
(
self
,
kernel_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
):
kernel_size
,
super
(
SparseMaxPool2d
,
self
).
__init__
(
2
,
kernel_size
,
stride
,
padding
,
stride
=
1
,
dilation
)
padding
=
0
,
dilation
=
1
):
super
(
SparseMaxPool2d
,
self
).
__init__
(
2
,
kernel_size
,
stride
,
padding
,
dilation
)
class
SparseMaxPool3d
(
SparseMaxPool
):
class
SparseMaxPool3d
(
SparseMaxPool
):
def
__init__
(
self
,
def
__init__
(
self
,
kernel_size
,
stride
=
1
,
padding
=
0
,
dilation
=
1
):
kernel_size
,
super
(
SparseMaxPool3d
,
self
).
__init__
(
3
,
kernel_size
,
stride
,
padding
,
stride
=
1
,
dilation
)
padding
=
0
,
dilation
=
1
):
super
(
SparseMaxPool3d
,
self
).
__init__
(
3
,
kernel_size
,
stride
,
padding
,
dilation
)
spconv/tables.py
View file @
19e73bbe
import
torch
from
torch.autograd
import
Function
from
torch.autograd
import
Function
import
spconv
#from torch.nn import Module
#from torch.nn import Module
from
spconv.modules
import
SparseModule
from
spconv.modules
import
SparseModule
import
spconv
import
torch
class
JoinTable
(
SparseModule
):
# Module):
class
JoinTable
(
SparseModule
):
# Module):
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
output
=
spconv
.
SparseConvTensor
(
output
=
spconv
.
SparseConvTensor
(
torch
.
cat
([
i
.
features
for
i
in
input
],
1
),
input
[
1
].
indices
,
torch
.
cat
([
i
.
features
for
i
in
input
],
1
),
input
[
1
].
indices
,
input
[
1
].
spatial_shape
,
input
[
0
].
batch_size
)
input
[
1
].
spatial_shape
,
input
[
0
].
batch_size
)
output
.
indice_dict
=
input
[
1
].
indice_dict
output
.
indice_dict
=
input
[
1
].
indice_dict
output
.
grid
=
input
[
1
].
grid
output
.
grid
=
input
[
1
].
grid
return
output
return
output
...
@@ -18,11 +19,12 @@ class JoinTable(SparseModule):# Module):
...
@@ -18,11 +19,12 @@ class JoinTable(SparseModule):# Module):
return
out_size
return
out_size
class
AddTable
(
SparseModule
):
# Module):
class
AddTable
(
SparseModule
):
# Module):
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
output
=
spconv
.
SparseConvTensor
(
output
=
spconv
.
SparseConvTensor
(
sum
([
i
.
features
for
i
in
input
]),
sum
([
i
.
features
for
i
in
input
]),
input
[
1
].
indices
,
input
[
1
].
indices
,
input
[
1
].
spatial_shape
,
input
[
1
].
batch_size
)
input
[
1
].
spatial_shape
,
input
[
1
].
batch_size
)
output
.
indice_dict
=
input
[
1
].
indice_dict
output
.
indice_dict
=
input
[
1
].
indice_dict
output
.
grid
=
input
[
1
].
grid
output
.
grid
=
input
[
1
].
grid
...
@@ -32,7 +34,7 @@ class AddTable(SparseModule): # Module):
...
@@ -32,7 +34,7 @@ class AddTable(SparseModule): # Module):
return
out_size
return
out_size
class
ConcatTable
(
SparseModule
):
# Module):
class
ConcatTable
(
SparseModule
):
# Module):
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
return
[
module
(
input
)
for
module
in
self
.
_modules
.
values
()]
return
[
module
(
input
)
for
module
in
self
.
_modules
.
values
()]
...
...
spconv/test_utils.py
View file @
19e73bbe
# Copyright 2019 Yan Yan
# Copyright 2019 Yan Yan
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -31,14 +31,14 @@ class TestCase(unittest.TestCase):
...
@@ -31,14 +31,14 @@ class TestCase(unittest.TestCase):
"""
"""
a
=
self
.
_GetNdArray
(
a
)
a
=
self
.
_GetNdArray
(
a
)
b
=
self
.
_GetNdArray
(
b
)
b
=
self
.
_GetNdArray
(
b
)
self
.
assertEqual
(
a
.
shape
,
b
.
shape
,
self
.
assertEqual
(
"Shape mismatch: expected %s, got %s."
%
(
a
.
shape
,
a
.
shape
,
b
.
shape
,
b
.
shape
))
"Shape mismatch: expected %s, got %s."
%
(
a
.
shape
,
b
.
shape
))
same
=
(
a
==
b
)
same
=
(
a
==
b
)
if
a
.
dtype
==
np
.
float32
or
a
.
dtype
==
np
.
float64
:
if
a
.
dtype
==
np
.
float32
or
a
.
dtype
==
np
.
float64
:
same
=
np
.
logical_or
(
same
,
np
.
logical_and
(
same
=
np
.
logical_or
(
same
,
np
.
logical_and
(
np
.
isnan
(
a
),
np
.
isnan
(
a
),
np
.
isnan
(
b
)))
np
.
isnan
(
b
)))
if
not
np
.
all
(
same
):
if
not
np
.
all
(
same
):
# Prints more details than np.testing.assert_array_equal.
# Prints more details than np.testing.assert_array_equal.
diff
=
np
.
logical_not
(
same
)
diff
=
np
.
logical_not
(
same
)
...
@@ -68,30 +68,29 @@ class TestCase(unittest.TestCase):
...
@@ -68,30 +68,29 @@ class TestCase(unittest.TestCase):
"""
"""
is_a_dict
=
isinstance
(
a
,
dict
)
is_a_dict
=
isinstance
(
a
,
dict
)
if
is_a_dict
!=
isinstance
(
b
,
dict
):
if
is_a_dict
!=
isinstance
(
b
,
dict
):
raise
ValueError
(
"Can't compare dict to non-dict, %s vs %s."
%
(
a
,
raise
ValueError
(
"Can't compare dict to non-dict, %s vs %s."
%
b
))
(
a
,
b
))
if
is_a_dict
:
if
is_a_dict
:
self
.
assertCountEqual
(
self
.
assertCountEqual
(
a
.
keys
(),
a
.
keys
(),
b
.
keys
(),
b
.
keys
(),
msg
=
"mismatched keys, expected %s, got %s"
%
msg
=
"mismatched keys, expected %s, got %s"
%
(
a
.
keys
(),
(
a
.
keys
(),
b
.
keys
()))
b
.
keys
()))
for
k
in
a
:
for
k
in
a
:
self
.
_assertArrayLikeAllClose
(
self
.
_assertArrayLikeAllClose
(
a
[
k
],
a
[
k
],
b
[
k
],
b
[
k
]
,
rtol
=
rtol
,
r
tol
=
r
tol
,
a
tol
=
a
tol
,
atol
=
atol
,
msg
=
"%s: expected %s, got %s."
%
msg
=
"%s: expected %s, got %s."
%
(
k
,
a
,
b
))
(
k
,
a
,
b
))
else
:
else
:
self
.
_assertArrayLikeAllClose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
self
.
_assertArrayLikeAllClose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
def
_assertArrayLikeAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
=
None
):
def
_assertArrayLikeAllClose
(
self
,
a
,
b
,
rtol
=
1e-6
,
atol
=
1e-6
,
msg
=
None
):
a
=
self
.
_GetNdArray
(
a
)
a
=
self
.
_GetNdArray
(
a
)
b
=
self
.
_GetNdArray
(
b
)
b
=
self
.
_GetNdArray
(
b
)
self
.
assertEqual
(
a
.
shape
,
b
.
shape
,
self
.
assertEqual
(
"Shape mismatch: expected %s, got %s."
%
(
a
.
shape
,
a
.
shape
,
b
.
shape
,
b
.
shape
))
"Shape mismatch: expected %s, got %s."
%
(
a
.
shape
,
b
.
shape
))
if
not
np
.
allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
):
if
not
np
.
allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
):
# Prints more details than np.testing.assert_allclose.
# Prints more details than np.testing.assert_allclose.
#
#
...
@@ -118,6 +117,7 @@ class TestCase(unittest.TestCase):
...
@@ -118,6 +117,7 @@ class TestCase(unittest.TestCase):
print
(
"dtype = %s, shape = %s"
%
(
a
.
dtype
,
a
.
shape
))
print
(
"dtype = %s, shape = %s"
%
(
a
.
dtype
,
a
.
shape
))
np
.
testing
.
assert_allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
,
err_msg
=
msg
)
np
.
testing
.
assert_allclose
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
,
err_msg
=
msg
)
def
params_grid
(
*
params
):
def
params_grid
(
*
params
):
size
=
len
(
params
)
size
=
len
(
params
)
length
=
1
length
=
1
...
@@ -127,7 +127,7 @@ def params_grid(*params):
...
@@ -127,7 +127,7 @@ def params_grid(*params):
counter
=
[
0
]
*
size
counter
=
[
0
]
*
size
total
=
[]
total
=
[]
for
i
in
range
(
length
):
for
i
in
range
(
length
):
total
.
append
([
0
]
*
size
)
total
.
append
([
0
]
*
size
)
for
i
in
range
(
length
):
for
i
in
range
(
length
):
for
j
in
range
(
size
):
for
j
in
range
(
size
):
total
[
i
][
j
]
=
params
[
j
][
counter
[
j
]]
total
[
i
][
j
]
=
params
[
j
][
counter
[
j
]]
...
@@ -138,13 +138,14 @@ def params_grid(*params):
...
@@ -138,13 +138,14 @@ def params_grid(*params):
counter
[
c
]
=
0
counter
[
c
]
=
0
return
total
return
total
def
generate_sparse_data
(
shape
,
def
generate_sparse_data
(
shape
,
num_points
,
num_points
,
num_channels
,
num_channels
,
integer
=
False
,
integer
=
False
,
data_range
=
(
-
1
,
1
),
data_range
=
(
-
1
,
1
),
with_dense
=
True
,
with_dense
=
True
,
dtype
=
np
.
float32
):
dtype
=
np
.
float32
):
dense_shape
=
shape
dense_shape
=
shape
ndim
=
len
(
dense_shape
)
ndim
=
len
(
dense_shape
)
# num_points = np.random.randint(10, 100, size=[batch_size, ndim])
# num_points = np.random.randint(10, 100, size=[batch_size, ndim])
...
@@ -152,32 +153,35 @@ def generate_sparse_data(shape,
...
@@ -152,32 +153,35 @@ def generate_sparse_data(shape,
# num_points = np.array([3, 2])
# num_points = np.array([3, 2])
batch_size
=
len
(
num_points
)
batch_size
=
len
(
num_points
)
batch_indices
=
[]
batch_indices
=
[]
coors_total
=
np
.
stack
(
coors_total
=
np
.
stack
(
np
.
meshgrid
(
*
[
np
.
arange
(
0
,
s
)
for
s
in
shape
]),
np
.
meshgrid
(
*
[
np
.
arange
(
0
,
s
)
for
s
in
shape
]),
axis
=-
1
)
axis
=-
1
)
coors_total
=
coors_total
.
reshape
(
-
1
,
ndim
)
coors_total
=
coors_total
.
reshape
(
-
1
,
ndim
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
np
.
random
.
shuffle
(
coors_total
)
np
.
random
.
shuffle
(
coors_total
)
inds_total
=
coors_total
[:
num_points
[
i
]]
inds_total
=
coors_total
[:
num_points
[
i
]]
inds_total
=
np
.
pad
(
inds_total
=
np
.
pad
(
inds_total
,
((
0
,
0
),
(
0
,
1
)),
inds_total
,
((
0
,
0
),
(
0
,
1
)),
mode
=
"constant"
,
constant_values
=
i
)
mode
=
"constant"
,
constant_values
=
i
)
batch_indices
.
append
(
inds_total
)
batch_indices
.
append
(
inds_total
)
if
integer
:
if
integer
:
sparse_data
=
np
.
random
.
randint
(
sparse_data
=
np
.
random
.
randint
(
data_range
[
0
],
data_range
[
0
],
data_range
[
1
],
size
=
[
num_points
.
sum
(),
num_channels
]).
astype
(
dtype
)
data_range
[
1
],
size
=
[
num_points
.
sum
(),
num_channels
]).
astype
(
dtype
)
else
:
else
:
sparse_data
=
np
.
random
.
uniform
(
sparse_data
=
np
.
random
.
uniform
(
data_range
[
0
],
data_range
[
0
],
data_range
[
1
],
size
=
[
num_points
.
sum
(),
num_channels
]).
astype
(
dtype
)
data_range
[
1
],
size
=
[
num_points
.
sum
(),
num_channels
]).
astype
(
dtype
)
# sparse_data = np.arange(1, num_points.sum() + 1).astype(np.float32).reshape(5, 1)
# sparse_data = np.arange(1, num_points.sum() + 1).astype(np.float32).reshape(5, 1)
res
=
{
res
=
{
"features"
:
sparse_data
.
astype
(
dtype
),
"features"
:
sparse_data
.
astype
(
dtype
),
}
}
if
with_dense
:
if
with_dense
:
dense_data
=
np
.
zeros
(
dense_data
=
np
.
zeros
(
[
batch_size
,
num_channels
,
*
dense_shape
],
[
batch_size
,
num_channels
,
*
dense_shape
],
dtype
=
sparse_data
.
dtype
)
dtype
=
sparse_data
.
dtype
)
start
=
0
start
=
0
for
i
,
inds
in
enumerate
(
batch_indices
):
for
i
,
inds
in
enumerate
(
batch_indices
):
for
j
,
ind
in
enumerate
(
inds
):
for
j
,
ind
in
enumerate
(
inds
):
...
@@ -187,4 +191,4 @@ def generate_sparse_data(shape,
...
@@ -187,4 +191,4 @@ def generate_sparse_data(shape,
res
[
"features_dense"
]
=
dense_data
.
astype
(
dtype
)
res
[
"features_dense"
]
=
dense_data
.
astype
(
dtype
)
batch_indices
=
np
.
concatenate
(
batch_indices
,
axis
=
0
)
batch_indices
=
np
.
concatenate
(
batch_indices
,
axis
=
0
)
res
[
"indices"
]
=
batch_indices
.
astype
(
np
.
int32
)
res
[
"indices"
]
=
batch_indices
.
astype
(
np
.
int32
)
return
res
return
res
spconv/utils/__init__.py
View file @
19e73bbe
...
@@ -15,9 +15,13 @@
...
@@ -15,9 +15,13 @@
import
numpy
as
np
import
numpy
as
np
from
spconv
import
spconv_utils
from
spconv
import
spconv_utils
from
spconv.spconv_utils
import
(
non_max_suppression_cpu
,
points_to_voxel_3d_np
,
from
spconv.spconv_utils
import
(
non_max_suppression_cpu
,
points_to_voxel_3d_np_mean
,
points_to_voxel_3d_with_filtering
,
points_to_voxel_3d_np
,
rbbox_intersection
,
rbbox_iou
,
rotate_non_max_suppression_cpu
)
points_to_voxel_3d_np_mean
,
points_to_voxel_3d_with_filtering
,
rbbox_intersection
,
rbbox_iou
,
rotate_non_max_suppression_cpu
)
try
:
try
:
from
spconv.spconv_utils
import
non_max_suppression
from
spconv.spconv_utils
import
non_max_suppression
except
ImportError
:
except
ImportError
:
...
@@ -71,10 +75,10 @@ def points_to_voxel(points,
...
@@ -71,10 +75,10 @@ def points_to_voxel(points,
voxelmap_shape
=
tuple
(
np
.
round
(
voxelmap_shape
).
astype
(
np
.
int32
).
tolist
())
voxelmap_shape
=
tuple
(
np
.
round
(
voxelmap_shape
).
astype
(
np
.
int32
).
tolist
())
voxelmap_shape
=
voxelmap_shape
[::
-
1
]
voxelmap_shape
=
voxelmap_shape
[::
-
1
]
num_points_per_voxel
=
np
.
zeros
(
shape
=
(
max_voxels
,
),
dtype
=
np
.
int32
)
num_points_per_voxel
=
np
.
zeros
(
shape
=
(
max_voxels
,
),
dtype
=
np
.
int32
)
voxels
=
np
.
zeros
(
voxels
=
np
.
zeros
(
shape
=
(
max_voxels
,
max_points
,
points
.
shape
[
-
1
]),
shape
=
(
max_voxels
,
max_points
,
points
.
shape
[
-
1
]),
dtype
=
points
.
dtype
)
dtype
=
points
.
dtype
)
voxel_point_mask
=
np
.
zeros
(
voxel_point_mask
=
np
.
zeros
(
shape
=
(
max_voxels
,
max_points
),
shape
=
(
max_voxels
,
max_points
),
dtype
=
points
.
dtype
)
dtype
=
points
.
dtype
)
coors
=
np
.
zeros
(
shape
=
(
max_voxels
,
3
),
dtype
=
np
.
int32
)
coors
=
np
.
zeros
(
shape
=
(
max_voxels
,
3
),
dtype
=
np
.
int32
)
res
=
{
res
=
{
"voxels"
:
voxels
,
"voxels"
:
voxels
,
...
@@ -83,12 +87,15 @@ def points_to_voxel(points,
...
@@ -83,12 +87,15 @@ def points_to_voxel(points,
"voxel_point_mask"
:
voxel_point_mask
,
"voxel_point_mask"
:
voxel_point_mask
,
}
}
if
full_mean
:
if
full_mean
:
means
=
np
.
zeros
(
means
=
np
.
zeros
(
shape
=
(
max_voxels
,
points
.
shape
[
-
1
]),
shape
=
(
max_voxels
,
points
.
shape
[
-
1
]),
dtype
=
points
.
dtype
)
dtype
=
points
.
dtype
)
voxel_num
=
points_to_voxel_3d_np_mean
(
voxel_num
=
points_to_voxel_3d_np_mean
(
points
,
voxels
,
points
,
voxels
,
voxel_point_mask
,
means
,
coors
,
voxel_point_mask
,
means
,
coors
,
num_points_per_voxel
,
coor_to_voxelidx
,
voxel_size
.
tolist
(),
num_points_per_voxel
,
coors_range
.
tolist
(),
max_points
,
max_voxels
)
coor_to_voxelidx
,
voxel_size
.
tolist
(),
coors_range
.
tolist
(),
max_points
,
max_voxels
)
else
:
else
:
if
block_filtering
:
if
block_filtering
:
block_shape
=
[
*
voxelmap_shape
[
1
:]]
block_shape
=
[
*
voxelmap_shape
[
1
:]]
...
@@ -123,10 +130,12 @@ def points_to_voxel(points,
...
@@ -123,10 +130,12 @@ def points_to_voxel(points,
res
[
"voxel_point_mask"
]
=
voxel_point_mask
[
voxel_mask
]
res
[
"voxel_point_mask"
]
=
voxel_point_mask
[
voxel_mask
]
voxel_num
=
coors_
.
shape
[
0
]
voxel_num
=
coors_
.
shape
[
0
]
else
:
else
:
voxel_num
=
points_to_voxel_3d_np
(
voxel_num
=
points_to_voxel_3d_np
(
points
,
voxels
,
voxel_point_mask
,
points
,
voxels
,
voxel_point_mask
,
coors
,
coors
,
num_points_per_voxel
,
num_points_per_voxel
,
coor_to_voxelidx
,
voxel_size
.
tolist
(),
coor_to_voxelidx
,
coors_range
.
tolist
(),
max_points
,
max_voxels
)
voxel_size
.
tolist
(),
coors_range
.
tolist
(),
max_points
,
max_voxels
)
res
[
"voxel_num"
]
=
voxel_num
res
[
"voxel_num"
]
=
voxel_num
res
[
"voxel_point_mask"
]
=
res
[
"voxel_point_mask"
].
reshape
(
res
[
"voxel_point_mask"
]
=
res
[
"voxel_point_mask"
].
reshape
(
-
1
,
max_points
,
1
)
-
1
,
max_points
,
1
)
...
@@ -143,8 +152,8 @@ class VoxelGenerator:
...
@@ -143,8 +152,8 @@ class VoxelGenerator:
point_cloud_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
point_cloud_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
# [0, -40, -3, 70.4, 40, 1]
# [0, -40, -3, 70.4, 40, 1]
voxel_size
=
np
.
array
(
voxel_size
,
dtype
=
np
.
float32
)
voxel_size
=
np
.
array
(
voxel_size
,
dtype
=
np
.
float32
)
grid_size
=
(
grid_size
=
(
point_cloud_range
[
3
:]
-
point_cloud_range
[
3
:]
-
point_cloud_range
[:
3
])
/
voxel_size
point_cloud_range
[:
3
])
/
voxel_size
grid_size
=
np
.
round
(
grid_size
).
astype
(
np
.
int64
)
grid_size
=
np
.
round
(
grid_size
).
astype
(
np
.
int64
)
voxelmap_shape
=
tuple
(
np
.
round
(
grid_size
).
astype
(
np
.
int32
).
tolist
())
voxelmap_shape
=
tuple
(
np
.
round
(
grid_size
).
astype
(
np
.
int32
).
tolist
())
voxelmap_shape
=
voxelmap_shape
[::
-
1
]
voxelmap_shape
=
voxelmap_shape
[::
-
1
]
...
@@ -216,8 +225,8 @@ class VoxelGeneratorV2:
...
@@ -216,8 +225,8 @@ class VoxelGeneratorV2:
point_cloud_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
point_cloud_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
# [0, -40, -3, 70.4, 40, 1]
# [0, -40, -3, 70.4, 40, 1]
voxel_size
=
np
.
array
(
voxel_size
,
dtype
=
np
.
float32
)
voxel_size
=
np
.
array
(
voxel_size
,
dtype
=
np
.
float32
)
grid_size
=
(
grid_size
=
(
point_cloud_range
[
3
:]
-
point_cloud_range
[
3
:]
-
point_cloud_range
[:
3
])
/
voxel_size
point_cloud_range
[:
3
])
/
voxel_size
grid_size
=
np
.
round
(
grid_size
).
astype
(
np
.
int64
)
grid_size
=
np
.
round
(
grid_size
).
astype
(
np
.
int64
)
if
block_filtering
:
if
block_filtering
:
assert
block_size
>
0
assert
block_size
>
0
...
@@ -240,32 +249,32 @@ class VoxelGeneratorV2:
...
@@ -240,32 +249,32 @@ class VoxelGeneratorV2:
self
.
_height_high_threshold
=
height_high_threshold
self
.
_height_high_threshold
=
height_high_threshold
def
generate
(
self
,
points
,
max_voxels
=
None
):
def
generate
(
self
,
points
,
max_voxels
=
None
):
res
=
points_to_voxel
(
res
=
points_to_voxel
(
points
,
self
.
_voxel_size
,
points
,
self
.
_voxel_size
,
self
.
_point_cloud_range
,
self
.
_point_cloud_range
,
self
.
_coor_to_voxelidx
,
self
.
_coor_to_voxelidx
,
self
.
_max_num_points
,
max_voxels
self
.
_max_num_points
,
max_voxels
or
self
.
_max_voxels
,
self
.
_full_mean
,
self
.
_block_filtering
,
or
self
.
_max_voxels
,
self
.
_full_mean
,
self
.
_block_factor
,
self
.
_block_size
,
self
.
_height_threshold
,
self
.
_block_filtering
,
self
.
_block_factor
,
self
.
_height_high_threshold
)
self
.
_block_size
,
self
.
_height_threshold
,
self
.
_height_high_threshold
)
for
k
,
v
in
res
.
items
():
for
k
,
v
in
res
.
items
():
if
k
!=
"voxel_num"
:
if
k
!=
"voxel_num"
:
res
[
k
]
=
v
[:
res
[
"voxel_num"
]]
res
[
k
]
=
v
[:
res
[
"voxel_num"
]]
return
res
return
res
def
generate_multi_gpu
(
self
,
points
,
max_voxels
=
None
):
def
generate_multi_gpu
(
self
,
points
,
max_voxels
=
None
):
res
=
points_to_voxel
(
res
=
points_to_voxel
(
points
,
points
,
self
.
_voxel_size
,
self
.
_voxel_size
,
self
.
_point_cloud_range
,
self
.
_point_cloud_range
,
self
.
_coor_to_voxelidx
,
self
.
_coor_to_voxelidx
,
self
.
_max_num_points
,
self
.
_max_num_points
,
max_voxels
or
self
.
_max_voxels
,
max_voxels
or
self
.
_max_voxels
,
self
.
_full_mean
,
self
.
_full_mean
,
self
.
_block_filtering
,
self
.
_block_filtering
,
self
.
_block_factor
,
self
.
_block_factor
,
self
.
_block_size
,
self
.
_block_size
,
self
.
_height_threshold
,
self
.
_height_threshold
,
self
.
_height_high_threshold
,
self
.
_height_high_threshold
,
pad_output
=
True
)
pad_output
=
True
)
return
res
return
res
@
property
@
property
...
...
src/cuhash/debugging.cpp
View file @
19e73bbe
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// -------------------------------------------------------------
// $Revision:$
// $Revision:$
// $Date:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
/**
* @file
* @file
...
@@ -24,18 +24,16 @@
...
@@ -24,18 +24,16 @@
namespace
cuhash
{
namespace
cuhash
{
void
OutputRetrievalStatistics
(
const
unsigned
n_queries
,
void
OutputRetrievalStatistics
(
const
unsigned
n_queries
,
const
unsigned
*
d_retrieval_probes
,
const
unsigned
*
d_retrieval_probes
,
const
unsigned
n_functions
)
const
unsigned
n_functions
)
{
{
unsigned
*
retrieval_probes
=
new
unsigned
[
n_queries
];
unsigned
*
retrieval_probes
=
new
unsigned
[
n_queries
];
CUDA_SAFE_CALL
(
cudaMemcpy
(
retrieval_probes
,
CUDA_SAFE_CALL
(
cudaMemcpy
(
retrieval_probes
,
d_retrieval_probes
,
d_retrieval_probes
,
sizeof
(
unsigned
)
*
n_queries
,
sizeof
(
unsigned
)
*
n_queries
,
cudaMemcpyDeviceToHost
));
cudaMemcpyDeviceToHost
));
// Create a histogram showing how many items needed how many probes to be found.
// Create a histogram showing how many items needed how many probes to be
// found.
unsigned
possible_probes
=
n_functions
+
2
;
unsigned
possible_probes
=
n_functions
+
2
;
unsigned
*
histogram
=
new
unsigned
[
possible_probes
];
unsigned
*
histogram
=
new
unsigned
[
possible_probes
];
memset
(
histogram
,
0
,
sizeof
(
unsigned
)
*
(
possible_probes
));
memset
(
histogram
,
0
,
sizeof
(
unsigned
)
*
(
possible_probes
));
...
@@ -51,16 +49,16 @@ void OutputRetrievalStatistics(const unsigned n_queries,
...
@@ -51,16 +49,16 @@ void OutputRetrievalStatistics(const unsigned n_queries,
sprintf
(
buffer
,
"
\t
(%u, %u)"
,
i
,
histogram
[
i
]);
sprintf
(
buffer
,
"
\t
(%u, %u)"
,
i
,
histogram
[
i
]);
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
}
}
delete
[]
retrieval_probes
;
delete
[]
retrieval_probes
;
delete
[]
histogram
;
delete
[]
histogram
;
}
}
void
OutputBuildStatistics
(
const
unsigned
n
,
void
OutputBuildStatistics
(
const
unsigned
n
,
const
unsigned
*
d_iterations_taken
)
{
const
unsigned
*
d_iterations_taken
)
{
// Output how many iterations each thread took until it found an empty slot.
// Output how many iterations each thread took until it found an empty slot.
unsigned
*
iterations_taken
=
new
unsigned
[
n
];
unsigned
*
iterations_taken
=
new
unsigned
[
n
];
CUDA_SAFE_CALL
(
cudaMemcpy
(
iterations_taken
,
d_iterations_taken
,
sizeof
(
unsigned
)
*
n
,
cudaMemcpyDeviceToHost
));
CUDA_SAFE_CALL
(
cudaMemcpy
(
iterations_taken
,
d_iterations_taken
,
sizeof
(
unsigned
)
*
n
,
cudaMemcpyDeviceToHost
));
std
::
sort
(
iterations_taken
,
iterations_taken
+
n
);
std
::
sort
(
iterations_taken
,
iterations_taken
+
n
);
unsigned
total_iterations
=
0
;
unsigned
total_iterations
=
0
;
unsigned
max_iterations_taken
=
0
;
unsigned
max_iterations_taken
=
0
;
...
@@ -86,17 +84,18 @@ void OutputBuildStatistics(const unsigned n,
...
@@ -86,17 +84,18 @@ void OutputBuildStatistics(const unsigned n,
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
sprintf
(
buffer
,
"Total iterations: %u"
,
total_iterations
);
sprintf
(
buffer
,
"Total iterations: %u"
,
total_iterations
);
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
sprintf
(
buffer
,
"Avg/Med/Max iterations: (%f %u %u)"
,
(
float
)
total_iterations
/
n
,
iterations_taken
[
n
/
2
],
iterations_taken
[
n
-
1
]);
sprintf
(
buffer
,
"Avg/Med/Max iterations: (%f %u %u)"
,
(
float
)
total_iterations
/
n
,
iterations_taken
[
n
/
2
],
iterations_taken
[
n
-
1
]);
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
delete
[]
iterations_taken
;
delete
[]
iterations_taken
;
// Print the length of the longest eviction chain.
// Print the length of the longest eviction chain.
sprintf
(
buffer
,
"Max iterations: %u"
,
max_iterations_taken
);
sprintf
(
buffer
,
"Max iterations: %u"
,
max_iterations_taken
);
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
}
}
};
// namespace cuhash
};
// namespace CuckooHashing
// Leave this at the end of the file
// Leave this at the end of the file
// Local Variables:
// Local Variables:
...
...
src/cuhash/debugging.cu
View file @
19e73bbe
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// -------------------------------------------------------------
// $Revision:$
// $Revision:$
// $Date:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
/**
* @file
* @file
...
@@ -24,24 +24,17 @@
...
@@ -24,24 +24,17 @@
namespace
cuhash
{
namespace
cuhash
{
//! Debugging function: Takes statistics on the hash functions' distribution.
//! Debugging function: Takes statistics on the hash functions' distribution.
/*! Determines:
/*! Determines:
* - How many unique slots each key has.
* - How many unique slots each key has.
* - How many keys hash into each slot.
* - How many keys hash into each slot.
* - Whether any keys failed to get a full set of slots.
* - Whether any keys failed to get a full set of slots.
*/
*/
__global__
__global__
void
take_hash_function_statistics_kernel
(
void
take_hash_function_statistics_kernel
(
const
unsigned
*
keys
,
const
unsigned
*
keys
,
const
unsigned
n_entries
,
const
unsigned
table_size
,
const
unsigned
n_entries
,
const
uint2
*
constants
,
const
unsigned
num_functions
,
const
unsigned
table_size
,
unsigned
*
num_slots_available
,
unsigned
*
num_hashing_in
,
unsigned
*
failed
)
{
const
uint2
*
constants
,
unsigned
thread_index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
const
unsigned
num_functions
,
unsigned
*
num_slots_available
,
unsigned
*
num_hashing_in
,
unsigned
*
failed
)
{
unsigned
thread_index
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
*
gridDim
.
x
;
blockIdx
.
y
*
blockDim
.
x
*
gridDim
.
x
;
if
(
thread_index
>=
n_entries
)
if
(
thread_index
>=
n_entries
)
...
@@ -83,12 +76,10 @@ void take_hash_function_statistics_kernel(const unsigned *keys,
...
@@ -83,12 +76,10 @@ void take_hash_function_statistics_kernel(const unsigned *keys,
}
}
}
}
void
TakeHashFunctionStatistics
(
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
void
TakeHashFunctionStatistics
(
const
unsigned
num_keys
,
const
unsigned
table_size
,
const
unsigned
*
d_keys
,
const
uint2
*
constants
,
const
unsigned
table_size
,
const
unsigned
kNumHashFunctions
)
{
const
uint2
*
constants
,
const
unsigned
kNumHashFunctions
)
{
char
buffer
[
16000
];
char
buffer
[
16000
];
PrintMessage
(
"Hash function constants: "
);
PrintMessage
(
"Hash function constants: "
);
...
@@ -98,35 +89,34 @@ void TakeHashFunctionStatistics(const unsigned num_keys,
...
@@ -98,35 +89,34 @@ void TakeHashFunctionStatistics(const unsigned num_keys,
}
}
unsigned
*
d_num_hashing_in
=
NULL
;
unsigned
*
d_num_hashing_in
=
NULL
;
#ifdef COUNT_HOW_MANY_HASH_INTO_EACH_SLOT
#ifdef COUNT_HOW_MANY_HASH_INTO_EACH_SLOT
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_num_hashing_in
,
CUDA_SAFE_CALL
(
sizeof
(
unsigned
)
*
table_size
));
cudaMalloc
((
void
**
)
&
d_num_hashing_in
,
sizeof
(
unsigned
)
*
table_size
));
CUDA_SAFE_CALL
(
cudaMemset
(
d_num_hashing_in
,
0
,
sizeof
(
unsigned
)
*
table_size
));
CUDA_SAFE_CALL
(
#endif
cudaMemset
(
d_num_hashing_in
,
0
,
sizeof
(
unsigned
)
*
table_size
));
#endif
unsigned
*
d_num_slots_available
=
NULL
;
unsigned
*
d_num_slots_available
=
NULL
;
#ifdef COUNT_HOW_MANY_HAVE_CYCLES
#ifdef COUNT_HOW_MANY_HAVE_CYCLES
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_num_slots_available
,
CUDA_SAFE_CALL
(
sizeof
(
unsigned
)
*
num_keys
));
cudaMalloc
((
void
**
)
&
d_num_slots_available
,
sizeof
(
unsigned
)
*
num_keys
));
#endif
#endif
uint2
*
d_constants
=
NULL
;
uint2
*
d_constants
=
NULL
;
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_constants
,
sizeof
(
uint2
)
*
kNumHashFunctions
));
CUDA_SAFE_CALL
(
CUDA_SAFE_CALL
(
cudaMemcpy
(
d_constants
,
constants
,
sizeof
(
uint2
)
*
kNumHashFunctions
,
cudaMemcpyHostToDevice
));
cudaMalloc
((
void
**
)
&
d_constants
,
sizeof
(
uint2
)
*
kNumHashFunctions
));
CUDA_SAFE_CALL
(
cudaMemcpy
(
d_constants
,
constants
,
take_hash_function_statistics_kernel
<<<
ComputeGridDim
(
num_keys
),
kBlockSize
>>>
sizeof
(
uint2
)
*
kNumHashFunctions
,
(
d_keys
,
num_keys
,
cudaMemcpyHostToDevice
));
table_size
,
d_constants
,
take_hash_function_statistics_kernel
<<<
ComputeGridDim
(
num_keys
),
kNumHashFunctions
,
kBlockSize
>>>
(
d_num_slots_available
,
d_keys
,
num_keys
,
table_size
,
d_constants
,
kNumHashFunctions
,
d_num_hashing_in
,
d_num_slots_available
,
d_num_hashing_in
,
NULL
);
NULL
);
CUDA_SAFE_CALL
(
cudaFree
(
d_constants
));
CUDA_SAFE_CALL
(
cudaFree
(
d_constants
));
#ifdef COUNT_HOW_MANY_HASH_INTO_EACH_SLOT
#ifdef COUNT_HOW_MANY_HASH_INTO_EACH_SLOT
unsigned
*
num_hashing_in
=
new
unsigned
[
table_size
];
unsigned
*
num_hashing_in
=
new
unsigned
[
table_size
];
CUDA_SAFE_CALL
(
cudaMemcpy
(
num_hashing_in
,
CUDA_SAFE_CALL
(
cudaMemcpy
(
num_hashing_in
,
d_num_hashing_in
,
d_num_hashing_in
,
sizeof
(
unsigned
)
*
table_size
,
sizeof
(
unsigned
)
*
table_size
,
cudaMemcpyDeviceToHost
));
cudaMemcpyDeviceToHost
));
...
@@ -165,14 +155,13 @@ void TakeHashFunctionStatistics(const unsigned num_keys,
...
@@ -165,14 +155,13 @@ void TakeHashFunctionStatistics(const unsigned num_keys,
sprintf
(
buffer
,
"
\t
(%u, %u)"
,
previous
,
count
);
sprintf
(
buffer
,
"
\t
(%u, %u)"
,
previous
,
count
);
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
delete
[]
num_hashing_in
;
delete
[]
num_hashing_in
;
CUDA_SAFE_CALL
(
cudaFree
(
d_num_hashing_in
));
CUDA_SAFE_CALL
(
cudaFree
(
d_num_hashing_in
));
#endif
#endif
#ifdef COUNT_HOW_MANY_HAVE_CYCLES
#ifdef COUNT_HOW_MANY_HAVE_CYCLES
unsigned
*
num_slots_available
=
new
unsigned
[
num_keys
];
unsigned
*
num_slots_available
=
new
unsigned
[
num_keys
];
CUDA_SAFE_CALL
(
cudaMemcpy
(
num_slots_available
,
CUDA_SAFE_CALL
(
cudaMemcpy
(
num_slots_available
,
d_num_slots_available
,
d_num_slots_available
,
sizeof
(
unsigned
)
*
num_keys
,
sizeof
(
unsigned
)
*
num_keys
,
cudaMemcpyDeviceToHost
));
cudaMemcpyDeviceToHost
));
...
@@ -189,38 +178,32 @@ void TakeHashFunctionStatistics(const unsigned num_keys,
...
@@ -189,38 +178,32 @@ void TakeHashFunctionStatistics(const unsigned num_keys,
}
}
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
delete
[]
histogram
;
delete
[]
histogram
;
delete
[]
num_slots_available
;
delete
[]
num_slots_available
;
CUDA_SAFE_CALL
(
cudaFree
(
d_num_slots_available
));
CUDA_SAFE_CALL
(
cudaFree
(
d_num_slots_available
));
#endif
#endif
}
}
bool
CheckAssignedSameSlot
(
const
unsigned
N
,
bool
CheckAssignedSameSlot
(
const
unsigned
N
,
const
unsigned
num_keys
,
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
const
unsigned
*
d_keys
,
uint2
*
constants
)
{
const
unsigned
table_size
,
uint2
*
constants
)
{
unsigned
*
d_cycle_exists
=
NULL
;
unsigned
*
d_cycle_exists
=
NULL
;
uint2
*
d_constants
=
NULL
;
uint2
*
d_constants
=
NULL
;
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_cycle_exists
,
sizeof
(
unsigned
)));
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_cycle_exists
,
sizeof
(
unsigned
)));
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_constants
,
sizeof
(
uint2
)
*
N
));
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_constants
,
sizeof
(
uint2
)
*
N
));
CUDA_SAFE_CALL
(
cudaMemset
(
d_cycle_exists
,
0
,
sizeof
(
unsigned
)));
CUDA_SAFE_CALL
(
cudaMemset
(
d_cycle_exists
,
0
,
sizeof
(
unsigned
)));
CUDA_SAFE_CALL
(
cudaMemcpy
(
d_constants
,
CUDA_SAFE_CALL
(
cudaMemcpy
(
d_constants
,
constants
,
sizeof
(
uint2
)
*
N
,
constants
,
sizeof
(
uint2
)
*
N
,
cudaMemcpyHostToDevice
));
cudaMemcpyHostToDevice
));
// Check if all keys were given a full set of N slots by the functions.
// Check if all keys were given a full set of N slots by the functions.
take_hash_function_statistics_kernel
<<<
ComputeGridDim
(
num_keys
),
kBlockSize
>>>
take_hash_function_statistics_kernel
<<<
ComputeGridDim
(
num_keys
),
(
d_keys
,
num_keys
,
table_size
,
d_constants
,
N
,
kBlockSize
>>>
(
NULL
,
NULL
,
d_cycle_exists
);
d_keys
,
num_keys
,
table_size
,
d_constants
,
N
,
NULL
,
NULL
,
d_cycle_exists
);
unsigned
cycle_exists
;
unsigned
cycle_exists
;
CUDA_SAFE_CALL
(
cudaMemcpy
(
&
cycle_exists
,
CUDA_SAFE_CALL
(
cudaMemcpy
(
&
cycle_exists
,
d_cycle_exists
,
sizeof
(
unsigned
),
d_cycle_exists
,
sizeof
(
unsigned
),
cudaMemcpyDeviceToHost
));
cudaMemcpyDeviceToHost
));
CUDA_SAFE_CALL
(
cudaFree
(
d_cycle_exists
));
CUDA_SAFE_CALL
(
cudaFree
(
d_cycle_exists
));
...
@@ -229,22 +212,22 @@ bool CheckAssignedSameSlot(const unsigned N,
...
@@ -229,22 +212,22 @@ bool CheckAssignedSameSlot(const unsigned N,
return
(
cycle_exists
!=
0
);
return
(
cycle_exists
!=
0
);
}
}
void
PrintStashContents
(
const
Entry
*
d_stash
)
{
void
PrintStashContents
(
const
Entry
*
d_stash
)
{
Entry
*
stash
=
new
Entry
[
cuhash
::
kStashSize
];
Entry
*
stash
=
new
Entry
[
cuhash
::
kStashSize
];
CUDA_SAFE_CALL
(
cudaMemcpy
(
stash
,
d_stash
,
sizeof
(
Entry
)
*
cuhash
::
kStashSize
,
cudaMemcpyDeviceToHost
));
CUDA_SAFE_CALL
(
cudaMemcpy
(
stash
,
d_stash
,
sizeof
(
Entry
)
*
cuhash
::
kStashSize
,
cudaMemcpyDeviceToHost
));
for
(
unsigned
i
=
0
;
i
<
cuhash
::
kStashSize
;
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
cuhash
::
kStashSize
;
++
i
)
{
if
(
get_key
(
stash
[
i
])
!=
kKeyEmpty
)
{
if
(
get_key
(
stash
[
i
])
!=
kKeyEmpty
)
{
char
buffer
[
256
];
char
buffer
[
256
];
sprintf
(
buffer
,
"Stash[%u]: %u = %u"
,
i
,
get_key
(
stash
[
i
]),
get_value
(
stash
[
i
]));
sprintf
(
buffer
,
"Stash[%u]: %u = %u"
,
i
,
get_key
(
stash
[
i
]),
get_value
(
stash
[
i
]));
PrintMessage
(
buffer
,
true
);
PrintMessage
(
buffer
,
true
);
}
}
}
}
delete
[]
stash
;
delete
[]
stash
;
}
}
};
// namespace cuhash
};
// namespace CuckooHashing
// Leave this at the end of the file
// Leave this at the end of the file
// Local Variables:
// Local Variables:
...
...
src/cuhash/hash_functions.cpp
View file @
19e73bbe
...
@@ -3,14 +3,12 @@
...
@@ -3,14 +3,12 @@
#include <random>
#include <random>
namespace
cuhash
{
namespace
cuhash
{
std
::
random_device
random_dev
;
std
::
random_device
random_dev
;
std
::
mt19937
random_engine
(
random_dev
());
std
::
mt19937
random_engine
(
random_dev
());
std
::
uniform_int_distribution
<
unsigned
>
uint_distribution
;
std
::
uniform_int_distribution
<
unsigned
>
uint_distribution
;
unsigned
generate_random_uint32
(){
unsigned
generate_random_uint32
()
{
return
uint_distribution
(
random_engine
);
}
return
uint_distribution
(
random_engine
);
}
}
}
// namespace cuhash
\ No newline at end of file
\ No newline at end of file
src/cuhash/hash_functions.cu
View file @
19e73bbe
#include <cuhash/hash_table.h>
#include <cuhash/hash_functions.h>
#include <cuhash/debugging.h>
#include <cassert>
#include <cassert>
#include <cuhash/debugging.h>
#include <cuhash/hash_functions.h>
#include <cuhash/hash_table.h>
namespace
cuhash
{
namespace
cuhash
{
void
GenerateFunctions
(
const
unsigned
N
,
void
GenerateFunctions
(
const
unsigned
N
,
const
unsigned
num_keys
,
const
unsigned
num_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
const
unsigned
*
d_keys
,
uint2
*
constants
)
{
const
unsigned
table_size
,
uint2
*
constants
)
{
bool
regenerate
=
true
;
bool
regenerate
=
true
;
while
(
regenerate
)
{
while
(
regenerate
)
{
regenerate
=
false
;
regenerate
=
false
;
// Generate a set of hash function constants for this build attempt.
// Generate a set of hash function constants for this build attempt.
for
(
unsigned
i
=
0
;
i
<
N
;
++
i
)
{
for
(
unsigned
i
=
0
;
i
<
N
;
++
i
)
{
// uint_distribution(random_engine) % kPrimeDivisor;
// uint_distribution(random_engine) % kPrimeDivisor;
// genrand_int32() % kPrimeDivisor;
// genrand_int32() % kPrimeDivisor;
unsigned
new_a
=
generate_random_uint32
()
%
kPrimeDivisor
;
unsigned
new_a
=
generate_random_uint32
()
%
kPrimeDivisor
;
...
@@ -26,15 +24,15 @@ void GenerateFunctions(const unsigned N,
...
@@ -26,15 +24,15 @@ void GenerateFunctions(const unsigned N,
#ifdef FORCEFULLY_GENERATE_NO_CYCLES
#ifdef FORCEFULLY_GENERATE_NO_CYCLES
// Ensure that every key gets N different slots.
// Ensure that every key gets N different slots.
regenerate
=
CheckAssignedSameSlot
(
N
,
num_keys
,
d_keys
,
table_size
,
constants
);
regenerate
=
CheckAssignedSameSlot
(
N
,
num_keys
,
d_keys
,
table_size
,
constants
);
#endif
#endif
}
}
#ifdef TAKE_HASH_FUNCTION_STATISTICS
#ifdef TAKE_HASH_FUNCTION_STATISTICS
// Examine how well distributed the items are.
// Examine how well distributed the items are.
TakeHashFunctionStatistics
(
num_keys
,
d_keys
,
table_size
,
constants
,
N
);
TakeHashFunctionStatistics
(
num_keys
,
d_keys
,
table_size
,
constants
,
N
);
#endif
#endif
}
}
};
// namespace
CuckooHashing
};
// namespace
cuhash
src/cuhash/hash_table.cpp
View file @
19e73bbe
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// -------------------------------------------------------------
// $Revision:$
// $Revision:$
// $Date:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
/**
* @file hash_table.cpp
* @file hash_table.cpp
...
@@ -14,16 +14,16 @@
...
@@ -14,16 +14,16 @@
* @brief Implements a basic hash table that stores one value per key.
* @brief Implements a basic hash table that stores one value per key.
*/
*/
#include <cuhash/hash_table.h>
#include <cuhash/debugging.h>
#include <cuhash/debugging.h>
#include <cuhash/hash_table.h>
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include <cstdio>
#include <cstdio>
#include <cstring>
#include <cstring>
#include <limits>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#include <cuhash/cuda_util.h>
#include <cuhash/cuda_util.h>
#include <limits>
namespace
cuhash
{
namespace
cuhash
{
...
@@ -32,227 +32,198 @@ char buffer[256];
...
@@ -32,227 +32,198 @@ char buffer[256];
//! @name Internal
//! @name Internal
/// @{
/// @{
dim3
ComputeGridDim
(
unsigned
n
)
{
dim3
ComputeGridDim
(
unsigned
n
)
{
// Round up in order to make sure all items are hashed in.
// Round up in order to make sure all items are hashed in.
dim3
grid
(
(
n
+
kBlockSize
-
1
)
/
kBlockSize
);
dim3
grid
((
n
+
kBlockSize
-
1
)
/
kBlockSize
);
if
(
grid
.
x
>
kGridSize
)
{
if
(
grid
.
x
>
kGridSize
)
{
grid
.
y
=
(
grid
.
x
+
kGridSize
-
1
)
/
kGridSize
;
grid
.
y
=
(
grid
.
x
+
kGridSize
-
1
)
/
kGridSize
;
grid
.
x
=
kGridSize
;
grid
.
x
=
kGridSize
;
}
}
return
grid
;
return
grid
;
}
}
unsigned
ComputeMaxIterations
(
const
unsigned
n
,
const
unsigned
table_size
,
unsigned
ComputeMaxIterations
(
const
unsigned
n
,
const
unsigned
table_size
,
const
unsigned
num_functions
)
{
const
unsigned
num_functions
)
{
float
lg_input_size
=
(
float
)(
log
((
double
)
n
)
/
log
(
2.0
));
float
lg_input_size
=
(
float
)(
log
((
double
)
n
)
/
log
(
2.0
));
// #define CONSTANT_ITERATIONS
// #define CONSTANT_ITERATIONS
#ifdef CONSTANT_ITERATIONS
#ifdef CONSTANT_ITERATIONS
// Set the maximum number of iterations to 7lg(N).
// Set the maximum number of iterations to 7lg(N).
const
unsigned
MAX_ITERATION_CONSTANT
=
7
;
const
unsigned
MAX_ITERATION_CONSTANT
=
7
;
unsigned
max_iterations
=
MAX_ITERATION_CONSTANT
*
lg_input_size
;
unsigned
max_iterations
=
MAX_ITERATION_CONSTANT
*
lg_input_size
;
#else
#else
// Use an empirical formula for determining what the maximum number of
// Use an empirical formula for determining what the maximum number of
// iterations should be. Works OK in most situations.
// iterations should be. Works OK in most situations.
float
load_factor
=
float
(
n
)
/
table_size
;
float
load_factor
=
float
(
n
)
/
table_size
;
float
ln_load_factor
=
(
float
)(
log
(
load_factor
)
/
log
(
2.71828183
));
float
ln_load_factor
=
(
float
)(
log
(
load_factor
)
/
log
(
2.71828183
));
unsigned
max_iterations
=
(
unsigned
)(
4.0
*
ceil
(
-
1.0
/
(
0.028255
+
1.1594772
*
unsigned
max_iterations
=
ln_load_factor
)
*
lg_input_size
));
(
unsigned
)(
4.0
*
ceil
(
-
1.0
/
(
0.028255
+
1.1594772
*
ln_load_factor
)
*
lg_input_size
));
#endif
#endif
return
max_iterations
;
return
max_iterations
;
}
}
/// @}
/// @}
HashTable
::
HashTable
()
HashTable
::
HashTable
()
:
table_size_
(
0
),
:
table_size_
(
0
),
d_contents_
(
NULL
),
stash_count_
(
0
),
d_failures_
(
NULL
)
{
d_contents_
(
NULL
),
CUDA_CHECK_ERROR
(
"Failed in constructor.
\n
"
);
stash_count_
(
0
),
}
d_failures_
(
NULL
)
{
CUDA_CHECK_ERROR
(
"Failed in constructor.
\n
"
);
}
bool
HashTable
::
Initialize
(
const
unsigned
max_table_entries
,
bool
HashTable
::
Initialize
(
const
unsigned
max_table_entries
,
const
float
space_usage
,
const
float
space_usage
,
const
unsigned
num_functions
)
{
const
unsigned
num_functions
)
{
Release
();
Release
();
// Determine the minimum amount of slots the table requires,
// Determine the minimum amount of slots the table requires,
// and whether the space_usage is within range.
// and whether the space_usage is within range.
float
minimum_space_usage
;
float
minimum_space_usage
;
if
(
num_functions
<
2
||
num_functions
>
5
)
{
if
(
num_functions
<
2
||
num_functions
>
5
)
{
char
message
[
256
]
=
"Number of hash functions must be from 2 to 5; "
char
message
[
256
]
=
"Number of hash functions must be from 2 to 5; "
"others are unimplemented."
;
"others are unimplemented."
;
PrintMessage
(
message
,
true
);
PrintMessage
(
message
,
true
);
return
false
;
return
false
;
}
else
{
}
else
{
minimum_space_usage
=
kMinimumSpaceUsages
[
num_functions
];
minimum_space_usage
=
kMinimumSpaceUsages
[
num_functions
];
}
}
if
(
space_usage
<
minimum_space_usage
)
{
if
(
space_usage
<
minimum_space_usage
)
{
sprintf
(
buffer
,
"Minimum possible space usage for %u functions is %f."
,
sprintf
(
buffer
,
"Minimum possible space usage for %u functions is %f."
,
num_functions
,
minimum_space_usage
);
num_functions
,
minimum_space_usage
);
PrintMessage
(
buffer
);
PrintMessage
(
buffer
);
return
false
;
return
false
;
}
}
num_hash_functions_
=
num_functions
;
num_hash_functions_
=
num_functions
;
table_size_
=
unsigned
(
ceil
(
max_table_entries
*
space_usage
));
table_size_
=
unsigned
(
ceil
(
max_table_entries
*
space_usage
));
// Allocate memory.
// Allocate memory.
const
unsigned
slots_to_allocate
=
table_size_
+
kStashSize
;
const
unsigned
slots_to_allocate
=
table_size_
+
kStashSize
;
CUDA_SAFE_CALL
(
cudaMalloc
(
(
void
**
)
&
d_contents_
,
CUDA_SAFE_CALL
(
sizeof
(
Entry
)
*
slots_to_allocate
));
cudaMalloc
((
void
**
)
&
d_contents_
,
sizeof
(
Entry
)
*
slots_to_allocate
));
CUDA_SAFE_CALL
(
cudaMalloc
(
(
void
**
)
&
d_failures_
,
sizeof
(
unsigned
)
));
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_failures_
,
sizeof
(
unsigned
)));
if
(
!
d_contents_
||
!
d_failures_
)
{
if
(
!
d_contents_
||
!
d_failures_
)
{
fprintf
(
stderr
,
"Failed to allocate %u slots.
\n
"
,
slots_to_allocate
);
fprintf
(
stderr
,
"Failed to allocate %u slots.
\n
"
,
slots_to_allocate
);
return
false
;
return
false
;
}
}
CUDA_CHECK_ERROR
(
"Failed to initialize.
\n
"
);
CUDA_CHECK_ERROR
(
"Failed to initialize.
\n
"
);
return
true
;
return
true
;
}
}
void
HashTable
::
Release
()
{
void
HashTable
::
Release
()
{
table_size_
=
0
;
table_size_
=
0
;
CUDA_SAFE_CALL
(
cudaFree
(
d_contents_
));
CUDA_SAFE_CALL
(
cudaFree
(
d_contents_
));
CUDA_SAFE_CALL
(
cudaFree
(
d_failures_
));
CUDA_SAFE_CALL
(
cudaFree
(
d_failures_
));
d_contents_
=
NULL
;
d_contents_
=
NULL
;
d_failures_
=
NULL
;
d_failures_
=
NULL
;
CUDA_CHECK_ERROR
(
"Failed during release.
\n
"
);
CUDA_CHECK_ERROR
(
"Failed during release.
\n
"
);
}
}
bool
HashTable
::
Build
(
const
unsigned
n
,
const
unsigned
*
d_keys
,
bool
HashTable
::
Build
(
const
unsigned
n
,
const
unsigned
*
d_keys
,
const
unsigned
*
d_values
)
{
const
unsigned
*
d_values
)
{
unsigned
max_iterations
=
ComputeMaxIterations
(
n
,
table_size_
,
unsigned
max_iterations
=
num_hash_functions_
);
ComputeMaxIterations
(
n
,
table_size_
,
num_hash_functions_
);
unsigned
num_failures
=
1
;
unsigned
num_failures
=
1
;
unsigned
num_attempts
=
0
;
unsigned
num_attempts
=
0
;
// Storage for statistics collection.
// Storage for statistics collection.
unsigned
*
d_iterations_taken
=
NULL
;
unsigned
*
d_iterations_taken
=
NULL
;
#ifdef TRACK_ITERATIONS
#ifdef TRACK_ITERATIONS
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_iterations_taken
,
sizeof
(
unsigned
)
*
n
));
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_iterations_taken
,
sizeof
(
unsigned
)
*
n
));
#endif
#endif
// Track how many items ended up in the stash.
// Track how many items ended up in the stash.
unsigned
*
d_stash_count
=
NULL
;
unsigned
*
d_stash_count
=
NULL
;
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_stash_count
,
sizeof
(
unsigned
)));
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_stash_count
,
sizeof
(
unsigned
)));
CUDA_CHECK_ERROR
(
"Failed before main build loop.
\n
"
);
CUDA_CHECK_ERROR
(
"Failed before main build loop.
\n
"
);
// Main build loop.
// Main build loop.
while
(
num_failures
&&
++
num_attempts
<
kMaxRestartAttempts
)
{
while
(
num_failures
&&
++
num_attempts
<
kMaxRestartAttempts
)
{
CUDA_SAFE_CALL
(
cudaMemset
(
d_stash_count
,
0
,
sizeof
(
unsigned
)));
CUDA_SAFE_CALL
(
cudaMemset
(
d_stash_count
,
0
,
sizeof
(
unsigned
)));
// Generate new hash functions.
// Generate new hash functions.
if
(
num_hash_functions_
==
2
)
if
(
num_hash_functions_
==
2
)
constants_2_
.
Generate
(
n
,
d_keys
,
table_size_
);
constants_2_
.
Generate
(
n
,
d_keys
,
table_size_
);
else
if
(
num_hash_functions_
==
3
)
else
if
(
num_hash_functions_
==
3
)
constants_3_
.
Generate
(
n
,
d_keys
,
table_size_
);
constants_3_
.
Generate
(
n
,
d_keys
,
table_size_
);
else
if
(
num_hash_functions_
==
4
)
else
if
(
num_hash_functions_
==
4
)
constants_4_
.
Generate
(
n
,
d_keys
,
table_size_
);
constants_4_
.
Generate
(
n
,
d_keys
,
table_size_
);
else
else
constants_5_
.
Generate
(
n
,
d_keys
,
table_size_
);
constants_5_
.
Generate
(
n
,
d_keys
,
table_size_
);
stash_constants_
.
x
=
std
::
max
(
1u
,
generate_random_uint32
())
%
kPrimeDivisor
;
stash_constants_
.
x
=
std
::
max
(
1u
,
generate_random_uint32
())
%
kPrimeDivisor
;
stash_constants_
.
y
=
generate_random_uint32
()
%
kPrimeDivisor
;
stash_constants_
.
y
=
generate_random_uint32
()
%
kPrimeDivisor
;
stash_count_
=
0
;
stash_count_
=
0
;
// Initialize memory.
// Initialize memory.
unsigned
slots_in_table
=
table_size_
+
kStashSize
;
unsigned
slots_in_table
=
table_size_
+
kStashSize
;
CUDAWrapper
::
ClearTable
(
slots_in_table
,
CUDAWrapper
::
ClearTable
(
slots_in_table
,
kEntryEmpty
,
d_contents_
);
kEntryEmpty
,
d_contents_
);
num_failures
=
0
;
num_failures
=
0
;
CUDAWrapper
::
CallCuckooHash
(
n
,
num_hash_functions_
,
d_keys
,
d_values
,
table_size_
,
constants_2_
,
CUDAWrapper
::
CallCuckooHash
(
n
,
constants_3_
,
constants_4_
,
constants_5_
,
max_iterations
,
d_contents_
,
num_hash_functions_
,
stash_constants_
,
d_stash_count
,
d_failures_
,
d_iterations_taken
);
d_keys
,
d_values
,
// Check if successful.
table_size_
,
CUDA_SAFE_CALL
(
cudaMemcpy
(
&
num_failures
,
d_failures_
,
sizeof
(
unsigned
),
constants_2_
,
cudaMemcpyDeviceToHost
));
constants_3_
,
constants_4_
,
constants_5_
,
max_iterations
,
d_contents_
,
stash_constants_
,
d_stash_count
,
d_failures_
,
d_iterations_taken
);
// Check if successful.
CUDA_SAFE_CALL
(
cudaMemcpy
(
&
num_failures
,
d_failures_
,
sizeof
(
unsigned
),
cudaMemcpyDeviceToHost
));
#ifdef COUNT_UNINSERTED
#ifdef COUNT_UNINSERTED
if
(
num_failures
)
{
if
(
num_failures
)
{
printf
(
"Failed to insert %u items.
\n
"
,
num_failures
);
printf
(
"Failed to insert %u items.
\n
"
,
num_failures
);
}
#endif
}
}
#endif
}
// Copy out the stash size.
// Copy out the stash size.
CUDA_SAFE_CALL
(
cudaMemcpy
(
&
stash_count_
,
d_stash_count
,
sizeof
(
unsigned
),
cudaMemcpyDeviceToHost
));
CUDA_SAFE_CALL
(
cudaMemcpy
(
&
stash_count_
,
d_stash_count
,
sizeof
(
unsigned
),
if
(
stash_count_
&&
num_failures
==
0
)
{
cudaMemcpyDeviceToHost
));
// sprintf(buffer, "Stash size: %u", stash_count_);
if
(
stash_count_
&&
num_failures
==
0
)
{
// PrintMessage(buffer, true);
// sprintf(buffer, "Stash size: %u", stash_count_);
// PrintMessage(buffer, true);
#ifdef _DEBUG
#ifdef _DEBUG
PrintStashContents
(
d_contents_
+
table_size_
);
PrintStashContents
(
d_contents_
+
table_size_
);
#endif
#endif
}
}
CUDA_SAFE_CALL
(
cudaFree
(
d_stash_count
));
CUDA_SAFE_CALL
(
cudaFree
(
d_stash_count
));
#ifdef TRACK_ITERATIONS
#ifdef TRACK_ITERATIONS
if
(
num_failures
==
0
)
{
if
(
num_failures
==
0
)
{
OutputBuildStatistics
(
n
,
d_iterations_taken
);
OutputBuildStatistics
(
n
,
d_iterations_taken
);
}
}
CUDA_SAFE_CALL
(
cudaFree
(
d_iterations_taken
));
CUDA_SAFE_CALL
(
cudaFree
(
d_iterations_taken
));
#endif
#endif
// Dump some info if a restart was required.
// Dump some info if a restart was required.
if
(
num_attempts
>=
kMaxRestartAttempts
)
{
if
(
num_attempts
>=
kMaxRestartAttempts
)
{
sprintf
(
buffer
,
"Completely failed to build"
);
sprintf
(
buffer
,
"Completely failed to build"
);
PrintMessage
(
buffer
,
true
);
PrintMessage
(
buffer
,
true
);
}
else
if
(
num_attempts
>
1
)
{
}
else
if
(
num_attempts
>
1
)
{
sprintf
(
buffer
,
"Needed %u attempts to build, you can ignore this message."
,
num_attempts
);
sprintf
(
buffer
,
"Needed %u attempts to build, you can ignore this message."
,
PrintMessage
(
buffer
,
true
);
num_attempts
);
}
PrintMessage
(
buffer
,
true
);
}
CUDA_CHECK_ERROR
(
"Error occurred during hash table build.
\n
"
);
return
num_failures
==
0
;
CUDA_CHECK_ERROR
(
"Error occurred during hash table build.
\n
"
);
return
num_failures
==
0
;
}
}
void
HashTable
::
Retrieve
(
const
unsigned
n_queries
,
const
unsigned
*
d_keys
,
void
HashTable
::
Retrieve
(
const
unsigned
n_queries
,
const
unsigned
*
d_keys
,
unsigned
*
d_values
)
{
unsigned
*
d_values
)
{
CUDAWrapper
::
CallHashRetrieve
(
n_queries
,
CUDAWrapper
::
CallHashRetrieve
(
n_queries
,
num_hash_functions_
,
d_keys
,
num_hash_functions_
,
table_size_
,
d_contents_
,
constants_2_
,
d_keys
,
constants_3_
,
constants_4_
,
constants_5_
,
table_size_
,
stash_constants_
,
stash_count_
,
d_values
);
d_contents_
,
constants_2_
,
constants_3_
,
constants_4_
,
constants_5_
,
stash_constants_
,
stash_count_
,
d_values
);
}
}
};
// namespace cuhash
};
// namesapce CuckooHashing
// Leave this at the end of the file
// Leave this at the end of the file
// Local Variables:
// Local Variables:
...
...
src/cuhash/hash_table.cu
View file @
19e73bbe
...
@@ -3,10 +3,10 @@
...
@@ -3,10 +3,10 @@
// -------------------------------------------------------------
// -------------------------------------------------------------
// $Revision:$
// $Revision:$
// $Date:$
// $Date:$
// -------------------------------------------------------------
// -------------------------------------------------------------
// This source code is distributed under the terms of license.txt in
// This source code is distributed under the terms of license.txt in
// the root directory of this source distribution.
// the root directory of this source distribution.
// -------------------------------------------------------------
// -------------------------------------------------------------
/**
/**
* @file hash_table.cu
* @file hash_table.cu
...
@@ -24,162 +24,89 @@
...
@@ -24,162 +24,89 @@
namespace
cuhash
{
namespace
cuhash
{
namespace
CUDAWrapper
{
namespace
CUDAWrapper
{
void
ClearTable
(
const
unsigned
slots_in_table
,
void
ClearTable
(
const
unsigned
slots_in_table
,
const
Entry
fill_value
,
const
Entry
fill_value
,
Entry
*
d_contents
)
{
Entry
*
d_contents
)
{
clear_table
<
Entry
><<<
ComputeGridDim
(
slots_in_table
),
kBlockSize
>>>
(
clear_table
<
Entry
><<<
ComputeGridDim
(
slots_in_table
),
kBlockSize
>>>
slots_in_table
,
fill_value
,
d_contents
);
(
slots_in_table
,
fill_value
,
d_contents
);
TV_CHECK_CUDA_ERR_V2
(
"Error occurred during hash table clear.
\n
"
);
TV_CHECK_CUDA_ERR_V2
(
"Error occurred during hash table clear.
\n
"
);
}
}
void
CallCuckooHash
(
const
unsigned
n
,
const
unsigned
num_hash_functions
,
const
unsigned
*
d_keys
,
const
unsigned
*
d_values
,
const
unsigned
table_size
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
5
>
constants_5
,
const
unsigned
max_iterations
,
Entry
*
d_contents
,
uint2
stash_constants
,
unsigned
*
d_stash_count
,
unsigned
*
d_failures
,
unsigned
*
d_iterations_taken
)
{
// Build the table.
cudaMemset
(
d_failures
,
0
,
sizeof
(
unsigned
));
if
(
num_hash_functions
==
2
)
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_2
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
else
if
(
num_hash_functions
==
3
)
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_3
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
else
if
(
num_hash_functions
==
4
)
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_4
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
else
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_5
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
CUDA_CHECK_ERROR
(
"Error occurred during hash table build.
\n
"
);
}
void
CallCuckooHash
(
const
unsigned
n
,
void
CallHashRetrieve
(
const
unsigned
n_queries
,
const
unsigned
num_hash_functions
,
const
unsigned
num_hash_functions
,
const
unsigned
*
d_keys
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
const
Entry
*
d_contents
,
const
unsigned
*
d_values
,
const
Functions
<
2
>
constants_2
,
const
unsigned
table_size
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
5
>
constants_5
,
const
Functions
<
4
>
constants_4
,
const
uint2
stash_constants
,
const
unsigned
stash_count
,
const
Functions
<
5
>
constants_5
,
unsigned
*
d_values
)
{
const
unsigned
max_iterations
,
unsigned
*
d_retrieval_probes
=
NULL
;
Entry
*
d_contents
,
#ifdef TRACK_ITERATIONS
uint2
stash_constants
,
CUDA_SAFE_CALL
(
unsigned
*
d_stash_count
,
cudaMalloc
((
void
**
)
&
d_retrieval_probes
,
sizeof
(
unsigned
)
*
n_queries
));
unsigned
*
d_failures
,
#endif
unsigned
*
d_iterations_taken
)
{
// Build the table.
cudaMemset
(
d_failures
,
0
,
sizeof
(
unsigned
));
if
(
num_hash_functions
==
2
)
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_2
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
else
if
(
num_hash_functions
==
3
)
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_3
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
else
if
(
num_hash_functions
==
4
)
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_4
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
else
{
CuckooHash
<<<
ComputeGridDim
(
n
),
kBlockSize
>>>
(
n
,
d_keys
,
d_values
,
table_size
,
constants_5
,
max_iterations
,
d_contents
,
stash_constants
,
d_stash_count
,
d_failures
,
d_iterations_taken
);
}
CUDA_CHECK_ERROR
(
"Error occurred during hash table build.
\n
"
);
}
if
(
num_hash_functions
==
2
)
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_2
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
else
if
(
num_hash_functions
==
3
)
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_3
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
else
if
(
num_hash_functions
==
4
)
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_4
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
else
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_5
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
void
CallHashRetrieve
(
const
unsigned
n_queries
,
CUDA_CHECK_ERROR
(
"Retrieval failed.
\n
"
);
const
unsigned
num_hash_functions
,
const
unsigned
*
d_keys
,
const
unsigned
table_size
,
const
Entry
*
d_contents
,
const
Functions
<
2
>
constants_2
,
const
Functions
<
3
>
constants_3
,
const
Functions
<
4
>
constants_4
,
const
Functions
<
5
>
constants_5
,
const
uint2
stash_constants
,
const
unsigned
stash_count
,
unsigned
*
d_values
)
{
unsigned
*
d_retrieval_probes
=
NULL
;
#ifdef TRACK_ITERATIONS
CUDA_SAFE_CALL
(
cudaMalloc
((
void
**
)
&
d_retrieval_probes
,
sizeof
(
unsigned
)
*
n_queries
));
#endif
if
(
num_hash_functions
==
2
)
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_2
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
else
if
(
num_hash_functions
==
3
)
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_3
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
else
if
(
num_hash_functions
==
4
)
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_4
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
else
{
hash_retrieve
<<<
ComputeGridDim
(
n_queries
),
kBlockSize
>>>
(
n_queries
,
d_keys
,
table_size
,
d_contents
,
constants_5
,
stash_constants
,
stash_count
,
d_values
,
d_retrieval_probes
);
}
CUDA_CHECK_ERROR
(
"Retrieval failed.
\n
"
);
#ifdef TRACK_ITERATIONS
OutputRetrievalStatistics
(
n_queries
,
d_retrieval_probes
,
num_hash_functions
);
CUDA_SAFE_CALL
(
cudaFree
(
d_retrieval_probes
));
#endif
}
};
// namespace CUDAWrapper
#ifdef TRACK_ITERATIONS
OutputRetrievalStatistics
(
n_queries
,
d_retrieval_probes
,
num_hash_functions
);
CUDA_SAFE_CALL
(
cudaFree
(
d_retrieval_probes
));
#endif
}
};
// namespace CUDAWrapper
};
// namespace
CuckooHashing
};
// namespace
cuhash
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment