Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
a6abf55d
Commit
a6abf55d
authored
Oct 20, 2021
by
yan.yan
Browse files
Merge branch 'develop'
parents
fad30002
79a3eaf2
Changes
142
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
764 additions
and
6464 deletions
+764
-6464
include/tensorview/common.h
include/tensorview/common.h
+0
-111
include/tensorview/cuda_utils.h
include/tensorview/cuda_utils.h
+0
-31
include/tensorview/eigen_utils.h
include/tensorview/eigen_utils.h
+0
-41
include/tensorview/kernel_utils.h
include/tensorview/kernel_utils.h
+0
-72
include/tensorview/mp_helper.h
include/tensorview/mp_helper.h
+0
-53
include/tensorview/prettyprint.h
include/tensorview/prettyprint.h
+0
-475
include/tensorview/pybind_utils.h
include/tensorview/pybind_utils.h
+0
-170
include/tensorview/tensor.h
include/tensorview/tensor.h
+0
-1003
include/tensorview/tensorview.h
include/tensorview/tensorview.h
+0
-1503
include/tensorview/tools.h
include/tensorview/tools.h
+0
-58
include/tensorview/torch_utils.h
include/tensorview/torch_utils.h
+0
-147
include/torch_utils.h
include/torch_utils.h
+0
-124
include/tsl/robin_growth_policy.h
include/tsl/robin_growth_policy.h
+0
-334
include/tsl/robin_hash.h
include/tsl/robin_hash.h
+0
-1360
include/tsl/robin_map.h
include/tsl/robin_map.h
+0
-715
include/utility/timer.h
include/utility/timer.h
+0
-58
pyproject.toml
pyproject.toml
+3
-0
setup.py
setup.py
+189
-93
spconv/__init__.py
spconv/__init__.py
+7
-116
spconv/algo.py
spconv/algo.py
+565
-0
No files found.
include/tensorview/common.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <sstream>
#ifdef TV_USE_STACKTRACE
#if defined(WIN32) || defined(_WIN32) || \
defined(__WIN32) && !defined(__CYGWIN__)
#define BOOST_STACKTRACE_USE_WINDBG
#else
// require linking with -ldl and -lbacktrace in linux
#define BOOST_STACKTRACE_USE_BACKTRACE
#endif
#include <boost/stacktrace.hpp>
#endif
#ifdef TV_CUDA
#include <cuda.h>
#endif
#if defined(TV_USE_BOOST_TYPEOF) || (!defined(__clang__) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
// a workaround when built with cuda 11
// two options: use BOOST_TYPEOF or identity_t.
// this is a nvcc bug, msvc/gcc/clang don't have this problem.
// #include <boost/typeof/typeof.hpp>
// #define TV_DECLTYPE(x) BOOST_TYPEOF(x)
namespace
tv
{
template
<
typename
T
>
using
identity_t
=
T
;
}
#define TV_DECLTYPE(x) tv::identity_t<decltype(x)>
#else
#define TV_DECLTYPE(x) decltype(x)
#endif
namespace
tv
{
template
<
class
SStream
,
class
T
>
void
sstream_print
(
SStream
&
ss
,
T
val
)
{
ss
<<
val
;
}
template
<
class
SStream
,
class
T
,
class
...
TArgs
>
void
sstream_print
(
SStream
&
ss
,
T
val
,
TArgs
...
args
)
{
ss
<<
val
<<
" "
;
sstream_print
(
ss
,
args
...);
}
template
<
class
...
TArgs
>
void
ssprint
(
TArgs
...
args
)
{
std
::
stringstream
ss
;
sstream_print
(
ss
,
args
...);
std
::
cout
<<
ss
.
str
()
<<
std
::
endl
;
}
#ifdef TV_USE_STACKTRACE
#define TV_BACKTRACE_PRINT(ss) \
ss << std::endl << boost::stacktrace::stacktrace();
#else
#define TV_BACKTRACE_PRINT(ss)
#endif
#define TV_THROW_RT_ERR(...) \
{ \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
}
#define TV_THROW_INVALID_ARG(...) \
{ \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::invalid_argument(__macro_s.str()); \
}
#define TV_ASSERT_RT_ERR(expr, ...) \
{ \
if (!(expr)) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << #expr << " assert faild. "; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
} \
}
#define TV_ASSERT_INVALID_ARG(expr, ...) \
{ \
if (!(expr)) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << #expr << " assert faild. "; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::invalid_argument(__macro_s.str()); \
} \
}
}
// namespace tv
\ No newline at end of file
include/tensorview/cuda_utils.h
deleted
100644 → 0
View file @
fad30002
#pragma once
// from pytorch.aten
#include "tensorview.h"
#include <type_traits>
namespace
tv
{
namespace
cuda
{
template
<
typename
T1
,
typename
T2
>
inline
int
DivUp
(
const
T1
a
,
const
T2
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr
int
CUDA_NUM_THREADS
=
1024
;
// CUDA: number of blocks for threads.
inline
int
getNumThreads
(
const
int
N
)
{
if
(
N
>
CUDA_NUM_THREADS
)
{
return
CUDA_NUM_THREADS
;
}
return
DivUp
(
N
,
32
)
*
32
;
}
inline
int
getBlocks
(
const
int
N
)
{
TV_ASSERT_RT_ERR
(
N
>
0
,
"CUDA kernel launch blocks must be positive, but got N="
,
N
);
return
DivUp
(
N
,
getNumThreads
(
N
));
}
}
// namespace cuda
}
// namespace tv
\ No newline at end of file
include/tensorview/eigen_utils.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensor.h"
#include "tensorview.h"
#include <eigen3/Eigen/Dense>
namespace
tv
{
template
<
typename
T
,
int
Row
=
Eigen
::
Dynamic
,
int
Col
=
Eigen
::
Dynamic
>
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Row
,
Col
,
Eigen
::
RowMajor
>>
tv2eigen
(
TensorView
<
T
>
view
)
{
TV_ASSERT_INVALID_ARG
(
view
.
ndim
()
<=
2
&&
view
.
ndim
()
>
0
,
"error"
);
if
(
Row
!=
Eigen
::
Dynamic
)
{
TV_ASSERT_INVALID_ARG
(
view
.
dim
(
0
)
==
Row
,
"error"
);
}
if
(
Col
!=
Eigen
::
Dynamic
)
{
TV_ASSERT_INVALID_ARG
(
view
.
dim
(
1
)
==
Col
,
"error"
);
}
int
row
=
1
;
if
(
view
.
ndim
()
==
2
)
{
row
=
view
.
dim
(
0
);
}
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Row
,
Col
,
Eigen
::
RowMajor
>>
eigen_map
(
view
.
data
(),
row
,
view
.
dim
(
1
));
return
eigen_map
;
}
}
// namespace tv
include/tensorview/kernel_utils.h
deleted
100644 → 0
View file @
fad30002
#pragma once
// from tensorflow
namespace
tv
{
namespace
detail
{
template
<
typename
T
>
class
KernelLoop
{
struct
Iterator
{
__forceinline__
__device__
Iterator
(
T
index
,
T
delta
)
:
index_
(
index
),
delta_
(
delta
)
{}
__forceinline__
__device__
T
operator
*
()
const
{
return
index_
;
}
__forceinline__
__device__
Iterator
&
operator
++
()
{
index_
+=
delta_
;
return
*
this
;
}
__forceinline__
__device__
bool
operator
!=
(
const
Iterator
&
other
)
const
{
bool
greater
=
index_
>
other
.
index_
;
bool
less
=
index_
<
other
.
index_
;
// Anything past an end iterator (delta_ == 0) is equal.
// In range-based for loops, this optimizes to 'return less'.
if
(
!
other
.
delta_
)
{
return
less
;
}
if
(
!
delta_
)
{
return
greater
;
}
return
less
||
greater
;
}
private:
T
index_
;
const
T
delta_
;
};
public:
__forceinline__
__device__
KernelLoop
(
T
begin
,
T
delta
,
T
end
)
:
begin_
(
begin
),
delta_
(
delta
),
end_
(
end
)
{}
__forceinline__
__device__
Iterator
begin
()
const
{
return
Iterator
{
begin_
,
delta_
};
}
__forceinline__
__device__
Iterator
end
()
const
{
return
Iterator
{
end_
,
0
};
}
private:
T
begin_
;
T
delta_
;
T
end_
;
};
}
// namespace detail
template
<
typename
T
,
int
NumILP
=
1
>
__forceinline__
__device__
detail
::
KernelLoop
<
T
>
KernelLoopX
(
T
count
)
{
return
detail
::
KernelLoop
<
T
>
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
,
gridDim
.
x
*
blockDim
.
x
*
NumILP
,
count
);
}
// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : KernelLoopY(count)) { visit(i); }
template
<
typename
T
,
int
NumILP
=
1
>
__forceinline__
__device__
detail
::
KernelLoop
<
T
>
KernelLoopY
(
T
count
)
{
return
detail
::
KernelLoop
<
T
>
(
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
,
gridDim
.
y
*
blockDim
.
y
*
NumILP
,
count
);
}
// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : KernelLoopZ(count)) { visit(i); }
template
<
typename
T
,
int
NumILP
=
1
>
__forceinline__
__device__
detail
::
KernelLoop
<
T
>
KernelLoopZ
(
T
count
)
{
return
detail
::
KernelLoop
<
T
>
(
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
,
gridDim
.
z
*
blockDim
.
z
*
NumILP
,
count
);
}
}
// namespace tv
\ No newline at end of file
include/tensorview/mp_helper.h
deleted
100644 → 0
View file @
fad30002
#ifndef MP_HELPER_H_
#define MP_HELPER_H_
#include <type_traits>
#include <utility>
namespace
tv
{
template
<
class
...
T
>
struct
mp_list
{};
template
<
class
T
,
T
...
I
>
using
mp_list_c
=
mp_list
<
std
::
integral_constant
<
T
,
I
>
...
>
;
namespace
detail
{
template
<
class
...
Ts
,
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<
Ts
...
>
,
F
&&
f
)
{
return
(
void
)(
std
::
initializer_list
<
int
>
{(
f
(
Ts
()),
0
)...}),
std
::
forward
<
F
>
(
f
);
}
template
<
class
F
>
constexpr
F
mp_for_each_impl
(
mp_list
<>
,
F
&&
f
)
{
return
std
::
forward
<
F
>
(
f
);
}
}
// namespace detail
template
<
class
...
T
>
using
mp_length
=
std
::
integral_constant
<
std
::
size_t
,
sizeof
...(
T
)
>
;
namespace
detail
{
template
<
class
A
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
{
// An error "no type named 'type'" here means that the first argument to
// mp_rename is not a list
};
template
<
template
<
class
...
>
class
A
,
class
...
T
,
template
<
class
...
>
class
B
>
struct
mp_rename_impl
<
A
<
T
...
>
,
B
>
{
using
type
=
B
<
T
...
>
;
};
}
// namespace detail
template
<
class
A
,
template
<
class
...
>
class
B
>
using
mp_rename
=
typename
detail
::
mp_rename_impl
<
A
,
B
>::
type
;
template
<
class
L
>
using
mp_size
=
mp_rename
<
L
,
mp_length
>
;
template
<
class
L
,
class
F
>
constexpr
F
mp_for_each
(
F
&&
f
)
{
return
detail
::
mp_for_each_impl
(
mp_rename
<
L
,
mp_list
>
(),
std
::
forward
<
F
>
(
f
));
}
}
// namespace tv
#endif
\ No newline at end of file
include/tensorview/prettyprint.h
deleted
100644 → 0
View file @
fad30002
// Copyright Louis Delacroix 2010 - 2014.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// A pretty printing library for C++
//
// Usage:
// Include this header, and operator<< will "just work".
#ifndef H_PRETTY_PRINT
#define H_PRETTY_PRINT
#include <cstddef>
#include <iterator>
#include <memory>
#include <ostream>
#include <set>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <valarray>
namespace
pretty_print
{
namespace
detail
{
// SFINAE type trait to detect whether T::const_iterator exists.
struct
sfinae_base
{
using
yes
=
char
;
using
no
=
yes
[
2
];
};
template
<
typename
T
>
struct
has_const_iterator
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
test
(
typename
C
::
const_iterator
*
);
template
<
typename
C
>
static
no
&
test
(...);
public:
static
const
bool
value
=
sizeof
(
test
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
using
type
=
T
;
};
template
<
typename
T
>
struct
has_begin_end
:
private
sfinae_base
{
private:
template
<
typename
C
>
static
yes
&
f
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
begin
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
>::
type
*
);
template
<
typename
C
>
static
no
&
f
(...);
template
<
typename
C
>
static
yes
&
g
(
typename
std
::
enable_if
<
std
::
is_same
<
decltype
(
static_cast
<
typename
C
::
const_iterator
(
C
::*
)()
const
>
(
&
C
::
end
)),
typename
C
::
const_iterator
(
C
::*
)()
const
>::
value
,
void
>::
type
*
);
template
<
typename
C
>
static
no
&
g
(...);
public:
static
bool
const
beg_value
=
sizeof
(
f
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
static
bool
const
end_value
=
sizeof
(
g
<
T
>
(
nullptr
))
==
sizeof
(
yes
);
};
}
// namespace detail
// Holds the delimiter values for a specific character type
template
<
typename
TChar
>
struct
delimiters_values
{
using
char_type
=
TChar
;
const
char_type
*
prefix
;
const
char_type
*
delimiter
;
const
char_type
*
postfix
;
};
// Defines the delimiter values for a specific container and character type
template
<
typename
T
,
typename
TChar
>
struct
delimiters
{
using
type
=
delimiters_values
<
TChar
>
;
static
const
type
values
;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template
<
typename
T
,
typename
TChar
=
char
,
typename
TCharTraits
=
::
std
::
char_traits
<
TChar
>,
typename
TDelimiters
=
delimiters
<
T
,
TChar
>>
struct
print_container_helper
{
using
delimiters_type
=
TDelimiters
;
using
ostream_type
=
std
::
basic_ostream
<
TChar
,
TCharTraits
>
;
template
<
typename
U
>
struct
printer
{
static
void
print_body
(
const
U
&
c
,
ostream_type
&
stream
)
{
using
std
::
begin
;
using
std
::
end
;
auto
it
=
begin
(
c
);
const
auto
the_end
=
end
(
c
);
if
(
it
!=
the_end
)
{
for
(;;)
{
stream
<<
*
it
;
if
(
++
it
==
the_end
)
break
;
if
(
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
delimiters_type
::
values
.
delimiter
;
}
}
}
};
print_container_helper
(
const
T
&
container
)
:
container_
(
container
)
{}
inline
void
operator
()(
ostream_type
&
stream
)
const
{
if
(
delimiters_type
::
values
.
prefix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
prefix
;
printer
<
T
>::
print_body
(
container_
,
stream
);
if
(
delimiters_type
::
values
.
postfix
!=
NULL
)
stream
<<
delimiters_type
::
values
.
postfix
;
}
private:
const
T
&
container_
;
};
// Specialization for pairs
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
T1
,
typename
T2
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
pair
<
T1
,
T2
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
static
void
print_body
(
const
std
::
pair
<
T1
,
T2
>
&
c
,
ostream_type
&
stream
)
{
stream
<<
c
.
first
;
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
c
.
second
;
}
};
// Specialization for tuples
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
template
<
typename
...
Args
>
struct
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
printer
<
std
::
tuple
<
Args
...
>>
{
using
ostream_type
=
typename
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
ostream_type
;
using
element_type
=
std
::
tuple
<
Args
...
>
;
template
<
std
::
size_t
I
>
struct
Int
{};
static
void
print_body
(
const
element_type
&
c
,
ostream_type
&
stream
)
{
tuple_print
(
c
,
stream
,
Int
<
0
>
());
}
static
void
tuple_print
(
const
element_type
&
,
ostream_type
&
,
Int
<
sizeof
...(
Args
)
>
)
{}
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
typename
std
::
conditional
<
sizeof
...(
Args
)
!=
0
,
Int
<
0
>
,
std
::
nullptr_t
>::
type
)
{
stream
<<
std
::
get
<
0
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
1
>
());
}
template
<
std
::
size_t
N
>
static
void
tuple_print
(
const
element_type
&
c
,
ostream_type
&
stream
,
Int
<
N
>
)
{
if
(
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
!=
NULL
)
stream
<<
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>::
delimiters_type
::
values
.
delimiter
;
stream
<<
std
::
get
<
N
>
(
c
);
tuple_print
(
c
,
stream
,
Int
<
N
+
1
>
());
}
};
// Prints a print_container_helper to the specified stream.
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
,
typename
TDelimiters
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
print_container_helper
<
T
,
TChar
,
TCharTraits
,
TDelimiters
>
&
helper
)
{
helper
(
stream
);
return
stream
;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types
template
<
typename
T
>
struct
is_container
:
public
std
::
integral_constant
<
bool
,
detail
::
has_const_iterator
<
T
>::
value
&&
detail
::
has_begin_end
<
T
>::
beg_value
&&
detail
::
has_begin_end
<
T
>::
end_value
>
{};
template
<
typename
T
,
std
::
size_t
N
>
struct
is_container
<
T
[
N
]
>
:
std
::
true_type
{};
template
<
std
::
size_t
N
>
struct
is_container
<
char
[
N
]
>
:
std
::
false_type
{};
template
<
typename
T
>
struct
is_container
<
std
::
valarray
<
T
>>
:
std
::
true_type
{};
template
<
typename
T1
,
typename
T2
>
struct
is_container
<
std
::
pair
<
T1
,
T2
>>
:
std
::
true_type
{};
template
<
typename
...
Args
>
struct
is_container
<
std
::
tuple
<
Args
...
>>
:
std
::
true_type
{};
// Default delimiters
template
<
typename
T
>
struct
delimiters
<
T
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
char
>
delimiters
<
T
,
char
>::
values
=
{
"["
,
", "
,
"]"
};
template
<
typename
T
>
struct
delimiters
<
T
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
>
const
delimiters_values
<
wchar_t
>
delimiters
<
T
,
wchar_t
>::
values
=
{
L"["
,
L", "
,
L"]"
};
// Delimiters for (multi)set and unordered_(multi)set
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
set
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
struct
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
TComp
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
multiset
<
T
,
TComp
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<
::
std
::
unordered_set
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
char
>
delimiters
<
::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
char
>::
values
=
{
"{"
,
", "
,
"}"
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
struct
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T
,
typename
THash
,
typename
TEqual
,
typename
TAllocator
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
unordered_multiset
<
T
,
THash
,
TEqual
,
TAllocator
>
,
wchar_t
>::
values
=
{
L"{"
,
L", "
,
L"}"
};
// Delimiters for pair and tuple
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
char
>
delimiters
<
std
::
pair
<
T1
,
T2
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
T1
,
typename
T2
>
struct
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
T1
,
typename
T2
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
pair
<
T1
,
T2
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
template
<
typename
...
Args
>
struct
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>
{
static
const
delimiters_values
<
char
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
char
>
delimiters
<
std
::
tuple
<
Args
...
>
,
char
>::
values
=
{
"("
,
", "
,
")"
};
template
<
typename
...
Args
>
struct
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>
{
static
const
delimiters_values
<
wchar_t
>
values
;
};
template
<
typename
...
Args
>
const
delimiters_values
<
wchar_t
>
delimiters
<::
std
::
tuple
<
Args
...
>
,
wchar_t
>::
values
=
{
L"("
,
L", "
,
L")"
};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar. Usage: "cout <<
// pretty_print::custom_delims<MyDelims>(x)".
struct
custom_delims_base
{
virtual
~
custom_delims_base
()
{}
virtual
std
::
ostream
&
stream
(
::
std
::
ostream
&
)
=
0
;
virtual
std
::
wostream
&
stream
(
::
std
::
wostream
&
)
=
0
;
};
template
<
typename
T
,
typename
Delims
>
struct
custom_delims_wrapper
:
custom_delims_base
{
custom_delims_wrapper
(
const
T
&
t_
)
:
t
(
t_
)
{}
std
::
ostream
&
stream
(
std
::
ostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
char
,
std
::
char_traits
<
char
>
,
Delims
>
(
t
);
}
std
::
wostream
&
stream
(
std
::
wostream
&
s
)
{
return
s
<<
print_container_helper
<
T
,
wchar_t
,
std
::
char_traits
<
wchar_t
>
,
Delims
>
(
t
);
}
private:
const
T
&
t
;
};
template
<
typename
Delims
>
struct
custom_delims
{
template
<
typename
Container
>
custom_delims
(
const
Container
&
c
)
:
base
(
new
custom_delims_wrapper
<
Container
,
Delims
>
(
c
))
{}
std
::
unique_ptr
<
custom_delims_base
>
base
;
};
template
<
typename
TChar
,
typename
TCharTraits
,
typename
Delims
>
inline
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
operator
<<
(
std
::
basic_ostream
<
TChar
,
TCharTraits
>
&
s
,
const
custom_delims
<
Delims
>
&
p
)
{
return
p
.
base
->
stream
(
s
);
}
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template
<
typename
T
>
struct
array_wrapper_n
{
typedef
const
T
*
const_iterator
;
typedef
T
value_type
;
array_wrapper_n
(
const
T
*
const
a
,
size_t
n
)
:
_array
(
a
),
_n
(
n
)
{}
inline
const_iterator
begin
()
const
{
return
_array
;
}
inline
const_iterator
end
()
const
{
return
_array
+
_n
;
}
private:
const
T
*
const
_array
;
size_t
_n
;
};
// A wrapper for hash-table based containers that offer local iterators to each
// bucket. Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket
// 5 of container m.)
template
<
typename
T
>
struct
bucket_print_wrapper
{
typedef
typename
T
::
const_local_iterator
const_iterator
;
typedef
typename
T
::
size_type
size_type
;
const_iterator
begin
()
const
{
return
m_map
.
cbegin
(
n
);
}
const_iterator
end
()
const
{
return
m_map
.
cend
(
n
);
}
bucket_print_wrapper
(
const
T
&
m
,
size_type
bucket
)
:
m_map
(
m
),
n
(
bucket
)
{}
private:
const
T
&
m_map
;
const
size_type
n
;
};
}
// namespace pretty_print
// Global accessor functions for the convenience wrappers
template
<
typename
T
>
inline
pretty_print
::
array_wrapper_n
<
T
>
pretty_print_array
(
const
T
*
const
a
,
size_t
n
)
{
return
pretty_print
::
array_wrapper_n
<
T
>
(
a
,
n
);
}
template
<
typename
T
>
pretty_print
::
bucket_print_wrapper
<
T
>
bucket_print
(
const
T
&
m
,
typename
T
::
size_type
n
)
{
return
pretty_print
::
bucket_print_wrapper
<
T
>
(
m
,
n
);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
namespace
std
{
// Prints a container to the stream using default delimiters
template
<
typename
T
,
typename
TChar
,
typename
TCharTraits
>
inline
typename
enable_if
<::
pretty_print
::
is_container
<
T
>::
value
,
basic_ostream
<
TChar
,
TCharTraits
>
&>::
type
operator
<<
(
basic_ostream
<
TChar
,
TCharTraits
>
&
stream
,
const
T
&
container
)
{
return
stream
<<
::
pretty_print
::
print_container_helper
<
T
,
TChar
,
TCharTraits
>
(
container
);
}
}
// namespace std
#endif // H_PRETTY_PRINT
include/tensorview/pybind_utils.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensor.h"
#include "tensorview.h"
#include <algorithm>
#include <array>
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace
py
=
pybind11
;
namespace
tv
{
template
<
typename
Tarr
>
bool
is_c_style
(
const
Tarr
&
arr
)
{
return
bool
(
arr
.
flags
()
&
py
::
array
::
c_style
);
}
template
<
typename
T
,
int
Rank
=
-
1
>
TensorView
<
T
,
Rank
>
arrayt2tv
(
py
::
array_t
<
T
>
arr
)
{
TV_ASSERT_INVALID_ARG
(
is_c_style
(
arr
),
"array must be c-contiguous array"
);
Shape
shape
;
for
(
int
i
=
0
;
i
<
arr
.
ndim
();
++
i
)
{
shape
.
push_back
(
arr
.
shape
(
i
));
}
if
(
Rank
>=
0
)
{
TV_ASSERT_INVALID_ARG
(
shape
.
ndim
()
==
Rank
,
"error"
);
}
return
TensorView
<
T
,
Rank
>
(
arr
.
mutable_data
(),
shape
);
}
template
<
typename
T
,
int
Rank
=
-
1
>
TensorView
<
const
T
>
carrayt2tv
(
py
::
array_t
<
T
>
arr
)
{
TV_ASSERT_INVALID_ARG
(
is_c_style
(
arr
),
"array must be c-contiguous array"
);
Shape
shape
;
for
(
int
i
=
0
;
i
<
arr
.
ndim
();
++
i
)
{
shape
.
push_back
(
arr
.
shape
(
i
));
}
if
(
Rank
>=
0
)
{
TV_ASSERT_INVALID_ARG
(
shape
.
ndim
()
==
Rank
,
"error"
);
}
return
TensorView
<
const
T
,
Rank
>
(
arr
.
data
(),
shape
);
}
template
<
typename
Tarr
>
tv
::
DType
get_array_tv_dtype
(
const
Tarr
&
arr
)
{
switch
(
arr
.
dtype
().
kind
())
{
case
'b'
:
return
tv
::
bool_
;
case
'i'
:
{
switch
(
arr
.
itemsize
())
{
case
1
:
return
tv
::
int8
;
case
2
:
return
tv
::
int16
;
case
4
:
return
tv
::
int32
;
case
8
:
return
tv
::
int64
;
default:
break
;
}
}
case
'u'
:
{
switch
(
arr
.
itemsize
())
{
case
1
:
return
tv
::
uint8
;
case
2
:
return
tv
::
uint16
;
case
4
:
return
tv
::
uint32
;
case
8
:
return
tv
::
uint64
;
default:
break
;
}
}
case
'f'
:
{
switch
(
arr
.
itemsize
())
{
case
2
:
return
tv
::
float16
;
case
4
:
return
tv
::
float32
;
case
8
:
return
tv
::
float64
;
default:
break
;
}
}
}
TV_THROW_RT_ERR
(
"unknown dtype"
,
arr
.
dtype
().
kind
(),
arr
.
itemsize
());
}
template
<
typename
Tarr
>
Tensor
array2tensor
(
Tarr
&
arr
)
{
TV_ASSERT_INVALID_ARG
(
is_c_style
(
arr
),
"array must be c-contiguous array"
);
TensorShape
shape
;
for
(
int
i
=
0
;
i
<
arr
.
ndim
();
++
i
)
{
shape
.
push_back
(
arr
.
shape
(
i
));
}
return
tv
::
from_blob
(
arr
.
mutable_data
(),
shape
,
get_array_tv_dtype
(
arr
),
-
1
);
}
template
<
typename
T
>
Tensor
arrayt2tensor
(
py
::
array_t
<
T
>
&
arr
)
{
TV_ASSERT_INVALID_ARG
(
is_c_style
(
arr
),
"array must be c-contiguous array"
);
TensorShape
shape
;
for
(
int
i
=
0
;
i
<
arr
.
ndim
();
++
i
)
{
shape
.
push_back
(
arr
.
shape
(
i
));
}
return
tv
::
from_blob
(
arr
.
mutable_data
(),
shape
,
tv
::
type_v
<
T
>
,
-
1
);
}
template
<
typename
TDType
>
py
::
dtype
tv_dtype_to_py
(
TDType
d
)
{
switch
(
d
)
{
case
float32
:
return
py
::
dtype
(
"float32"
);
case
float64
:
return
py
::
dtype
(
"float64"
);
case
float16
:
return
py
::
dtype
(
"float16"
);
case
int32
:
return
py
::
dtype
(
"int32"
);
case
int16
:
return
py
::
dtype
(
"int16"
);
case
int8
:
return
py
::
dtype
(
"int8"
);
case
int64
:
return
py
::
dtype
(
"int64"
);
case
uint32
:
return
py
::
dtype
(
"uint32"
);
case
uint16
:
return
py
::
dtype
(
"uint16"
);
case
uint8
:
return
py
::
dtype
(
"uint8"
);
case
uint64
:
return
py
::
dtype
(
"uint64"
);
case
bool_
:
return
py
::
dtype
(
"bool_"
);
default:
;
}
TV_THROW_INVALID_ARG
(
"unknown dtype"
,
d
);
}
// add template to define function in header
template
<
typename
Ttensor
>
py
::
array
tensor2array
(
Ttensor
&
tensor
)
{
// you cant call this function during GIL released.
TV_ASSERT_INVALID_ARG
(
tensor
.
device
()
==
-
1
,
"must be cpu tensor"
);
auto
shape
=
tensor
.
shape
();
std
::
vector
<
int
>
shape_vec
(
shape
.
begin
(),
shape
.
end
());
auto
dtype
=
tv_dtype_to_py
(
tensor
.
dtype
());
// construct py::array will copy content from ptr.
// its expected because we can't transfer ownership from
// c++ tv::Tensor to numpy array when c++ object is deleted.
return
py
::
array
(
dtype
,
shape_vec
,
{},
tensor
.
raw_data
());
}
}
// namespace tv
include/tensorview/tensor.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
tv::Tensor is a lightweight header-only tensor container
without template and annoying dependencies. no algorithm is implemented.
it should only be used when you want a no-template simple container but
dont want to link with libtorch.
If you can use libtorch, dont use tv::Tensor.
*/
#pragma once
#include "mp_helper.h"
#include "tensorview.h"
#include <cstring>
#include <iomanip>
#include <memory>
#include <type_traits>
#ifdef TV_CUDA
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#endif
namespace
tv
{
enum
DType
{
float32
,
int32
,
int16
,
int8
,
float64
,
bool_
,
uint8
,
float16
,
int64
,
uint16
,
uint32
,
uint64
};
namespace
detail
{
using
dtype_collection_t
=
tv
::
mp_list_c
<
int
,
float32
,
int32
,
int16
,
int8
,
float64
,
bool_
,
uint8
,
float16
,
int64
,
uint16
,
uint32
,
uint64
>
;
#ifdef TV_CUDA
using
all_tensor_types_t
=
std
::
tuple
<
float
,
double
,
int8_t
,
int16_t
,
int32_t
,
int64_t
,
uint8_t
,
uint16_t
,
uint32_t
,
uint64_t
,
bool
>
;
#else
using
all_tensor_types_t
=
std
::
tuple
<
float
,
double
,
int8_t
,
int16_t
,
int32_t
,
int64_t
,
uint8_t
,
uint16_t
,
uint32_t
,
uint64_t
,
bool
>
;
#endif
template
<
typename
T
>
class
TensorStorage
{
public:
TensorStorage
(
size_t
size
,
int
device
=
-
1
,
bool
managed
=
false
,
bool
pinned
=
false
)
:
mSize
(
size
),
device_
(
device
),
managed_
(
managed
),
pinned_
(
pinned
)
{
if
(
size
==
0
)
{
mPtr
=
nullptr
;
}
else
{
if
(
device
==
-
1
)
{
if
(
pinned_
)
{
#ifdef TV_CUDA
checkCudaErrors
(
cudaMallocHost
(
&
mPtr
,
size
*
sizeof
(
T
)));
#else
TV_THROW_INVALID_ARG
(
"you need to define TV_CUDA to use pinned"
);
#endif
}
else
{
mPtr
=
new
T
[
size
];
}
}
else
{
#ifdef TV_CUDA
// we should select device in external
/*
int deviceCount;
cudaGetDeviceCount(&deviceCount);
if (device >= deviceCount) {
TV_THROW_INVALID_ARG("you provide device ", device,
" but you only have ", deviceCount, " device.");
}
cudaSetDevice(device);
*/
if
(
managed
)
{
checkCudaErrors
(
cudaMallocManaged
(
&
this
->
mPtr
,
size
*
sizeof
(
T
)));
}
else
{
checkCudaErrors
(
cudaMalloc
(
&
mPtr
,
size
*
sizeof
(
T
)));
}
#else
TV_THROW_INVALID_ARG
(
"don't compiled with cuda"
);
#endif
}
}
}
TensorStorage
(
T
*
ptr
,
size_t
size
,
int
device
)
:
mSize
(
size
),
mPtr
(
ptr
),
from_blob_
(
true
),
device_
(
device
)
{}
virtual
~
TensorStorage
()
{
if
(
empty
())
{
return
;
}
if
(
from_blob_
)
{
return
;
}
if
(
device_
==
-
1
)
{
if
(
pinned_
)
{
#ifdef TV_CUDA
cudaFreeHost
(
mPtr
);
#endif
}
else
{
delete
[]
mPtr
;
}
}
else
{
#ifdef TV_CUDA
cudaFree
(
mPtr
);
#endif
}
};
inline
size_t
size
()
const
{
return
mSize
;
}
T
*
data
()
{
return
mPtr
;
}
const
T
*
data
()
const
{
return
mPtr
;
}
bool
empty
()
const
{
return
mPtr
==
nullptr
||
mSize
==
0
;
}
bool
managed
()
const
{
return
managed_
;
}
bool
pinned
()
const
{
return
pinned_
;
}
int
device
()
const
{
return
device_
;
}
void
zero_
()
{
if
(
device_
==
-
1
)
{
std
::
memset
(
data
(),
0
,
mSize
);
// std::fill(data(), data() + mSize, 0);
}
else
{
#ifdef TV_CUDA
checkCudaErrors
(
cudaMemset
(
data
(),
0
,
mSize
/
sizeof
(
T
)));
#else
TV_THROW_INVALID_ARG
(
"don't compiled with cuda"
);
#endif
}
}
private:
size_t
mSize
=
0
;
T
*
mPtr
=
nullptr
;
bool
from_blob_
=
false
;
int
device_
=
-
1
;
bool
managed_
=
false
;
bool
pinned_
=
false
;
};
template
<
typename
T
>
size_t
sizeof_dtype
(
T
dtype
)
{
switch
(
dtype
)
{
case
float32
:
return
sizeof
(
float
);
case
int8
:
return
sizeof
(
int8_t
);
case
int16
:
return
sizeof
(
int16_t
);
case
int32
:
return
sizeof
(
int32_t
);
case
float64
:
return
sizeof
(
double
);
case
int64
:
return
sizeof
(
int64_t
);
case
bool_
:
return
sizeof
(
bool
);
case
uint8
:
return
sizeof
(
uint8_t
);
case
uint16
:
return
sizeof
(
uint16_t
);
case
uint32
:
return
sizeof
(
uint32_t
);
case
uint64
:
return
sizeof
(
uint64_t
);
case
float16
:
return
2
;
default:
TV_THROW_RT_ERR
(
"unsupported dtype"
);
}
return
0
;
}
template
<
typename
T
>
std
::
string
typeString
(
T
t
)
{
switch
(
t
)
{
case
DType
::
bool_
:
return
"bool"
;
case
DType
::
float32
:
return
"float32"
;
case
DType
::
int8
:
return
"int8"
;
case
DType
::
int16
:
return
"int16"
;
case
DType
::
int32
:
return
"int32"
;
case
DType
::
float64
:
return
"float64"
;
case
DType
::
int64
:
return
"int64"
;
case
DType
::
uint8
:
return
"uint8"
;
case
DType
::
uint16
:
return
"uint16"
;
case
DType
::
uint32
:
return
"uint32"
;
case
DType
::
uint64
:
return
"uint64"
;
case
DType
::
float16
:
return
"half"
;
default:
return
""
;
}
}
template
<
typename
T
>
struct
TypeToDtypeTraits
;
template
<
>
struct
TypeToDtypeTraits
<
int32_t
>
{
static
constexpr
DType
dtype
=
int32
;
};
#ifdef TV_CUDA
template
<
>
struct
TypeToDtypeTraits
<
__half
>
{
static
constexpr
DType
dtype
=
float16
;
};
#endif
template
<
>
struct
TypeToDtypeTraits
<
float
>
{
static
constexpr
DType
dtype
=
float32
;
};
template
<
>
struct
TypeToDtypeTraits
<
double
>
{
static
constexpr
DType
dtype
=
float64
;
};
template
<
>
struct
TypeToDtypeTraits
<
int16_t
>
{
static
constexpr
DType
dtype
=
int16
;
};
template
<
>
struct
TypeToDtypeTraits
<
int8_t
>
{
static
constexpr
DType
dtype
=
int8
;
};
template
<
>
struct
TypeToDtypeTraits
<
int64_t
>
{
static
constexpr
DType
dtype
=
int64
;
};
template
<
>
struct
TypeToDtypeTraits
<
uint8_t
>
{
static
constexpr
DType
dtype
=
uint8
;
};
template
<
>
struct
TypeToDtypeTraits
<
uint16_t
>
{
static
constexpr
DType
dtype
=
uint16
;
};
template
<
>
struct
TypeToDtypeTraits
<
uint32_t
>
{
static
constexpr
DType
dtype
=
uint32
;
};
template
<
>
struct
TypeToDtypeTraits
<
uint64_t
>
{
static
constexpr
DType
dtype
=
uint64
;
};
template
<
>
struct
TypeToDtypeTraits
<
bool
>
{
static
constexpr
DType
dtype
=
bool_
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
int32_t
>
{
static
constexpr
DType
dtype
=
int32
;
};
#ifdef TV_CUDA
template
<
>
struct
TypeToDtypeTraits
<
const
__half
>
{
static
constexpr
DType
dtype
=
float16
;
};
#endif
template
<
>
struct
TypeToDtypeTraits
<
const
float
>
{
static
constexpr
DType
dtype
=
float32
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
double
>
{
static
constexpr
DType
dtype
=
float64
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
int16_t
>
{
static
constexpr
DType
dtype
=
int16
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
int8_t
>
{
static
constexpr
DType
dtype
=
int8
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
int64_t
>
{
static
constexpr
DType
dtype
=
int64
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
uint8_t
>
{
static
constexpr
DType
dtype
=
uint8
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
uint16_t
>
{
static
constexpr
DType
dtype
=
uint16
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
uint32_t
>
{
static
constexpr
DType
dtype
=
uint32
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
uint64_t
>
{
static
constexpr
DType
dtype
=
uint64
;
};
template
<
>
struct
TypeToDtypeTraits
<
const
bool
>
{
static
constexpr
DType
dtype
=
bool_
;
};
}
// namespace detail
template
<
class
T
>
constexpr
DType
type_v
=
detail
::
TypeToDtypeTraits
<
T
>::
dtype
;
template
<
class
...
Ts
,
typename
F
>
bool
dispatch_noexcept
(
DType
t
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
type_v
<
TV_DECLTYPE
(
I
)
>
==
t
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
TV_DECLTYPE
(
I
)());
notFound
=
false
;
}
});
return
!
notFound
;
}
template
<
class
...
Ts
,
typename
F
>
void
dispatch
(
DType
t
,
F
&&
f
)
{
if
(
!
dispatch_noexcept
<
Ts
...
>
(
t
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
detail
::
TypeToString
<
TV_DECLTYPE
(
I
)
>::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown type"
,
detail
::
typeString
(
t
),
", available:"
,
ss
.
str
());
}
}
template
<
typename
T
,
T
...
Is
,
typename
F
>
void
dispatch_scalar
(
T
idx
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Is
)
>
0
,
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
T
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
T
(
I
)
==
idx
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
});
if
(
notFound
)
{
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
T
,
Is
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
T
(
I
)
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
template
<
int
...
Is
,
typename
F
>
bool
dispatch_int_noexcept
(
int
idx
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Is
)
>
0
,
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
TV_DECLTYPE
(
I
)
::
value
==
idx
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
});
return
!
notFound
;
}
template
<
int
...
Is
,
typename
F
,
class
BinaryPredicate
>
bool
dispatch_int_noexcept
(
int
idx
,
BinaryPredicate
p
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Is
)
>
0
,
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
p
(
idx
,
TV_DECLTYPE
(
I
)
::
value
)
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
});
return
!
notFound
;
}
template
<
int
...
Is
,
typename
F
>
void
dispatch_int
(
int
idx
,
F
&&
f
)
{
if
(
!
dispatch_int_noexcept
<
Is
...
>
(
idx
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
[
=
,
&
ss
](
auto
I
)
{
ss
<<
TV_DECLTYPE
(
I
)
::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
template
<
int
...
Is
,
typename
F
,
class
BinaryPredicate
>
void
dispatch_int
(
int
idx
,
BinaryPredicate
p
,
F
&&
f
)
{
// BinaryPredicate: BinaryPredicate(idx, candidate)
if
(
!
dispatch_int_noexcept
<
Is
...
>
(
idx
,
p
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
[
=
,
&
ss
](
auto
I
)
{
ss
<<
TV_DECLTYPE
(
I
)
::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
// Ts is pack of mp_list_c
template
<
class
...
Ts
,
typename
Iterator
,
typename
F
>
bool
dispatch_container_noexcept
(
Iterator
begin
,
Iterator
end
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
using
val_lst_t
=
TV_DECLTYPE
(
I
);
auto
val_lst_size
=
mp_size
<
val_lst_t
>::
value
;
bool
equal
=
true
;
std
::
size_t
count
=
0
;
auto
iter
=
begin
;
mp_for_each
<
val_lst_t
>
([
&
](
auto
E
)
{
if
(
iter
==
end
||
!
equal
)
{
return
;
}
if
(
count
>=
val_lst_size
)
{
TV_THROW_INVALID_ARG
(
"iterator length invalid:"
,
val_lst_size
);
}
constexpr
auto
c
=
TV_DECLTYPE
(
E
)
::
value
;
if
(
c
!=
*
iter
)
{
equal
=
false
;
}
++
count
;
std
::
advance
(
iter
,
1
);
});
if
(
count
!=
val_lst_size
||
iter
!=
end
)
{
equal
=
false
;
}
if
(
equal
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
});
return
!
notFound
;
}
template
<
class
...
Ts
,
typename
Iterator
,
typename
F
>
void
dispatch_container
(
Iterator
begin
,
Iterator
end
,
F
&&
f
)
{
if
(
!
dispatch_container_noexcept
<
Ts
...
>
(
begin
,
end
,
std
::
forward
<
F
>
(
f
)))
{
std
::
stringstream
ss
;
ss
<<
"unknown value ["
;
for
(
auto
iter
=
begin
;
iter
!=
end
;
std
::
advance
(
iter
,
1
))
{
ss
<<
*
iter
<<
","
;
}
ss
<<
"], available: "
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
"["
;
mp_for_each
<
TV_DECLTYPE
(
I
)
>
(
[
=
,
&
ss
](
auto
E
)
{
ss
<<
TV_DECLTYPE
(
E
)
::
value
<<
","
;
});
ss
<<
"]"
;
});
TV_THROW_RT_ERR
(
ss
.
str
());
}
}
/*
template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
return dispatch_scalar<int, Is...>(idx, f);
}
*/
template
<
class
T
>
struct
Dispatch
;
template
<
template
<
class
...
>
class
T
,
class
...
Args
>
struct
Dispatch
<
T
<
Args
...
>>
{
template
<
typename
F
>
inline
void
operator
()(
DType
t
,
F
&&
f
)
{
return
dispatch
<
Args
...
>
(
t
,
std
::
forward
<
F
>
(
f
));
}
};
template
<
class
T
>
struct
DispatchContainer
;
template
<
template
<
class
...
>
class
T
,
class
...
Args
>
struct
DispatchContainer
<
T
<
Args
...
>>
{
template
<
typename
Iterator
,
typename
F
>
inline
void
operator
()(
Iterator
begin
,
Iterator
end
,
F
&&
f
)
{
return
dispatch_container
<
Args
...
>
(
begin
,
end
,
std
::
forward
<
F
>
(
f
));
}
};
template
<
class
T
>
struct
DispatchContainerNoexcept
;
template
<
template
<
class
...
>
class
T
,
class
...
Args
>
struct
DispatchContainerNoexcept
<
T
<
Args
...
>>
{
template
<
typename
Iterator
,
typename
F
>
inline
bool
operator
()(
Iterator
begin
,
Iterator
end
,
F
&&
f
)
{
return
dispatch_container_noexcept
<
Args
...
>
(
begin
,
end
,
std
::
forward
<
F
>
(
f
));
}
};
template
<
class
T
>
struct
DispatchInt
;
// Args should be std::integral_constant<int, value>
// you need to use type_container<std::integral_constant<int, value>...>
// as template parameter of DispatchInt.
// tv::mp_list_c is ok.
template
<
template
<
class
...
>
class
T
,
class
...
Args
>
struct
DispatchInt
<
T
<
Args
...
>>
{
template
<
typename
F
>
inline
void
operator
()(
int
t
,
F
&&
f
)
{
return
dispatch_int
<
Args
::
value
...
>
(
t
,
std
::
forward
<
F
>
(
f
));
}
template
<
typename
F
,
typename
BinaryPredicate
>
inline
void
operator
()(
int
t
,
BinaryPredicate
p
,
F
&&
f
)
{
return
dispatch_int
<
Args
::
value
...
>
(
t
,
p
,
std
::
forward
<
F
>
(
f
));
}
};
constexpr
size_t
kTensorMaxDim
=
10
;
using
TensorShape
=
ShapeBase
<
kTensorMaxDim
,
int64_t
>
;
struct
Tensor
{
Tensor
()
{}
Tensor
(
TensorShape
shape
,
TensorShape
stride
,
DType
dtype
,
int
device
=
-
1
,
bool
pinned
=
false
,
bool
managed
=
false
)
:
dtype_
(
dtype
)
{
TV_ASSERT_INVALID_ARG
(
!
shape
.
empty
(),
"dont support empty shape"
);
storage_
=
std
::
make_shared
<
detail
::
TensorStorage
<
uint8_t
>>
(
shape
.
size
()
*
detail
::
sizeof_dtype
(
dtype
),
device
,
managed
,
pinned
);
shape_
=
shape
;
stride_
=
stride
;
}
Tensor
(
TensorShape
shape
,
DType
dtype
,
int
device
=
-
1
,
bool
pinned
=
false
,
bool
managed
=
false
)
:
dtype_
(
dtype
)
{
TV_ASSERT_INVALID_ARG
(
!
shape
.
empty
(),
"dont support empty shape"
);
storage_
=
std
::
make_shared
<
detail
::
TensorStorage
<
uint8_t
>>
(
shape
.
size
()
*
detail
::
sizeof_dtype
(
dtype
),
device
,
managed
,
pinned
);
shape_
=
shape
;
stride_
=
shape
.
stride_rowmajor
();
}
Tensor
(
void
*
ptr
,
TensorShape
shape
,
TensorShape
stride
,
DType
dtype
,
int
device
=
-
1
)
:
dtype_
(
dtype
)
{
TV_ASSERT_INVALID_ARG
(
!
shape
.
empty
(),
"dont support empty shape"
);
storage_
=
std
::
make_shared
<
detail
::
TensorStorage
<
uint8_t
>>
(
reinterpret_cast
<
uint8_t
*>
(
ptr
),
shape
.
size
()
*
detail
::
sizeof_dtype
(
dtype
),
device
);
shape_
=
shape
;
stride_
=
stride
;
}
Tensor
(
void
*
ptr
,
TensorShape
shape
,
DType
dtype
,
int
device
=
-
1
)
:
dtype_
(
dtype
)
{
TV_ASSERT_INVALID_ARG
(
!
shape
.
empty
(),
"dont support empty shape"
);
storage_
=
std
::
make_shared
<
detail
::
TensorStorage
<
uint8_t
>>
(
reinterpret_cast
<
uint8_t
*>
(
ptr
),
shape
.
size
()
*
detail
::
sizeof_dtype
(
dtype
),
device
);
shape_
=
shape
;
stride_
=
shape
.
stride_rowmajor
();
}
Tensor
(
const
void
*
ptr
,
TensorShape
shape
,
TensorShape
stride
,
DType
dtype
,
int
device
=
-
1
)
:
dtype_
(
dtype
),
writeable_
(
false
)
{
TV_ASSERT_INVALID_ARG
(
!
shape
.
empty
(),
"dont support empty shape"
);
storage_
=
std
::
make_shared
<
detail
::
TensorStorage
<
uint8_t
>>
(
reinterpret_cast
<
uint8_t
*>
(
const_cast
<
void
*>
(
ptr
)),
shape
.
size
()
*
detail
::
sizeof_dtype
(
dtype
),
device
);
shape_
=
shape
;
stride_
=
stride
;
}
Tensor
(
const
void
*
ptr
,
TensorShape
shape
,
DType
dtype
,
int
device
=
-
1
)
:
dtype_
(
dtype
),
writeable_
(
false
)
{
TV_ASSERT_INVALID_ARG
(
!
shape
.
empty
(),
"dont support empty shape"
);
storage_
=
std
::
make_shared
<
detail
::
TensorStorage
<
uint8_t
>>
(
reinterpret_cast
<
uint8_t
*>
(
const_cast
<
void
*>
(
ptr
)),
shape
.
size
()
*
detail
::
sizeof_dtype
(
dtype
),
device
);
shape_
=
shape
;
stride_
=
shape
.
stride_rowmajor
();
}
Tensor
(
std
::
initializer_list
<
int32_t
>
init
)
:
Tensor
({
int
(
init
.
size
())},
tv
::
int32
)
{
std
::
copy
(
init
.
begin
(),
init
.
end
(),
data
<
int32_t
>
());
}
Tensor
(
std
::
initializer_list
<
int64_t
>
init
)
:
Tensor
({
int
(
init
.
size
())},
tv
::
int64
)
{
std
::
copy
(
init
.
begin
(),
init
.
end
(),
data
<
int64_t
>
());
}
Tensor
(
std
::
initializer_list
<
float
>
init
)
:
Tensor
({
int
(
init
.
size
())},
tv
::
float32
)
{
std
::
copy
(
init
.
begin
(),
init
.
end
(),
data
<
float
>
());
}
Tensor
(
std
::
initializer_list
<
double
>
init
)
:
Tensor
({
int
(
init
.
size
())},
tv
::
float64
)
{
std
::
copy
(
init
.
begin
(),
init
.
end
(),
data
<
double
>
());
}
template
<
typename
T
,
int
Rank
=
-
1
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
,
typename
std
::
enable_if
<
(
Rank
>
0
),
int
>::
type
=
0
>
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
tview
()
{
using
tv_shape_t
=
typename
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>::
tv_shape_t
;
writable_check
();
static_assert
(
Rank
==
-
1
||
Rank
>
0
,
"error"
);
TV_ASSERT_RT_ERR
(
dtype_
==
type_v
<
T
>
,
"error"
);
tv_shape_t
shape
(
Rank
),
stride
(
Rank
);
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
shape
[
i
]
=
shape_
[
i
];
stride
[
i
]
=
stride_
[
i
];
}
return
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
reinterpret_cast
<
T
*>
(
data
<
T
>
()),
shape
,
stride
);
}
template
<
typename
T
,
int
Rank
=
-
1
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
,
typename
std
::
enable_if
<
Rank
==
-
1
,
int
>
::
type
=
0
>
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
tview
()
{
writable_check
();
static_assert
(
Rank
==
-
1
||
Rank
>
0
,
"error"
);
TV_ASSERT_RT_ERR
(
dtype_
==
type_v
<
T
>
,
"error"
);
ShapeBase
<
TV_MAX_DIM
,
Tindex
>
shape
(
ndim
()),
stride
(
ndim
());
for
(
size_t
i
=
0
;
i
<
ndim
();
++
i
)
{
shape
[
i
]
=
shape_
[
i
];
stride
[
i
]
=
stride_
[
i
];
}
return
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
reinterpret_cast
<
T
*>
(
data
<
T
>
()),
shape
,
stride
);
}
template
<
typename
T
,
int
Rank
=
-
1
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
,
typename
std
::
enable_if
<
(
Rank
>
0
),
int
>::
type
=
0
>
TensorView
<
const
std
::
remove_const_t
<
T
>
,
Rank
,
PtrTraits
,
Tindex
>
tview
()
const
{
static_assert
(
Rank
==
-
1
||
Rank
>
0
,
"error"
);
if
(
Rank
>
0
)
{
TV_ASSERT_RT_ERR
(
Rank
==
ndim
(),
"error"
);
}
TV_ASSERT_RT_ERR
(
dtype_
==
type_v
<
T
>
,
"error"
);
ShapeBase
<
Rank
==
-
1
?
TV_MAX_DIM
:
Rank
,
Tindex
>
shape
(
Rank
),
stride
(
Rank
);
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
shape
[
i
]
=
shape_
[
i
];
stride
[
i
]
=
stride_
[
i
];
}
return
TensorView
<
const
std
::
remove_const_t
<
T
>
,
Rank
,
PtrTraits
,
Tindex
>
(
reinterpret_cast
<
const
std
::
remove_const_t
<
T
>
*>
(
data
<
T
>
()),
shape
,
stride
);
}
template
<
typename
T
,
int
Rank
=
-
1
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
,
typename
std
::
enable_if
<
Rank
==
-
1
,
int
>
::
type
=
0
>
TensorView
<
const
std
::
remove_const_t
<
T
>
,
Rank
,
PtrTraits
,
Tindex
>
tview
()
const
{
static_assert
(
Rank
==
-
1
||
Rank
>
0
,
"error"
);
if
(
Rank
>
0
)
{
TV_ASSERT_RT_ERR
(
Rank
==
ndim
(),
"error"
);
}
TV_ASSERT_RT_ERR
(
dtype_
==
type_v
<
T
>
,
"error"
);
ShapeBase
<
TV_MAX_DIM
,
Tindex
>
shape
(
ndim
()),
stride
(
ndim
());
for
(
int
i
=
0
;
i
<
int
(
ndim
());
++
i
)
{
shape
[
i
]
=
shape_
[
i
];
stride
[
i
]
=
stride_
[
i
];
}
return
TensorView
<
const
std
::
remove_const_t
<
T
>
,
Rank
,
PtrTraits
,
Tindex
>
(
reinterpret_cast
<
const
std
::
remove_const_t
<
T
>
*>
(
data
<
T
>
()),
shape
,
stride
);
}
template
<
class
...
Inds
>
Tensor
view
(
Inds
...
newShapes
)
const
{
static_assert
(
sizeof
...(
newShapes
)
>
0
,
"dont support empty for now"
);
TensorShape
shape
{
int
(
newShapes
)...};
bool
found_minus_1
=
false
;
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
();
++
i
)
{
if
(
!
found_minus_1
)
{
if
(
shape
[
i
]
==
-
1
)
{
shape
[
i
]
=
1
;
shape
[
i
]
=
size
()
/
shape
.
size
();
found_minus_1
=
true
;
}
else
{
TV_ASSERT_INVALID_ARG
(
shape
[
i
]
>
0
,
"shape except -1 must larger than 0"
);
}
}
else
{
TV_ASSERT_INVALID_ARG
(
shape
[
i
]
>
0
,
"multiple -1 in your argument."
);
}
}
TV_ASSERT_RT_ERR
(
shape
.
size
()
==
size
(),
"error"
);
Tensor
res
(
*
this
);
res
.
shape_
=
shape
;
res
.
stride_
=
shape
.
stride_rowmajor
();
return
res
;
}
Tensor
view
(
TensorShape
shape
)
const
{
TV_ASSERT_RT_ERR
(
shape
.
size
()
==
size
(),
"error"
);
Tensor
res
(
*
this
);
res
.
shape_
=
shape
;
res
.
stride_
=
shape
.
stride_rowmajor
();
return
res
;
}
Tensor
operator
[](
int64_t
index
)
{
TV_ASSERT_INVALID_ARG
(
ndim
()
>
1
,
"error"
);
if
(
index
<
0
)
{
index
+=
dim
(
0
);
}
TV_ASSERT_INVALID_ARG
(
index
<
dim
(
0
),
"error"
);
Tensor
res
=
Tensor
();
res
.
storage_
=
storage_
;
res
.
shape_
=
shape_
.
subshape
(
1
);
res
.
offset_
=
offset_
+
index
*
stride_
[
0
];
res
.
stride_
=
stride_
.
subshape
(
1
);
res
.
writeable_
=
writeable_
;
return
res
;
}
Tensor
squeeze
()
const
{
return
view
(
shape_
.
squeeze
());
}
Tensor
squeeze
(
int
axis
)
const
{
if
(
axis
<
0
)
{
axis
=
ndim
()
+
axis
;
}
return
view
(
shape_
.
squeeze
(
axis
));
}
Tensor
unsqueeze
(
int
axis
)
const
{
if
(
axis
<
0
)
{
axis
=
ndim
()
+
axis
;
}
return
view
(
shape_
.
unsqueeze
(
axis
));
}
bool
pinned
()
const
{
return
storage_
->
pinned
();
}
Tensor
slice_first_axis
(
int
start
,
int
end
)
const
{
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
if
(
start
<
0
)
{
start
=
shape_
[
0
]
+
start
;
}
if
(
end
<
0
)
{
end
=
shape_
[
0
]
+
end
;
}
TV_ASSERT_INVALID_ARG
(
start
<
shape_
[
0
],
"start must small than dim 0"
);
TV_ASSERT_INVALID_ARG
(
start
<
end
,
"start must small than end"
);
size_t
new_offset
=
start
*
shape_
.
prod
(
1
)
*
itemsize
();
Tensor
res
(
*
this
);
TensorShape
newshape
(
shape_
);
newshape
[
0
]
=
end
-
start
;
res
.
shape_
=
newshape
;
res
.
stride_
=
stride_
;
res
.
offset_
=
new_offset
;
return
res
;
}
bool
empty
()
const
{
return
storage_
->
empty
();
}
DType
dtype
()
const
{
return
dtype_
;
}
int
device
()
const
{
return
storage_
->
device
();
}
size_t
ndim
()
const
{
return
shape_
.
ndim
();
}
const
TensorShape
&
shape
()
const
{
return
shape_
;
}
const
TensorShape
&
sizes
()
const
{
return
shape_
;
}
const
TensorShape
&
stride
()
const
{
return
stride_
;
}
int
dim
(
int
idx
)
const
{
if
(
idx
<
0
)
{
TV_ASSERT_RT_ERR
(
shape_
.
size
()
+
idx
<
shape_
.
size
(),
idx
,
shape_
);
return
shape_
[
shape_
.
size
()
+
idx
];
}
else
{
TV_ASSERT_RT_ERR
(
idx
<
int
(
shape_
.
size
()),
idx
,
shape_
);
return
shape_
[
idx
];
}
}
const
uint8_t
*
raw_data
()
const
{
return
storage_
->
data
()
+
offset_
;
}
size_t
raw_size
()
const
{
return
size
()
*
itemsize
();
}
size_t
size
()
const
{
return
shape_
.
size
();
}
size_t
size
(
int64_t
idx
)
const
{
return
dim
(
idx
);
}
size_t
itemsize
()
const
{
return
detail
::
sizeof_dtype
(
dtype_
);
}
Tensor
&
zero_
()
{
writable_check
();
storage_
->
zero_
();
return
*
this
;
}
uint8_t
*
raw_data
()
{
writable_check
();
return
storage_
->
data
()
+
offset_
;
}
template
<
typename
T
>
Tensor
&
fill_
(
T
value
)
{
writable_check
();
TV_ASSERT_RT_ERR
(
device
()
==
-
1
,
"error"
);
Dispatch
<
detail
::
all_tensor_types_t
>
()(
dtype_
,
[
&
](
auto
I
)
{
using
Treal
=
TV_DECLTYPE
(
I
);
if
(
std
::
is_convertible
<
T
,
Treal
>::
value
)
{
auto
ptr
=
reinterpret_cast
<
Treal
*>
(
raw_data
());
std
::
fill
(
ptr
,
ptr
+
size
(),
Treal
(
value
));
}
else
{
TV_THROW_INVALID_ARG
(
"not convertable from"
,
type_s
<
T
>
,
"to"
,
type_s
<
Treal
>
);
}
});
return
*
this
;
}
template
<
typename
T
>
T
*
data
()
{
TV_ASSERT_RT_ERR
(
dtype_
==
type_v
<
T
>
,
"error"
);
writable_check
();
return
reinterpret_cast
<
T
*>
(
raw_data
());
}
template
<
typename
T
>
const
T
*
data
()
const
{
TV_ASSERT_RT_ERR
(
dtype_
==
type_v
<
T
>
,
"error"
);
return
reinterpret_cast
<
const
T
*>
(
raw_data
());
}
template
<
typename
T
>
T
*
data_ptr
()
{
return
data
<
T
>
();
}
template
<
typename
T
>
const
T
*
data_ptr
()
const
{
return
data
<
T
>
();
}
void
*
data_ptr
()
{
return
reinterpret_cast
<
void
*>
(
raw_data
());
}
const
void
*
data_ptr
()
const
{
return
reinterpret_cast
<
const
void
*>
(
raw_data
());
}
void
copy_
(
const
Tensor
&
tensor
)
{
writable_check
();
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
TV_ASSERT_RT_ERR
(
!
empty
()
&&
!
tensor
.
empty
(),
"must not empty"
);
TV_ASSERT_RT_ERR
(
size
()
==
tensor
.
size
(),
"must have same size"
);
TV_ASSERT_RT_ERR
(
dtype
()
==
tensor
.
dtype
(),
"must have same dtype"
,
detail
::
typeString
(
dtype
()),
detail
::
typeString
(
tensor
.
dtype
()));
if
(
device
()
==
-
1
&&
tensor
.
device
()
==
-
1
)
{
#ifdef TV_CUDA
host2host
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
));
#else
std
::
copy
(
tensor
.
raw_data
(),
tensor
.
raw_data
()
+
size
()
*
detail
::
sizeof_dtype
(
dtype_
),
storage_
->
data
());
#endif
}
#ifdef TV_CUDA
else
if
(
device
()
>=
0
&&
tensor
.
device
()
==
-
1
)
{
host2dev
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
));
}
else
if
(
device
()
==
-
1
&&
tensor
.
device
()
>=
0
)
{
dev2host
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
));
}
else
if
(
device
()
>=
0
&&
tensor
.
device
()
>=
0
)
{
dev2dev
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
));
}
#endif
else
{
TV_THROW_RT_ERR
(
"only support cpu tensor"
);
}
}
#ifdef TV_CUDA
void
copy_
(
const
Tensor
&
tensor
,
cudaStream_t
stream
)
{
writable_check
();
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
TV_ASSERT_RT_ERR
(
!
empty
()
&&
!
tensor
.
empty
(),
"must not empty"
);
TV_ASSERT_RT_ERR
(
size
()
==
tensor
.
size
(),
"must have same size"
);
TV_ASSERT_RT_ERR
(
dtype
()
==
tensor
.
dtype
(),
"must have same dtype"
,
detail
::
typeString
(
dtype
()),
detail
::
typeString
(
tensor
.
dtype
()));
if
(
device
()
==
-
1
&&
tensor
.
device
()
==
-
1
)
{
host2host
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
),
stream
);
}
else
if
(
device
()
>=
0
&&
tensor
.
device
()
==
-
1
)
{
host2dev
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
),
stream
);
}
else
if
(
device
()
==
-
1
&&
tensor
.
device
()
>=
0
)
{
dev2host
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
),
stream
);
}
else
if
(
device
()
>=
0
&&
tensor
.
device
()
>=
0
)
{
dev2dev
(
storage_
->
data
(),
tensor
.
raw_data
(),
size
()
*
detail
::
sizeof_dtype
(
dtype_
),
stream
);
}
else
{
TV_THROW_RT_ERR
(
"only support cpu tensor"
);
}
}
#endif
Tensor
cpu
()
const
{
if
(
storage_
->
device
()
==
-
1
)
{
// cpu() should always copy tensor.
return
clone
();
}
Tensor
res
(
shape_
,
stride_
,
dtype_
,
-
1
,
storage_
->
managed
());
res
.
copy_
(
*
this
);
return
res
;
}
template
<
typename
T
>
void
copy_
(
const
TensorView
<
T
>
&
tensor
,
int
device
)
{
writable_check
();
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
Tensor
src
=
from_blob
(
tensor
,
device
);
return
copy_
(
src
);
}
Tensor
&
operator
=
(
const
Tensor
&
tensor
)
{
dtype_
=
tensor
.
dtype_
;
storage_
=
tensor
.
storage_
;
shape_
=
tensor
.
shape_
;
writeable_
=
tensor
.
writeable_
;
offset_
=
tensor
.
offset_
;
stride_
=
tensor
.
stride_
;
return
*
this
;
}
Tensor
(
const
Tensor
&
tensor
)
{
dtype_
=
tensor
.
dtype_
;
storage_
=
tensor
.
storage_
;
shape_
=
tensor
.
shape_
;
writeable_
=
tensor
.
writeable_
;
offset_
=
tensor
.
offset_
;
stride_
=
tensor
.
stride_
;
}
Tensor
clone
(
bool
pinned
=
false
)
const
{
TV_ASSERT_RT_ERR
(
!
empty
(),
"clone a empty tensor"
);
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
Tensor
newtensor
(
shape_
,
stride_
,
dtype_
,
device
(),
pinned
,
storage_
->
managed
());
newtensor
.
copy_
(
*
this
);
return
newtensor
;
}
Tensor
astype
(
DType
dtype
)
{
if
(
dtype
==
dtype_
)
{
return
clone
();
}
TV_ASSERT_INVALID_ARG
(
device
()
==
-
1
,
"only support cpu tensor"
);
TV_ASSERT_INVALID_ARG
(
!
empty
(),
"can't be used in empty tensor"
);
TV_ASSERT_INVALID_ARG
(
contiguous_
,
"only support contiguous for now"
);
auto
tensor
=
Tensor
();
Dispatch
<
detail
::
all_tensor_types_t
>
()(
dtype
,
[
&
](
auto
Idst
)
{
using
Tdst
=
TV_DECLTYPE
(
Idst
);
Dispatch
<
detail
::
all_tensor_types_t
>
()(
this
->
dtype_
,
[
&
](
auto
Icur
)
{
using
Tcur
=
TV_DECLTYPE
(
Icur
);
if
(
std
::
is_convertible
<
Tcur
,
Tdst
>::
value
)
{
auto
ptr
=
this
->
data
<
Tcur
>
();
tensor
=
Tensor
(
this
->
shape_
,
this
->
stride_
,
dtype
,
this
->
device
(),
this
->
pinned
(),
this
->
storage_
->
managed
());
std
::
copy
(
ptr
,
ptr
+
this
->
size
(),
tensor
.
data
<
Tdst
>
());
}
else
{
TV_THROW_INVALID_ARG
(
"not convertable from"
,
type_s
<
Tcur
>
,
"to"
,
type_s
<
Tdst
>
);
}
});
});
return
tensor
;
}
template
<
class
...
Ts
,
typename
F
>
inline
void
dispatch
(
F
&&
f
)
{
return
tv
::
dispatch
<
Ts
...
>
(
dtype_
,
std
::
forward
<
F
>
(
f
));
}
protected:
inline
void
writable_check
()
{
TV_ASSERT_RT_ERR
(
writeable_
,
"you cant do non-const operation when not writable"
);
}
DType
dtype_
;
std
::
shared_ptr
<
detail
::
TensorStorage
<
uint8_t
>>
storage_
;
TensorShape
shape_
;
size_t
offset_
=
0
;
TensorShape
stride_
;
private:
bool
writeable_
=
true
;
bool
contiguous_
=
true
;
};
template
<
typename
Os
>
Os
&
operator
<<
(
Os
&
os
,
const
Tensor
&
tensor
)
{
TV_ASSERT_INVALID_ARG
(
tensor
.
device
()
==
-
1
,
"must be cpu tensor"
);
Dispatch
<
detail
::
all_tensor_types_t
>
()(
tensor
.
dtype
(),
[
&
](
auto
I
)
{
using
T
=
TV_DECLTYPE
(
I
);
std
::
stringstream
ss
;
if
(
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
)
{
ss
<<
std
::
setprecision
(
4
);
}
os
<<
tensor
.
tview
<
T
,
-
1
,
DefaultPtrTraits
,
int64_t
>
().
repr
(
ss
);
});
return
os
;
}
inline
Tensor
from_blob
(
void
*
ptr
,
TensorShape
shape
,
DType
dtype
,
int
device
)
{
return
Tensor
(
ptr
,
shape
,
dtype
,
device
);
}
inline
Tensor
from_blob
(
const
void
*
ptr
,
TensorShape
shape
,
DType
dtype
,
int
device
)
{
return
Tensor
(
ptr
,
shape
,
dtype
,
device
);
}
}
// namespace tv
\ No newline at end of file
include/tensorview/tensorview.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common.h"
#include "prettyprint.h"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <type_traits>
#include <vector>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
namespace
tv
{
#if (defined(__clang__) && defined(__CUDA__)) || defined(__NVCC__)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__
#define TV_ASSERT(expr) assert(expr)
#elif defined(__CUDACC_RTC__)
#define TV_ASSERT(expr) assert(expr)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__
#else
#define TV_ASSERT(x) assert(x)
#define TV_HOST_DEVICE_INLINE inline
#define TV_HOST_DEVICE
#endif
#define TV_REQUIRE(expr, ...) \
{ \
if (!(expr)) { \
printf(__VA_ARGS__); \
assert(expr); \
} \
}
#define TV_CHECK_CUDA_ERR() \
{ \
auto __macro_err = cudaGetLastError(); \
if (__macro_err != cudaSuccess) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << "cuda execution failed with error " << __macro_err; \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
} \
}
#define TV_CHECK_CUDA_ERR_V2(...) \
{ \
auto __macro_err = cudaGetLastError(); \
if (__macro_err != cudaSuccess) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << "cuda execution failed with error " << __macro_err; \
__macro_s << " " << cudaGetErrorString(__macro_err) << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
} \
}
#ifdef TV_CUDA
struct
GPU
{
GPU
(
cudaStream_t
s
=
0
)
:
mStream
(
s
)
{}
virtual
cudaStream_t
getStream
()
const
{
return
mStream
;
}
cudaStream_t
mStream
=
0
;
};
#endif
struct
CPU
{};
#ifndef TV_MAX_DIM
#define TV_MAX_DIM 6
#endif
template
<
typename
T
>
struct
DefaultPtrTraits
{
typedef
T
*
type
;
};
#if defined(__CUDACC__) || defined(__HIPCC__)
template
<
typename
T
>
struct
RestrictPtrTraits
{
typedef
T
*
__restrict__
type
;
};
#endif
/*
template <typename T>
constexpr size_t calc_align(size_t ndim)
{
if (ndim * sizeof(T) == 1)
return 1;
else if (ndim * sizeof(T) == 2)
return 2;
else if (ndim * sizeof(T) <= 4 && ndim * sizeof(T) > 2)
return 4;
else if (ndim * sizeof(T) <= 8 && ndim * sizeof(T) > 4)
return 8;
else if (ndim * sizeof(T) <= 16 && ndim * sizeof(T) > 8)
return 16;
else if (ndim * sizeof(T) <= 32 && ndim * sizeof(T) > 16)
return 32;
else
return 64;
}
*/
namespace
detail
{
template
<
typename
_InIter
>
using
_RequireInputIter
=
typename
std
::
enable_if
<
std
::
is_convertible
<
typename
std
::
iterator_traits
<
_InIter
>::
iterator_category
,
std
::
input_iterator_tag
>::
value
>::
type
;
}
template
<
typename
T
,
size_t
MaxDim
=
TV_MAX_DIM
>
struct
/*alignas(calc_align<T>(MaxDim))*/
SimpleVector
{
public:
TV_HOST_DEVICE_INLINE
SimpleVector
(){};
TV_HOST_DEVICE_INLINE
SimpleVector
(
size_t
count
,
T
init
=
T
())
:
size_
(
count
)
{
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
array_
[
i
]
=
init
;
}
};
template
<
typename
Iterator
,
typename
=
detail
::
_RequireInputIter
<
Iterator
>
>
SimpleVector
(
Iterator
first
,
Iterator
last
)
{
size_
=
0
;
for
(;
first
!=
last
;
++
first
)
{
if
(
size_
>=
MaxDim
)
{
TV_THROW_INVALID_ARG
(
"iterator too long"
);
}
array_
[
size_
++
]
=
*
first
;
}
};
TV_HOST_DEVICE_INLINE
SimpleVector
(
std
::
initializer_list
<
T
>
q
)
{
TV_ASSERT
(
q
.
size
()
<=
MaxDim
);
size_
=
0
;
for
(
T
s
:
q
)
{
array_
[
size_
++
]
=
s
;
}
size_
=
q
.
size
();
}
SimpleVector
(
const
std
::
vector
<
T
>
&
arr
)
{
TV_ASSERT
(
arr
.
size
()
<=
MaxDim
);
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
array_
[
i
]
=
arr
[
i
];
}
size_
=
arr
.
size
();
}
TV_HOST_DEVICE_INLINE
SimpleVector
(
const
SimpleVector
<
T
,
MaxDim
>
&
arr
)
{
TV_ASSERT
(
arr
.
size
()
<=
MaxDim
);
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
array_
[
i
]
=
arr
[
i
];
}
size_
=
arr
.
size
();
}
TV_HOST_DEVICE_INLINE
T
&
operator
[](
int
idx
)
{
#ifdef TV_DEBUG
TV_ASSERT
(
idx
>=
0
&&
idx
<
size_
);
#endif
return
array_
[
idx
];
}
TV_HOST_DEVICE_INLINE
const
T
&
operator
[](
int
idx
)
const
{
#ifdef TV_DEBUG
TV_ASSERT
(
idx
>=
0
&&
idx
<
size_
);
#endif
return
array_
[
idx
];
}
TV_HOST_DEVICE_INLINE
void
push_back
(
T
s
)
{
#ifdef TV_DEBUG
TV_ASSERT
(
size_
<
MaxDim
);
#endif
array_
[
size_
]
=
s
;
size_
++
;
}
TV_HOST_DEVICE_INLINE
void
pop_back
()
{
#ifdef TV_DEBUG
TV_ASSERT
(
size_
>
0
);
#endif
size_
--
;
}
TV_HOST_DEVICE_INLINE
size_t
size
()
const
{
return
size_
;
}
TV_HOST_DEVICE_INLINE
const
T
*
data
()
const
{
return
array_
;
}
TV_HOST_DEVICE_INLINE
T
*
data
()
{
return
array_
;
}
TV_HOST_DEVICE_INLINE
size_t
empty
()
const
{
return
size_
==
0
;
}
typedef
size_t
size_type
;
class
iterator
{
public:
typedef
iterator
self_type
;
typedef
T
value_type
;
typedef
T
&
reference
;
typedef
T
*
pointer
;
typedef
std
::
forward_iterator_tag
iterator_category
;
typedef
std
::
ptrdiff_t
difference_type
;
TV_HOST_DEVICE_INLINE
iterator
(
pointer
ptr
)
:
ptr_
(
ptr
)
{}
TV_HOST_DEVICE_INLINE
self_type
operator
++
(
int
junk
)
{
self_type
i
=
*
this
;
ptr_
++
;
return
i
;
}
TV_HOST_DEVICE_INLINE
self_type
operator
++
()
{
ptr_
++
;
return
*
this
;
}
TV_HOST_DEVICE_INLINE
reference
operator
*
()
{
return
*
ptr_
;
}
TV_HOST_DEVICE_INLINE
pointer
operator
->
()
{
return
ptr_
;
}
TV_HOST_DEVICE_INLINE
bool
operator
==
(
const
self_type
&
rhs
)
const
{
return
ptr_
==
rhs
.
ptr_
;
}
TV_HOST_DEVICE_INLINE
bool
operator
!=
(
const
self_type
&
rhs
)
const
{
return
ptr_
!=
rhs
.
ptr_
;
}
private:
pointer
ptr_
;
};
class
const_iterator
{
public:
typedef
const_iterator
self_type
;
typedef
T
value_type
;
typedef
const
T
&
reference
;
typedef
const
T
*
pointer
;
typedef
std
::
ptrdiff_t
difference_type
;
typedef
std
::
forward_iterator_tag
iterator_category
;
TV_HOST_DEVICE_INLINE
const_iterator
(
pointer
ptr
)
:
ptr_
(
ptr
)
{}
TV_HOST_DEVICE_INLINE
self_type
operator
++
(
int
junk
)
{
self_type
i
=
*
this
;
ptr_
++
;
return
i
;
}
TV_HOST_DEVICE_INLINE
self_type
operator
++
()
{
ptr_
++
;
return
*
this
;
}
TV_HOST_DEVICE_INLINE
reference
operator
*
()
{
return
*
ptr_
;
}
TV_HOST_DEVICE_INLINE
pointer
operator
->
()
{
return
ptr_
;
}
TV_HOST_DEVICE_INLINE
bool
operator
==
(
const
self_type
&
rhs
)
const
{
return
ptr_
==
rhs
.
ptr_
;
}
TV_HOST_DEVICE_INLINE
bool
operator
!=
(
const
self_type
&
rhs
)
const
{
return
ptr_
!=
rhs
.
ptr_
;
}
private:
pointer
ptr_
;
};
TV_HOST_DEVICE_INLINE
iterator
begin
()
{
return
iterator
(
array_
);
}
TV_HOST_DEVICE_INLINE
iterator
end
()
{
return
iterator
(
array_
+
size_
);
}
TV_HOST_DEVICE_INLINE
const_iterator
begin
()
const
{
return
const_iterator
(
array_
);
}
TV_HOST_DEVICE_INLINE
const_iterator
end
()
const
{
return
const_iterator
(
array_
+
size_
);
}
TV_HOST_DEVICE_INLINE
const_iterator
cbegin
()
const
{
return
const_iterator
(
array_
);
}
TV_HOST_DEVICE_INLINE
const_iterator
cend
()
const
{
return
const_iterator
(
array_
+
size_
);
}
protected:
T
array_
[
MaxDim
];
size_t
size_
=
0
;
};
template
<
typename
T
,
size_t
MaxDim
>
bool
operator
==
(
const
SimpleVector
<
T
,
MaxDim
>
&
lfs
,
const
SimpleVector
<
T
,
MaxDim
>
&
rfs
)
{
if
(
lfs
.
size
()
!=
rfs
.
size
())
return
false
;
for
(
size_t
i
=
0
;
i
<
lfs
.
size
();
++
i
)
{
if
(
lfs
[
i
]
!=
rfs
[
i
])
return
false
;
}
return
true
;
}
template
<
typename
T
,
size_t
MaxDim
>
bool
operator
!=
(
const
SimpleVector
<
T
,
MaxDim
>
&
lfs
,
const
SimpleVector
<
T
,
MaxDim
>
&
rfs
)
{
return
!
(
lfs
==
rfs
);
}
struct
Slice
{
template
<
class
...
Integers
>
TV_HOST_DEVICE_INLINE
Slice
(
Integers
...
ints
)
{
static_assert
(
sizeof
...(
ints
)
<=
3
,
"slice init must smaller than 3"
);
SimpleVector
<
int
,
3
>
slices
{
int
(
ints
)...};
slices_
[
0
]
=
-
1
;
slices_
[
1
]
=
-
1
;
slices_
[
2
]
=
-
1
;
for
(
size_t
i
=
0
;
i
<
slices
.
size
();
++
i
)
{
slices_
[
i
]
=
slices
[
i
];
}
}
TV_HOST_DEVICE_INLINE
Slice
()
{
slices_
[
0
]
=
-
1
;
slices_
[
1
]
=
-
1
;
slices_
[
2
]
=
-
1
;
}
template
<
typename
T
>
TV_HOST_DEVICE_INLINE
Slice
(
std
::
initializer_list
<
T
>
slice
)
{
slices_
[
0
]
=
-
1
;
slices_
[
1
]
=
-
1
;
slices_
[
2
]
=
-
1
;
TV_ASSERT
(
slice
.
size
()
<=
3
);
int
idx
=
0
;
for
(
T
s
:
slice
)
{
slices_
[
idx
]
=
int
(
s
);
++
idx
;
}
}
TV_HOST_DEVICE_INLINE
int
&
operator
[](
int
idx
)
{
#ifdef TV_DEBUG
TV_ASSERT
(
idx
>=
0
&&
idx
<
3
);
#endif
return
slices_
[
idx
];
}
TV_HOST_DEVICE_INLINE
const
int
&
operator
[](
int
idx
)
const
{
#ifdef TV_DEBUG
TV_ASSERT
(
idx
>=
0
&&
idx
<
3
);
#endif
return
slices_
[
idx
];
}
protected:
int
slices_
[
3
];
};
template
<
size_t
MaxDim
=
TV_MAX_DIM
,
typename
Tindex
=
int
>
struct
ShapeBase
:
public
SimpleVector
<
Tindex
,
MaxDim
>
{
TV_HOST_DEVICE_INLINE
ShapeBase
()
:
SimpleVector
<
Tindex
,
MaxDim
>
(){};
TV_HOST_DEVICE_INLINE
ShapeBase
(
std
::
initializer_list
<
Tindex
>
shape
)
:
SimpleVector
<
Tindex
,
MaxDim
>
(
shape
)
{}
TV_HOST_DEVICE_INLINE
ShapeBase
(
SimpleVector
<
Tindex
,
MaxDim
>
vec
)
:
SimpleVector
<
Tindex
,
MaxDim
>
(
vec
)
{}
template
<
typename
T
,
template
<
class
...
>
class
Container
>
ShapeBase
(
Container
<
T
>
shape
)
:
SimpleVector
<
Tindex
,
MaxDim
>
(
shape
)
{}
TV_HOST_DEVICE_INLINE
ShapeBase
(
const
ShapeBase
<
MaxDim
>
&
shape
)
:
SimpleVector
<
Tindex
,
MaxDim
>
(
shape
)
{}
ShapeBase
(
const
std
::
vector
<
Tindex
>
&
arr
)
:
SimpleVector
<
Tindex
,
MaxDim
>
(
arr
)
{}
ShapeBase
<
MaxDim
,
Tindex
>
&
operator
=
(
const
ShapeBase
<
MaxDim
,
Tindex
>
&
shape
)
=
default
;
TV_HOST_DEVICE
ShapeBase
<
MaxDim
,
Tindex
>
subshape
(
Tindex
start
,
Tindex
end
)
const
{
#ifdef TV_DEBUG
TV_ASSERT
(
start
>=
0
&&
end
<=
this
->
size_
&&
end
>
start
);
#endif
ShapeBase
<
MaxDim
,
Tindex
>
shape
;
for
(
Tindex
i
=
start
;
i
<
end
;
++
i
)
{
shape
.
push_back
(
this
->
array_
[
i
]);
}
return
shape
;
}
TV_HOST_DEVICE
ShapeBase
<
MaxDim
,
Tindex
>
subshape
(
Tindex
start
)
const
{
#ifdef TV_DEBUG
TV_ASSERT
(
start
>=
0
&&
start
<=
this
->
size_
);
#endif
ShapeBase
<
MaxDim
,
Tindex
>
shape
;
for
(
size_t
i
=
start
;
i
<
this
->
size_
;
++
i
)
{
shape
.
push_back
(
this
->
array_
[
i
]);
}
return
shape
;
}
TV_HOST_DEVICE
size_t
size
()
const
{
if
(
this
->
size_
==
0
)
return
0
;
size_t
s
=
1
;
for
(
int
i
=
0
;
i
<
int
(
this
->
size_
);
++
i
)
{
s
*=
this
->
array_
[
i
];
}
return
s
;
}
TV_HOST_DEVICE_INLINE
size_t
ndim
()
const
{
return
this
->
size_
;
}
TV_HOST_DEVICE
ShapeBase
<
MaxDim
,
Tindex
>
squeeze
()
const
{
ShapeBase
<
MaxDim
,
Tindex
>
shape
;
for
(
size_t
i
=
0
;
i
<
this
->
size_
;
++
i
)
{
if
(
this
->
array_
[
i
]
!=
1
)
shape
.
push_back
(
this
->
array_
[
i
]);
}
if
(
shape
.
empty
())
{
// dont support empty shape for now
shape
.
push_back
(
1
);
}
return
shape
;
}
template
<
size_t
MaxDim2
=
MaxDim
>
TV_HOST_DEVICE
ShapeBase
<
MaxDim2
,
Tindex
>
squeeze
(
int
dim
)
const
{
static_assert
(
MaxDim2
>=
MaxDim
-
1
,
"error"
);
ShapeBase
<
MaxDim2
,
Tindex
>
shape
;
for
(
size_t
i
=
0
;
i
<
this
->
size_
;
++
i
)
{
if
(
i
!=
size_t
(
dim
)
||
this
->
array_
[
i
]
!=
1
)
shape
.
push_back
(
this
->
array_
[
i
]);
}
return
shape
;
}
template
<
size_t
MaxDim2
=
MaxDim
>
TV_HOST_DEVICE
ShapeBase
<
MaxDim2
,
Tindex
>
unsqueeze
(
int
dim
)
const
{
static_assert
(
MaxDim2
>=
MaxDim
-
1
,
"error"
);
ShapeBase
<
MaxDim2
,
Tindex
>
shape
;
for
(
size_t
i
=
0
;
i
<
this
->
size_
;
++
i
)
{
if
(
i
==
size_t
(
dim
))
shape
.
push_back
(
1
);
shape
.
push_back
(
this
->
array_
[
i
]);
}
return
shape
;
}
TV_HOST_DEVICE
size_t
prod
(
Tindex
start
=
0
)
const
{
size_t
res
=
1
;
for
(
size_t
i
=
start
;
i
<
this
->
size_
;
++
i
)
{
res
*=
this
->
array_
[
i
];
}
return
res
;
}
template
<
size_t
MaxDim2
=
MaxDim
>
TV_HOST_DEVICE
ShapeBase
<
MaxDim2
,
Tindex
>
stride_rowmajor
()
{
static_assert
(
MaxDim2
>=
MaxDim
,
"error"
);
Tindex
p
=
Tindex
(
1
);
ShapeBase
<
MaxDim2
,
Tindex
>
res
(
this
->
size_
);
for
(
Tindex
i
=
this
->
size_
-
1
;
i
>=
0
;
--
i
)
{
res
[
i
]
=
p
;
p
*=
this
->
array_
[
i
];
}
return
res
;
}
};
using
Shape
=
ShapeBase
<
TV_MAX_DIM
,
int
>
;
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
unsigned
rowArrayIdx
(
std
::
vector
<
int
>
&
shape
,
Inds
...
indexes
)
{
unsigned
offset
=
0
;
unsigned
m
=
1
;
int
indexes_vec
[
sizeof
...(
indexes
)]
=
{
indexes
...};
#ifdef TV_DEBUG
TV_ASSERT
(
sizeof
...(
indexes
)
==
shape
.
size
());
#endif
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
for
(
int
i
=
sizeof
...(
indexes
)
-
1
;
i
>=
0
;
--
i
)
{
offset
+=
m
*
indexes_vec
[
i
];
m
*=
shape
[
i
];
}
return
offset
;
}
TV_HOST_DEVICE_INLINE
unsigned
rowArrayIdx
(
std
::
vector
<
int
>
&
shape
,
std
::
vector
<
int
>
&
indexes_vec
)
{
unsigned
offset
=
0
;
unsigned
m
=
1
;
for
(
int
i
=
shape
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
offset
+=
m
*
indexes_vec
[
i
];
m
*=
shape
[
i
];
}
return
offset
;
}
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
unsigned
rowArrayIdx
(
const
Shape
&
shape
,
Inds
...
indexes
)
{
unsigned
offset
=
0
;
unsigned
m
=
1
;
int
indexes_vec
[
sizeof
...(
indexes
)]
=
{
indexes
...};
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
for
(
int
i
=
sizeof
...(
indexes
)
-
1
;
i
>=
0
;
--
i
)
{
offset
+=
m
*
indexes_vec
[
i
];
m
*=
shape
[
i
];
}
return
offset
;
}
TV_HOST_DEVICE_INLINE
unsigned
rowArrayIdx
(
const
Shape
&
shape
,
const
Shape
&
indexes_vec
)
{
unsigned
offset
=
0
;
unsigned
m
=
1
;
for
(
int
i
=
indexes_vec
.
ndim
()
-
1
;
i
>=
0
;
--
i
)
{
offset
+=
m
*
indexes_vec
[
i
];
m
*=
shape
[
i
];
}
return
offset
;
}
template
<
typename
Index
,
unsigned
NDim
>
TV_HOST_DEVICE_INLINE
unsigned
rowArrayIdx
(
const
Index
*
indexes
,
const
Index
*
shape
)
{
unsigned
offset
=
0
;
unsigned
m
=
1
;
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
for
(
int
i
=
NDim
-
1
;
i
>=
0
;
--
i
)
{
offset
+=
m
*
indexes
[
i
];
m
*=
shape
[
i
];
}
return
offset
;
}
template
<
typename
Index
,
unsigned
NDim
>
TV_HOST_DEVICE_INLINE
Index
rowArrayIdxInv
(
Index
index
,
Index
*
output
,
const
Index
*
shape
)
{
#pragma unroll
for
(
int
i
=
NDim
-
1
;
i
>=
0
;
--
i
)
{
output
[
i
]
=
index
%
shape
[
i
];
index
-=
output
[
i
];
index
/=
shape
[
i
];
}
return
index
;
}
template
<
typename
Index
>
TV_HOST_DEVICE
Index
rowArrayIdxInv
(
Index
index
,
Index
*
output
,
const
Index
*
shape
,
int
ndim
)
{
for
(
int
i
=
ndim
-
1
;
i
>=
0
;
--
i
)
{
output
[
i
]
=
index
%
shape
[
i
];
index
-=
output
[
i
];
index
/=
shape
[
i
];
}
return
index
;
}
template
<
int
N
>
struct
ArrayIndexRowMajorReverse
{
template
<
typename
TShape
,
typename
T
,
class
...
Ts
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
shape
,
T
index
,
Ts
...
inds
)
{
return
index
+
shape
[
N
-
1
]
*
ArrayIndexRowMajorReverse
<
N
-
1
>::
run
(
shape
,
inds
...);
}
template
<
typename
T
,
class
...
Ts
>
TV_HOST_DEVICE_INLINE
static
unsigned
runShape
(
const
Shape
&
shape
,
T
index
,
Ts
...
inds
)
{
return
index
+
shape
[
N
-
1
]
*
ArrayIndexRowMajorReverse
<
N
-
1
>::
run
(
shape
,
inds
...);
}
};
template
<
>
struct
ArrayIndexRowMajorReverse
<
1
>
{
template
<
typename
TShape
,
typename
T
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
shape
,
T
idx
)
{
return
idx
;
}
template
<
typename
T
>
TV_HOST_DEVICE_INLINE
static
unsigned
runShape
(
const
Shape
&
shape
,
T
idx
)
{
return
idx
;
}
};
template
<
int
N
,
int
Ndim
>
struct
ArrayIndexRowMajor
{
// this array index provide almost same compiled code. compile it in
// https://godbolt.org/ for more details.
template
<
typename
TShape
,
typename
Tinit
,
typename
T
,
class
...
Ts
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
shape
,
Tinit
start
,
T
index
,
Ts
...
inds
)
{
return
ArrayIndexRowMajor
<
N
-
1
,
Ndim
>::
run
(
shape
,
(
index
+
start
)
*
shape
[
Ndim
-
N
+
1
],
inds
...);
}
template
<
typename
Tinit
,
typename
T
,
class
...
Ts
>
TV_HOST_DEVICE_INLINE
static
unsigned
runShape
(
const
Shape
&
shape
,
Tinit
start
,
T
index
,
Ts
...
inds
)
{
return
ArrayIndexRowMajor
<
N
-
1
,
Ndim
>::
runShape
(
shape
,
(
index
+
start
)
*
shape
[
Ndim
-
N
+
1
],
inds
...);
}
template
<
typename
TShape
,
typename
Tinit
>
TV_HOST_DEVICE_INLINE
static
unsigned
runPtrs
(
const
TShape
*
indexes
,
const
TShape
*
shape
,
Tinit
start
)
{
return
ArrayIndexRowMajor
<
N
-
1
,
Ndim
>::
runPtrs
(
indexes
,
shape
,
(
indexes
[
Ndim
-
N
]
+
start
)
*
shape
[
Ndim
-
N
+
1
]);
}
};
template
<
int
Ndim
>
struct
ArrayIndexRowMajor
<
1
,
Ndim
>
{
template
<
typename
TShape
,
typename
Tinit
,
typename
T
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
shape
,
Tinit
start
,
T
idx
)
{
return
start
+
idx
;
}
template
<
typename
Tinit
,
typename
T
>
TV_HOST_DEVICE_INLINE
static
unsigned
runShape
(
const
Shape
&
shape
,
Tinit
start
,
T
idx
)
{
return
start
+
idx
;
}
template
<
typename
TShape
,
typename
Tinit
>
TV_HOST_DEVICE_INLINE
static
unsigned
runPtrs
(
const
TShape
*
indexes
,
const
TShape
*
shape
,
Tinit
start
)
{
return
start
+
indexes
[
Ndim
-
1
];
}
};
template
<
>
struct
ArrayIndexRowMajor
<
0
,
0
>
{
template
<
typename
TShape
,
typename
Tinit
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
shape
,
Tinit
start
)
{
return
0
;
}
template
<
typename
Tinit
>
TV_HOST_DEVICE_INLINE
static
unsigned
runShape
(
const
Shape
&
shape
,
Tinit
start
)
{
return
0
;
}
template
<
typename
TShape
,
typename
Tinit
>
TV_HOST_DEVICE_INLINE
static
unsigned
runPtrs
(
const
TShape
*
indexes
,
const
TShape
*
shape
,
Tinit
start
)
{
return
0
;
}
};
template
<
int
N
,
int
Ndim
>
struct
ArrayIndexStride
{
// this array index provide almost same compiled code. compile it in
// https://godbolt.org/ for more details.
template
<
typename
TShape
,
typename
Tinit
,
typename
T
,
class
...
Ts
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
stride
,
Tinit
start
,
T
index
,
Ts
...
inds
)
{
return
ArrayIndexStride
<
N
-
1
,
Ndim
>::
run
(
stride
,
start
+
index
*
stride
[
Ndim
-
N
+
1
],
inds
...);
}
};
template
<
int
Ndim
>
struct
ArrayIndexStride
<
1
,
Ndim
>
{
template
<
typename
TShape
,
typename
Tinit
,
typename
T
>
TV_HOST_DEVICE_INLINE
static
unsigned
run
(
const
TShape
*
stride
,
Tinit
start
,
T
idx
)
{
return
start
+
idx
*
stride
[
Ndim
-
1
];
}
};
#if __cplusplus >= 201703L
template
<
size_t
...
N
,
class
T
,
class
...
Ts
>
TV_HOST_DEVICE_INLINE
T
array_index_stride
(
const
T
*
stride
,
Ts
...
ids
)
{
return
((
stride
[
N
]
*
std
::
get
<
N
>
(
std
::
forward_as_tuple
(
ids
...)))
+
...);
}
#endif
namespace
detail
{
template
<
typename
T
>
struct
TypeToString
;
template
<
>
struct
TypeToString
<
bool
>
{
static
constexpr
const
char
*
value
=
"bool"
;
};
template
<
>
struct
TypeToString
<
const
bool
>
{
static
constexpr
const
char
*
value
=
"bool"
;
};
template
<
>
struct
TypeToString
<
int32_t
>
{
static
constexpr
const
char
*
value
=
"int32"
;
};
template
<
>
struct
TypeToString
<
float
>
{
static
constexpr
const
char
*
value
=
"float"
;
};
template
<
>
struct
TypeToString
<
double
>
{
static
constexpr
const
char
*
value
=
"double"
;
};
template
<
>
struct
TypeToString
<
int16_t
>
{
static
constexpr
const
char
*
value
=
"int16"
;
};
template
<
>
struct
TypeToString
<
int8_t
>
{
static
constexpr
const
char
*
value
=
"int8"
;
};
template
<
>
struct
TypeToString
<
int64_t
>
{
static
constexpr
const
char
*
value
=
"int64"
;
};
template
<
>
struct
TypeToString
<
uint8_t
>
{
static
constexpr
const
char
*
value
=
"uint8"
;
};
template
<
>
struct
TypeToString
<
uint16_t
>
{
static
constexpr
const
char
*
value
=
"uint16"
;
};
template
<
>
struct
TypeToString
<
uint32_t
>
{
static
constexpr
const
char
*
value
=
"uint32"
;
};
template
<
>
struct
TypeToString
<
uint64_t
>
{
static
constexpr
const
char
*
value
=
"uint64"
;
};
template
<
>
struct
TypeToString
<
const
int32_t
>
{
static
constexpr
const
char
*
value
=
"int32"
;
};
template
<
>
struct
TypeToString
<
const
float
>
{
static
constexpr
const
char
*
value
=
"float"
;
};
template
<
>
struct
TypeToString
<
const
double
>
{
static
constexpr
const
char
*
value
=
"double"
;
};
template
<
>
struct
TypeToString
<
const
int16_t
>
{
static
constexpr
const
char
*
value
=
"int16"
;
};
template
<
>
struct
TypeToString
<
const
int8_t
>
{
static
constexpr
const
char
*
value
=
"int8"
;
};
template
<
>
struct
TypeToString
<
const
int64_t
>
{
static
constexpr
const
char
*
value
=
"int64"
;
};
template
<
>
struct
TypeToString
<
const
uint8_t
>
{
static
constexpr
const
char
*
value
=
"uint8"
;
};
template
<
>
struct
TypeToString
<
const
uint16_t
>
{
static
constexpr
const
char
*
value
=
"uint16"
;
};
template
<
>
struct
TypeToString
<
const
uint32_t
>
{
static
constexpr
const
char
*
value
=
"uint32"
;
};
template
<
>
struct
TypeToString
<
const
uint64_t
>
{
static
constexpr
const
char
*
value
=
"uint64"
;
};
}
// namespace detail
template
<
typename
T
>
constexpr
const
char
*
type_s
=
detail
::
TypeToString
<
T
>::
value
;
namespace
detail
{
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
>
struct
TensorAccesserBase
{
static
constexpr
int
rank_value
=
Rank
;
using
ptr_t
=
typename
PtrTraits
<
T
>::
type
;
static_assert
(
Rank
>
0
,
"error"
);
explicit
TV_HOST_DEVICE_INLINE
TensorAccesserBase
(
ptr_t
ptr
,
const
Tindex
*
stride_ptr
)
:
ptr_
(
ptr
),
stride_ptr_
(
stride_ptr
)
{}
TV_HOST_DEVICE_INLINE
ptr_t
data
()
{
return
ptr_
;
}
TV_HOST_DEVICE_INLINE
const
ptr_t
data
()
const
{
return
ptr_
;
}
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
T
&
operator
()(
Inds
...
inds
)
{
static_assert
(
sizeof
...(
inds
)
==
Rank
,
"error"
);
return
ptr_
[
ArrayIndexStride
<
Rank
,
Rank
>::
run
(
stride_ptr_
,
0
,
inds
...)];
}
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
const
T
&
operator
()(
Inds
...
inds
)
const
{
static_assert
(
sizeof
...(
inds
)
==
Rank
,
"error"
);
return
ptr_
[
ArrayIndexStride
<
Rank
,
Rank
>::
run
(
stride_ptr_
,
0
,
inds
...)];
}
protected:
ptr_t
ptr_
;
const
Tindex
*
stride_ptr_
;
};
}
// namespace detail
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
>
struct
TensorAccesser
:
public
detail
::
TensorAccesserBase
<
T
,
Rank
,
PtrTraits
,
Tindex
>
{
using
ptr_t
=
typename
PtrTraits
<
T
>::
type
;
static_assert
(
Rank
>
0
,
"error"
);
explicit
TV_HOST_DEVICE_INLINE
TensorAccesser
(
ptr_t
ptr
,
const
Tindex
*
stride_ptr
)
:
detail
::
TensorAccesserBase
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr
,
stride_ptr
)
{}
TV_HOST_DEVICE_INLINE
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
operator
[](
int
i
)
{
return
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
(
this
->
ptr_
+
this
->
stride_ptr_
[
0
]
*
i
,
this
->
stride_ptr_
+
1
);
}
TV_HOST_DEVICE_INLINE
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
operator
[](
int
i
)
const
{
return
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
(
this
->
ptr_
+
this
->
stride_ptr_
[
0
]
*
i
,
this
->
stride_ptr_
+
1
);
}
};
template
<
typename
T
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
struct
TensorAccesser
<
T
,
1
,
PtrTraits
,
Tindex
>
:
public
detail
::
TensorAccesserBase
<
T
,
1
,
PtrTraits
,
Tindex
>
{
using
ptr_t
=
typename
PtrTraits
<
T
>::
type
;
explicit
TV_HOST_DEVICE_INLINE
TensorAccesser
(
ptr_t
ptr
,
const
Tindex
*
stride_ptr
)
:
detail
::
TensorAccesserBase
<
T
,
1
,
PtrTraits
,
Tindex
>
(
ptr
,
stride_ptr
)
{}
TV_HOST_DEVICE_INLINE
T
&
operator
[](
int
i
)
{
return
this
->
ptr_
[
this
->
stride_ptr_
[
0
]
*
i
];
}
TV_HOST_DEVICE_INLINE
T
&
operator
[](
int
i
)
const
{
return
this
->
ptr_
[
this
->
stride_ptr_
[
0
]
*
i
];
}
};
template
<
typename
T
,
int
Rank
=
-
1
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
>
struct
TensorView
{
static
constexpr
int
rank_value
=
Rank
;
using
ptr_t
=
typename
PtrTraits
<
T
>::
type
;
using
tv_shape_t
=
ShapeBase
<
Rank
==
-
1
?
TV_MAX_DIM
:
Rank
,
Tindex
>
;
using
no_cv_type
=
typename
std
::
remove_cv
<
T
>::
type
;
static_assert
(
Rank
==
-
1
||
Rank
>
0
,
"error"
);
TV_HOST_DEVICE_INLINE
TensorView
()
{}
explicit
TV_HOST_DEVICE_INLINE
TensorView
(
ptr_t
ptr
,
tv_shape_t
shape
)
:
ptr_
(
ptr
),
shape_
(
shape
),
stride_
(
shape
.
stride_rowmajor
())
{}
explicit
TV_HOST_DEVICE_INLINE
TensorView
(
ptr_t
ptr
,
tv_shape_t
shape
,
tv_shape_t
stride
)
:
ptr_
(
ptr
),
shape_
(
shape
),
stride_
(
stride
)
{}
operator
TensorView
<
const
no_cv_type
,
Rank
,
PtrTraits
,
Tindex
>
()
{
return
TensorView
<
const
no_cv_type
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr_
,
shape_
);
}
// conversion function
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
T
&
operator
()(
Inds
...
inds
)
{
static_assert
(
Rank
==
-
1
||
sizeof
...(
inds
)
==
Rank
,
"error"
);
#if defined TV_DEBUG
int
idxes
[
sizeof
...(
Inds
)]{
int
(
inds
)...};
TV_REQUIRE
(
sizeof
...(
inds
)
==
shape_
.
ndim
(),
"you provide %d indexes, but dim is %d
\n
"
,
sizeof
...(
inds
),
shape_
.
ndim
());
for
(
int
i
=
0
;
i
<
sizeof
...(
inds
);
++
i
)
{
TV_REQUIRE
(
idxes
[
i
]
>=
0
&&
idxes
[
i
]
<
shape_
[
i
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
i
,
idxes
[
i
],
shape_
[
i
]);
}
#endif
constexpr
int
Ndim
=
sizeof
...(
Inds
);
return
ptr_
[
ArrayIndexRowMajor
<
Ndim
,
Ndim
>::
runShape
(
shape_
,
0
,
inds
...)];
}
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
const
T
&
operator
()(
Inds
...
inds
)
const
{
static_assert
(
Rank
==
-
1
||
sizeof
...(
inds
)
==
Rank
,
"error"
);
#if defined TV_DEBUG
int
idxes
[
sizeof
...(
Inds
)]{
int
(
inds
)...};
TV_REQUIRE
(
sizeof
...(
inds
)
==
shape_
.
ndim
(),
"you provide %d indexes, but dim is %d
\n
"
,
sizeof
...(
inds
),
shape_
.
ndim
());
for
(
int
i
=
0
;
i
<
sizeof
...(
inds
);
++
i
)
{
TV_REQUIRE
(
idxes
[
i
]
>=
0
&&
idxes
[
i
]
<
shape_
[
i
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
i
,
idxes
[
i
],
shape_
[
i
]);
}
#endif
constexpr
int
Ndim
=
sizeof
...(
Inds
);
return
ptr_
[
ArrayIndexRowMajor
<
Ndim
,
Ndim
>::
runShape
(
shape_
,
0
,
inds
...)];
}
TV_HOST_DEVICE_INLINE
T
&
operator
()()
{
static_assert
(
Rank
==
-
1
||
0
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
ptr_
!=
nullptr
,
"you want get value but the view is empty.%s"
,
"
\n
"
);
TV_REQUIRE
(
shape_
.
ndim
()
==
0
,
"you provide 0 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
#endif
return
ptr_
[
0
];
}
TV_HOST_DEVICE_INLINE
const
T
&
operator
()()
const
{
static_assert
(
Rank
==
-
1
||
0
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
ptr_
!=
nullptr
,
"you want get value but the view is empty.%s"
,
"
\n
"
);
TV_REQUIRE
(
shape_
.
ndim
()
==
0
,
"you provide 0 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
#endif
return
ptr_
[
0
];
}
template
<
class
T1
>
TV_HOST_DEVICE_INLINE
T
&
operator
()(
T1
i1
)
{
static_assert
(
Rank
==
-
1
||
1
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
1
,
"you provide 1 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
i1
,
shape_
[
0
]);
#endif
return
ptr_
[
i1
];
}
template
<
class
T1
,
class
T2
>
TV_HOST_DEVICE_INLINE
T
&
operator
()(
T1
i1
,
T2
i2
)
{
static_assert
(
Rank
==
-
1
||
2
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
2
,
"you provide 2 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
TV_REQUIRE
(
i2
>=
0
&&
i2
<
shape_
[
1
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
1
,
int
(
i2
),
shape_
[
1
]);
#endif
return
ptr_
[
i1
*
shape_
[
1
]
+
i2
];
}
template
<
class
T1
,
class
T2
,
class
T3
>
TV_HOST_DEVICE_INLINE
T
&
operator
()(
T1
i1
,
T2
i2
,
T3
i3
)
{
static_assert
(
Rank
==
-
1
||
3
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
3
,
"you provide 3 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
TV_REQUIRE
(
i2
>=
0
&&
i2
<
shape_
[
1
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
1
,
int
(
i2
),
shape_
[
1
]);
TV_REQUIRE
(
i3
>=
0
&&
i3
<
shape_
[
2
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
2
,
int
(
i3
),
shape_
[
2
]);
#endif
return
ptr_
[(
i1
*
shape_
[
1
]
+
i2
)
*
shape_
[
2
]
+
i3
];
}
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
TV_HOST_DEVICE_INLINE
T
&
operator
()(
T1
i1
,
T2
i2
,
T3
i3
,
T4
i4
)
{
static_assert
(
Rank
==
-
1
||
4
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
4
,
"you provide 4 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
TV_REQUIRE
(
i2
>=
0
&&
i2
<
shape_
[
1
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
1
,
int
(
i2
),
shape_
[
1
]);
TV_REQUIRE
(
i3
>=
0
&&
i3
<
shape_
[
2
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
2
,
int
(
i3
),
shape_
[
2
]);
TV_REQUIRE
(
i4
>=
0
&&
i4
<
shape_
[
3
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
3
,
int
(
i4
),
shape_
[
3
]);
#endif
return
ptr_
[((
i1
*
shape_
[
1
]
+
i2
)
*
shape_
[
2
]
+
i3
)
*
shape_
[
3
]
+
i4
];
}
template
<
class
T1
>
TV_HOST_DEVICE_INLINE
const
T
&
operator
()(
T1
i1
)
const
{
static_assert
(
Rank
==
-
1
||
1
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
1
,
"you provide 1 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
#endif
return
ptr_
[
i1
];
}
template
<
class
T1
,
class
T2
>
TV_HOST_DEVICE_INLINE
const
T
&
operator
()(
T1
i1
,
T2
i2
)
const
{
static_assert
(
Rank
==
-
1
||
2
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
2
,
"you provide 2 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
TV_REQUIRE
(
i2
>=
0
&&
i2
<
shape_
[
1
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
1
,
int
(
i2
),
shape_
[
1
]);
#endif
return
ptr_
[
i1
*
shape_
[
1
]
+
i2
];
}
template
<
class
T1
,
class
T2
,
class
T3
>
TV_HOST_DEVICE_INLINE
const
T
&
operator
()(
T1
i1
,
T2
i2
,
T3
i3
)
const
{
static_assert
(
Rank
==
-
1
||
3
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
3
,
"you provide 3 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
TV_REQUIRE
(
i2
>=
0
&&
i2
<
shape_
[
1
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
1
,
int
(
i2
),
shape_
[
1
]);
TV_REQUIRE
(
i3
>=
0
&&
i3
<
shape_
[
2
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
2
,
int
(
i3
),
shape_
[
2
]);
#endif
return
ptr_
[(
i1
*
shape_
[
1
]
+
i2
)
*
shape_
[
2
]
+
i3
];
}
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
TV_HOST_DEVICE_INLINE
const
T
&
operator
()(
T1
i1
,
T2
i2
,
T3
i3
,
T4
i4
)
const
{
static_assert
(
Rank
==
-
1
||
4
==
Rank
,
"error"
);
#if defined TV_DEBUG
TV_REQUIRE
(
shape_
.
ndim
()
==
4
,
"you provide 4 indexes, but dim is %ld
\n
"
,
shape_
.
ndim
());
TV_REQUIRE
(
i1
>=
0
&&
i1
<
shape_
[
0
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
0
,
int
(
i1
),
shape_
[
0
]);
TV_REQUIRE
(
i2
>=
0
&&
i2
<
shape_
[
1
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
1
,
int
(
i2
),
shape_
[
1
]);
TV_REQUIRE
(
i3
>=
0
&&
i3
<
shape_
[
2
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
2
,
int
(
i3
),
shape_
[
2
]);
TV_REQUIRE
(
i4
>=
0
&&
i4
<
shape_
[
3
],
"index-%d(%d) out-of-range: [0, %d)
\n
"
,
3
,
int
(
i4
),
shape_
[
3
]);
#endif
return
ptr_
[((
i1
*
shape_
[
1
]
+
i2
)
*
shape_
[
2
]
+
i3
)
*
shape_
[
3
]
+
i4
];
}
TV_HOST_DEVICE_INLINE
T
&
operator
[](
int
idx
)
{
#ifdef TV_DEBUG
TV_REQUIRE
(
idx
>=
0
&&
idx
<
size
(),
"index(%d) out-of-range: [0, %ld)
\n
"
,
int
(
idx
),
size
());
#endif
return
ptr_
[
idx
];
}
TV_HOST_DEVICE_INLINE
const
T
&
operator
[](
int
idx
)
const
{
#ifdef TV_DEBUG
TV_REQUIRE
(
idx
>=
0
&&
idx
<
size
(),
"index(%d) out-of-range: [0, %ld)
\n
"
,
int
(
idx
),
size
());
#endif
return
ptr_
[
idx
];
}
TV_HOST_DEVICE_INLINE
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
accessor
(
Tindex
idx
)
{
static_assert
(
Rank
>
1
,
"for Rank == 1, use accessor() or just use []"
);
return
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
(
ptr_
+
stride_
[
0
]
*
idx
,
stride_
.
data
()
+
1
);
}
TV_HOST_DEVICE_INLINE
TensorAccesser
<
T
,
Rank
,
PtrTraits
,
Tindex
>
accessor
()
{
static_assert
(
Rank
>
0
,
"rank must higher than zero"
);
return
TensorAccesser
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr_
,
stride_
.
data
());
}
TV_HOST_DEVICE_INLINE
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
accessor
(
Tindex
idx
)
const
{
static_assert
(
Rank
>
1
,
"for Rank == 1, use accessor() or just use []"
);
return
TensorAccesser
<
T
,
Rank
-
1
,
PtrTraits
,
Tindex
>
(
ptr_
+
stride_
[
0
]
*
idx
,
stride_
.
data
()
+
1
);
}
TV_HOST_DEVICE_INLINE
TensorAccesser
<
T
,
Rank
,
PtrTraits
,
Tindex
>
accessor
()
const
{
static_assert
(
Rank
>
0
,
"error"
);
return
TensorAccesser
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr_
,
stride_
.
data
(),
"rank must higher than zero"
);
}
TV_HOST_DEVICE_INLINE
bool
empty
()
const
{
return
ptr_
==
nullptr
;
}
TV_HOST_DEVICE_INLINE
ptr_t
data
()
{
return
ptr_
;
}
TV_HOST_DEVICE_INLINE
const
ptr_t
data
()
const
{
return
ptr_
;
}
TV_HOST_DEVICE_INLINE
const
tv_shape_t
&
shape
()
const
{
return
shape_
;
}
TV_HOST_DEVICE_INLINE
const
tv_shape_t
&
stride
()
const
{
return
stride_
;
}
TV_HOST_DEVICE_INLINE
int
dim
(
int
idx
)
const
{
return
shape_
[
idx
];
}
TV_HOST_DEVICE_INLINE
int
ndim
()
const
{
return
shape_
.
ndim
();
}
template
<
class
...
Inds
>
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
Rank
==
-
1
?
-
1
:
sizeof
...(
Inds
),
PtrTraits
,
Tindex
>
view
(
Inds
...
newShapes
)
const
{
ShapeBase
<
Rank
==
-
1
?
TV_MAX_DIM
:
sizeof
...(
Inds
),
Tindex
>
shapes
{
int
(
newShapes
)...};
for
(
size_t
i
=
0
;
i
<
sizeof
...(
newShapes
);
++
i
)
{
if
(
shapes
[
i
]
==
-
1
)
{
shapes
[
i
]
=
1
;
shapes
[
i
]
=
size
()
/
shapes
.
size
();
break
;
}
}
TV_ASSERT
(
shapes
.
size
()
==
size
());
return
TensorView
<
T
,
Rank
==
-
1
?
-
1
:
sizeof
...(
Inds
),
PtrTraits
,
Tindex
>
(
ptr_
,
shapes
);
}
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
view
(
Shape
shapes
)
const
{
TV_ASSERT
(
shapes
.
size
()
==
size
());
return
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
(
ptr_
,
shapes
);
}
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
squeeze
()
const
{
return
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
(
ptr_
,
shape_
.
squeeze
());
}
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
Rank
==
-
1
?
-
1
:
Rank
-
1
,
PtrTraits
,
Tindex
>
squeeze
(
int
dim
)
const
{
return
TensorView
<
T
,
Rank
==
-
1
?
-
1
:
Rank
-
1
,
PtrTraits
,
Tindex
>
(
ptr_
,
shape_
.
squeeze
<
Rank
==
-
1
?
TV_MAX_DIM
:
Rank
-
1
>
(
dim
));
}
TV_HOST_DEVICE_INLINE
size_t
size
()
const
{
return
shape_
.
size
();
}
template
<
class
...
Integers
>
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
subview
(
int
id
,
Integers
...
ints
)
{
tv_shape_t
start
=
{
id
,
ints
...};
for
(
int
i
=
1
+
sizeof
...(
ints
);
i
<
ndim
();
++
i
)
{
start
.
push_back
(
0
);
}
return
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr_
+
rowArrayIdx
(
shape_
,
start
),
shape_
.
subshape
(
sizeof
...(
ints
)
+
1
));
}
template
<
class
...
Integers
>
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
subview
(
int
id
,
Integers
...
ints
)
const
{
tv_shape_t
start
=
{
id
,
ints
...};
for
(
int
i
=
1
+
sizeof
...(
ints
);
i
<
ndim
();
++
i
)
{
start
.
push_back
(
0
);
}
return
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr_
+
rowArrayIdx
(
shape_
,
start
),
shape_
.
subshape
(
sizeof
...(
ints
)
+
1
));
}
TV_HOST_DEVICE_INLINE
TensorView
<
T
,
-
1
,
PtrTraits
,
Tindex
>
subview
(
SimpleVector
<
int
>
ids
)
const
{
Shape
start
=
ids
;
for
(
int
i
=
ids
.
size
();
i
<
ndim
();
++
i
)
{
start
.
push_back
(
0
);
}
return
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
ptr_
+
rowArrayIdx
(
shape_
,
start
),
shape_
.
subshape
(
ids
.
size
()));
}
template
<
typename
Os
>
std
::
string
repr
(
Os
&
ss
)
const
{
if
(
empty
())
return
""
;
if
(
shape_
.
ndim
()
==
0
)
{
ss
<<
"Tensor["
<<
type_s
<
T
>
<<
"]"
<<
std
::
endl
;
ss
<<
*
ptr_
;
return
ss
.
str
();
}
SimpleVector
<
int64_t
,
TV_MAX_DIM
>
prev
(
ndim
(),
-
1
);
SimpleVector
<
int64_t
,
TV_MAX_DIM
>
nd_index
(
ndim
());
SimpleVector
<
int64_t
,
TV_MAX_DIM
>
_shape
;
for
(
auto
s
:
shape
())
{
_shape
.
push_back
(
s
);
}
ss
<<
"Tensor["
<<
type_s
<
T
>
<<
"]: shape="
<<
shape
()
<<
", stride="
<<
stride
()
<<
std
::
endl
;
auto
ndimValue
=
ndim
();
for
(
int64_t
i
=
0
;
i
<
int64_t
(
size
());
++
i
)
{
rowArrayIdxInv
(
i
,
nd_index
.
data
(),
_shape
.
data
(),
ndimValue
);
bool
newline
=
false
;
int
end_count
=
0
;
for
(
int
j
=
0
;
j
<
ndimValue
;
++
j
)
{
if
(
nd_index
[
j
]
!=
prev
[
j
]
&&
nd_index
[
j
]
==
0
&&
prev
[
j
]
!=
0
&&
prev
[
j
]
!=
-
1
)
{
ss
<<
"]"
;
++
end_count
;
newline
=
true
;
}
}
if
(
prev
[
0
]
==
-
1
)
{
end_count
=
ndimValue
;
}
if
(
newline
)
{
ss
<<
"
\n
"
;
}
int
starts_count
=
0
;
for
(
int
j
=
0
;
j
<
ndimValue
;
++
j
)
{
if
(
nd_index
[
j
]
!=
prev
[
j
]
&&
nd_index
[
j
]
==
0
&&
prev
[
j
]
!=
0
)
{
++
starts_count
;
}
}
if
(
starts_count
>
0
)
{
for
(
int
j
=
0
;
j
<
ndimValue
-
end_count
;
++
j
)
{
ss
<<
" "
;
}
for
(
int
j
=
0
;
j
<
starts_count
;
++
j
)
{
ss
<<
"["
;
}
}
if
(
std
::
is_same
<
T
,
uint8_t
>::
value
||
std
::
is_same
<
T
,
const
uint8_t
>::
value
)
{
ss
<<
unsigned
((
*
this
)[
i
]);
}
else
{
ss
<<
(
*
this
)[
i
];
}
if
(
nd_index
[
ndimValue
-
1
]
!=
_shape
[
ndimValue
-
1
]
-
1
)
{
ss
<<
","
;
}
for
(
int
j
=
0
;
j
<
ndimValue
;
++
j
)
{
prev
[
j
]
=
nd_index
[
j
];
}
}
for
(
int
j
=
0
;
j
<
ndimValue
;
++
j
)
{
ss
<<
"]"
;
}
return
ss
.
str
();
}
std
::
string
repr
()
const
{
std
::
ostringstream
ss
;
return
repr
(
ss
);
}
protected:
template
<
typename
T1
>
TV_HOST_DEVICE_INLINE
Slice
to_slice
(
T1
s
)
const
{
return
Slice
{
int
(
s
),
-
1
,
-
1
};
}
TV_HOST_DEVICE_INLINE
Slice
to_slice
(
Slice
s
)
const
{
return
Slice
(
s
);
}
ptr_t
ptr_
=
nullptr
;
tv_shape_t
shape_
;
tv_shape_t
stride_
;
};
template
<
typename
T
>
TensorView
<
T
>
vector2tv
(
std
::
vector
<
T
>
&
arr
)
{
return
TensorView
<
T
>
(
arr
.
data
(),
{
arr
.
size
()});
}
template
<
typename
T
>
TensorView
<
T
>
vector2tv
(
std
::
vector
<
T
>
&
arr
,
Shape
shape
)
{
TV_ASSERT_INVALID_ARG
(
shape
.
prod
()
==
arr
.
size
(),
"error"
);
return
TensorView
<
T
>
(
arr
.
data
(),
shape
);
}
template
<
typename
T
>
TensorView
<
const
T
>
vector2tv
(
const
std
::
vector
<
T
>
&
arr
)
{
return
TensorView
<
const
T
>
(
arr
.
data
(),
{
arr
.
size
()});
}
template
<
typename
Os
,
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
Os
&
operator
<<
(
Os
&
os
,
const
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
&
dt
)
{
os
<<
dt
.
repr
();
return
os
;
}
template
<
typename
Os
,
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
Os
&
operator
<<
(
Os
&
os
,
const
TensorView
<
const
T
,
Rank
,
PtrTraits
,
Tindex
>
&
dt
)
{
os
<<
dt
.
repr
();
return
os
;
}
namespace
detail
{
template
<
typename
T
>
struct
TypePrintfFormat
;
template
<
>
struct
TypePrintfFormat
<
float
>
{
static
constexpr
const
char
*
value
=
"%.2f"
;
};
template
<
>
struct
TypePrintfFormat
<
double
>
{
static
constexpr
const
char
*
value
=
"%.2f"
;
};
template
<
>
struct
TypePrintfFormat
<
int8_t
>
{
static
constexpr
const
char
*
value
=
"%d"
;
};
template
<
>
struct
TypePrintfFormat
<
int16_t
>
{
static
constexpr
const
char
*
value
=
"%d"
;
};
template
<
>
struct
TypePrintfFormat
<
int32_t
>
{
static
constexpr
const
char
*
value
=
"%d"
;
};
template
<
>
struct
TypePrintfFormat
<
uint8_t
>
{
static
constexpr
const
char
*
value
=
"%u"
;
};
template
<
>
struct
TypePrintfFormat
<
uint16_t
>
{
static
constexpr
const
char
*
value
=
"%u"
;
};
template
<
>
struct
TypePrintfFormat
<
uint32_t
>
{
static
constexpr
const
char
*
value
=
"%u"
;
};
template
<
>
struct
TypePrintfFormat
<
int64_t
>
{
static
constexpr
const
char
*
value
=
"%ld"
;
};
template
<
>
struct
TypePrintfFormat
<
uint64_t
>
{
static
constexpr
const
char
*
value
=
"%lu"
;
};
template
<
>
struct
TypePrintfFormat
<
bool
>
{
static
constexpr
const
char
*
value
=
"%d"
;
};
template
<
typename
T
>
constexpr
const
char
*
type_printf_format_v
=
TypePrintfFormat
<
T
>::
value
;
};
// namespace detail
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
TV_HOST_DEVICE
void
printTensorView
(
const
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
&
tensor
,
const
char
*
format
)
{
// used to print tensor in cuda kernel.
if
(
tensor
.
empty
())
return
;
if
(
tensor
.
ndim
()
==
0
)
{
printf
(
format
,
tensor
());
printf
(
"
\n
"
);
return
;
}
SimpleVector
<
int64_t
,
TV_MAX_DIM
>
prev
(
tensor
.
ndim
(),
-
1
);
SimpleVector
<
int64_t
,
TV_MAX_DIM
>
nd_index
(
tensor
.
ndim
());
SimpleVector
<
int64_t
,
TV_MAX_DIM
>
shape
(
tensor
.
shape
());
auto
ndim
=
tensor
.
ndim
();
for
(
int64_t
i
=
0
;
i
<
tensor
.
size
();
++
i
)
{
rowArrayIdxInv
(
i
,
nd_index
.
data
(),
shape
.
data
(),
ndim
);
bool
newline
=
false
;
int
end_count
=
0
;
for
(
int
j
=
0
;
j
<
ndim
;
++
j
)
{
if
(
nd_index
[
j
]
!=
prev
[
j
]
&&
nd_index
[
j
]
==
0
&&
prev
[
j
]
!=
0
&&
prev
[
j
]
!=
-
1
)
{
printf
(
"]"
);
++
end_count
;
newline
=
true
;
}
}
if
(
prev
[
0
]
==
-
1
)
{
end_count
=
ndim
;
}
if
(
newline
)
{
printf
(
"
\n
"
);
}
int
starts_count
=
0
;
for
(
int
j
=
0
;
j
<
ndim
;
++
j
)
{
if
(
nd_index
[
j
]
!=
prev
[
j
]
&&
nd_index
[
j
]
==
0
&&
prev
[
j
]
!=
0
)
{
++
starts_count
;
}
}
if
(
starts_count
>
0
)
{
for
(
int
j
=
0
;
j
<
ndim
-
end_count
;
++
j
)
{
printf
(
" "
);
}
for
(
int
j
=
0
;
j
<
starts_count
;
++
j
)
{
printf
(
"]"
);
}
}
printf
(
format
,
tensor
[
i
]);
if
(
nd_index
[
ndim
-
1
]
!=
shape
[
ndim
-
1
]
-
1
)
{
printf
(
","
);
}
for
(
int
j
=
0
;
j
<
ndim
;
++
j
)
{
prev
[
j
]
=
nd_index
[
j
];
}
}
for
(
int
j
=
0
;
j
<
ndim
;
++
j
)
{
printf
(
"]"
);
}
printf
(
"
\n
"
);
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
TV_HOST_DEVICE
void
printTensorView
(
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
tensor
)
{
using
Traw
=
typename
std
::
remove_const
<
T
>::
type
;
return
printTensorView
(
tensor
,
detail
::
type_printf_format_v
<
Traw
>
);
}
template
<
typename
T
>
TV_HOST_DEVICE
void
printTensorView
(
const
T
*
ptr
,
Shape
shape
)
{
using
Traw
=
typename
std
::
remove_const
<
T
>::
type
;
return
printTensorView
(
TensorView
<
const
T
>
(
ptr
,
shape
),
detail
::
type_printf_format_v
<
Traw
>
);
}
template
<
typename
T
>
TV_HOST_DEVICE
void
printTensorView
(
const
T
*
ptr
,
Shape
shape
,
const
char
*
format
)
{
return
printTensorView
(
TensorView
<
const
T
>
(
ptr
,
shape
),
format
);
}
#ifdef TV_CUDA
#ifdef __DRIVER_TYPES_H__
#ifndef DEVICE_RESET
#define DEVICE_RESET cudaDeviceReset();
#endif
#else
#ifndef DEVICE_RESET
#define DEVICE_RESET
#endif
#endif
template
<
typename
T
>
void
check
(
T
result
,
char
const
*
const
func
,
const
char
*
const
file
,
int
const
line
)
{
if
(
result
)
{
fprintf
(
stderr
,
"CUDA error at %s:%d code=%d
\"
%s
\"
\n
"
,
file
,
line
,
static_cast
<
unsigned
int
>
(
result
),
func
);
DEVICE_RESET
// Make sure we call CUDA Device Reset before exiting
exit
(
EXIT_FAILURE
);
}
}
#define checkCudaErrors(val) tv::check((val), #val, __FILE__, __LINE__)
template
<
typename
T
>
void
host2dev
(
T
*
dst
,
const
T
*
src
,
size_t
size
,
cudaStream_t
s
=
0
)
{
checkCudaErrors
(
cudaMemcpyAsync
(
dst
,
src
,
size
*
sizeof
(
T
),
cudaMemcpyHostToDevice
,
s
));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
host2dev
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
const
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
host2dev
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
host2dev
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
host2dev
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
>
void
host2dev_sync
(
T
*
dst
,
const
T
*
src
,
size_t
size
)
{
checkCudaErrors
(
cudaMemcpy
(
dst
,
src
,
size
*
sizeof
(
T
),
cudaMemcpyHostToDevice
));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
host2dev_sync
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
const
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
)
{
host2dev_sync
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
host2dev_sync
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
)
{
host2dev_sync
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()));
}
template
<
typename
T
>
void
dev2host
(
T
*
dst
,
const
T
*
src
,
size_t
size
,
cudaStream_t
s
=
0
)
{
checkCudaErrors
(
cudaMemcpyAsync
(
dst
,
src
,
size
*
sizeof
(
T
),
cudaMemcpyDeviceToHost
,
s
));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
dev2host
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
const
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
dev2host
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
dev2host
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
dev2host
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
>
void
dev2dev
(
T
*
dst
,
const
T
*
src
,
size_t
size
,
cudaStream_t
s
=
0
)
{
checkCudaErrors
(
cudaMemcpyAsync
(
dst
,
src
,
size
*
sizeof
(
T
),
cudaMemcpyDeviceToDevice
,
s
));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
dev2dev
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
const
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
dev2dev
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
dev2dev
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
dev2dev
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
>
void
host2host
(
T
*
dst
,
const
T
*
src
,
size_t
size
,
cudaStream_t
s
=
0
)
{
checkCudaErrors
(
cudaMemcpyAsync
(
dst
,
src
,
size
*
sizeof
(
T
),
cudaMemcpyHostToHost
,
s
));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
host2host
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
const
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
host2host
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits1
,
template
<
class
>
class
PtrTraits2
,
typename
Tindex1
,
typename
Tindex2
>
void
host2host
(
TensorView
<
T
,
Rank
,
PtrTraits1
,
Tindex1
>
dst
,
const
TensorView
<
T
,
Rank
,
PtrTraits2
,
Tindex2
>
src
,
cudaStream_t
s
=
0
)
{
host2host
(
dst
.
data
(),
src
.
data
(),
std
::
min
(
dst
.
size
(),
src
.
size
()),
s
);
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
void
zero_dev
(
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
tensor
)
{
checkCudaErrors
(
cudaMemset
(
tensor
.
data
(),
0
,
tensor
.
size
()
*
sizeof
(
T
)));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
void
zero_dev
(
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
tensor
,
cudaStream_t
s
)
{
checkCudaErrors
(
cudaMemsetAsync
(
tensor
.
data
(),
0
,
tensor
.
size
()
*
sizeof
(
T
),
s
));
}
template
<
typename
T
,
int
Rank
,
template
<
class
>
class
PtrTraits
,
typename
Tindex
>
void
zero_host
(
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
tensor
)
{
std
::
fill
(
tensor
.
data
(),
tensor
.
data
()
+
tensor
.
size
(),
0
);
}
#endif
}
// namespace tv
\ No newline at end of file
include/tensorview/tools.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
#include <iostream>
namespace
tv
{
#ifdef TV_CUDA
template
<
typename
TimeT
=
std
::
chrono
::
microseconds
>
struct
CudaContextTimer
{
CudaContextTimer
()
{
cudaDeviceSynchronize
();
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
}
typename
TimeT
::
rep
report
()
{
cudaDeviceSynchronize
();
auto
duration
=
std
::
chrono
::
duration_cast
<
TimeT
>
(
std
::
chrono
::
steady_clock
::
now
()
-
mCurTime
);
auto
res
=
duration
.
count
();
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
return
res
;
}
private:
std
::
chrono
::
time_point
<
std
::
chrono
::
steady_clock
>
mCurTime
;
};
#endif
template
<
typename
TimeT
=
std
::
chrono
::
microseconds
>
struct
CPUTimer
{
CPUTimer
()
{
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
}
typename
TimeT
::
rep
report
()
{
auto
duration
=
std
::
chrono
::
duration_cast
<
TimeT
>
(
std
::
chrono
::
steady_clock
::
now
()
-
mCurTime
);
auto
res
=
duration
.
count
();
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
return
res
;
}
private:
std
::
chrono
::
time_point
<
std
::
chrono
::
steady_clock
>
mCurTime
;
};
}
// namespace tv
include/tensorview/torch_utils.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mp_helper.h"
#include <tensorview/tensorview.h>
#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
namespace
tv
{
#ifdef TV_CUDA
struct
TorchGPU
:
public
tv
::
GPU
{
virtual
cudaStream_t
getStream
()
const
override
{
return
at
::
cuda
::
getCurrentCUDAStream
();
}
};
#endif
namespace
detail
{
template
<
typename
T
>
struct
TypeToTorchDtypeTraits
;
template
<
>
struct
TypeToTorchDtypeTraits
<
int32_t
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kInt32
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
int16_t
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kInt16
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
int8_t
>
{
static
constexpr
decltype
(
torch
::
kInt8
)
value
=
torch
::
kInt8
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
int64_t
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kInt64
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
uint8_t
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kUInt8
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
bool
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kBool
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
float
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kFloat32
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
double
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kFloat64
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
at
::
Half
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kHalf
;
};
using
all_torch_types_t
=
std
::
tuple
<
float
,
double
,
int8_t
,
int16_t
,
int32_t
,
int64_t
,
uint8_t
,
bool
,
at
::
Half
>
;
}
// namespace detail
template
<
typename
T
>
constexpr
decltype
(
torch
::
kInt32
)
torch_type_v
=
detail
::
TypeToTorchDtypeTraits
<
T
>::
value
;
template
<
class
...
Ts
,
typename
F
>
void
dispatch_torch
(
at
::
ScalarType
t
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
bool
notFound
=
true
;
tv
::
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
detail
::
TypeToTorchDtypeTraits
<
TV_DECLTYPE
(
I
)
>::
value
==
t
)
{
std
::
forward
<
F
>
(
f
)(
TV_DECLTYPE
(
I
)());
notFound
=
false
;
}
});
if
(
notFound
)
{
std
::
stringstream
ss
;
tv
::
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
tv
::
detail
::
TypeToString
<
TV_DECLTYPE
(
I
)
>::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown type"
,
t
,
", available:"
,
ss
.
str
());
}
}
template
<
class
T
>
struct
DispatchTorch
;
template
<
template
<
class
...
>
class
T
,
class
...
Args
>
struct
DispatchTorch
<
T
<
Args
...
>>
{
template
<
typename
F
>
inline
void
operator
()(
at
::
ScalarType
t
,
F
&&
f
)
{
return
dispatch_torch
<
Args
...
>
(
t
,
std
::
forward
<
F
>
(
f
));
}
};
template
<
typename
T
>
void
check_torch_dtype
(
const
torch
::
Tensor
&
tensor
)
{
DispatchTorch
<
detail
::
all_torch_types_t
>
()(
tensor
.
scalar_type
(),
[
&
](
auto
I
)
{
using
Ttensor
=
TV_DECLTYPE
(
I
);
constexpr
bool
val
=
std
::
is_same
<
std
::
remove_cv_t
<
T
>
,
Ttensor
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
});
}
template
<
typename
T
,
int
Rank
=
-
1
,
template
<
class
>
class
PtrTraits
=
DefaultPtrTraits
,
typename
Tindex
=
int
>
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
torch2tv
(
const
torch
::
Tensor
&
tensor
)
{
using
tv_shape_t
=
typename
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>::
tv_shape_t
;
check_torch_dtype
<
T
>
(
tensor
);
// TODO stride
if
(
Rank
>
0
)
{
TV_ASSERT_INVALID_ARG
(
tensor
.
dim
()
==
Rank
,
"error"
);
}
tv_shape_t
shape
;
for
(
auto
i
:
tensor
.
sizes
())
{
shape
.
push_back
(
i
);
}
return
tv
::
TensorView
<
T
,
Rank
,
PtrTraits
,
Tindex
>
(
tensor
.
data_ptr
<
std
::
remove_const_t
<
T
>>
(),
shape
);
}
template
<
typename
T
>
torch
::
Tensor
torch_slice_first_axis
(
torch
::
Tensor
tensor
,
T
start
,
T
end
)
{
// only torch >= 1.5 have tensor slice.
torch
::
Tensor
res
;
auto
tensor_shape
=
tensor
.
sizes
();
std
::
vector
<
int64_t
>
shape
(
tensor_shape
.
begin
(),
tensor_shape
.
end
());
shape
[
0
]
=
end
-
start
;
uint8_t
*
ptr
=
reinterpret_cast
<
uint8_t
*>
(
tensor
.
data_ptr
());
res
=
torch
::
from_blob
(
ptr
+
start
*
tensor
.
stride
(
0
)
*
tensor
.
itemsize
(),
torch
::
IntArrayRef
(
shape
),
tensor
.
options
());
return
res
;
}
namespace
detail
{
template
<
>
struct
TypeToString
<
at
::
Half
>
{
static
constexpr
const
char
*
value
=
"half"
;
};
}
// namespace detail
}
// namespace tv
\ No newline at end of file
include/torch_utils.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <tensorview/mp_helper.h>
#include <tensorview/tensorview.h>
#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
namespace
tv
{
#ifdef TV_CUDA
struct
TorchGPU
:
public
tv
::
GPU
{
virtual
cudaStream_t
getStream
()
const
override
{
return
at
::
cuda
::
getCurrentCUDAStream
();
}
};
#endif
template
<
typename
T
>
void
check_torch_dtype
(
const
torch
::
Tensor
&
tensor
)
{
switch
(
tensor
.
scalar_type
())
{
case
at
::
ScalarType
::
Double
:
{
auto
val
=
std
::
is_same
<
std
::
remove_const_t
<
T
>
,
double
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
break
;
}
case
at
::
ScalarType
::
Float
:
{
auto
val
=
std
::
is_same
<
std
::
remove_const_t
<
T
>
,
float
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
break
;
}
case
at
::
ScalarType
::
Int
:
{
auto
val
=
std
::
is_same
<
std
::
remove_const_t
<
T
>
,
int
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
break
;
}
case
at
::
ScalarType
::
Half
:
{
auto
val
=
std
::
is_same
<
std
::
remove_const_t
<
T
>
,
at
::
Half
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
break
;
}
case
at
::
ScalarType
::
Long
:
{
auto
val
=
std
::
is_same
<
std
::
remove_const_t
<
T
>
,
long
>::
value
;
TV_ASSERT_RT_ERR
(
val
,
"error"
);
break
;
}
default:
TV_ASSERT_RT_ERR
(
false
,
"error"
);
}
}
namespace
detail
{
template
<
typename
T
>
struct
TypeToTorchDtypeTraits
;
template
<
>
struct
TypeToTorchDtypeTraits
<
int32_t
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kInt32
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
int64_t
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kInt64
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
float
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kFloat32
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
double
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kFloat64
;
};
template
<
>
struct
TypeToTorchDtypeTraits
<
at
::
Half
>
{
static
constexpr
decltype
(
torch
::
kInt32
)
value
=
torch
::
kHalf
;
};
}
// namespace detail
template
<
typename
T
>
constexpr
decltype
(
torch
::
kInt32
)
torch_type_v
=
detail
::
TypeToTorchDtypeTraits
<
T
>::
value
;
template
<
typename
T
>
tv
::
TensorView
<
T
>
torch2tv
(
const
torch
::
Tensor
&
tensor
)
{
check_torch_dtype
<
T
>
(
tensor
);
tv
::
Shape
shape
;
for
(
auto
i
:
tensor
.
sizes
())
{
shape
.
push_back
(
i
);
}
return
tv
::
TensorView
<
T
>
(
tensor
.
data_ptr
<
std
::
remove_const_t
<
T
>>
(),
shape
);
}
namespace
detail
{
template
<
>
struct
TypeToString
<
at
::
Half
>
{
static
constexpr
const
char
*
value
=
"half"
;
};
}
// namespace detail
template
<
class
...
Ts
,
typename
F
>
void
dispatch_torch
(
at
::
ScalarType
t
,
F
&&
f
)
{
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
bool
notFound
=
true
;
spconv
::
tv
::
mp_for_each
<
spconv
::
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
torch_type_v
<
decltype
(
I
)
>
==
t
)
{
std
::
forward
<
F
>
(
f
)(
decltype
(
I
)());
notFound
=
false
;
}
});
if
(
notFound
)
{
std
::
stringstream
ss
;
spconv
::
tv
::
mp_for_each
<
spconv
::
mp_list
<
Ts
...
>>
([
=
,
&
ss
](
auto
I
)
{
ss
<<
tv
::
detail
::
TypeToString
<
decltype
(
I
)
>::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown type"
,
t
,
", available: "
,
ss
.
str
());
}
}
}
// namespace tv
\ No newline at end of file
include/tsl/robin_growth_policy.h
deleted
100644 → 0
View file @
fad30002
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
#ifdef TSL_DEBUG
#define tsl_rh_assert(expr) assert(expr)
#else
#define tsl_rh_assert(expr) (static_cast<void>(0))
#endif
/**
* If exceptions are enabled, throw the exception passed in parameter, otherwise
* call std::terminate.
*/
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || \
(defined(_MSC_VER) && defined(_CPPUNWIND))) && \
!defined(TSL_NO_EXCEPTIONS)
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#else
#ifdef NDEBUG
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#else
#include <cstdio>
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) \
do { \
std::fprintf(stderr, msg); \
std::terminate(); \
} while (0)
#endif
#endif
#if defined(__GNUC__) || defined(__clang__)
#define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#else
#define TSL_RH_LIKELY(exp) (exp)
#endif
namespace
tsl
{
namespace
rh
{
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template
<
std
::
size_t
GrowthFactor
>
class
power_of_two_growth_policy
{
public:
/**
* Called on the hash table creation and on rehash. The number of buckets for
* the table is passed in parameter. This number is a minimum, the policy may
* update this value with a higher value if needed (but not lower).
*
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy
* creation and bucket_for_hash must always return 0 in this case.
*/
explicit
power_of_two_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
if
(
min_bucket_count_in_out
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
if
(
min_bucket_count_in_out
>
0
)
{
min_bucket_count_in_out
=
round_up_to_power_of_two
(
min_bucket_count_in_out
);
m_mask
=
min_bucket_count_in_out
-
1
;
}
else
{
m_mask
=
0
;
}
}
/**
* Return the bucket [0, bucket_count()) to which the hash belongs.
* If bucket_count() is 0, it must always return 0.
*/
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
hash
&
m_mask
;
}
/**
* Return the number of buckets that should be used on next growth.
*/
std
::
size_t
next_bucket_count
()
const
{
if
((
m_mask
+
1
)
>
max_bucket_count
()
/
GrowthFactor
)
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
return
(
m_mask
+
1
)
*
GrowthFactor
;
}
/**
* Return the maximum number of buckets supported by the policy.
*/
std
::
size_t
max_bucket_count
()
const
{
// Largest power of two.
return
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()
/
2
)
+
1
;
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is
* called.
*/
void
clear
()
noexcept
{
m_mask
=
0
;
}
private:
static
std
::
size_t
round_up_to_power_of_two
(
std
::
size_t
value
)
{
if
(
is_power_of_two
(
value
))
{
return
value
;
}
if
(
value
==
0
)
{
return
1
;
}
--
value
;
for
(
std
::
size_t
i
=
1
;
i
<
sizeof
(
std
::
size_t
)
*
CHAR_BIT
;
i
*=
2
)
{
value
|=
value
>>
i
;
}
return
value
+
1
;
}
static
constexpr
bool
is_power_of_two
(
std
::
size_t
value
)
{
return
value
!=
0
&&
(
value
&
(
value
-
1
))
==
0
;
}
protected:
static_assert
(
is_power_of_two
(
GrowthFactor
)
&&
GrowthFactor
>=
2
,
"GrowthFactor must be a power of two >= 2."
);
std
::
size_t
m_mask
;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/
template
<
class
GrowthFactor
=
std
::
ratio
<
3
,
2
>
>
class
mod_growth_policy
{
public:
explicit
mod_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
if
(
min_bucket_count_in_out
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
if
(
min_bucket_count_in_out
>
0
)
{
m_mod
=
min_bucket_count_in_out
;
}
else
{
m_mod
=
1
;
}
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
hash
%
m_mod
;
}
std
::
size_t
next_bucket_count
()
const
{
if
(
m_mod
==
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
const
double
next_bucket_count
=
std
::
ceil
(
double
(
m_mod
)
*
REHASH_SIZE_MULTIPLICATION_FACTOR
);
if
(
!
std
::
isnormal
(
next_bucket_count
))
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
if
(
next_bucket_count
>
double
(
max_bucket_count
()))
{
return
max_bucket_count
();
}
else
{
return
std
::
size_t
(
next_bucket_count
);
}
}
std
::
size_t
max_bucket_count
()
const
{
return
MAX_BUCKET_COUNT
;
}
void
clear
()
noexcept
{
m_mod
=
1
;
}
private:
static
constexpr
double
REHASH_SIZE_MULTIPLICATION_FACTOR
=
1.0
*
GrowthFactor
::
num
/
GrowthFactor
::
den
;
static
const
std
::
size_t
MAX_BUCKET_COUNT
=
std
::
size_t
(
double
(
std
::
numeric_limits
<
std
::
size_t
>::
max
()
/
REHASH_SIZE_MULTIPLICATION_FACTOR
));
static_assert
(
REHASH_SIZE_MULTIPLICATION_FACTOR
>=
1.1
,
"Growth factor should be >= 1.1."
);
std
::
size_t
m_mod
;
};
namespace
detail
{
static
constexpr
const
std
::
array
<
std
::
size_t
,
40
>
PRIMES
=
{
{
1ul
,
5ul
,
17ul
,
29ul
,
37ul
,
53ul
,
67ul
,
79ul
,
97ul
,
131ul
,
193ul
,
257ul
,
389ul
,
521ul
,
769ul
,
1031ul
,
1543ul
,
2053ul
,
3079ul
,
6151ul
,
12289ul
,
24593ul
,
49157ul
,
98317ul
,
196613ul
,
393241ul
,
786433ul
,
1572869ul
,
3145739ul
,
6291469ul
,
12582917ul
,
25165843ul
,
50331653ul
,
100663319ul
,
201326611ul
,
402653189ul
,
805306457ul
,
1610612741ul
,
3221225473ul
,
4294967291ul
}};
template
<
unsigned
int
IPrime
>
static
constexpr
std
::
size_t
mod
(
std
::
size_t
hash
)
{
return
hash
%
PRIMES
[
IPrime
];
}
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
static
constexpr
const
std
::
array
<
std
::
size_t
(
*
)(
std
::
size_t
),
40
>
MOD_PRIME
=
{{
&
mod
<
0
>
,
&
mod
<
1
>
,
&
mod
<
2
>
,
&
mod
<
3
>
,
&
mod
<
4
>
,
&
mod
<
5
>
,
&
mod
<
6
>
,
&
mod
<
7
>
,
&
mod
<
8
>
,
&
mod
<
9
>
,
&
mod
<
10
>
,
&
mod
<
11
>
,
&
mod
<
12
>
,
&
mod
<
13
>
,
&
mod
<
14
>
,
&
mod
<
15
>
,
&
mod
<
16
>
,
&
mod
<
17
>
,
&
mod
<
18
>
,
&
mod
<
19
>
,
&
mod
<
20
>
,
&
mod
<
21
>
,
&
mod
<
22
>
,
&
mod
<
23
>
,
&
mod
<
24
>
,
&
mod
<
25
>
,
&
mod
<
26
>
,
&
mod
<
27
>
,
&
mod
<
28
>
,
&
mod
<
29
>
,
&
mod
<
30
>
,
&
mod
<
31
>
,
&
mod
<
32
>
,
&
mod
<
33
>
,
&
mod
<
34
>
,
&
mod
<
35
>
,
&
mod
<
36
>
,
&
mod
<
37
>
,
&
mod
<
38
>
,
&
mod
<
39
>
}};
}
// namespace detail
/**
* Grow the hash table by using prime numbers as bucket count. Slower than
* tsl::rh::power_of_two_growth_policy in general but will probably distribute
* the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant primes numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize
* the operation by a series of multiplications, substractions and shifts.
*
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64 bits environement.
*/
class
prime_growth_policy
{
public:
explicit
prime_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
auto
it_prime
=
std
::
lower_bound
(
detail
::
PRIMES
.
begin
(),
detail
::
PRIMES
.
end
(),
min_bucket_count_in_out
);
if
(
it_prime
==
detail
::
PRIMES
.
end
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
m_iprime
=
static_cast
<
unsigned
int
>
(
std
::
distance
(
detail
::
PRIMES
.
begin
(),
it_prime
));
if
(
min_bucket_count_in_out
>
0
)
{
min_bucket_count_in_out
=
*
it_prime
;
}
else
{
min_bucket_count_in_out
=
0
;
}
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
detail
::
MOD_PRIME
[
m_iprime
](
hash
);
}
std
::
size_t
next_bucket_count
()
const
{
if
(
m_iprime
+
1
>=
detail
::
PRIMES
.
size
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The hash table exceeds its maxmimum size."
);
}
return
detail
::
PRIMES
[
m_iprime
+
1
];
}
std
::
size_t
max_bucket_count
()
const
{
return
detail
::
PRIMES
.
back
();
}
void
clear
()
noexcept
{
m_iprime
=
0
;
}
private:
unsigned
int
m_iprime
;
static_assert
(
std
::
numeric_limits
<
decltype
(
m_iprime
)
>::
max
()
>=
detail
::
PRIMES
.
size
(),
"The type of m_iprime is not big enough."
);
};
}
// namespace rh
}
// namespace tsl
#endif
include/tsl/robin_hash.h
deleted
100644 → 0
View file @
fad30002
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_HASH_H
#define TSL_ROBIN_HASH_H
#include "robin_growth_policy.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
namespace
tsl
{
namespace
detail_robin_hash
{
template
<
typename
T
>
struct
make_void
{
using
type
=
void
;
};
template
<
typename
T
,
typename
=
void
>
struct
has_is_transparent
:
std
::
false_type
{};
template
<
typename
T
>
struct
has_is_transparent
<
T
,
typename
make_void
<
typename
T
::
is_transparent
>::
type
>
:
std
::
true_type
{};
template
<
typename
U
>
struct
is_power_of_two_policy
:
std
::
false_type
{};
template
<
std
::
size_t
GrowthFactor
>
struct
is_power_of_two_policy
<
tsl
::
rh
::
power_of_two_growth_policy
<
GrowthFactor
>>
:
std
::
true_type
{};
// Only available in C++17, we need to be compatible with C++11
template
<
class
T
>
const
T
&
clamp
(
const
T
&
v
,
const
T
&
lo
,
const
T
&
hi
)
{
return
std
::
min
(
hi
,
std
::
max
(
lo
,
v
));
}
using
truncated_hash_type
=
std
::
uint_least32_t
;
/**
* Helper class that stores a truncated hash if StoreHash is true and nothing
* otherwise.
*/
template
<
bool
StoreHash
>
class
bucket_entry_hash
{
public:
bool
bucket_hash_equal
(
std
::
size_t
/*hash*/
)
const
noexcept
{
return
true
;
}
truncated_hash_type
truncated_hash
()
const
noexcept
{
return
0
;
}
protected:
void
set_hash
(
truncated_hash_type
/*hash*/
)
noexcept
{}
};
template
<
>
class
bucket_entry_hash
<
true
>
{
public:
bool
bucket_hash_equal
(
std
::
size_t
hash
)
const
noexcept
{
return
m_hash
==
truncated_hash_type
(
hash
);
}
truncated_hash_type
truncated_hash
()
const
noexcept
{
return
m_hash
;
}
protected:
void
set_hash
(
truncated_hash_type
hash
)
noexcept
{
m_hash
=
truncated_hash_type
(
hash
);
}
private:
truncated_hash_type
m_hash
;
};
/**
* Each bucket entry has:
* - A value of type `ValueType`.
* - An integer to store how far the value of the bucket, if any, is from its
* ideal bucket (ex: if the current bucket 5 has the value 'foo' and
* `hash('foo') % nb_buckets` == 3, `dist_from_ideal_bucket()` will return 2 as
* the current value of the bucket is two buckets away from its ideal bucket) If
* there is no value in the bucket (i.e. `empty()` is true)
* `dist_from_ideal_bucket()` will be < 0.
* - A marker which tells us if the bucket is the last bucket of the bucket
* array (useful for the iterator of the hash table).
* - If `StoreHash` is true, 32 bits of the hash of the value, if any, are also
* stored in the bucket. If the size of the hash is more than 32 bits, it is
* truncated. We don't store the full hash as storing the hash is a potential
* opportunity to use the unused space due to the alignement of the bucket_entry
* structure. We can thus potentially store the hash without any extra space
* (which would not be possible with 64 bits of the hash).
*/
template
<
typename
ValueType
,
bool
StoreHash
>
class
bucket_entry
:
public
bucket_entry_hash
<
StoreHash
>
{
using
bucket_hash
=
bucket_entry_hash
<
StoreHash
>
;
public:
using
value_type
=
ValueType
;
using
distance_type
=
std
::
int_least16_t
;
bucket_entry
()
noexcept
:
bucket_hash
(),
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
m_last_bucket
(
false
)
{
tsl_rh_assert
(
empty
());
}
bucket_entry
(
bool
last_bucket
)
noexcept
:
bucket_hash
(),
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
m_last_bucket
(
last_bucket
)
{
tsl_rh_assert
(
empty
());
}
bucket_entry
(
const
bucket_entry
&
other
)
noexcept
(
std
::
is_nothrow_copy_constructible
<
value_type
>::
value
)
:
bucket_hash
(
other
),
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
m_last_bucket
(
other
.
m_last_bucket
)
{
if
(
!
other
.
empty
())
{
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
other
.
value
());
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
}
}
/**
* Never really used, but still necessary as we must call resize on an empty
* `std::vector<bucket_entry>`. and we need to support move-only types. See
* robin_hash constructor for details.
*/
bucket_entry
(
bucket_entry
&&
other
)
noexcept
(
std
::
is_nothrow_move_constructible
<
value_type
>::
value
)
:
bucket_hash
(
std
::
move
(
other
)),
m_dist_from_ideal_bucket
(
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
),
m_last_bucket
(
other
.
m_last_bucket
)
{
if
(
!
other
.
empty
())
{
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
std
::
move
(
other
.
value
()));
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
}
}
bucket_entry
&
operator
=
(
const
bucket_entry
&
other
)
noexcept
(
std
::
is_nothrow_copy_constructible
<
value_type
>::
value
)
{
if
(
this
!=
&
other
)
{
clear
();
bucket_hash
::
operator
=
(
other
);
if
(
!
other
.
empty
())
{
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
other
.
value
());
}
m_dist_from_ideal_bucket
=
other
.
m_dist_from_ideal_bucket
;
m_last_bucket
=
other
.
m_last_bucket
;
}
return
*
this
;
}
bucket_entry
&
operator
=
(
bucket_entry
&&
)
=
delete
;
~
bucket_entry
()
noexcept
{
clear
();
}
void
clear
()
noexcept
{
if
(
!
empty
())
{
destroy_value
();
m_dist_from_ideal_bucket
=
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
;
}
}
bool
empty
()
const
noexcept
{
return
m_dist_from_ideal_bucket
==
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
;
}
value_type
&
value
()
noexcept
{
tsl_rh_assert
(
!
empty
());
return
*
reinterpret_cast
<
value_type
*>
(
std
::
addressof
(
m_value
));
}
const
value_type
&
value
()
const
noexcept
{
tsl_rh_assert
(
!
empty
());
return
*
reinterpret_cast
<
const
value_type
*>
(
std
::
addressof
(
m_value
));
}
distance_type
dist_from_ideal_bucket
()
const
noexcept
{
return
m_dist_from_ideal_bucket
;
}
bool
last_bucket
()
const
noexcept
{
return
m_last_bucket
;
}
void
set_as_last_bucket
()
noexcept
{
m_last_bucket
=
true
;
}
template
<
typename
...
Args
>
void
set_value_of_empty_bucket
(
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
Args
&&
...
value_type_args
)
{
tsl_rh_assert
(
dist_from_ideal_bucket
>=
0
);
tsl_rh_assert
(
empty
());
::
new
(
static_cast
<
void
*>
(
std
::
addressof
(
m_value
)))
value_type
(
std
::
forward
<
Args
>
(
value_type_args
)...);
this
->
set_hash
(
hash
);
m_dist_from_ideal_bucket
=
dist_from_ideal_bucket
;
tsl_rh_assert
(
!
empty
());
}
void
swap_with_value_in_bucket
(
distance_type
&
dist_from_ideal_bucket
,
truncated_hash_type
&
hash
,
value_type
&
value
)
{
tsl_rh_assert
(
!
empty
());
using
std
::
swap
;
swap
(
value
,
this
->
value
());
swap
(
dist_from_ideal_bucket
,
m_dist_from_ideal_bucket
);
// Avoid warning of unused variable if StoreHash is false
(
void
)
hash
;
if
(
StoreHash
)
{
const
truncated_hash_type
tmp_hash
=
this
->
truncated_hash
();
this
->
set_hash
(
hash
);
hash
=
tmp_hash
;
}
}
static
truncated_hash_type
truncate_hash
(
std
::
size_t
hash
)
noexcept
{
return
truncated_hash_type
(
hash
);
}
private:
void
destroy_value
()
noexcept
{
tsl_rh_assert
(
!
empty
());
value
().
~
value_type
();
}
private:
using
storage
=
typename
std
::
aligned_storage
<
sizeof
(
value_type
),
alignof
(
value_type
)
>::
type
;
static
const
distance_type
EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET
=
-
1
;
distance_type
m_dist_from_ideal_bucket
;
bool
m_last_bucket
;
storage
m_value
;
};
/**
* Internal common class used by `robin_map` and `robin_set`.
*
* ValueType is what will be stored by `robin_hash` (usually `std::pair<Key, T>`
* for map and `Key` for set).
*
* `KeySelect` should be a `FunctionObject` which takes a `ValueType` in
* parameter and returns a reference to the key.
*
* `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in
* parameter and returns a reference to the value. `ValueSelect` should be void
* if there is no value (in a set for example).
*
* The strong exception guarantee only holds if the expression
* `std::is_nothrow_swappable<ValueType>::value &&
* std::is_nothrow_move_constructible<ValueType>::value` is true.
*
* Behaviour is undefined if the destructor of `ValueType` throws.
*/
template
<
class
ValueType
,
class
KeySelect
,
class
ValueSelect
,
class
Hash
,
class
KeyEqual
,
class
Allocator
,
bool
StoreHash
,
class
GrowthPolicy
>
class
robin_hash
:
private
Hash
,
private
KeyEqual
,
private
GrowthPolicy
{
private:
template
<
typename
U
>
using
has_mapped_type
=
typename
std
::
integral_constant
<
bool
,
!
std
::
is_same
<
U
,
void
>::
value
>
;
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
bucket_for_hash
(
std
::
size_t
(
0
))),
"GrowthPolicy::bucket_for_hash must be noexcept."
);
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
clear
()),
"GrowthPolicy::clear must be noexcept."
);
public:
template
<
bool
IsConst
>
class
robin_iterator
;
using
key_type
=
typename
KeySelect
::
key_type
;
using
value_type
=
ValueType
;
using
size_type
=
std
::
size_t
;
using
difference_type
=
std
::
ptrdiff_t
;
using
hasher
=
Hash
;
using
key_equal
=
KeyEqual
;
using
allocator_type
=
Allocator
;
using
reference
=
value_type
&
;
using
const_reference
=
const
value_type
&
;
using
pointer
=
value_type
*
;
using
const_pointer
=
const
value_type
*
;
using
iterator
=
robin_iterator
<
false
>
;
using
const_iterator
=
robin_iterator
<
true
>
;
private:
/**
* Either store the hash because we are asked by the `StoreHash` template
* parameter or store the hash because it doesn't cost us anything in size and
* can be used to speed up rehash.
*/
static
constexpr
bool
STORE_HASH
=
StoreHash
||
((
sizeof
(
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
true
>
)
==
sizeof
(
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
false
>
))
&&
(
sizeof
(
std
::
size_t
)
==
sizeof
(
truncated_hash_type
)
||
is_power_of_two_policy
<
GrowthPolicy
>::
value
)
&&
// Don't store the hash for primitive types with default hash.
(
!
std
::
is_arithmetic
<
key_type
>::
value
||
!
std
::
is_same
<
Hash
,
std
::
hash
<
key_type
>>::
value
));
/**
* Only use the stored hash on lookup if we are explictly asked. We are not
* sure how slow the KeyEqual operation is. An extra comparison may slow
* things down with a fast KeyEqual.
*/
static
constexpr
bool
USE_STORED_HASH_ON_LOOKUP
=
StoreHash
;
/**
* We can only use the hash on rehash if the size of the hash type is the same
* as the stored one or if we use a power of two modulo. In the case of the
* power of two modulo, we just mask the least significant bytes, we just have
* to check that the truncated_hash_type didn't truncated more bytes.
*/
static
bool
USE_STORED_HASH_ON_REHASH
(
size_type
bucket_count
)
{
(
void
)
bucket_count
;
if
(
STORE_HASH
&&
sizeof
(
std
::
size_t
)
==
sizeof
(
truncated_hash_type
))
{
return
true
;
}
else
if
(
STORE_HASH
&&
is_power_of_two_policy
<
GrowthPolicy
>::
value
)
{
tsl_rh_assert
(
bucket_count
>
0
);
return
(
bucket_count
-
1
)
<=
std
::
numeric_limits
<
truncated_hash_type
>::
max
();
}
else
{
return
false
;
}
}
using
bucket_entry
=
tsl
::
detail_robin_hash
::
bucket_entry
<
value_type
,
STORE_HASH
>
;
using
distance_type
=
typename
bucket_entry
::
distance_type
;
using
buckets_allocator
=
typename
std
::
allocator_traits
<
allocator_type
>::
template
rebind_alloc
<
bucket_entry
>;
using
buckets_container_type
=
std
::
vector
<
bucket_entry
,
buckets_allocator
>
;
public:
/**
* The 'operator*()' and 'operator->()' methods return a const reference and
* const pointer respectively to the stored value type.
*
* In case of a map, to get a mutable reference to the value associated to a
* key (the '.second' in the stored pair), you have to call 'value()'.
*
* The main reason for this is that if we returned a `std::pair<Key, T>&`
* instead of a `const std::pair<Key, T>&`, the user may modify the key which
* will put the map in a undefined state.
*/
template
<
bool
IsConst
>
class
robin_iterator
{
friend
class
robin_hash
;
private:
using
bucket_entry_ptr
=
typename
std
::
conditional
<
IsConst
,
const
bucket_entry
*
,
bucket_entry
*>::
type
;
robin_iterator
(
bucket_entry_ptr
bucket
)
noexcept
:
m_bucket
(
bucket
)
{}
public:
using
iterator_category
=
std
::
forward_iterator_tag
;
using
value_type
=
const
typename
robin_hash
::
value_type
;
using
difference_type
=
std
::
ptrdiff_t
;
using
reference
=
value_type
&
;
using
pointer
=
value_type
*
;
robin_iterator
()
noexcept
{}
// Copy constructor from iterator to const_iterator.
template
<
bool
TIsConst
=
IsConst
,
typename
std
::
enable_if
<
TIsConst
>
::
type
*
=
nullptr
>
robin_iterator
(
const
robin_iterator
<!
TIsConst
>
&
other
)
noexcept
:
m_bucket
(
other
.
m_bucket
)
{}
robin_iterator
(
const
robin_iterator
&
other
)
=
default
;
robin_iterator
(
robin_iterator
&&
other
)
=
default
;
robin_iterator
&
operator
=
(
const
robin_iterator
&
other
)
=
default
;
robin_iterator
&
operator
=
(
robin_iterator
&&
other
)
=
default
;
const
typename
robin_hash
::
key_type
&
key
()
const
{
return
KeySelect
()(
m_bucket
->
value
());
}
template
<
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
&&
IsConst
>::
type
*
=
nullptr
>
const
typename
U
::
value_type
&
value
()
const
{
return
U
()(
m_bucket
->
value
());
}
template
<
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
&&
!
IsConst
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
value
()
{
return
U
()(
m_bucket
->
value
());
}
reference
operator
*
()
const
{
return
m_bucket
->
value
();
}
pointer
operator
->
()
const
{
return
std
::
addressof
(
m_bucket
->
value
());
}
robin_iterator
&
operator
++
()
{
while
(
true
)
{
if
(
m_bucket
->
last_bucket
())
{
++
m_bucket
;
return
*
this
;
}
++
m_bucket
;
if
(
!
m_bucket
->
empty
())
{
return
*
this
;
}
}
}
robin_iterator
operator
++
(
int
)
{
robin_iterator
tmp
(
*
this
);
++*
this
;
return
tmp
;
}
friend
bool
operator
==
(
const
robin_iterator
&
lhs
,
const
robin_iterator
&
rhs
)
{
return
lhs
.
m_bucket
==
rhs
.
m_bucket
;
}
friend
bool
operator
!=
(
const
robin_iterator
&
lhs
,
const
robin_iterator
&
rhs
)
{
return
!
(
lhs
==
rhs
);
}
private:
bucket_entry_ptr
m_bucket
;
};
public:
#if defined(__cplusplus) && __cplusplus >= 201402L
robin_hash
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
KeyEqual
&
equal
,
const
Allocator
&
alloc
,
float
min_load_factor
=
DEFAULT_MIN_LOAD_FACTOR
,
float
max_load_factor
=
DEFAULT_MAX_LOAD_FACTOR
)
:
Hash
(
hash
),
KeyEqual
(
equal
),
GrowthPolicy
(
bucket_count
),
m_buckets_data
(
[
&
]()
{
if
(
bucket_count
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The map exceeds its maximum bucket count."
);
}
return
bucket_count
;
}(),
alloc
),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
bucket_count
),
m_nb_elements
(
0
),
m_grow_on_next_insert
(
false
),
m_try_skrink_on_next_insert
(
false
)
{
if
(
m_bucket_count
>
0
)
{
tsl_rh_assert
(
!
m_buckets_data
.
empty
());
m_buckets_data
.
back
().
set_as_last_bucket
();
}
this
->
min_load_factor
(
min_load_factor
);
this
->
max_load_factor
(
max_load_factor
);
}
#else
/**
* C++11 doesn't support the creation of a std::vector with a custom allocator
* and 'count' default-inserted elements. The needed contructor `explicit
* vector(size_type count, const Allocator& alloc = Allocator());` is only
* available in C++14 and later. We thus must resize after using the
* `vector(const Allocator& alloc)` constructor.
*
* We can't use `vector(size_type count, const T& value, const Allocator&
* alloc)` as it requires the value T to be copyable.
*/
robin_hash
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
KeyEqual
&
equal
,
const
Allocator
&
alloc
,
float
min_load_factor
=
DEFAULT_MIN_LOAD_FACTOR
,
float
max_load_factor
=
DEFAULT_MAX_LOAD_FACTOR
)
:
Hash
(
hash
),
KeyEqual
(
equal
),
GrowthPolicy
(
bucket_count
),
m_buckets_data
(
alloc
),
m_buckets
(
static_empty_bucket_ptr
()),
m_bucket_count
(
bucket_count
),
m_nb_elements
(
0
),
m_grow_on_next_insert
(
false
),
m_try_skrink_on_next_insert
(
false
)
{
if
(
bucket_count
>
max_bucket_count
())
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
length_error
,
"The map exceeds its maxmimum bucket count."
);
}
if
(
m_bucket_count
>
0
)
{
m_buckets_data
.
resize
(
m_bucket_count
);
m_buckets
=
m_buckets_data
.
data
();
tsl_rh_assert
(
!
m_buckets_data
.
empty
());
m_buckets_data
.
back
().
set_as_last_bucket
();
}
this
->
min_load_factor
(
min_load_factor
);
this
->
max_load_factor
(
max_load_factor
);
}
#endif
robin_hash
(
const
robin_hash
&
other
)
:
Hash
(
other
),
KeyEqual
(
other
),
GrowthPolicy
(
other
),
m_buckets_data
(
other
.
m_buckets_data
),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
other
.
m_bucket_count
),
m_nb_elements
(
other
.
m_nb_elements
),
m_load_threshold
(
other
.
m_load_threshold
),
m_max_load_factor
(
other
.
m_max_load_factor
),
m_grow_on_next_insert
(
other
.
m_grow_on_next_insert
),
m_min_load_factor
(
other
.
m_min_load_factor
),
m_try_skrink_on_next_insert
(
other
.
m_try_skrink_on_next_insert
)
{}
robin_hash
(
robin_hash
&&
other
)
noexcept
(
std
::
is_nothrow_move_constructible
<
Hash
>::
value
&&
std
::
is_nothrow_move_constructible
<
KeyEqual
>::
value
&&
std
::
is_nothrow_move_constructible
<
GrowthPolicy
>::
value
&&
std
::
is_nothrow_move_constructible
<
buckets_container_type
>::
value
)
:
Hash
(
std
::
move
(
static_cast
<
Hash
&>
(
other
))),
KeyEqual
(
std
::
move
(
static_cast
<
KeyEqual
&>
(
other
))),
GrowthPolicy
(
std
::
move
(
static_cast
<
GrowthPolicy
&>
(
other
))),
m_buckets_data
(
std
::
move
(
other
.
m_buckets_data
)),
m_buckets
(
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
()),
m_bucket_count
(
other
.
m_bucket_count
),
m_nb_elements
(
other
.
m_nb_elements
),
m_load_threshold
(
other
.
m_load_threshold
),
m_max_load_factor
(
other
.
m_max_load_factor
),
m_grow_on_next_insert
(
other
.
m_grow_on_next_insert
),
m_min_load_factor
(
other
.
m_min_load_factor
),
m_try_skrink_on_next_insert
(
other
.
m_try_skrink_on_next_insert
)
{
other
.
GrowthPolicy
::
clear
();
other
.
m_buckets_data
.
clear
();
other
.
m_buckets
=
static_empty_bucket_ptr
();
other
.
m_bucket_count
=
0
;
other
.
m_nb_elements
=
0
;
other
.
m_load_threshold
=
0
;
other
.
m_grow_on_next_insert
=
false
;
other
.
m_try_skrink_on_next_insert
=
false
;
}
robin_hash
&
operator
=
(
const
robin_hash
&
other
)
{
if
(
&
other
!=
this
)
{
Hash
::
operator
=
(
other
);
KeyEqual
::
operator
=
(
other
);
GrowthPolicy
::
operator
=
(
other
);
m_buckets_data
=
other
.
m_buckets_data
;
m_buckets
=
m_buckets_data
.
empty
()
?
static_empty_bucket_ptr
()
:
m_buckets_data
.
data
();
m_bucket_count
=
other
.
m_bucket_count
;
m_nb_elements
=
other
.
m_nb_elements
;
m_load_threshold
=
other
.
m_load_threshold
;
m_max_load_factor
=
other
.
m_max_load_factor
;
m_grow_on_next_insert
=
other
.
m_grow_on_next_insert
;
m_min_load_factor
=
other
.
m_min_load_factor
;
m_try_skrink_on_next_insert
=
other
.
m_try_skrink_on_next_insert
;
}
return
*
this
;
}
robin_hash
&
operator
=
(
robin_hash
&&
other
)
{
other
.
swap
(
*
this
);
other
.
clear
();
return
*
this
;
}
allocator_type
get_allocator
()
const
{
return
m_buckets_data
.
get_allocator
();
}
/*
* Iterators
*/
iterator
begin
()
noexcept
{
std
::
size_t
i
=
0
;
while
(
i
<
m_bucket_count
&&
m_buckets
[
i
].
empty
())
{
i
++
;
}
return
iterator
(
m_buckets
+
i
);
}
const_iterator
begin
()
const
noexcept
{
return
cbegin
();
}
const_iterator
cbegin
()
const
noexcept
{
std
::
size_t
i
=
0
;
while
(
i
<
m_bucket_count
&&
m_buckets
[
i
].
empty
())
{
i
++
;
}
return
const_iterator
(
m_buckets
+
i
);
}
iterator
end
()
noexcept
{
return
iterator
(
m_buckets
+
m_bucket_count
);
}
const_iterator
end
()
const
noexcept
{
return
cend
();
}
const_iterator
cend
()
const
noexcept
{
return
const_iterator
(
m_buckets
+
m_bucket_count
);
}
/*
* Capacity
*/
bool
empty
()
const
noexcept
{
return
m_nb_elements
==
0
;
}
size_type
size
()
const
noexcept
{
return
m_nb_elements
;
}
size_type
max_size
()
const
noexcept
{
return
m_buckets_data
.
max_size
();
}
/*
* Modifiers
*/
void
clear
()
noexcept
{
for
(
auto
&
bucket
:
m_buckets_data
)
{
bucket
.
clear
();
}
m_nb_elements
=
0
;
m_grow_on_next_insert
=
false
;
}
template
<
typename
P
>
std
::
pair
<
iterator
,
bool
>
insert
(
P
&&
value
)
{
return
insert_impl
(
KeySelect
()(
value
),
std
::
forward
<
P
>
(
value
));
}
template
<
typename
P
>
iterator
insert_hint
(
const_iterator
hint
,
P
&&
value
)
{
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
KeySelect
()(
value
)))
{
return
mutable_iterator
(
hint
);
}
return
insert
(
std
::
forward
<
P
>
(
value
)).
first
;
}
template
<
class
InputIt
>
void
insert
(
InputIt
first
,
InputIt
last
)
{
if
(
std
::
is_base_of
<
std
::
forward_iterator_tag
,
typename
std
::
iterator_traits
<
InputIt
>::
iterator_category
>::
value
)
{
const
auto
nb_elements_insert
=
std
::
distance
(
first
,
last
);
const
size_type
nb_free_buckets
=
m_load_threshold
-
size
();
tsl_rh_assert
(
m_load_threshold
>=
size
());
if
(
nb_elements_insert
>
0
&&
nb_free_buckets
<
size_type
(
nb_elements_insert
))
{
reserve
(
size
()
+
size_type
(
nb_elements_insert
));
}
}
for
(;
first
!=
last
;
++
first
)
{
insert
(
*
first
);
}
}
template
<
class
K
,
class
M
>
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
K
&&
key
,
M
&&
obj
)
{
auto
it
=
try_emplace
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
M
>
(
obj
));
if
(
!
it
.
second
)
{
it
.
first
.
value
()
=
std
::
forward
<
M
>
(
obj
);
}
return
it
;
}
template
<
class
K
,
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
K
&&
key
,
M
&&
obj
)
{
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
key
))
{
auto
it
=
mutable_iterator
(
hint
);
it
.
value
()
=
std
::
forward
<
M
>
(
obj
);
return
it
;
}
return
insert_or_assign
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
M
>
(
obj
)).
first
;
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
emplace
(
Args
&&
...
args
)
{
return
insert
(
value_type
(
std
::
forward
<
Args
>
(
args
)...));
}
template
<
class
...
Args
>
iterator
emplace_hint
(
const_iterator
hint
,
Args
&&
...
args
)
{
return
insert_hint
(
hint
,
value_type
(
std
::
forward
<
Args
>
(
args
)...));
}
template
<
class
K
,
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
K
&&
key
,
Args
&&
...
args
)
{
return
insert_impl
(
key
,
std
::
piecewise_construct
,
std
::
forward_as_tuple
(
std
::
forward
<
K
>
(
key
)),
std
::
forward_as_tuple
(
std
::
forward
<
Args
>
(
args
)...));
}
template
<
class
K
,
class
...
Args
>
iterator
try_emplace_hint
(
const_iterator
hint
,
K
&&
key
,
Args
&&
...
args
)
{
if
(
hint
!=
cend
()
&&
compare_keys
(
KeySelect
()(
*
hint
),
key
))
{
return
mutable_iterator
(
hint
);
}
return
try_emplace
(
std
::
forward
<
K
>
(
key
),
std
::
forward
<
Args
>
(
args
)...).
first
;
}
/**
* Here to avoid `template<class K> size_type erase(const K& key)` being used
* when we use an `iterator` instead of a `const_iterator`.
*/
iterator
erase
(
iterator
pos
)
{
erase_from_bucket
(
pos
);
/**
* Erase bucket used a backward shift after clearing the bucket.
* Check if there is a new value in the bucket, if not get the next
* non-empty.
*/
if
(
pos
.
m_bucket
->
empty
())
{
++
pos
;
}
m_try_skrink_on_next_insert
=
true
;
return
pos
;
}
iterator
erase
(
const_iterator
pos
)
{
return
erase
(
mutable_iterator
(
pos
));
}
iterator
erase
(
const_iterator
first
,
const_iterator
last
)
{
if
(
first
==
last
)
{
return
mutable_iterator
(
first
);
}
auto
first_mutable
=
mutable_iterator
(
first
);
auto
last_mutable
=
mutable_iterator
(
last
);
for
(
auto
it
=
first_mutable
.
m_bucket
;
it
!=
last_mutable
.
m_bucket
;
++
it
)
{
if
(
!
it
->
empty
())
{
it
->
clear
();
m_nb_elements
--
;
}
}
if
(
last_mutable
==
end
())
{
return
end
();
}
/*
* Backward shift on the values which come after the deleted values.
* We try to move the values closer to their ideal bucket.
*/
std
::
size_t
icloser_bucket
=
static_cast
<
std
::
size_t
>
(
first_mutable
.
m_bucket
-
m_buckets
);
std
::
size_t
ito_move_closer_value
=
static_cast
<
std
::
size_t
>
(
last_mutable
.
m_bucket
-
m_buckets
);
tsl_rh_assert
(
ito_move_closer_value
>
icloser_bucket
);
const
std
::
size_t
ireturn_bucket
=
ito_move_closer_value
-
std
::
min
(
ito_move_closer_value
-
icloser_bucket
,
std
::
size_t
(
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()));
while
(
ito_move_closer_value
<
m_bucket_count
&&
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()
>
0
)
{
icloser_bucket
=
ito_move_closer_value
-
std
::
min
(
ito_move_closer_value
-
icloser_bucket
,
std
::
size_t
(
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()));
tsl_rh_assert
(
m_buckets
[
icloser_bucket
].
empty
());
const
distance_type
new_distance
=
distance_type
(
m_buckets
[
ito_move_closer_value
].
dist_from_ideal_bucket
()
-
(
ito_move_closer_value
-
icloser_bucket
));
m_buckets
[
icloser_bucket
].
set_value_of_empty_bucket
(
new_distance
,
m_buckets
[
ito_move_closer_value
].
truncated_hash
(),
std
::
move
(
m_buckets
[
ito_move_closer_value
].
value
()));
m_buckets
[
ito_move_closer_value
].
clear
();
++
icloser_bucket
;
++
ito_move_closer_value
;
}
m_try_skrink_on_next_insert
=
true
;
return
iterator
(
m_buckets
+
ireturn_bucket
);
}
template
<
class
K
>
size_type
erase
(
const
K
&
key
)
{
return
erase
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
size_type
erase
(
const
K
&
key
,
std
::
size_t
hash
)
{
auto
it
=
find
(
key
,
hash
);
if
(
it
!=
end
())
{
erase_from_bucket
(
it
);
m_try_skrink_on_next_insert
=
true
;
return
1
;
}
else
{
return
0
;
}
}
void
swap
(
robin_hash
&
other
)
{
using
std
::
swap
;
swap
(
static_cast
<
Hash
&>
(
*
this
),
static_cast
<
Hash
&>
(
other
));
swap
(
static_cast
<
KeyEqual
&>
(
*
this
),
static_cast
<
KeyEqual
&>
(
other
));
swap
(
static_cast
<
GrowthPolicy
&>
(
*
this
),
static_cast
<
GrowthPolicy
&>
(
other
));
swap
(
m_buckets_data
,
other
.
m_buckets_data
);
swap
(
m_buckets
,
other
.
m_buckets
);
swap
(
m_bucket_count
,
other
.
m_bucket_count
);
swap
(
m_nb_elements
,
other
.
m_nb_elements
);
swap
(
m_load_threshold
,
other
.
m_load_threshold
);
swap
(
m_max_load_factor
,
other
.
m_max_load_factor
);
swap
(
m_grow_on_next_insert
,
other
.
m_grow_on_next_insert
);
swap
(
m_min_load_factor
,
other
.
m_min_load_factor
);
swap
(
m_try_skrink_on_next_insert
,
other
.
m_try_skrink_on_next_insert
);
}
/*
* Lookup
*/
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
at
(
const
K
&
key
)
{
return
at
(
key
,
hash_key
(
key
));
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
at
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
const_cast
<
typename
U
::
value_type
&>
(
static_cast
<
const
robin_hash
*>
(
this
)
->
at
(
key
,
hash
));
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
const
typename
U
::
value_type
&
at
(
const
K
&
key
)
const
{
return
at
(
key
,
hash_key
(
key
));
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
const
typename
U
::
value_type
&
at
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
auto
it
=
find
(
key
,
hash
);
if
(
it
!=
cend
())
{
return
it
.
value
();
}
else
{
TSL_RH_THROW_OR_TERMINATE
(
std
::
out_of_range
,
"Couldn't find key."
);
}
}
template
<
class
K
,
class
U
=
ValueSelect
,
typename
std
::
enable_if
<
has_mapped_type
<
U
>
::
value
>::
type
*
=
nullptr
>
typename
U
::
value_type
&
operator
[](
K
&&
key
)
{
return
try_emplace
(
std
::
forward
<
K
>
(
key
)).
first
.
value
();
}
template
<
class
K
>
size_type
count
(
const
K
&
key
)
const
{
return
count
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
size_type
count
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
if
(
find
(
key
,
hash
)
!=
cend
())
{
return
1
;
}
else
{
return
0
;
}
}
template
<
class
K
>
iterator
find
(
const
K
&
key
)
{
return
find_impl
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
iterator
find
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
find_impl
(
key
,
hash
);
}
template
<
class
K
>
const_iterator
find
(
const
K
&
key
)
const
{
return
find_impl
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
const_iterator
find
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
return
find_impl
(
key
,
hash
);
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
)
{
return
equal_range
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
{
iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
end
())
?
it
:
std
::
next
(
it
));
}
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
)
const
{
return
equal_range
(
key
,
hash_key
(
key
));
}
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
const_iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
cend
())
?
it
:
std
::
next
(
it
));
}
/*
* Bucket interface
*/
size_type
bucket_count
()
const
{
return
m_bucket_count
;
}
size_type
max_bucket_count
()
const
{
return
std
::
min
(
GrowthPolicy
::
max_bucket_count
(),
m_buckets_data
.
max_size
());
}
/*
* Hash policy
*/
float
load_factor
()
const
{
if
(
bucket_count
()
==
0
)
{
return
0
;
}
return
float
(
m_nb_elements
)
/
float
(
bucket_count
());
}
float
min_load_factor
()
const
{
return
m_min_load_factor
;
}
float
max_load_factor
()
const
{
return
m_max_load_factor
;
}
void
min_load_factor
(
float
ml
)
{
m_min_load_factor
=
clamp
(
ml
,
float
(
MINIMUM_MIN_LOAD_FACTOR
),
float
(
MAXIMUM_MIN_LOAD_FACTOR
));
}
void
max_load_factor
(
float
ml
)
{
m_max_load_factor
=
clamp
(
ml
,
float
(
MINIMUM_MAX_LOAD_FACTOR
),
float
(
MAXIMUM_MAX_LOAD_FACTOR
));
m_load_threshold
=
size_type
(
float
(
bucket_count
())
*
m_max_load_factor
);
}
void
rehash
(
size_type
count
)
{
count
=
std
::
max
(
count
,
size_type
(
std
::
ceil
(
float
(
size
())
/
max_load_factor
())));
rehash_impl
(
count
);
}
void
reserve
(
size_type
count
)
{
rehash
(
size_type
(
std
::
ceil
(
float
(
count
)
/
max_load_factor
())));
}
/*
* Observers
*/
hasher
hash_function
()
const
{
return
static_cast
<
const
Hash
&>
(
*
this
);
}
key_equal
key_eq
()
const
{
return
static_cast
<
const
KeyEqual
&>
(
*
this
);
}
/*
* Other
*/
iterator
mutable_iterator
(
const_iterator
pos
)
{
return
iterator
(
const_cast
<
bucket_entry
*>
(
pos
.
m_bucket
));
}
private:
template
<
class
K
>
std
::
size_t
hash_key
(
const
K
&
key
)
const
{
return
Hash
::
operator
()(
key
);
}
template
<
class
K1
,
class
K2
>
bool
compare_keys
(
const
K1
&
key1
,
const
K2
&
key2
)
const
{
return
KeyEqual
::
operator
()(
key1
,
key2
);
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
{
const
std
::
size_t
bucket
=
GrowthPolicy
::
bucket_for_hash
(
hash
);
tsl_rh_assert
(
bucket
<
m_bucket_count
||
(
bucket
==
0
&&
m_bucket_count
==
0
));
return
bucket
;
}
template
<
class
U
=
GrowthPolicy
,
typename
std
::
enable_if
<
is_power_of_two_policy
<
U
>
::
value
>::
type
*
=
nullptr
>
std
::
size_t
next_bucket
(
std
::
size_t
index
)
const
noexcept
{
tsl_rh_assert
(
index
<
bucket_count
());
return
(
index
+
1
)
&
this
->
m_mask
;
}
template
<
class
U
=
GrowthPolicy
,
typename
std
::
enable_if
<!
is_power_of_two_policy
<
U
>
::
value
>::
type
*
=
nullptr
>
std
::
size_t
next_bucket
(
std
::
size_t
index
)
const
noexcept
{
tsl_rh_assert
(
index
<
bucket_count
());
index
++
;
return
(
index
!=
bucket_count
())
?
index
:
0
;
}
template
<
class
K
>
iterator
find_impl
(
const
K
&
key
,
std
::
size_t
hash
)
{
return
mutable_iterator
(
static_cast
<
const
robin_hash
*>
(
this
)
->
find
(
key
,
hash
));
}
template
<
class
K
>
const_iterator
find_impl
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
std
::
size_t
ibucket
=
bucket_for_hash
(
hash
);
distance_type
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
TSL_RH_LIKELY
(
(
!
USE_STORED_HASH_ON_LOOKUP
||
m_buckets
[
ibucket
].
bucket_hash_equal
(
hash
))
&&
compare_keys
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()),
key
)))
{
return
const_iterator
(
m_buckets
+
ibucket
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
return
cend
();
}
void
erase_from_bucket
(
iterator
pos
)
{
pos
.
m_bucket
->
clear
();
m_nb_elements
--
;
/**
* Backward shift, swap the empty bucket, previous_ibucket, with the values
* on its right, ibucket, until we cross another empty bucket or if the
* other bucket has a distance_from_ideal_bucket == 0.
*
* We try to move the values closer to their ideal bucket.
*/
std
::
size_t
previous_ibucket
=
static_cast
<
std
::
size_t
>
(
pos
.
m_bucket
-
m_buckets
);
std
::
size_t
ibucket
=
next_bucket
(
previous_ibucket
);
while
(
m_buckets
[
ibucket
].
dist_from_ideal_bucket
()
>
0
)
{
tsl_rh_assert
(
m_buckets
[
previous_ibucket
].
empty
());
const
distance_type
new_distance
=
distance_type
(
m_buckets
[
ibucket
].
dist_from_ideal_bucket
()
-
1
);
m_buckets
[
previous_ibucket
].
set_value_of_empty_bucket
(
new_distance
,
m_buckets
[
ibucket
].
truncated_hash
(),
std
::
move
(
m_buckets
[
ibucket
].
value
()));
m_buckets
[
ibucket
].
clear
();
previous_ibucket
=
ibucket
;
ibucket
=
next_bucket
(
ibucket
);
}
}
template
<
class
K
,
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
insert_impl
(
const
K
&
key
,
Args
&&
...
value_type_args
)
{
const
std
::
size_t
hash
=
hash_key
(
key
);
std
::
size_t
ibucket
=
bucket_for_hash
(
hash
);
distance_type
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
((
!
USE_STORED_HASH_ON_LOOKUP
||
m_buckets
[
ibucket
].
bucket_hash_equal
(
hash
))
&&
compare_keys
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()),
key
))
{
return
std
::
make_pair
(
iterator
(
m_buckets
+
ibucket
),
false
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
if
(
rehash_on_extreme_load
())
{
ibucket
=
bucket_for_hash
(
hash
);
dist_from_ideal_bucket
=
0
;
while
(
dist_from_ideal_bucket
<=
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
}
if
(
m_buckets
[
ibucket
].
empty
())
{
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
forward
<
Args
>
(
value_type_args
)...);
}
else
{
insert_value
(
ibucket
,
dist_from_ideal_bucket
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
forward
<
Args
>
(
value_type_args
)...);
}
m_nb_elements
++
;
/*
* The value will be inserted in ibucket in any case, either because it was
* empty or by stealing the bucket (robin hood).
*/
return
std
::
make_pair
(
iterator
(
m_buckets
+
ibucket
),
true
);
}
template
<
class
...
Args
>
void
insert_value
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
Args
&&
...
value_type_args
)
{
value_type
value
(
std
::
forward
<
Args
>
(
value_type_args
)...);
insert_value_impl
(
ibucket
,
dist_from_ideal_bucket
,
hash
,
value
);
}
void
insert_value
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
value_type
&&
value
)
{
insert_value_impl
(
ibucket
,
dist_from_ideal_bucket
,
hash
,
value
);
}
/*
* We don't use `value_type&& value` as last argument due to a bug in MSVC
* when `value_type` is a pointer, The compiler is not able to see the
* difference between `std::string*` and `std::string*&&` resulting in compile
* error.
*
* The `value` will be in a moved state at the end of the function.
*/
void
insert_value_impl
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
value_type
&
value
)
{
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
while
(
!
m_buckets
[
ibucket
].
empty
())
{
if
(
dist_from_ideal_bucket
>
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
dist_from_ideal_bucket
>=
REHASH_ON_HIGH_NB_PROBES__NPROBES
&&
load_factor
()
>=
REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR
)
{
/**
* The number of probes is really high, rehash the map on the next
* insert. Difficult to do now as rehash may throw an exception.
*/
m_grow_on_next_insert
=
true
;
}
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
}
ibucket
=
next_bucket
(
ibucket
);
dist_from_ideal_bucket
++
;
}
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
hash
,
std
::
move
(
value
));
}
void
rehash_impl
(
size_type
count
)
{
robin_hash
new_table
(
count
,
static_cast
<
Hash
&>
(
*
this
),
static_cast
<
KeyEqual
&>
(
*
this
),
get_allocator
(),
m_min_load_factor
,
m_max_load_factor
);
const
bool
use_stored_hash
=
USE_STORED_HASH_ON_REHASH
(
new_table
.
bucket_count
());
for
(
auto
&
bucket
:
m_buckets_data
)
{
if
(
bucket
.
empty
())
{
continue
;
}
const
std
::
size_t
hash
=
use_stored_hash
?
bucket
.
truncated_hash
()
:
new_table
.
hash_key
(
KeySelect
()(
bucket
.
value
()));
new_table
.
insert_value_on_rehash
(
new_table
.
bucket_for_hash
(
hash
),
0
,
bucket_entry
::
truncate_hash
(
hash
),
std
::
move
(
bucket
.
value
()));
}
new_table
.
m_nb_elements
=
m_nb_elements
;
new_table
.
swap
(
*
this
);
}
void
insert_value_on_rehash
(
std
::
size_t
ibucket
,
distance_type
dist_from_ideal_bucket
,
truncated_hash_type
hash
,
value_type
&&
value
)
{
while
(
true
)
{
if
(
dist_from_ideal_bucket
>
m_buckets
[
ibucket
].
dist_from_ideal_bucket
())
{
if
(
m_buckets
[
ibucket
].
empty
())
{
m_buckets
[
ibucket
].
set_value_of_empty_bucket
(
dist_from_ideal_bucket
,
hash
,
std
::
move
(
value
));
return
;
}
else
{
m_buckets
[
ibucket
].
swap_with_value_in_bucket
(
dist_from_ideal_bucket
,
hash
,
value
);
}
}
dist_from_ideal_bucket
++
;
ibucket
=
next_bucket
(
ibucket
);
}
}
/**
* Grow the table if m_grow_on_next_insert is true or we reached the
* max_load_factor. Shrink the table if m_try_skrink_on_next_insert is true
* (an erase occured) and we're below the min_load_factor.
*
* Return true if the table has been rehashed.
*/
bool
rehash_on_extreme_load
()
{
if
(
m_grow_on_next_insert
||
size
()
>=
m_load_threshold
)
{
rehash_impl
(
GrowthPolicy
::
next_bucket_count
());
m_grow_on_next_insert
=
false
;
return
true
;
}
if
(
m_try_skrink_on_next_insert
)
{
m_try_skrink_on_next_insert
=
false
;
if
(
m_min_load_factor
!=
0.0
f
&&
load_factor
()
<
m_min_load_factor
)
{
reserve
(
size
()
+
1
);
return
true
;
}
}
return
false
;
}
public:
static
const
size_type
DEFAULT_INIT_BUCKETS_SIZE
=
0
;
static
constexpr
float
DEFAULT_MAX_LOAD_FACTOR
=
0.5
f
;
static
constexpr
float
MINIMUM_MAX_LOAD_FACTOR
=
0.2
f
;
static
constexpr
float
MAXIMUM_MAX_LOAD_FACTOR
=
0.95
f
;
static
constexpr
float
DEFAULT_MIN_LOAD_FACTOR
=
0.0
f
;
static
constexpr
float
MINIMUM_MIN_LOAD_FACTOR
=
0.0
f
;
static
constexpr
float
MAXIMUM_MIN_LOAD_FACTOR
=
0.15
f
;
static_assert
(
MINIMUM_MAX_LOAD_FACTOR
<
MAXIMUM_MAX_LOAD_FACTOR
,
"MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR"
);
static_assert
(
MINIMUM_MIN_LOAD_FACTOR
<
MAXIMUM_MIN_LOAD_FACTOR
,
"MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR"
);
static_assert
(
MAXIMUM_MIN_LOAD_FACTOR
<
MINIMUM_MAX_LOAD_FACTOR
,
"MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR"
);
private:
static
const
distance_type
REHASH_ON_HIGH_NB_PROBES__NPROBES
=
128
;
static
constexpr
float
REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR
=
0.15
f
;
/**
* Return an always valid pointer to an static empty bucket_entry with
* last_bucket() == true.
*/
bucket_entry
*
static_empty_bucket_ptr
()
{
static
bucket_entry
empty_bucket
(
true
);
return
&
empty_bucket
;
}
private:
buckets_container_type
m_buckets_data
;
/**
* Points to m_buckets_data.data() if !m_buckets_data.empty() otherwise points
* to static_empty_bucket_ptr. This variable is useful to avoid the cost of
* checking if m_buckets_data is empty when trying to find an element.
*
* TODO Remove m_buckets_data and only use a pointer instead of a
* pointer+vector to save some space in the robin_hash object. Manage the
* Allocator manually.
*/
bucket_entry
*
m_buckets
;
/**
* Used a lot in find, avoid the call to m_buckets_data.size() which is a bit
* slower.
*/
size_type
m_bucket_count
;
size_type
m_nb_elements
;
size_type
m_load_threshold
;
float
m_max_load_factor
;
bool
m_grow_on_next_insert
;
float
m_min_load_factor
;
/**
* We can't shrink down the map on erase operations as the erase methods need
* to return the next iterator. Shrinking the map would invalidate all the
* iterators and we could not return the next iterator in a meaningful way, On
* erase, we thus just indicate on erase that we should try to shrink the hash
* table on the next insert if we go below the min_load_factor.
*/
bool
m_try_skrink_on_next_insert
;
};
}
// namespace detail_robin_hash
}
// namespace tsl
#endif
include/tsl/robin_map.h
deleted
100644 → 0
View file @
fad30002
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_MAP_H
#define TSL_ROBIN_MAP_H
#include "robin_hash.h"
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
namespace
tsl
{
/**
* Implementation of a hash map using open-adressing and the robin hood hashing
* algorithm with backward shift deletion.
*
* For operations modifying the hash map (insert, erase, rehash, ...), the
* strong exception guarantee is only guaranteed when the expression
* `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true,
* otherwise if an exception is thrown during the swap or the move, the hash map
* may end up in a undefined state. Per the standard a `Key` or `T` with a
* noexcept copy constructor and no move constructor also satisfies the
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and
* will thus guarantee the strong exception for the map).
*
* When `StoreHash` is true, 32 bits of the hash are stored alongside the
* values. It can improve the performance during lookups if the `KeyEqual`
* function takes time (if it engenders a cache-miss for example) as we then
* compare the stored hashes before comparing the keys. When
* `tsl::rh::power_of_two_growth_policy` is used as `GrowthPolicy`, it may also
* speed-up the rehash process as we can avoid to recalculate the hash. When it
* is detected that storing the hash will not incur any memory penality due to
* alignement (i.e. `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType,
* true>) == sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`)
* and `tsl::rh::power_of_two_growth_policy` is used, the hash will be stored
* even if `StoreHash` is false so that we can speed-up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
*
* `GrowthPolicy` defines how the map grows and consequently how a hash value is
* mapped to a bucket. By default the map uses
* `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of
* buckets to a power of two and uses a mask to map the hash to a bucket instead
* of the slow modulo. Other growth policies are available and you may define
* your own growth policy, check `tsl::rh::power_of_two_growth_policy` for the
* interface.
*
* `std::pair<Key, T>` must be swappable.
*
* `Key` and `T` must be copy and/or move constructible.
*
* If the destructor of `Key` or `T` throws an exception, the behaviour of the
* class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective
* insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template
<
class
Key
,
class
T
,
class
Hash
=
std
::
hash
<
Key
>,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
bool
StoreHash
=
false
,
class
GrowthPolicy
=
tsl
::
rh
::
power_of_two_growth_policy
<
2
>>
class
robin_map
{
private:
template
<
typename
U
>
using
has_is_transparent
=
tsl
::
detail_robin_hash
::
has_is_transparent
<
U
>
;
class
KeySelect
{
public:
using
key_type
=
Key
;
const
key_type
&
operator
()(
const
std
::
pair
<
Key
,
T
>
&
key_value
)
const
noexcept
{
return
key_value
.
first
;
}
key_type
&
operator
()(
std
::
pair
<
Key
,
T
>
&
key_value
)
noexcept
{
return
key_value
.
first
;
}
};
class
ValueSelect
{
public:
using
value_type
=
T
;
const
value_type
&
operator
()(
const
std
::
pair
<
Key
,
T
>
&
key_value
)
const
noexcept
{
return
key_value
.
second
;
}
value_type
&
operator
()(
std
::
pair
<
Key
,
T
>
&
key_value
)
noexcept
{
return
key_value
.
second
;
}
};
using
ht
=
detail_robin_hash
::
robin_hash
<
std
::
pair
<
Key
,
T
>
,
KeySelect
,
ValueSelect
,
Hash
,
KeyEqual
,
Allocator
,
StoreHash
,
GrowthPolicy
>
;
public:
using
key_type
=
typename
ht
::
key_type
;
using
mapped_type
=
T
;
using
value_type
=
typename
ht
::
value_type
;
using
size_type
=
typename
ht
::
size_type
;
using
difference_type
=
typename
ht
::
difference_type
;
using
hasher
=
typename
ht
::
hasher
;
using
key_equal
=
typename
ht
::
key_equal
;
using
allocator_type
=
typename
ht
::
allocator_type
;
using
reference
=
typename
ht
::
reference
;
using
const_reference
=
typename
ht
::
const_reference
;
using
pointer
=
typename
ht
::
pointer
;
using
const_pointer
=
typename
ht
::
const_pointer
;
using
iterator
=
typename
ht
::
iterator
;
using
const_iterator
=
typename
ht
::
const_iterator
;
public:
/*
* Constructors
*/
robin_map
()
:
robin_map
(
ht
::
DEFAULT_INIT_BUCKETS_SIZE
)
{}
explicit
robin_map
(
size_type
bucket_count
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
m_ht
(
bucket_count
,
hash
,
equal
,
alloc
)
{}
robin_map
(
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{}
robin_map
(
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{}
explicit
robin_map
(
const
Allocator
&
alloc
)
:
robin_map
(
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
alloc
)
{}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
=
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
robin_map
(
bucket_count
,
hash
,
equal
,
alloc
)
{
insert
(
first
,
last
);
}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
first
,
last
,
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{}
template
<
class
InputIt
>
robin_map
(
InputIt
first
,
InputIt
last
,
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
first
,
last
,
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
size_type
bucket_count
=
ht
::
DEFAULT_INIT_BUCKETS_SIZE
,
const
Hash
&
hash
=
Hash
(),
const
KeyEqual
&
equal
=
KeyEqual
(),
const
Allocator
&
alloc
=
Allocator
())
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
hash
,
equal
,
alloc
)
{}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
size_type
bucket_count
,
const
Allocator
&
alloc
)
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
Hash
(),
KeyEqual
(),
alloc
)
{}
robin_map
(
std
::
initializer_list
<
value_type
>
init
,
size_type
bucket_count
,
const
Hash
&
hash
,
const
Allocator
&
alloc
)
:
robin_map
(
init
.
begin
(),
init
.
end
(),
bucket_count
,
hash
,
KeyEqual
(),
alloc
)
{}
robin_map
&
operator
=
(
std
::
initializer_list
<
value_type
>
ilist
)
{
m_ht
.
clear
();
m_ht
.
reserve
(
ilist
.
size
());
m_ht
.
insert
(
ilist
.
begin
(),
ilist
.
end
());
return
*
this
;
}
allocator_type
get_allocator
()
const
{
return
m_ht
.
get_allocator
();
}
/*
* Iterators
*/
iterator
begin
()
noexcept
{
return
m_ht
.
begin
();
}
const_iterator
begin
()
const
noexcept
{
return
m_ht
.
begin
();
}
const_iterator
cbegin
()
const
noexcept
{
return
m_ht
.
cbegin
();
}
iterator
end
()
noexcept
{
return
m_ht
.
end
();
}
const_iterator
end
()
const
noexcept
{
return
m_ht
.
end
();
}
const_iterator
cend
()
const
noexcept
{
return
m_ht
.
cend
();
}
/*
* Capacity
*/
bool
empty
()
const
noexcept
{
return
m_ht
.
empty
();
}
size_type
size
()
const
noexcept
{
return
m_ht
.
size
();
}
size_type
max_size
()
const
noexcept
{
return
m_ht
.
max_size
();
}
/*
* Modifiers
*/
void
clear
()
noexcept
{
m_ht
.
clear
();
}
std
::
pair
<
iterator
,
bool
>
insert
(
const
value_type
&
value
)
{
return
m_ht
.
insert
(
value
);
}
template
<
class
P
,
typename
std
::
enable_if
<
std
::
is_constructible
<
value_type
,
P
&&
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
bool
>
insert
(
P
&&
value
)
{
return
m_ht
.
emplace
(
std
::
forward
<
P
>
(
value
));
}
std
::
pair
<
iterator
,
bool
>
insert
(
value_type
&&
value
)
{
return
m_ht
.
insert
(
std
::
move
(
value
));
}
iterator
insert
(
const_iterator
hint
,
const
value_type
&
value
)
{
return
m_ht
.
insert_hint
(
hint
,
value
);
}
template
<
class
P
,
typename
std
::
enable_if
<
std
::
is_constructible
<
value_type
,
P
&&
>
::
value
>::
type
*
=
nullptr
>
iterator
insert
(
const_iterator
hint
,
P
&&
value
)
{
return
m_ht
.
emplace_hint
(
hint
,
std
::
forward
<
P
>
(
value
));
}
iterator
insert
(
const_iterator
hint
,
value_type
&&
value
)
{
return
m_ht
.
insert_hint
(
hint
,
std
::
move
(
value
));
}
template
<
class
InputIt
>
void
insert
(
InputIt
first
,
InputIt
last
)
{
m_ht
.
insert
(
first
,
last
);
}
void
insert
(
std
::
initializer_list
<
value_type
>
ilist
)
{
m_ht
.
insert
(
ilist
.
begin
(),
ilist
.
end
());
}
template
<
class
M
>
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
const
key_type
&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
k
,
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
std
::
pair
<
iterator
,
bool
>
insert_or_assign
(
key_type
&&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
std
::
move
(
k
),
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
const
key_type
&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
hint
,
k
,
std
::
forward
<
M
>
(
obj
));
}
template
<
class
M
>
iterator
insert_or_assign
(
const_iterator
hint
,
key_type
&&
k
,
M
&&
obj
)
{
return
m_ht
.
insert_or_assign
(
hint
,
std
::
move
(
k
),
std
::
forward
<
M
>
(
obj
));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the
* key-value once. The method is equivalent to
* insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
emplace
(
Args
&&
...
args
)
{
return
m_ht
.
emplace
(
std
::
forward
<
Args
>
(
args
)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy
* the key-value once. The method is equivalent to insert(hint,
* value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template
<
class
...
Args
>
iterator
emplace_hint
(
const_iterator
hint
,
Args
&&
...
args
)
{
return
m_ht
.
emplace_hint
(
hint
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
const
key_type
&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace
(
k
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
std
::
pair
<
iterator
,
bool
>
try_emplace
(
key_type
&&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace
(
std
::
move
(
k
),
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
iterator
try_emplace
(
const_iterator
hint
,
const
key_type
&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace_hint
(
hint
,
k
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
class
...
Args
>
iterator
try_emplace
(
const_iterator
hint
,
key_type
&&
k
,
Args
&&
...
args
)
{
return
m_ht
.
try_emplace_hint
(
hint
,
std
::
move
(
k
),
std
::
forward
<
Args
>
(
args
)...);
}
iterator
erase
(
iterator
pos
)
{
return
m_ht
.
erase
(
pos
);
}
iterator
erase
(
const_iterator
pos
)
{
return
m_ht
.
erase
(
pos
);
}
iterator
erase
(
const_iterator
first
,
const_iterator
last
)
{
return
m_ht
.
erase
(
first
,
last
);
}
size_type
erase
(
const
key_type
&
key
)
{
return
m_ht
.
erase
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup to the value if you already have the hash.
*/
size_type
erase
(
const
key_type
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
erase
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
erase
(
const
K
&
key
)
{
return
m_ht
.
erase
(
key
);
}
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup to the value if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
erase
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
erase
(
key
,
precalculated_hash
);
}
void
swap
(
robin_map
&
other
)
{
other
.
m_ht
.
swap
(
m_ht
);
}
/*
* Lookup
*/
T
&
at
(
const
Key
&
key
)
{
return
m_ht
.
at
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
T
&
at
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
const
T
&
at
(
const
Key
&
key
)
const
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const
T
&
at
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
T
&
at
(
const
K
&
key
)
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
T
&
at
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
/**
* @copydoc at(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const
T
&
at
(
const
K
&
key
)
const
{
return
m_ht
.
at
(
key
);
}
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const
T
&
at
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
at
(
key
,
precalculated_hash
);
}
T
&
operator
[](
const
Key
&
key
)
{
return
m_ht
[
key
];
}
T
&
operator
[](
Key
&&
key
)
{
return
m_ht
[
std
::
move
(
key
)];
}
size_type
count
(
const
Key
&
key
)
const
{
return
m_ht
.
count
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
size_type
count
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
count
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
count
(
const
K
&
key
)
const
{
return
m_ht
.
count
(
key
);
}
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
size_type
count
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
count
(
key
,
precalculated_hash
);
}
iterator
find
(
const
Key
&
key
)
{
return
m_ht
.
find
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
iterator
find
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
const_iterator
find
(
const
Key
&
key
)
const
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator
find
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
iterator
find
(
const
K
&
key
)
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
iterator
find
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
/**
* @copydoc find(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const_iterator
find
(
const
K
&
key
)
const
{
return
m_ht
.
find
(
key
);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
const_iterator
find
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
find
(
key
,
precalculated_hash
);
}
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
Key
&
key
)
{
return
m_ht
.
equal_range
(
key
);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
Key
&
key
)
const
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
Key
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
)
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Usefull to speed-up
* the lookup if you already have the hash.
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/**
* @copydoc equal_range(const K& key)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
)
const
{
return
m_ht
.
equal_range
(
key
);
}
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template
<
class
K
,
class
KE
=
KeyEqual
,
typename
std
::
enable_if
<
has_is_transparent
<
KE
>
::
value
>::
type
*
=
nullptr
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
precalculated_hash
)
const
{
return
m_ht
.
equal_range
(
key
,
precalculated_hash
);
}
/*
* Bucket interface
*/
size_type
bucket_count
()
const
{
return
m_ht
.
bucket_count
();
}
size_type
max_bucket_count
()
const
{
return
m_ht
.
max_bucket_count
();
}
/*
* Hash policy
*/
float
load_factor
()
const
{
return
m_ht
.
load_factor
();
}
float
min_load_factor
()
const
{
return
m_ht
.
min_load_factor
();
}
float
max_load_factor
()
const
{
return
m_ht
.
max_load_factor
();
}
/**
* Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes
* below `min_load_factor` after some erase operations, the map will be
* shrunk when an insertion occurs. The erase method itself never shrinks
* the map.
*
* The default value of `min_load_factor` is 0.0f, the map never shrinks by
* default.
*/
void
min_load_factor
(
float
ml
)
{
m_ht
.
min_load_factor
(
ml
);
}
void
max_load_factor
(
float
ml
)
{
m_ht
.
max_load_factor
(
ml
);
}
void
rehash
(
size_type
count
)
{
m_ht
.
rehash
(
count
);
}
void
reserve
(
size_type
count
)
{
m_ht
.
reserve
(
count
);
}
/*
* Observers
*/
hasher
hash_function
()
const
{
return
m_ht
.
hash_function
();
}
key_equal
key_eq
()
const
{
return
m_ht
.
key_eq
();
}
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator
mutable_iterator
(
const_iterator
pos
)
{
return
m_ht
.
mutable_iterator
(
pos
);
}
friend
bool
operator
==
(
const
robin_map
&
lhs
,
const
robin_map
&
rhs
)
{
if
(
lhs
.
size
()
!=
rhs
.
size
())
{
return
false
;
}
for
(
const
auto
&
element_lhs
:
lhs
)
{
const
auto
it_element_rhs
=
rhs
.
find
(
element_lhs
.
first
);
if
(
it_element_rhs
==
rhs
.
cend
()
||
element_lhs
.
second
!=
it_element_rhs
->
second
)
{
return
false
;
}
}
return
true
;
}
friend
bool
operator
!=
(
const
robin_map
&
lhs
,
const
robin_map
&
rhs
)
{
return
!
operator
==
(
lhs
,
rhs
);
}
friend
void
swap
(
robin_map
&
lhs
,
robin_map
&
rhs
)
{
lhs
.
swap
(
rhs
);
}
private:
ht
m_ht
;
};
/**
* Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash,
* tsl::rh::prime_growth_policy>`.
*/
template
<
class
Key
,
class
T
,
class
Hash
=
std
::
hash
<
Key
>,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
bool
StoreHash
=
false
>
using
robin_pg_map
=
robin_map
<
Key
,
T
,
Hash
,
KeyEqual
,
Allocator
,
StoreHash
,
tsl
::
rh
::
prime_growth_policy
>
;
}
// end namespace tsl
#endif
include/utility/timer.h
deleted
100644 → 0
View file @
fad30002
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
#include <iostream>
namespace
spconv
{
#ifdef TV_CUDA
template
<
typename
TimeT
=
std
::
chrono
::
microseconds
>
struct
CudaContextTimer
{
CudaContextTimer
()
{
cudaDeviceSynchronize
();
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
}
typename
TimeT
::
rep
report
()
{
cudaDeviceSynchronize
();
auto
duration
=
std
::
chrono
::
duration_cast
<
TimeT
>
(
std
::
chrono
::
steady_clock
::
now
()
-
mCurTime
);
auto
res
=
duration
.
count
();
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
return
res
;
}
private:
std
::
chrono
::
time_point
<
std
::
chrono
::
steady_clock
>
mCurTime
;
};
#endif
template
<
typename
TimeT
=
std
::
chrono
::
microseconds
>
struct
CPUTimer
{
CPUTimer
()
{
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
}
typename
TimeT
::
rep
report
()
{
auto
duration
=
std
::
chrono
::
duration_cast
<
TimeT
>
(
std
::
chrono
::
steady_clock
::
now
()
-
mCurTime
);
auto
res
=
duration
.
count
();
mCurTime
=
std
::
chrono
::
steady_clock
::
now
();
return
res
;
}
private:
std
::
chrono
::
time_point
<
std
::
chrono
::
steady_clock
>
mCurTime
;
};
}
// namespace spconv
pyproject.toml
0 → 100644
View file @
a6abf55d
[build-system]
requires
=
[
"setuptools>=41.0"
,
"wheel"
,
"pccm>=0.2.14"
,
"cumm>=0.1.7"
]
build-backend
=
"setuptools.build_meta"
setup.py
View file @
a6abf55d
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Note: To use the 'upload' functionality of this file, you must:
# $ pip install twine
import
io
import
os
import
platform
import
re
import
subprocess
import
shutil
import
sys
from
distutils.version
import
LooseVersion
from
pathlib
import
Path
from
shutil
import
rmtree
from
typing
import
List
import
pccm
from
pccm.extension
import
ExtCallback
,
PCCMBuild
,
PCCMExtension
from
setuptools
import
Command
,
find_packages
,
setup
from
setuptools.extension
import
Extension
from
ccimport
import
compat
import
subprocess
import
re
# Package meta-data.
NAME
=
'spconv'
RELEASE_NAME
=
NAME
deps
=
[
"cumm"
]
cuda_ver
=
os
.
environ
.
get
(
"CUMM_CUDA_VERSION"
,
""
)
if
not
cuda_ver
:
nvcc_version
=
subprocess
.
check_output
([
"nvcc"
,
"--version"
]).
decode
(
"utf-8"
).
strip
()
nvcc_version_str
=
nvcc_version
.
split
(
"
\n
"
)[
3
]
version_str
:
str
=
re
.
findall
(
r
"release (\d+.\d+)"
,
nvcc_version_str
)[
0
]
cuda_ver
=
version_str
cuda_ver
=
cuda_ver
.
replace
(
"."
,
""
)
# 10.2 to 102
RELEASE_NAME
+=
"-cu{}"
.
format
(
cuda_ver
)
deps
=
[
"cumm-cu{}"
.
format
(
cuda_ver
)]
DESCRIPTION
=
'spatial sparse convolution'
URL
=
'https://github.com/traveller59/spconv'
EMAIL
=
'yanyan.sub@outlook.com'
AUTHOR
=
'Yan Yan'
REQUIRES_PYTHON
=
'>=3.6'
VERSION
=
None
# What packages are required for this module to be executed?
REQUIRED
=
[
"pccm>=0.2.14"
,
"pybind11>=2.6.0"
,
"fire"
,
"numpy"
,
*
deps
]
# What packages are optional?
EXTRAS
=
{
# 'fancy feature': ['django'],
}
# The rest you shouldn't have to touch too much :)
# ------------------------------------------------
# Except, perhaps the License and Trove Classifiers!
# If you do change the License, remember to change the Trove Classifier for that!
here
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
__file__
))
sys
.
path
.
append
(
str
(
Path
(
__file__
).
parent
))
# Import the README and use it as the long-description.
# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
try
:
with
io
.
open
(
os
.
path
.
join
(
here
,
'README.md'
),
encoding
=
'utf-8'
)
as
f
:
long_description
=
'
\n
'
+
f
.
read
()
except
FileNotFoundError
:
long_description
=
DESCRIPTION
# Load the package's __version__.py module as a dictionary.
about
=
{}
if
not
VERSION
:
with
open
(
'version.txt'
,
'r'
)
as
f
:
version
=
f
.
read
().
strip
()
else
:
version
=
VERSION
cwd
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
import
torch
from
setuptools
import
Extension
,
find_packages
,
setup
from
setuptools.command.build_ext
import
build_ext
def
_convert_build_number
(
build_number
):
parts
=
build_number
.
split
(
"."
)
if
len
(
parts
)
==
2
:
return
"{}{:03d}"
.
format
(
int
(
parts
[
0
]),
int
(
parts
[
1
]))
elif
len
(
parts
)
==
1
:
return
build_number
else
:
raise
NotImplementedError
# if 'LIBTORCH_ROOT' not in os.environ:
# raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.")
LIBTORCH_ROOT
=
str
(
Path
(
torch
.
__file__
).
parent
)
env_suffix
=
os
.
environ
.
get
(
"SPCONV_VERSION_SUFFIX"
,
""
)
if
env_suffix
!=
""
:
version
+=
".dev{}"
.
format
(
_convert_build_number
(
env_suffix
))
version_path
=
os
.
path
.
join
(
cwd
,
NAME
,
'__version__.py'
)
about
[
'__version__'
]
=
version
SPCONV_FORCE_BUILD_CUDA
=
os
.
getenv
(
"SPCONV_FORCE_BUILD_CUDA"
)
with
open
(
version_path
,
'w'
)
as
f
:
f
.
write
(
"__version__ = '{}'
\n
"
.
format
(
version
))
PYTHON_VERSION
=
"{}.{}"
.
format
(
sys
.
version_info
.
major
,
sys
.
version_info
.
minor
)
class
UploadCommand
(
Command
):
"""Support setup.py upload."""
remove_device
=
re
.
search
(
r
"(\+|\.)(dev|cu|cpu)"
,
torch
.
__version__
)
PYTORCH_VERSION
=
torch
.
__version__
if
remove_device
is
not
None
:
PYTORCH_VERSION
=
torch
.
__version__
[:
remove_device
.
start
()]
PYTORCH_VERSION
=
list
(
map
(
int
,
PYTORCH_VERSION
.
split
(
"."
)))
PYTORCH_VERSION_NUMBER
=
PYTORCH_VERSION
[
0
]
*
10000
+
PYTORCH_VERSION
[
1
]
*
100
+
PYTORCH_VERSION
[
2
]
class
CMakeExtension
(
Extension
):
def
__init__
(
self
,
name
,
sourcedir
=
''
,
library_dirs
=
[]):
Extension
.
__init__
(
self
,
name
,
sources
=
[],
library_dirs
=
library_dirs
)
self
.
sourcedir
=
os
.
path
.
abspath
(
sourcedir
)
description
=
'Build and publish the package.'
user_options
=
[]
@
staticmethod
def
status
(
s
):
"""Prints things in bold."""
print
(
'
\033
[1m{0}
\033
[0m'
.
format
(
s
))
def
initialize_options
(
self
):
pass
def
finalize_options
(
self
):
pass
class
CMakeBuild
(
build_ext
):
def
run
(
self
):
try
:
out
=
subprocess
.
check_output
([
'cmake'
,
'--version'
])
self
.
status
(
'Removing previous builds...'
)
rmtree
(
os
.
path
.
join
(
here
,
'dist'
))
except
OSError
:
raise
RuntimeError
(
"CMake must be installed to build the following extensions: "
+
", "
.
join
(
e
.
name
for
e
in
self
.
extensions
))
if
platform
.
system
()
==
"Windows"
:
cmake_version
=
LooseVersion
(
re
.
search
(
r
'version\s*([\d.]+)'
,
out
.
decode
()).
group
(
1
))
if
cmake_version
<
'3.13.0'
:
raise
RuntimeError
(
"CMake >= 3.13.0 is required on Windows"
)
for
ext
in
self
.
extensions
:
self
.
build_extension
(
ext
)
def
build_extension
(
self
,
ext
):
extdir
=
os
.
path
.
abspath
(
os
.
path
.
dirname
(
self
.
get_ext_fullpath
(
ext
.
name
)))
cmake_args
=
[
# '-G "Visual Studio 15 2017 Win64"',
'-DCMAKE_PREFIX_PATH={}'
.
format
(
LIBTORCH_ROOT
),
'-DPYBIND11_PYTHON_VERSION={}'
.
format
(
PYTHON_VERSION
),
'-DSPCONV_BuildTests=OFF'
,
'-DPYTORCH_VERSION={}'
.
format
(
PYTORCH_VERSION_NUMBER
)
]
# -arch=sm_61
if
not
torch
.
cuda
.
is_available
()
and
SPCONV_FORCE_BUILD_CUDA
is
None
:
cmake_args
+=
[
'-DSPCONV_BuildCUDA=OFF'
]
else
:
cuda_flags
=
[
"
\"
--expt-relaxed-constexpr
\"
"
]
# must add following flags to use at::Half
# but will remove raw half operators.
cuda_flags
+=
[
"-D__CUDA_NO_HALF_OPERATORS__"
,
"-D__CUDA_NO_HALF_CONVERSIONS__"
]
# cuda_flags += ["-D__CUDA_NO_HALF2_OPERATORS__"]
cmake_args
+=
[
'-DCMAKE_CUDA_FLAGS='
+
" "
.
join
(
cuda_flags
)]
cfg
=
'Debug'
if
self
.
debug
else
'Release'
assert
cfg
==
"Release"
,
"pytorch ops don't support debug build."
build_args
=
[
'--config'
,
cfg
]
print
(
cfg
)
if
platform
.
system
()
==
"Windows"
:
cmake_args
+=
[
'-DCMAKE_BUILD_TYPE='
+
cfg
]
cmake_args
+=
[
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'
.
format
(
cfg
.
upper
(),
str
(
Path
(
extdir
)
/
"spconv"
))]
# cmake_args += ['-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), str(Path(extdir) / "spconv"))]
cmake_args
+=
[
'-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}'
.
format
(
cfg
.
upper
(),
str
(
Path
(
extdir
)
/
"spconv"
))]
cmake_args
+=
[
"-DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE"
]
if
sys
.
maxsize
>
2
**
32
:
cmake_args
+=
[
'-A'
,
'x64'
]
build_args
+=
[
'--'
,
'/m'
]
else
:
cmake_args
+=
[
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'
.
format
(
str
(
Path
(
extdir
)
/
"spconv"
))]
cmake_args
+=
[
'-DCMAKE_BUILD_TYPE='
+
cfg
]
build_args
+=
[
'--'
,
'-j4'
]
env
=
os
.
environ
.
copy
()
env
[
'CXXFLAGS'
]
=
'{} -DVERSION_INFO=
\\
"{}
\\
"'
.
format
(
env
.
get
(
'CXXFLAGS'
,
''
),
self
.
distribution
.
get_version
())
if
not
os
.
path
.
exists
(
self
.
build_temp
):
os
.
makedirs
(
self
.
build_temp
)
print
(
"|||||CMAKE ARGS|||||"
,
cmake_args
)
subprocess
.
check_call
([
'cmake'
,
ext
.
sourcedir
]
+
cmake_args
,
cwd
=
self
.
build_temp
,
env
=
env
)
subprocess
.
check_call
([
'cmake'
,
'--build'
,
'.'
]
+
build_args
,
cwd
=
self
.
build_temp
)
packages
=
find_packages
(
exclude
=
(
'tools'
,
'tools.*'
))
pass
self
.
status
(
'Building Source and Wheel (universal) distribution...'
)
os
.
system
(
'{0} setup.py sdist bdist_wheel --universal'
.
format
(
sys
.
executable
))
self
.
status
(
'Uploading the package to PyPI via Twine...'
)
os
.
system
(
'twine upload dist/*'
)
self
.
status
(
'Pushing git tags...'
)
os
.
system
(
'git tag v{0}'
.
format
(
about
[
'__version__'
]))
os
.
system
(
'git push --tags'
)
sys
.
exit
()
disable_jit
=
os
.
getenv
(
"SPCONV_DISABLE_JIT"
,
None
)
if
disable_jit
is
not
None
and
disable_jit
==
"1"
:
cmdclass
=
{
'upload'
:
UploadCommand
,
'build_ext'
:
PCCMBuild
,
}
from
cumm.gemm.main
import
GemmMainUnitTest
from
spconv.core
import
SHUFFLE_SIMT_PARAMS
,
SHUFFLE_VOLTA_PARAMS
,
SHUFFLE_TURING_PARAMS
from
spconv.csrc.sparse.all
import
SpconvOps
cu
=
GemmMainUnitTest
(
SHUFFLE_SIMT_PARAMS
+
SHUFFLE_VOLTA_PARAMS
+
SHUFFLE_TURING_PARAMS
)
cu
.
namespace
=
"cumm.gemm.main"
cuda_ver_number
=
int
(
cuda_ver
)
if
cuda_ver_number
<
110
:
std
=
"c++14"
else
:
std
=
"c++17"
ext_modules
:
List
[
Extension
]
=
[
PCCMExtension
([
cu
,
SpconvOps
()],
"spconv/core_cc"
,
Path
(
__file__
).
resolve
().
parent
/
"spconv"
,
objects_folder
=
"objects"
,
std
=
std
,
disable_pch
=
True
)
]
else
:
cmdclass
=
{
'upload'
:
UploadCommand
,
}
ext_modules
=
[]
# Where the magic happens:
setup
(
name
=
'spconv'
,
version
=
'1.2.1'
,
author
=
'Yan Yan'
,
author_email
=
'scrin@foxmail.com'
,
description
=
'spatial sparse convolution for pytorch'
,
long_description
=
''
,
setup_requires
=
[
'torch>=1.3.0'
],
packages
=
packages
,
package_dir
=
{
'spconv'
:
'spconv'
},
ext_modules
=
[
CMakeExtension
(
'spconv'
,
library_dirs
=
[])],
cmdclass
=
dict
(
build_ext
=
CMakeBuild
),
zip_safe
=
False
,
name
=
RELEASE_NAME
,
version
=
about
[
'__version__'
],
description
=
DESCRIPTION
,
long_description
=
long_description
,
long_description_content_type
=
'text/markdown'
,
author
=
AUTHOR
,
author_email
=
EMAIL
,
python_requires
=
REQUIRES_PYTHON
,
url
=
URL
,
packages
=
find_packages
(
exclude
=
(
'tests'
,
)),
# If your package is a single module, use this instead of 'packages':
# py_modules=['mypackage'],
entry_points
=
{
'console_scripts'
:
[],
},
install_requires
=
REQUIRED
,
extras_require
=
EXTRAS
,
include_package_data
=
True
,
license
=
'MIT'
,
classifiers
=
[
# Trove classifiers
# Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
'License :: OSI Approved :: MIT License'
,
'Programming Language :: Python'
,
'Programming Language :: Python :: 3'
,
'Programming Language :: Python :: Implementation :: CPython'
,
'Programming Language :: Python :: Implementation :: PyPy'
],
# $ setup.py publish support.
cmdclass
=
cmdclass
,
ext_modules
=
ext_modules
,
)
spconv/__init__.py
View file @
a6abf55d
# Copyright 201
9
Yan Yan
#
# Copyright 20
2
1 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
platform
from
pathlib
import
Path
from
.
import
build
as
_build
import
numpy
as
np
import
torch
from
spconv
import
ops
,
utils
from
spconv.conv
import
(
SparseConv2d
,
SparseConv3d
,
SparseConvTranspose2d
,
SparseConvTranspose3d
,
SparseInverseConv2d
,
SparseInverseConv3d
,
SubMConv2d
,
SubMConv3d
)
from
spconv.identity
import
Identity
from
spconv.modules
import
SparseModule
,
SparseSequential
from
spconv.ops
import
ConvAlgo
from
spconv.pool
import
SparseMaxPool2d
,
SparseMaxPool3d
from
spconv.tables
import
AddTable
,
ConcatTable
,
JoinTable
_LIB_FILE_NAME
=
"libspconv.so"
if
platform
.
system
()
==
"Windows"
:
_LIB_FILE_NAME
=
"spconv.dll"
_LIB_PATH
=
str
(
Path
(
__file__
).
parent
/
_LIB_FILE_NAME
)
torch
.
ops
.
load_library
(
_LIB_PATH
)
def
scatter_nd
(
indices
,
updates
,
shape
):
"""pytorch edition of tensorflow scatter_nd.
this function don't contain except handle code. so use this carefully
when indice repeats, don't support repeat add which is supported
in tensorflow.
"""
ret
=
torch
.
zeros
(
*
shape
,
dtype
=
updates
.
dtype
,
device
=
updates
.
device
)
ndim
=
indices
.
shape
[
-
1
]
output_shape
=
list
(
indices
.
shape
[:
-
1
])
+
shape
[
indices
.
shape
[
-
1
]:]
flatted_indices
=
indices
.
view
(
-
1
,
ndim
)
slices
=
[
flatted_indices
[:,
i
]
for
i
in
range
(
ndim
)]
slices
+=
[
Ellipsis
]
ret
[
slices
]
=
updates
.
view
(
*
output_shape
)
return
ret
class
SparseConvTensor
(
object
):
def
__init__
(
self
,
features
,
indices
,
spatial_shape
,
batch_size
,
grid
=
None
):
"""
Args:
features: [num_points, num_features] feature tensor
indices: [num_points, ndim + 1] indice tensor. batch index saved in indices[:, 0]
spatial_shape: spatial shape of your sparse data
batch_size: batch size of your sparse data
grid: pre-allocated grid tensor. should be used when the volume of spatial shape
is very large.
"""
self
.
features
=
features
self
.
indices
=
indices
self
.
spatial_shape
=
spatial_shape
self
.
batch_size
=
batch_size
self
.
indice_dict
=
{}
self
.
grid
=
grid
@
classmethod
def
from_dense
(
cls
,
x
:
torch
.
Tensor
):
"""create sparse tensor fron channel last dense tensor by to_sparse
x must be NHWC tensor, channel last
"""
x
=
x
.
to_sparse
(
x
.
ndim
-
1
)
spatial_shape
=
x
.
shape
[
1
:
-
1
]
batch_size
=
x
.
shape
[
0
]
indices_th
=
x
.
indices
().
permute
(
1
,
0
).
contiguous
().
int
()
features_th
=
x
.
values
()
return
cls
(
features_th
,
indices_th
,
spatial_shape
,
batch_size
)
@
property
def
spatial_size
(
self
):
return
np
.
prod
(
self
.
spatial_shape
)
def
find_indice_pair
(
self
,
key
):
if
key
is
None
:
return
None
if
key
in
self
.
indice_dict
:
return
self
.
indice_dict
[
key
]
return
None
def
dense
(
self
,
channels_first
=
True
):
output_shape
=
[
self
.
batch_size
]
+
list
(
self
.
spatial_shape
)
+
[
self
.
features
.
shape
[
1
]]
res
=
scatter_nd
(
self
.
indices
.
to
(
self
.
features
.
device
).
long
(),
self
.
features
,
output_shape
)
if
not
channels_first
:
return
res
ndim
=
len
(
self
.
spatial_shape
)
trans_params
=
list
(
range
(
0
,
ndim
+
1
))
trans_params
.
insert
(
1
,
ndim
+
1
)
return
res
.
permute
(
*
trans_params
).
contiguous
()
@
property
def
sparity
(
self
):
return
self
.
indices
.
shape
[
0
]
/
np
.
prod
(
self
.
spatial_shape
)
/
self
.
batch_size
class
ToDense
(
SparseModule
):
"""convert SparseConvTensor to NCHW dense tensor.
"""
def
forward
(
self
,
x
:
SparseConvTensor
):
return
x
.
dense
()
class
RemoveGrid
(
SparseModule
):
"""remove pre-allocated grid buffer.
"""
def
forward
(
self
,
x
:
SparseConvTensor
):
x
.
grid
=
None
return
x
from
.core
import
ConvAlgo
,
AlgoHint
from
.
import
constants
\ No newline at end of file
spconv/algo.py
0 → 100644
View file @
a6abf55d
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
enum
import
Enum
from
cumm
import
tensorview
as
tv
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
spconv.core_cc.cumm.gemm.main
import
GemmAlgoDesp
,
GemmMainUnitTest
,
GemmParams
# from spconv.core_cc.cumm.gemm.gather import GatherAll, ScatterAll
from
cumm.gemm.algospec.core
import
ShuffleStrideType
,
get_min_arch_of_algo_str
,
get_available_algo_str_from_arch
from
cumm.gemm.codeops
import
group_by
,
div_up
from
typing
import
Optional
import
time
import
numpy
as
np
from
.core
import
ConvAlgo
,
AlgoHint
ALL_ALGO_DESPS
=
GemmMainUnitTest
.
get_all_algo_desp
()
_GEMM_STATIC_KEY
=
Tuple
[
bool
,
bool
,
bool
,
int
,
int
,
int
,
str
,
str
]
# GATHER = GatherAll()
# SCATTER = ScatterAll()
class
SimpleGemmAlgoMeta
:
def
__init__
(
self
,
tile_ms
:
List
[
int
],
tile_ns
:
List
[
int
],
tile_ks
:
List
[
int
],
tile_shape_to_algos
:
Dict
[
int
,
List
[
int
]])
->
None
:
self
.
tile_shape_to_algos
=
tile_shape_to_algos
self
.
tile_ms
=
tile_ms
self
.
tile_ns
=
tile_ns
self
.
tile_ks
=
tile_ks
class
BestAlgoByProfile
:
def
__init__
(
self
,
algo_desp
:
GemmAlgoDesp
,
external_gather
:
bool
,
external_scatter
:
bool
,
gather_params
:
Optional
[
Tuple
[
int
,
int
,
int
,
int
]]
=
None
,
scatter_params
:
Optional
[
Tuple
[
int
,
int
,
int
,
int
]]
=
None
,
splitk
:
int
=
1
)
->
None
:
self
.
algo_desp
=
algo_desp
self
.
external_gather
=
external_gather
self
.
external_scatter
=
external_scatter
self
.
gather_params
=
gather_params
self
.
scatter_params
=
scatter_params
self
.
splitk
=
splitk
class
SimpleGemm
:
def
__init__
(
self
,
desps
:
List
[
GemmAlgoDesp
])
->
None
:
self
.
desps
=
desps
self
.
static_key_to_desps
=
group_by
(
self
.
get_static_key
,
desps
)
self
.
static_key_to_meta
:
Dict
[
_GEMM_STATIC_KEY
,
SimpleGemmAlgoMeta
]
=
{}
for
k
,
static_desps
in
self
.
static_key_to_desps
.
items
():
tile_shape_to_algos
:
Dict
[
int
,
List
[
int
]]
=
{}
tile_ms
:
Set
[
int
]
=
set
()
tile_ns
:
Set
[
int
]
=
set
()
tile_ks
:
Set
[
int
]
=
set
()
for
i
,
desp
in
enumerate
(
static_desps
):
ts
=
desp
.
tile_shape
tile_ms
.
add
(
ts
[
0
])
tile_ns
.
add
(
ts
[
1
])
tile_ks
.
add
(
ts
[
2
])
tile_key
=
ts
[
0
]
|
(
ts
[
1
]
<<
20
)
|
(
ts
[
2
]
<<
40
)
if
tile_key
not
in
tile_shape_to_algos
:
tile_shape_to_algos
[
tile_key
]
=
[]
tile_shape_to_algos
[
tile_key
].
append
(
i
)
tile_ms_list
=
list
(
tile_ms
)
tile_ns_list
=
list
(
tile_ns
)
tile_ks_list
=
list
(
tile_ks
)
tile_ms_list
.
sort
()
tile_ns_list
.
sort
()
tile_ks_list
.
sort
()
self
.
static_key_to_meta
[
k
]
=
SimpleGemmAlgoMeta
(
tile_ms_list
,
tile_ns_list
,
tile_ks_list
,
tile_shape_to_algos
)
self
.
nk_forward_cache
:
Dict
[
Tuple
[
int
,
int
],
BestAlgoByProfile
]
=
{}
# for forward
self
.
nk_dgrad_cache
:
Dict
[
Tuple
[
int
,
int
],
BestAlgoByProfile
]
=
{}
# for backward weight
self
.
mn_cache
:
Dict
[
Tuple
[
int
,
int
],
BestAlgoByProfile
]
=
{}
# for backward weight
@
staticmethod
def
get_static_key
(
d
:
GemmAlgoDesp
)
->
_GEMM_STATIC_KEY
:
return
(
d
.
trans_a
,
d
.
trans_b
,
d
.
trans_c
,
d
.
dtype_a
,
d
.
dtype_b
,
d
.
dtype_c
,
d
.
shuffle_type
,
d
.
algo
)
def
device_synchronize
(
self
):
return
GemmMainUnitTest
.
device_synchronize
()
def
get_all_available
(
self
,
a
:
tv
.
Tensor
,
b
:
tv
.
Tensor
,
c
:
tv
.
Tensor
,
trans_a
:
bool
,
trans_b
:
bool
,
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
):
if
trans_c
:
trans_a
=
not
trans_a
trans_b
=
not
trans_b
trans_a
,
trans_b
=
trans_b
,
trans_a
a
,
b
=
b
,
a
trans_c
=
False
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
GemmAlgoDesp
]
=
[]
for
algo
in
avail_algos
:
static_key
=
(
trans_a
,
trans_b
,
trans_c
,
a
.
dtype
,
b
.
dtype
,
c
.
dtype
,
shuffle_type
.
value
,
algo
)
desps
=
self
.
static_key_to_desps
.
get
(
static_key
,
None
)
if
desps
is
None
or
len
(
desps
)
==
0
:
continue
for
desp
in
desps
:
lda
=
a
.
dim
(
1
)
ldb
=
b
.
dim
(
1
)
ldc
=
c
.
dim
(
1
)
if
desp
.
supported_ldx
(
lda
,
ldb
,
ldc
):
finally_algos
.
append
(
desp
)
return
finally_algos
def
select
(
self
,
a
:
tv
.
Tensor
,
b
:
tv
.
Tensor
,
c
:
tv
.
Tensor
,
trans_a
:
bool
,
trans_b
:
bool
,
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
,
a_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
b_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
c_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
hint
:
int
=
AlgoHint
.
NoHint
.
value
):
m
,
n
,
k
=
GemmMainUnitTest
.
extract_mnk
(
a
.
shape
,
b
.
shape
,
trans_a
,
trans_b
,
trans_c
,
shuffle_type
.
value
,
a_inds
.
shape
,
b_inds
.
shape
,
c_inds
.
shape
)
if
trans_c
:
trans_a
=
not
trans_a
trans_b
=
not
trans_b
trans_a
,
trans_b
=
trans_b
,
trans_a
a
,
b
=
b
,
a
trans_c
=
False
avail_algos
=
get_available_algo_str_from_arch
(
arch
)
finally_algos
:
List
[
GemmAlgoDesp
]
=
[]
for
algo
in
avail_algos
:
static_key
=
(
trans_a
,
trans_b
,
trans_c
,
a
.
dtype
,
b
.
dtype
,
c
.
dtype
,
shuffle_type
.
value
,
algo
)
desps
=
self
.
static_key_to_desps
.
get
(
static_key
,
None
)
if
desps
is
None
or
len
(
desps
)
==
0
:
continue
meta
=
self
.
static_key_to_meta
[
static_key
]
# for shuffle stride algos, we need to make channel tile size as large as possible.
# so if ShuffleAC, we need to make k largest.
selected_algo_desps
=
GemmMainUnitTest
.
simple_select_tile_shape
(
m
,
n
,
k
,
meta
.
tile_ms
,
meta
.
tile_ns
,
meta
.
tile_ks
,
meta
.
tile_shape_to_algos
,
large_k_first
=
shuffle_type
==
shuffle_type
.
ShuffleAC
)
if
not
selected_algo_desps
:
candidate
=
desps
else
:
candidate
=
[
desps
[
i
]
for
i
in
selected_algo_desps
]
# select by hint
if
hint
==
0
:
return
candidate
[
0
]
if
hint
&
(
AlgoHint
.
Fowrard
.
value
|
AlgoHint
.
BackwardInput
.
value
):
# m may be huge, n and k are small
# don't need mixed precision
# don't need splitk
finally_algos
=
[]
if
a
.
dtype
==
tv
.
float16
:
dacc
=
tv
.
float16
dcomp
=
tv
.
float16
for
can
in
candidate
:
if
can
.
dacc
==
dacc
and
can
.
dcomp
==
dcomp
:
finally_algos
.
append
(
can
)
else
:
finally_algos
=
candidate
elif
hint
&
AlgoHint
.
BackwardWeight
.
value
:
# k is huge
# don't support i8
# if f16, acc and comp must be f32
finally_algos
=
[]
candidate_filtered
:
List
[
GemmAlgoDesp
]
=
list
(
filter
(
lambda
x
:
x
.
split_k_serial
,
candidate
))
if
not
candidate_filtered
:
candidate_filtered
=
candidate
if
a
.
dtype
==
tv
.
int8
:
continue
elif
a
.
dtype
==
tv
.
float16
:
dacc
=
tv
.
float32
dcomp
=
tv
.
float32
for
can
in
candidate_filtered
:
if
can
.
dacc
==
dacc
and
can
.
dcomp
==
dcomp
:
finally_algos
.
append
(
can
)
else
:
finally_algos
=
candidate_filtered
else
:
return
candidate
[
0
]
# print(finally_algos)
if
finally_algos
:
return
finally_algos
[
0
]
return
None
def
get_profiled_algo
(
self
,
a_shape
:
List
[
int
],
b_shape
:
List
[
int
],
c_shape
:
List
[
int
],
trans_a
:
bool
,
trans_b
:
bool
,
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
,
a_inds_shape
:
Optional
[
List
[
int
]]
=
None
,
b_inds_shape
:
Optional
[
List
[
int
]]
=
None
,
c_inds_shape
:
Optional
[
List
[
int
]]
=
None
,
hint
:
int
=
AlgoHint
.
NoHint
.
value
):
if
a_inds_shape
is
None
:
a_inds_shape
=
[]
if
b_inds_shape
is
None
:
b_inds_shape
=
[]
if
c_inds_shape
is
None
:
c_inds_shape
=
[]
m
,
n
,
k
=
GemmMainUnitTest
.
extract_mnk
(
a_shape
,
b_shape
,
trans_a
,
trans_b
,
trans_c
,
shuffle_type
.
value
,
a_inds_shape
,
b_inds_shape
,
c_inds_shape
)
if
hint
&
AlgoHint
.
BackwardWeight
.
value
:
key
=
(
m
,
n
)
return
self
.
mn_cache
.
get
(
key
,
None
)
elif
hint
&
AlgoHint
.
BackwardInput
.
value
:
key
=
(
n
,
k
)
return
self
.
nk_dgrad_cache
.
get
(
key
,
None
)
elif
hint
&
AlgoHint
.
Fowrard
.
value
:
key
=
(
n
,
k
)
return
self
.
nk_forward_cache
.
get
(
key
,
None
)
raise
NotImplementedError
def
extract_mnk
(
self
,
a_shape
:
List
[
int
],
b_shape
:
List
[
int
],
trans_a
:
bool
,
trans_b
:
bool
,
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
,
a_inds_shape
:
Optional
[
List
[
int
]]
=
None
,
b_inds_shape
:
Optional
[
List
[
int
]]
=
None
,
c_inds_shape
:
Optional
[
List
[
int
]]
=
None
,
hint
:
int
=
AlgoHint
.
NoHint
.
value
):
if
a_inds_shape
is
None
:
a_inds_shape
=
[]
if
b_inds_shape
is
None
:
b_inds_shape
=
[]
if
c_inds_shape
is
None
:
c_inds_shape
=
[]
m
,
n
,
k
=
GemmMainUnitTest
.
extract_mnk
(
a_shape
,
b_shape
,
trans_a
,
trans_b
,
trans_c
,
shuffle_type
.
value
,
a_inds_shape
,
b_inds_shape
,
c_inds_shape
)
return
m
,
n
,
k
def
profile_and_cache
(
self
,
a
:
tv
.
Tensor
,
b
:
tv
.
Tensor
,
c
:
tv
.
Tensor
,
trans_a
:
bool
,
trans_b
:
bool
,
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
,
a_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
b_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
c_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
hint
:
int
=
AlgoHint
.
NoHint
.
value
,
alpha
:
float
=
1.0
,
beta
:
float
=
0.0
,
gather_data
:
tv
.
Tensor
=
tv
.
Tensor
(),
scatter_data
:
tv
.
Tensor
=
tv
.
Tensor
(),
# mm_func
stream
:
int
=
0
):
m
,
n
,
k
=
GemmMainUnitTest
.
extract_mnk
(
a
.
shape
,
b
.
shape
,
trans_a
,
trans_b
,
trans_c
,
shuffle_type
.
value
,
a_inds
.
shape
,
b_inds
.
shape
,
c_inds
.
shape
)
if
hint
&
AlgoHint
.
BackwardWeight
.
value
:
key
=
(
m
,
n
)
else
:
key
=
(
n
,
k
)
avail
=
self
.
get_all_available
(
a
,
b
,
c
,
trans_a
,
trans_b
,
trans_c
,
arch
,
shuffle_type
)
c_
=
c
.
clone
()
times
:
List
[
float
]
=
[]
# gather_algos: List[GemmAlgoDesp] = []
# find fastest gather algo for this input
best_gather_params
=
(
-
1
,
-
1
,
-
1
,
-
1
)
best_scatter_params
=
(
-
1
,
-
1
,
-
1
,
-
1
)
# gather_data_ = tv.Tensor()
# if not gather_data.empty(
# ) and not hint & AlgoHint.BackwardWeight.value:
# # run gather here
# all_gather_params = GATHER.get_all_gather_params()
# gather_data_ = gather_data.clone()
# gather_times: List[float] = []
# for gather_params in all_gather_params:
# if GATHER.supported(gather_params[2], a.dim(1), a.dtype):
# this_times = []
# for j in range(10):
# GemmMainUnitTest.stream_synchronize(stream)
# t = time.time()
# GATHER.gather(gather_data_, a, a_inds, *gather_params)
# GemmMainUnitTest.stream_synchronize(stream)
# this_times.append(time.time() - t)
# gather_times.append(np.mean(this_times[5:]))
# min_time = 1000
# min_idx = -1
# for i, t in enumerate(gather_times):
# if t < min_time:
# min_time = t
# min_idx = i
# best_gather_params = all_gather_params[min_idx]
# if not scatter_data.empty(
# ) and not hint & AlgoHint.BackwardWeight.value:
# # run gather here
# all_scatter_params = SCATTER.get_all_scatter_params()
# scatter_data_ = scatter_data.clone()
# scatter_times: List[float] = []
# for params in all_scatter_params:
# if SCATTER.supported_scatter(*params, a.dim(1), a.dtype):
# this_times = []
# for j in range(10):
# GemmMainUnitTest.stream_synchronize(stream)
# t = time.time()
# SCATTER.scatter(c_, scatter_data_, c_inds, *params)
# GemmMainUnitTest.stream_synchronize(stream)
# this_times.append(time.time() - t)
# scatter_times.append(np.mean(this_times[5:]))
# min_time = 1000
# min_idx = -1
# for i, t in enumerate(scatter_times):
# if t < min_time:
# min_time = t
# min_idx = i
# best_scatter_params = all_scatter_params[min_idx]
all_profile_res
:
List
[
BestAlgoByProfile
]
=
[]
for
desp
in
avail
:
c_
.
zero_
()
split_k_slices
=
1
# TODO better splitk selection
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
split_k_slices
=
max
(
min
(
32
,
k
//
128
),
1
)
params
=
GemmParams
()
params
.
a
=
a
params
.
b
=
b
params
.
c
=
c_
params
.
a_inds
=
a_inds
params
.
b_inds
=
b_inds
params
.
c_inds
=
c_inds
params
.
algo_desp
=
desp
params
.
alpha
=
alpha
params
.
beta
=
beta
params
.
stream
=
stream
if
desp
.
split_k_serial
and
hint
&
AlgoHint
.
BackwardWeight
.
value
:
splitk_tests
=
[
1
,
2
,
4
,
8
,
16
,
32
,
64
]
else
:
splitk_tests
=
[
1
]
spk_speeds
=
[]
for
spk
in
splitk_tests
:
this_times
=
[]
for
j
in
range
(
3
):
GemmMainUnitTest
.
stream_synchronize
(
stream
)
t
=
time
.
time
()
params
.
split_k_slices
=
spk
GemmMainUnitTest
.
matmul2
(
params
)
GemmMainUnitTest
.
stream_synchronize
(
stream
)
this_times
.
append
(
time
.
time
()
-
t
)
times
.
append
(
np
.
mean
(
this_times
[
1
:]))
spk_speeds
.
append
(
times
[
-
1
])
all_profile_res
.
append
(
BestAlgoByProfile
(
desp
,
False
,
False
,
best_gather_params
,
best_scatter_params
,
splitk
=
spk
))
# if desp.split_k_serial:
# print(a.shape, b.shape, spk_speeds)
# if not gather_data.empty(
# ) and not hint & AlgoHint.BackwardWeight.value:
# # run gather here
# for spk in splitk_tests:
# this_times = []
# for j in range(3):
# GemmMainUnitTest.stream_synchronize(stream)
# t = time.time()
# params.a_inds = tv.Tensor()
# params.a = gather_data_
# params.split_k_slices = spk
# GATHER.gather(gather_data_,
# a,
# a_inds,
# *best_gather_params,
# stream=stream)
# GemmMainUnitTest.matmul2(params)
# GemmMainUnitTest.stream_synchronize(stream)
# this_times.append(time.time() - t)
# times.append(np.mean(this_times[1:]))
# # print("G", times[-1], times[-2])
# all_profile_res.append(
# BestAlgoByProfile(desp,
# True,
# False,
# best_gather_params, best_scatter_params,
# splitk=spk))
min_time
=
1000
min_idx
=
-
1
for
i
,
t
in
enumerate
(
times
):
if
t
<
min_time
:
min_time
=
t
min_idx
=
i
res
=
all_profile_res
[
min_idx
]
if
hint
&
AlgoHint
.
BackwardWeight
.
value
:
key
=
(
m
,
n
)
self
.
mn_cache
[
key
]
=
res
elif
hint
&
AlgoHint
.
BackwardInput
.
value
:
key
=
(
n
,
k
)
self
.
nk_dgrad_cache
[
key
]
=
res
elif
hint
&
AlgoHint
.
Fowrard
.
value
:
key
=
(
n
,
k
)
self
.
nk_forward_cache
[
key
]
=
res
else
:
raise
NotImplementedError
return
res
,
min_time
def
run_profile
(
self
,
profile_res
:
BestAlgoByProfile
,
a
:
tv
.
Tensor
,
b
:
tv
.
Tensor
,
c
:
tv
.
Tensor
,
trans_a
:
bool
,
trans_b
:
bool
,
trans_c
:
bool
,
arch
:
Tuple
[
int
,
int
],
stream
:
int
,
shuffle_type
:
ShuffleStrideType
=
ShuffleStrideType
.
NoShuffle
,
a_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
b_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
c_inds
:
tv
.
Tensor
=
tv
.
Tensor
(),
hint
:
int
=
AlgoHint
.
NoHint
.
value
,
alpha
:
float
=
1.0
,
beta
:
float
=
0.0
,
gather_data
:
tv
.
Tensor
=
tv
.
Tensor
(),
workspace
:
tv
.
Tensor
=
tv
.
Tensor
()):
m
,
n
,
k
=
GemmMainUnitTest
.
extract_mnk
(
a
.
shape
,
b
.
shape
,
trans_a
,
trans_b
,
trans_c
,
shuffle_type
.
value
,
a_inds
.
shape
,
b_inds
.
shape
,
c_inds
.
shape
)
# GemmMainUnitTest.stream_synchronize(stream)
algo_desp
=
profile_res
.
algo_desp
assert
algo_desp
is
not
None
split_k_slices
=
1
# TODO better splitk selection
# if algo_desp.split_k_serial and hint & AlgoHint.BackwardWeight.value:
# split_k_slices = max(min(32, k // 128), 1)
if
profile_res
.
splitk
>
1
:
split_k_slices
=
profile_res
.
splitk
params
=
GemmParams
()
params
.
a
=
a
params
.
b
=
b
params
.
c
=
c
params
.
a_inds
=
a_inds
params
.
b_inds
=
b_inds
params
.
c_inds
=
c_inds
params
.
algo_desp
=
algo_desp
params
.
split_k_slices
=
split_k_slices
params
.
stream
=
stream
params
.
alpha
=
alpha
params
.
beta
=
beta
params
.
workspace
=
workspace
# gather = 0
# if profile_res.external_gather and not gather_data.empty():
# GemmMainUnitTest.stream_synchronize(stream)
# tt = time.time()
# assert not gather_data.empty()
# params.a_inds = tv.Tensor()
# params.a = gather_data
# # print(profile_res.gather_params, gather_data.shape, a.shape, a_inds.shape)
# GATHER.gather(gather_data,
# a,
# a_inds,
# *profile_res.gather_params,
# stream=stream)
# GemmMainUnitTest.stream_synchronize(stream)
# gather = time.time() - tt
GemmMainUnitTest
.
matmul2
(
params
)
# GemmMainUnitTest.stream_synchronize(stream)
return
algo_desp
GEMM
=
SimpleGemm
(
ALL_ALGO_DESPS
)
if
__name__
==
"__main__"
:
print
(
len
(
ALL_ALGO_DESPS
))
print
(
ALL_ALGO_DESPS
[
0
])
a
=
tv
.
zeros
([
64000
,
32
],
dtype
=
tv
.
float16
)
b
=
tv
.
zeros
([
32
,
64
],
dtype
=
tv
.
float16
)
c
=
tv
.
zeros
([
64000
,
64
],
dtype
=
tv
.
float16
)
a_inds
=
tv
.
zeros
([
64000
],
dtype
=
tv
.
int32
)
c_inds
=
tv
.
zeros
([
64000
],
dtype
=
tv
.
int32
)
t
=
time
.
time
()
for
i
in
range
(
100
):
algo
=
GEMM
.
select
(
a
,
c
,
b
,
True
,
False
,
False
,
(
7
,
5
),
ShuffleStrideType
.
ShuffleAB
,
a_inds
=
a_inds
,
b_inds
=
c_inds
)
print
((
time
.
time
()
-
t
)
/
100
)
print
(
algo
)
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment