Commit a6abf55d authored by yan.yan's avatar yan.yan
Browse files

Merge branch 'develop'

parents fad30002 79a3eaf2
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <sstream>
#include <stdexcept>
#ifdef TV_USE_STACKTRACE
#if defined(WIN32) || defined(_WIN32) || \
defined(__WIN32) && !defined(__CYGWIN__)
#define BOOST_STACKTRACE_USE_WINDBG
#else
// require linking with -ldl and -lbacktrace in linux
#define BOOST_STACKTRACE_USE_BACKTRACE
#endif
#include <boost/stacktrace.hpp>
#endif
#ifdef TV_CUDA
#include <cuda.h>
#endif
// TV_DECLTYPE(x): a drop-in spelling of decltype(x).
// When TV_USE_BOOST_TYPEOF is defined, or when building with a non-clang
// compiler and CUDA_VERSION >= 11000, decltype(x) is wrapped in the alias
// tv::identity_t; otherwise the macro is plain decltype(x).
#if defined(TV_USE_BOOST_TYPEOF) || (!defined(__clang__) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
// a workaround when built with cuda 11
// two options: use BOOST_TYPEOF or identity_t.
// this is a nvcc bug, msvc/gcc/clang don't have this problem.
// The BOOST_TYPEOF option is kept below for reference but is disabled:
// #include <boost/typeof/typeof.hpp>
// #define TV_DECLTYPE(x) BOOST_TYPEOF(x)
namespace tv{
// Identity alias: identity_t<T> is exactly T.
template <typename T>
using identity_t = T;
}
#define TV_DECLTYPE(x) tv::identity_t<decltype(x)>
#else
#define TV_DECLTYPE(x) decltype(x)
#endif
namespace tv {
// Base case of the variadic printer: streams the last (or only) value into
// `out` without a trailing separator.
template <class SStream, class T>
void sstream_print(SStream &out, T last) {
  out << last;
}
// Recursive case: streams `head` followed by a single space, then forwards
// the remaining values. The output is therefore space separated with no
// trailing space. Values are taken by copy.
template <class SStream, class T, class... TArgs>
void sstream_print(SStream &out, T head, TArgs... tail) {
  out << head << " ";
  sstream_print(out, tail...);
}
// Prints all arguments to std::cout on one line, separated by single spaces.
// The line is assembled in a string stream first and written in one go,
// followed by std::endl (newline + flush). At least one argument is required
// because sstream_print has no zero-argument overload.
template <class... TArgs> void ssprint(TArgs... args) {
  std::stringstream buffer;
  sstream_print(buffer, args...);
  const auto line = buffer.str();
  std::cout << line << std::endl;
}
// TV_BACKTRACE_PRINT(ss): appends a newline and the current
// boost::stacktrace::stacktrace() to stream `ss` when TV_USE_STACKTRACE is
// defined; expands to nothing otherwise. The enabled expansion already ends
// with ';'.
#ifdef TV_USE_STACKTRACE
#define TV_BACKTRACE_PRINT(ss) \
ss << std::endl << boost::stacktrace::stacktrace();
#else
#define TV_BACKTRACE_PRINT(ss)
#endif
// TV_THROW_RT_ERR(args...): unconditionally throws std::runtime_error.
// The message is "<__FILE__> <__LINE__>\n" followed by the arguments joined
// by single spaces (via tv::sstream_print) and, if enabled, a stack trace.
// At least one argument is required. The expansion is a braced block that
// declares a local named __macro_s.
#define TV_THROW_RT_ERR(...) \
{ \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
}
// TV_THROW_INVALID_ARG(args...): unconditionally throws
// std::invalid_argument. The message is "<__FILE__> <__LINE__>\n" followed by
// the arguments joined by single spaces (via tv::sstream_print) and, if
// enabled, a stack trace. At least one argument is required. The expansion is
// a braced block that declares a local named __macro_s.
#define TV_THROW_INVALID_ARG(...) \
{ \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::invalid_argument(__macro_s.str()); \
}
// TV_ASSERT_RT_ERR(expr, args...): throws std::runtime_error when `expr`
// evaluates to false; does nothing otherwise.
// The message is "<__FILE__> <__LINE__>\n", then the stringified expression
// followed by " assert failed. ", then the extra arguments joined by single
// spaces (via tv::sstream_print) and, if enabled, a stack trace.
// At least one extra argument is required. The expansion is a braced block
// that declares a local named __macro_s.
#define TV_ASSERT_RT_ERR(expr, ...)                                            \
{                                                                              \
  if (!(expr)) {                                                               \
    std::stringstream __macro_s;                                               \
    __macro_s << __FILE__ << " " << __LINE__ << "\n";                          \
    __macro_s << #expr << " assert failed. ";                                  \
    tv::sstream_print(__macro_s, __VA_ARGS__);                                 \
    TV_BACKTRACE_PRINT(__macro_s);                                             \
    throw std::runtime_error(__macro_s.str());                                 \
  }                                                                            \
}
// TV_ASSERT_INVALID_ARG(expr, args...): throws std::invalid_argument when
// `expr` evaluates to false; does nothing otherwise.
// The message is "<__FILE__> <__LINE__>\n", then the stringified expression
// followed by " assert failed. ", then the extra arguments joined by single
// spaces (via tv::sstream_print) and, if enabled, a stack trace.
// At least one extra argument is required. The expansion is a braced block
// that declares a local named __macro_s.
#define TV_ASSERT_INVALID_ARG(expr, ...)                                       \
{                                                                              \
  if (!(expr)) {                                                               \
    std::stringstream __macro_s;                                               \
    __macro_s << __FILE__ << " " << __LINE__ << "\n";                          \
    __macro_s << #expr << " assert failed. ";                                  \
    tv::sstream_print(__macro_s, __VA_ARGS__);                                 \
    TV_BACKTRACE_PRINT(__macro_s);                                             \
    throw std::invalid_argument(__macro_s.str());                              \
  }                                                                            \
}
} // namespace tv
\ No newline at end of file
#pragma once
// from pytorch.aten
#include "tensorview.h"
#include <type_traits>
namespace tv {
namespace cuda {
// Ceiling division: computes (a + b - 1) / b and converts the result to int.
// For positive integers this is a / b rounded up.
template <typename T1, typename T2> inline int DivUp(const T1 a, const T2 b) {
  const auto biased = a + b - 1;
  return biased / b;
}
// Maximum number of threads per block handed out by getNumThreads().
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of threads per block for N work items.
// Returns CUDA_NUM_THREADS when N exceeds it; otherwise N rounded up to the
// next multiple of 32.
inline int getNumThreads(const int N) {
  return N > CUDA_NUM_THREADS ? CUDA_NUM_THREADS : DivUp(N, 32) * 32;
}
// CUDA: number of blocks needed to cover N work items when each block runs
// getNumThreads(N) threads. Throws std::runtime_error (via TV_ASSERT_RT_ERR)
// when N is not positive.
inline int getBlocks(const int N) {
  TV_ASSERT_RT_ERR(N > 0,
                   "CUDA kernel launch blocks must be positive, but got N=", N);
  const int threads_per_block = getNumThreads(N);
  return DivUp(N, threads_per_block);
}
} // namespace cuda
} // namespace tv
\ No newline at end of file
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensor.h"
#include "tensorview.h"
#include <eigen3/Eigen/Dense>
namespace tv {
// Wraps a 1-D or 2-D TensorView as a row-major Eigen::Map without copying.
//
// A 2-D view of shape (r, c) maps to an r x c matrix. A 1-D view of length n
// maps to a single row (1 x n); the column count is taken from dim(0) in that
// case, since a 1-D view has no dim(1).
//
// Template parameters Row / Col may fix the matrix extents at compile time;
// when they are not Eigen::Dynamic the mapped extents must match them.
//
// Throws std::invalid_argument (via TV_ASSERT_INVALID_ARG) when the view is
// not 1-D/2-D or when a fixed extent does not match.
// The returned map refers to view.data(); it does not own the memory.
template <typename T, int Row = Eigen::Dynamic, int Col = Eigen::Dynamic>
Eigen::Map<Eigen::Matrix<T, Row, Col, Eigen::RowMajor>>
tv2eigen(TensorView<T> view) {
  TV_ASSERT_INVALID_ARG(view.ndim() <= 2 && view.ndim() > 0,
                        "tv2eigen expects a 1-D or 2-D view, but got ndim=",
                        view.ndim());
  const bool is_matrix = view.ndim() == 2;
  const int row = is_matrix ? view.dim(0) : 1;
  const int col = is_matrix ? view.dim(1) : view.dim(0);
  if (Row != Eigen::Dynamic) {
    TV_ASSERT_INVALID_ARG(row == Row, "row count mismatch: expected", Row,
                          "but got", row);
  }
  if (Col != Eigen::Dynamic) {
    TV_ASSERT_INVALID_ARG(col == Col, "column count mismatch: expected", Col,
                          "but got", col);
  }
  Eigen::Map<Eigen::Matrix<T, Row, Col, Eigen::RowMajor>> eigen_map(
      view.data(), row, col);
  return eigen_map;
}
} // namespace tv
#pragma once
// from tensorflow
namespace tv {
namespace detail {
// Range object usable in a device-side range-based for loop. Iteration
// starts at `begin`, advances by `delta` on every step, and stops once the
// index is no longer less than `end`.
template <typename T> class KernelLoop {
// Forward iterator over the index sequence. An end iterator is marked by
// delta_ == 0.
struct Iterator {
__forceinline__ __device__ Iterator(T index, T delta)
: index_(index), delta_(delta) {}
// Current index.
__forceinline__ __device__ T operator*() const { return index_; }
// Advances the index by the stride.
__forceinline__ __device__ Iterator &operator++() {
index_ += delta_;
return *this;
}
// Inequality that treats every index at or beyond an end iterator as equal
// to it, so a stride that steps over `end` still terminates the loop.
__forceinline__ __device__ bool operator!=(const Iterator &other) const {
bool greater = index_ > other.index_;
bool less = index_ < other.index_;
// Anything past an end iterator (delta_ == 0) is equal.
// In range-based for loops, this optimizes to 'return less'.
if (!other.delta_) {
// `other` is the end iterator: unequal only while we are still below it.
return less;
}
if (!delta_) {
// `this` is the end iterator: unequal only while `other` is below it.
return greater;
}
// Neither side is an end iterator: ordinary inequality.
return less || greater;
}
private:
T index_;
const T delta_;
};
public:
__forceinline__ __device__ KernelLoop(T begin, T delta, T end)
: begin_(begin), delta_(delta), end_(end) {}
// Iterator positioned at the first index, stepping by delta_.
__forceinline__ __device__ Iterator begin() const {
return Iterator{begin_, delta_};
}
// End sentinel: index end_ with a zero stride.
__forceinline__ __device__ Iterator end() const { return Iterator{end_, 0}; }
private:
T begin_;
T delta_;
T end_;
};
} // namespace detail
// Helper to visit indices below `count` using the x-coordinate: starts at
// blockIdx.x * blockDim.x + threadIdx.x and steps by
// gridDim.x * blockDim.x * NumILP.
// Usage: for(int i : KernelLoopX(count)) { visit(i); }
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopX(T count) {
  const T first = blockIdx.x * blockDim.x + threadIdx.x;
  const T stride = gridDim.x * blockDim.x * NumILP;
  return detail::KernelLoop<T>(first, stride, count);
}
// Helper to visit indices below `count` using the y-coordinate: starts at
// blockIdx.y * blockDim.y + threadIdx.y and steps by
// gridDim.y * blockDim.y * NumILP.
// Usage: for(int i : KernelLoopY(count)) { visit(i); }
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopY(T count) {
  const T first = blockIdx.y * blockDim.y + threadIdx.y;
  const T stride = gridDim.y * blockDim.y * NumILP;
  return detail::KernelLoop<T>(first, stride, count);
}
// Helper to visit indices below `count` using the z-coordinate: starts at
// blockIdx.z * blockDim.z + threadIdx.z and steps by
// gridDim.z * blockDim.z * NumILP.
// Usage: for(int i : KernelLoopZ(count)) { visit(i); }
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopZ(T count) {
  const T first = blockIdx.z * blockDim.z + threadIdx.z;
  const T stride = gridDim.z * blockDim.z * NumILP;
  return detail::KernelLoop<T>(first, stride, count);
}
} // namespace tv
\ No newline at end of file
#ifndef MP_HELPER_H_
#define MP_HELPER_H_
#include <type_traits>
#include <utility>
namespace tv {
// Compile-time list of types; carries no data.
template <class... Ts> struct mp_list {};

// List of std::integral_constant<T, V>, one entry per value V.
template <class T, T... Values>
using mp_list_c = mp_list<std::integral_constant<T, Values>...>;

// Number of types in a pack, as a std::integral_constant<std::size_t, N>.
template <class... Ts>
using mp_length = std::integral_constant<std::size_t, sizeof...(Ts)>;

namespace detail {
// Calls f on a default-constructed instance of every element type, in list
// order (the elements of a braced-init-list are evaluated left to right),
// and returns f.
template <class... Elems, class Fn>
constexpr Fn mp_for_each_impl(mp_list<Elems...>, Fn &&f) {
  return (void)(std::initializer_list<int>{(f(Elems()), 0)...}),
         std::forward<Fn>(f);
}

// Empty list: nothing to visit, hand the callable straight back.
template <class Fn> constexpr Fn mp_for_each_impl(mp_list<>, Fn &&f) {
  return std::forward<Fn>(f);
}

// Primary template deliberately has no nested `type`.
template <class List, template <class...> class Target> struct mp_rename_impl {
  // An error "no type named 'type'" here means that the first argument to
  // mp_rename is not a list
};

// Re-wraps the element types of Source<Elems...> in Target<...>.
template <template <class...> class Source, class... Elems,
          template <class...> class Target>
struct mp_rename_impl<Source<Elems...>, Target> {
  using type = Target<Elems...>;
};
} // namespace detail

// mp_rename<L<Ts...>, B> is B<Ts...>.
template <class List, template <class...> class Target>
using mp_rename = typename detail::mp_rename_impl<List, Target>::type;

// Number of elements in a type list, as an integral constant.
template <class List> using mp_size = mp_rename<List, mp_length>;

// Invokes f(T()) for each element type T of List, in order, and returns f.
template <class List, class Fn> constexpr Fn mp_for_each(Fn &&f) {
  return detail::mp_for_each_impl(mp_rename<List, mp_list>(),
                                  std::forward<Fn>(f));
}
} // namespace tv
#endif
\ No newline at end of file
// Copyright Louis Delacroix 2010 - 2014.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// A pretty printing library for C++
//
// Usage:
// Include this header, and operator<< will "just work".
#ifndef H_PRETTY_PRINT
#define H_PRETTY_PRINT
#include <cstddef>
#include <iterator>
#include <memory>
#include <ostream>
#include <set>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <valarray>
namespace pretty_print {
namespace detail {
// Shared helper types for the sizeof-based SFINAE detectors below: `yes` and
// `no` have different sizes so the chosen overload can be told apart.
struct sfinae_base {
using yes = char;
using no = yes[2];
};
// SFINAE type trait to detect whether T::const_iterator exists.
// `value` is true when the first `test` overload is viable for T.
template <typename T> struct has_const_iterator : private sfinae_base {
private:
template <typename C> static yes &test(typename C::const_iterator *);
template <typename C> static no &test(...);
public:
static const bool value = sizeof(test<T>(nullptr)) == sizeof(yes);
using type = T;
};
// Detects const member functions `begin` / `end` returning
// T::const_iterator. `beg_value` reports on begin(), `end_value` on end().
template <typename T> struct has_begin_end : private sfinae_base {
private:
// Viable only if &C::begin can be cast to
// `C::const_iterator (C::*)() const`.
template <typename C>
static yes &
f(typename std::enable_if<
std::is_same<decltype(static_cast<typename C::const_iterator (C::*)()
const>(&C::begin)),
typename C::const_iterator (C::*)() const>::value>::type *);
template <typename C> static no &f(...);
// Same check as f, applied to &C::end.
template <typename C>
static yes &
g(typename std::enable_if<
std::is_same<decltype(static_cast<typename C::const_iterator (C::*)()
const>(&C::end)),
typename C::const_iterator (C::*)() const>::value,
void>::type *);
template <typename C> static no &g(...);
public:
static bool const beg_value = sizeof(f<T>(nullptr)) == sizeof(yes);
static bool const end_value = sizeof(g<T>(nullptr)) == sizeof(yes);
};
} // namespace detail
// Holds the delimiter values for a specific character type.
// Aggregate of three C strings, initialized in the order
// {prefix, delimiter, postfix}; the printing code skips any that are null.
template <typename TChar> struct delimiters_values {
using char_type = TChar;
const char_type *prefix;    // written before the elements
const char_type *delimiter; // written between consecutive elements
const char_type *postfix;   // written after the elements
};
// Defines the delimiter values for a specific container and character type.
// The primary template only declares `values`; the definitions come from the
// partial specializations for char and wchar_t further down.
template <typename T, typename TChar> struct delimiters {
using type = delimiters_values<TChar>;
static const type values;
};
// Functor to print containers. You can use this directly if you want
// to specify a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
//
// The helper stores a reference to the container, so it must not outlive it.
// operator() writes the prefix, the elements separated by the delimiter, and
// the postfix; any delimiter string that is null is skipped.
template <typename T, typename TChar = char,
          typename TCharTraits = ::std::char_traits<TChar>,
          typename TDelimiters = delimiters<T, TChar>>
struct print_container_helper {
  using delimiters_type = TDelimiters;
  using ostream_type = std::basic_ostream<TChar, TCharTraits>;

  // Default body printer: iterates [begin(c), end(c)) and writes the
  // delimiter between consecutive elements (never after the last one).
  template <typename U> struct printer {
    static void print_body(const U &c, ostream_type &stream) {
      using std::begin;
      using std::end;
      auto it = begin(c);
      const auto the_end = end(c);
      if (it != the_end) {
        for (;;) {
          stream << *it;
          if (++it == the_end)
            break;
          if (delimiters_type::values.delimiter != nullptr)
            stream << delimiters_type::values.delimiter;
        }
      }
    }
  };

  print_container_helper(const T &container) : container_(container) {}

  inline void operator()(ostream_type &stream) const {
    if (delimiters_type::values.prefix != nullptr)
      stream << delimiters_type::values.prefix;
    printer<T>::print_body(container_, stream);
    if (delimiters_type::values.postfix != nullptr)
      stream << delimiters_type::values.postfix;
  }

private:
  const T &container_;
};
// Specialization for pairs: prints `first`, the delimiter (when it is not
// null), then `second`.
template <typename T, typename TChar, typename TCharTraits,
          typename TDelimiters>
template <typename T1, typename T2>
struct print_container_helper<T, TChar, TCharTraits,
                              TDelimiters>::printer<std::pair<T1, T2>> {
  using ostream_type =
      typename print_container_helper<T, TChar, TCharTraits,
                                      TDelimiters>::ostream_type;
  static void print_body(const std::pair<T1, T2> &c, ostream_type &stream) {
    stream << c.first;
    // TDelimiters is the enclosing helper's delimiters_type.
    if (TDelimiters::values.delimiter != nullptr)
      stream << TDelimiters::values.delimiter;
    stream << c.second;
  }
};
// Specialization for tuples: prints the elements in index order, separated
// by the delimiter (when it is not null). The recursion is driven by overload
// resolution on the tag type Int<I>.
template <typename T, typename TChar, typename TCharTraits,
          typename TDelimiters>
template <typename... Args>
struct print_container_helper<T, TChar, TCharTraits,
                              TDelimiters>::printer<std::tuple<Args...>> {
  using ostream_type =
      typename print_container_helper<T, TChar, TCharTraits,
                                      TDelimiters>::ostream_type;
  using element_type = std::tuple<Args...>;

  // Tag type carrying the index of the next element to print.
  template <std::size_t I> struct Int {};

  static void print_body(const element_type &c, ostream_type &stream) {
    tuple_print(c, stream, Int<0>());
  }

  // End of recursion: index equals the tuple size, nothing left to print.
  static void tuple_print(const element_type &, ostream_type &,
                          Int<sizeof...(Args)>) {}

  // First element: printed without a leading delimiter. For an empty tuple
  // the tag parameter becomes std::nullptr_t so that this overload does not
  // collide with the end-of-recursion overload above.
  static void
  tuple_print(const element_type &c, ostream_type &stream,
              typename std::conditional<sizeof...(Args) != 0, Int<0>,
                                        std::nullptr_t>::type) {
    stream << std::get<0>(c);
    tuple_print(c, stream, Int<1>());
  }

  // Remaining elements: delimiter first, then the element, then recurse.
  template <std::size_t N>
  static void tuple_print(const element_type &c, ostream_type &stream, Int<N>) {
    // TDelimiters is the enclosing helper's delimiters_type.
    if (TDelimiters::values.delimiter != nullptr)
      stream << TDelimiters::values.delimiter;
    stream << std::get<N>(c);
    tuple_print(c, stream, Int<N + 1>());
  }
};
// Stream insertion for print_container_helper: runs the helper against the
// stream and returns the stream so insertions can be chained.
template <typename T, typename TChar, typename TCharTraits,
          typename TDelimiters>
inline std::basic_ostream<TChar, TCharTraits> &operator<<(
    std::basic_ostream<TChar, TCharTraits> &os,
    const print_container_helper<T, TChar, TCharTraits, TDelimiters> &printer) {
  printer(os);
  return os;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types.
// By default a type counts as a container when it has a nested
// const_iterator and const begin()/end() members returning it.
template <typename T>
struct is_container
: public std::integral_constant<bool,
detail::has_const_iterator<T>::value &&
detail::has_begin_end<T>::beg_value &&
detail::has_begin_end<T>::end_value> {};
// Built-in arrays are containers...
template <typename T, std::size_t N>
struct is_container<T[N]> : std::true_type {};
// ...except char arrays, which are explicitly excluded.
template <std::size_t N> struct is_container<char[N]> : std::false_type {};
// valarray, pair and tuple are opted in explicitly.
template <typename T> struct is_container<std::valarray<T>> : std::true_type {};
template <typename T1, typename T2>
struct is_container<std::pair<T1, T2>> : std::true_type {};
template <typename... Args>
struct is_container<std::tuple<Args...>> : std::true_type {};
// Default delimiters: "[", ", ", "]" for both char and wchar_t streams.
template <typename T> struct delimiters<T, char> {
static const delimiters_values<char> values;
};
template <typename T>
const delimiters_values<char> delimiters<T, char>::values = {"[", ", ", "]"};
// Wide-character counterpart of the default delimiters.
template <typename T> struct delimiters<T, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T>
const delimiters_values<wchar_t> delimiters<T, wchar_t>::values = {L"[", L", ",
L"]"};
// Delimiters for (multi)set and unordered_(multi)set: "{", ", ", "}".
// Each container gets a char and a wchar_t specialization.
// std::set
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::set<T, TComp, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char>
delimiters<::std::set<T, TComp, TAllocator>, char>::values = {"{", ", ",
"}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::set<T, TComp, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::set<T, TComp, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
// std::multiset
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::multiset<T, TComp, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char> delimiters<::std::multiset<T, TComp, TAllocator>,
char>::values = {"{", ", ", "}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::multiset<T, TComp, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::multiset<T, TComp, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
// std::unordered_set
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_set<T, THash, TEqual, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters<
::std::unordered_set<T, THash, TEqual, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters<
::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
// std::unordered_multiset
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
char> {
static const delimiters_values<char> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters<
::std::unordered_multiset<T, THash, TEqual, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
wchar_t>::values = {L"{", L", ", L"}"};
// Delimiters for pair and tuple: "(", ", ", ")".
// Each gets a char and a wchar_t specialization.
// std::pair
template <typename T1, typename T2> struct delimiters<std::pair<T1, T2>, char> {
static const delimiters_values<char> values;
};
template <typename T1, typename T2>
const delimiters_values<char> delimiters<std::pair<T1, T2>, char>::values = {
"(", ", ", ")"};
template <typename T1, typename T2>
struct delimiters<::std::pair<T1, T2>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T1, typename T2>
const delimiters_values<wchar_t>
delimiters<::std::pair<T1, T2>, wchar_t>::values = {L"(", L", ", L")"};
// std::tuple
template <typename... Args> struct delimiters<std::tuple<Args...>, char> {
static const delimiters_values<char> values;
};
template <typename... Args>
const delimiters_values<char> delimiters<std::tuple<Args...>, char>::values = {
"(", ", ", ")"};
template <typename... Args> struct delimiters<::std::tuple<Args...>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename... Args>
const delimiters_values<wchar_t>
delimiters<::std::tuple<Args...>, wchar_t>::values = {L"(", L", ", L")"};
// Type-erasing helper for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar.
// Usage: "cout << pretty_print::custom_delims<MyDelims>(x)".
//
// Abstract interface: one print entry point per supported stream type.
struct custom_delims_base {
  virtual ~custom_delims_base() = default;
  virtual std::ostream &stream(::std::ostream &) = 0;
  virtual std::wostream &stream(::std::wostream &) = 0;
};
// Concrete custom_delims_base for a container of type T printed with the
// delimiter set Delims. Holds a reference to the container, so the wrapper
// must not outlive it.
template <typename T, typename Delims>
struct custom_delims_wrapper : custom_delims_base {
  custom_delims_wrapper(const T &t_) : t(t_) {}

  // Prints to a narrow stream using Delims' char delimiters.
  std::ostream &stream(std::ostream &s) override {
    return s << print_container_helper<T, char, std::char_traits<char>, Delims>(
               t);
  }

  // Prints to a wide stream using Delims' wchar_t delimiters.
  std::wostream &stream(std::wostream &s) override {
    return s << print_container_helper<T, wchar_t, std::char_traits<wchar_t>,
                                       Delims>(t);
  }

private:
  const T &t;
};
// Value handed to operator<< to print a container with the delimiter set
// Delims. Owns a heap-allocated, type-erased wrapper that refers to the
// container passed to the constructor.
template <typename Delims> struct custom_delims {
  std::unique_ptr<custom_delims_base> base;

  template <typename Container>
  custom_delims(const Container &c)
      : base(new custom_delims_wrapper<Container, Delims>(c)) {}
};
template <typename TChar, typename TCharTraits, typename Delims>
inline std::basic_ostream<TChar, TCharTraits> &
operator<<(std::basic_ostream<TChar, TCharTraits> &s,
const custom_delims<Delims> &p) {
return p.base->stream(s);
}
// Non-owning view of a C-style array given as pointer plus element count,
// exposing begin()/end() so that it can be printed like a container.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template <typename T> struct array_wrapper_n {
  using const_iterator = const T *;
  using value_type = T;

  array_wrapper_n(const T *const a, size_t n) : _array(a), _n(n) {}

  // Pointer to the first element.
  inline const_iterator begin() const { return _array; }
  // Pointer one past the last element.
  inline const_iterator end() const { return begin() + _n; }

private:
  const T *const _array;
  size_t _n;
};
// View of a single bucket of a hash-table based container that offers local
// iterators. begin()/end() forward to cbegin(n)/cend(n) of the wrapped
// container, which is held by reference.
// Usage: std::cout << bucket_print(m, 4) << std::endl; (prints the bucket
// with index 4 of container m.)
template <typename T> struct bucket_print_wrapper {
  using const_iterator = typename T::const_local_iterator;
  using size_type = typename T::size_type;

  bucket_print_wrapper(const T &m, size_type bucket) : m_map(m), n(bucket) {}

  const_iterator begin() const { return m_map.cbegin(n); }
  const_iterator end() const { return m_map.cend(n); }

private:
  const T &m_map;
  const size_type n;
};
} // namespace pretty_print
// Global accessor functions for the convenience wrappers.
// Wraps the n elements starting at `a` so they can be streamed like a
// container. The array is not copied.
template <typename T>
inline pretty_print::array_wrapper_n<T> pretty_print_array(const T *const a,
                                                           size_t n) {
  using wrapper_type = pretty_print::array_wrapper_n<T>;
  return wrapper_type(a, n);
}
// Wraps bucket `n` of the hash container `m` so it can be streamed like a
// container. The container is referenced, not copied.
template <typename T>
pretty_print::bucket_print_wrapper<T> bucket_print(const T &m,
                                                   typename T::size_type n) {
  using wrapper_type = pretty_print::bucket_print_wrapper<T>;
  return wrapper_type(m, n);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
// NOTE(review): adding overloads to namespace std is not sanctioned by the
// C++ standard; the original authors flag this themselves above.
namespace std {
// Prints a container to the stream using default delimiters.
// Participates in overload resolution only for types T for which
// ::pretty_print::is_container<T>::value is true.
template <typename T, typename TChar, typename TCharTraits>
inline typename enable_if<::pretty_print::is_container<T>::value,
basic_ostream<TChar, TCharTraits> &>::type
operator<<(basic_ostream<TChar, TCharTraits> &stream, const T &container) {
return stream
<< ::pretty_print::print_container_helper<T, TChar, TCharTraits>(
container);
}
} // namespace std
#endif // H_PRETTY_PRINT
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensor.h"
#include "tensorview.h"
#include <algorithm>
#include <array>
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace tv {
template <typename Tarr> bool is_c_style(const Tarr &arr) {
return bool(arr.flags() & py::array::c_style);
}
// Builds a mutable TensorView over the memory of a typed numpy array without
// copying. The view's shape is the array's shape.
// Throws std::invalid_argument (via TV_ASSERT_INVALID_ARG) when the array is
// not C-contiguous, or when Rank >= 0 and the array's dimensionality differs
// from Rank.
template <typename T, int Rank = -1>
TensorView<T, Rank> arrayt2tv(py::array_t<T> arr) {
  TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
  Shape shape;
  const auto num_axes = arr.ndim();
  for (int axis = 0; axis < num_axes; ++axis) {
    shape.push_back(arr.shape(axis));
  }
  if (Rank >= 0) {
    TV_ASSERT_INVALID_ARG(shape.ndim() == Rank, "error");
  }
  T *buffer = arr.mutable_data();
  return TensorView<T, Rank>(buffer, shape);
}
// Builds a read-only TensorView over the memory of a typed numpy array
// without copying. The view's shape is the array's shape.
//
// The return type carries Rank, matching arrayt2tv and the view constructed
// in the return statement. With the default Rank = -1 the type is the same
// as before.
//
// Throws std::invalid_argument (via TV_ASSERT_INVALID_ARG) when the array is
// not C-contiguous, or when Rank >= 0 and the array's dimensionality differs
// from Rank.
template <typename T, int Rank = -1>
TensorView<const T, Rank> carrayt2tv(py::array_t<T> arr) {
  TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
  Shape shape;
  for (int i = 0; i < arr.ndim(); ++i) {
    shape.push_back(arr.shape(i));
  }
  if (Rank >= 0) {
    TV_ASSERT_INVALID_ARG(shape.ndim() == Rank, "error");
  }
  return TensorView<const T, Rank>(arr.data(), shape);
}
// Maps a numpy array's dtype to the matching tv::DType, based on the dtype's
// kind character and the array's item size in bytes:
//   'b'            -> bool_
//   'i' (1/2/4/8)  -> int8 / int16 / int32 / int64
//   'u' (1/2/4/8)  -> uint8 / uint16 / uint32 / uint64
//   'f' (2/4/8)    -> float16 / float32 / float64
// Any other combination throws std::runtime_error (via TV_THROW_RT_ERR).
template <typename Tarr> tv::DType get_array_tv_dtype(const Tarr &arr) {
  switch (arr.dtype().kind()) {
  case 'b':
    return tv::bool_;
  case 'i': {
    switch (arr.itemsize()) {
    case 1:
      return tv::int8;
    case 2:
      return tv::int16;
    case 4:
      return tv::int32;
    case 8:
      return tv::int64;
    default:
      break;
    }
    // Unsupported size: leave the outer switch explicitly instead of falling
    // through into the next kind.
    break;
  }
  case 'u': {
    switch (arr.itemsize()) {
    case 1:
      return tv::uint8;
    case 2:
      return tv::uint16;
    case 4:
      return tv::uint32;
    case 8:
      return tv::uint64;
    default:
      break;
    }
    break;
  }
  case 'f': {
    switch (arr.itemsize()) {
    case 2:
      return tv::float16;
    case 4:
      return tv::float32;
    case 8:
      return tv::float64;
    default:
      break;
    }
    break;
  }
  default:
    break;
  }
  TV_THROW_RT_ERR("unknown dtype", arr.dtype().kind(), arr.itemsize());
}
// Wraps a numpy array's memory as a tv::Tensor via tv::from_blob, with the
// dtype derived from the array at runtime and device -1.
// Throws std::invalid_argument when the array is not C-contiguous.
template <typename Tarr> Tensor array2tensor(Tarr &arr) {
  TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
  TensorShape shape;
  const auto num_axes = arr.ndim();
  for (int axis = 0; axis < num_axes; ++axis) {
    shape.push_back(arr.shape(axis));
  }
  return tv::from_blob(arr.mutable_data(), shape, get_array_tv_dtype(arr), -1);
}
// Wraps a typed numpy array's memory as a tv::Tensor via tv::from_blob, with
// the dtype taken from the element type T (tv::type_v<T>) and device -1.
// Throws std::invalid_argument when the array is not C-contiguous.
template <typename T> Tensor arrayt2tensor(py::array_t<T> &arr) {
  TV_ASSERT_INVALID_ARG(is_c_style(arr), "array must be c-contiguous array");
  TensorShape shape;
  const auto num_axes = arr.ndim();
  for (int axis = 0; axis < num_axes; ++axis) {
    shape.push_back(arr.shape(axis));
  }
  return tv::from_blob(arr.mutable_data(), shape, tv::type_v<T>, -1);
}
// Converts a tv dtype enumerator to the numpy dtype of the same name.
// Throws std::invalid_argument (via TV_THROW_INVALID_ARG) for any value that
// has no mapping.
template <typename TDType> py::dtype tv_dtype_to_py(TDType d) {
  // Pick the numpy dtype name first; a null name means "no mapping".
  const char *numpy_name = nullptr;
  switch (d) {
  case bool_:
    numpy_name = "bool_";
    break;
  case int8:
    numpy_name = "int8";
    break;
  case int16:
    numpy_name = "int16";
    break;
  case int32:
    numpy_name = "int32";
    break;
  case int64:
    numpy_name = "int64";
    break;
  case uint8:
    numpy_name = "uint8";
    break;
  case uint16:
    numpy_name = "uint16";
    break;
  case uint32:
    numpy_name = "uint32";
    break;
  case uint64:
    numpy_name = "uint64";
    break;
  case float16:
    numpy_name = "float16";
    break;
  case float32:
    numpy_name = "float32";
    break;
  case float64:
    numpy_name = "float64";
    break;
  default:
    break;
  }
  if (numpy_name != nullptr) {
    return py::dtype(numpy_name);
  }
  TV_THROW_INVALID_ARG("unknown dtype", d);
}
// Converts a CPU tensor to a numpy array.
// Written as a template so the definition can live in a header.
// Must not be called while the GIL is released.
// Throws std::invalid_argument when tensor.device() is not -1.
template <typename Ttensor> py::array tensor2array(Ttensor &tensor) {
  TV_ASSERT_INVALID_ARG(tensor.device() == -1, "must be cpu tensor");
  auto tensor_shape = tensor.shape();
  std::vector<int> dims(tensor_shape.begin(), tensor_shape.end());
  auto numpy_dtype = tv_dtype_to_py(tensor.dtype());
  // Constructing the py::array copies the content behind the pointer. That is
  // intended: ownership cannot be transferred from the C++ tv::Tensor to the
  // numpy array, since the C++ object may be deleted independently.
  return py::array(numpy_dtype, dims, {}, tensor.raw_data());
}
} // namespace tv
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
tv::Tensor is a lightweight, header-only tensor container without templates
or heavyweight dependencies. No algorithms are implemented.
It should only be used when you want a simple, non-template container but
don't want to link against libtorch.
If you can use libtorch, don't use tv::Tensor.
*/
#pragma once
#include "mp_helper.h"
#include "tensorview.h"
#include <cstring>
#include <iomanip>
#include <memory>
#include <type_traits>
#ifdef TV_CUDA
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#endif
namespace tv {
// Element type tag for tv tensors.
// The enumerators have no explicit values, so their numeric values follow the
// declaration order below (float32 == 0, ..., uint64 == 11). Reordering them
// changes those values.
enum DType {
float32,
int32,
int16,
int8,
float64,
bool_,
uint8,
float16,
int64,
uint16,
uint32,
uint64
};
namespace detail {
// Compile-time list of every DType enumerator, each wrapped as a
// std::integral_constant<int, value>, in the enum's declaration order.
using dtype_collection_t =
tv::mp_list_c<int, float32, int32, int16, int8, float64, bool_, uint8,
float16, int64, uint16, uint32, uint64>;
// Tuple listing the C++ element types supported by tensors.
// NOTE(review): both preprocessor branches are currently identical; the
// TV_CUDA branch presumably was meant to add a half-precision type — confirm.
#ifdef TV_CUDA
using all_tensor_types_t =
std::tuple<float, double, int8_t, int16_t, int32_t, int64_t, uint8_t,
uint16_t, uint32_t, uint64_t, bool>;
#else
using all_tensor_types_t =
std::tuple<float, double, int8_t, int16_t, int32_t, int64_t, uint8_t,
uint16_t, uint32_t, uint64_t, bool>;
#endif
// Buffer of `size` elements of type T living either on the host
// (device == -1) or on a CUDA device (any other device value).
//
// Constructed from a size, the storage allocates and owns the memory:
//   - host, not pinned : new T[size]
//   - host, pinned     : cudaMallocHost (requires TV_CUDA)
//   - device, managed  : cudaMallocManaged (requires TV_CUDA)
//   - device           : cudaMalloc (requires TV_CUDA)
// Constructed from an existing pointer ("from blob"), the storage only
// borrows the memory and never frees it.
//
// NOTE(review): copy construction/assignment are not disabled, so copying an
// owning storage would free the same buffer twice. Hold instances through a
// pointer or smart pointer.
template <typename T> class TensorStorage {
public:
  // Allocates `size` elements. A size of 0 allocates nothing.
  // Throws std::invalid_argument when pinned or device memory is requested
  // in a build without TV_CUDA.
  TensorStorage(size_t size, int device = -1, bool managed = false,
                bool pinned = false)
      : mSize(size), device_(device), managed_(managed), pinned_(pinned) {
    if (size == 0) {
      mPtr = nullptr;
    } else {
      if (device == -1) {
        if (pinned_) {
#ifdef TV_CUDA
          checkCudaErrors(cudaMallocHost(&mPtr, size * sizeof(T)));
#else
          TV_THROW_INVALID_ARG("you need to define TV_CUDA to use pinned");
#endif
        } else {
          mPtr = new T[size];
        }
      } else {
#ifdef TV_CUDA
        // The CUDA device is expected to be selected by the caller; no
        // cudaSetDevice / device-count validation is done here.
        if (managed) {
          checkCudaErrors(cudaMallocManaged(&this->mPtr, size * sizeof(T)));
        } else {
          checkCudaErrors(cudaMalloc(&mPtr, size * sizeof(T)));
        }
#else
        TV_THROW_INVALID_ARG("don't compiled with cuda");
#endif
      }
    }
  }

  // Wraps externally owned memory; the destructor will not free it.
  TensorStorage(T *ptr, size_t size, int device)
      : mSize(size), mPtr(ptr), from_blob_(true), device_(device) {}

  // Releases the buffer with the deallocator matching how it was allocated.
  // Empty and borrowed storages are left untouched.
  virtual ~TensorStorage() {
    if (empty()) {
      return;
    }
    if (from_blob_) {
      return;
    }
    if (device_ == -1) {
      if (pinned_) {
#ifdef TV_CUDA
        cudaFreeHost(mPtr);
#endif
      } else {
        delete[] mPtr;
      }
    } else {
#ifdef TV_CUDA
      cudaFree(mPtr);
#endif
    }
  }

  // Number of elements (not bytes).
  inline size_t size() const { return mSize; }
  T *data() { return mPtr; }
  const T *data() const { return mPtr; }
  bool empty() const { return mPtr == nullptr || mSize == 0; }
  bool managed() const { return managed_; }
  bool pinned() const { return pinned_; }
  int device() const { return device_; }

  // Sets every byte of the buffer to zero.
  // mSize counts elements, while memset / cudaMemset take a byte count, so
  // the whole buffer is mSize * sizeof(T) bytes.
  // Throws std::invalid_argument for device storage in a build without
  // TV_CUDA.
  void zero_() {
    const size_t num_bytes = mSize * sizeof(T);
    if (device_ == -1) {
      // Skip the call for an empty storage: data() may be null there.
      if (num_bytes != 0 && data() != nullptr) {
        std::memset(data(), 0, num_bytes);
      }
    } else {
#ifdef TV_CUDA
      checkCudaErrors(cudaMemset(data(), 0, num_bytes));
#else
      TV_THROW_INVALID_ARG("don't compiled with cuda");
#endif
    }
  }

private:
  size_t mSize = 0;        // element count
  T *mPtr = nullptr;       // start of the buffer, null when empty
  bool from_blob_ = false; // true when the memory is borrowed
  int device_ = -1;        // -1 for host memory
  bool managed_ = false;   // allocated with cudaMallocManaged
  bool pinned_ = false;    // allocated with cudaMallocHost
};
// Size in bytes of one element of the given dtype.
// float16 is reported as 2 bytes. Throws std::runtime_error (via
// TV_THROW_RT_ERR) for a value that is not a known dtype.
template <typename T> size_t sizeof_dtype(T dtype) {
  switch (dtype) {
  case bool_:
    return sizeof(bool);
  case int8:
    return sizeof(int8_t);
  case uint8:
    return sizeof(uint8_t);
  case int16:
    return sizeof(int16_t);
  case uint16:
    return sizeof(uint16_t);
  case int32:
    return sizeof(int32_t);
  case uint32:
    return sizeof(uint32_t);
  case int64:
    return sizeof(int64_t);
  case uint64:
    return sizeof(uint64_t);
  case float16:
    return 2;
  case float32:
    return sizeof(float);
  case float64:
    return sizeof(double);
  default:
    TV_THROW_RT_ERR("unsupported dtype");
  }
  // Not reached: every switch path returns or throws.
  return 0;
}
// Human-readable name of a dtype. float16 is reported as "half"; a value
// that is not a known dtype yields an empty string.
template <typename T> std::string typeString(T t) {
  switch (t) {
  case DType::float32:
    return "float32";
  case DType::float64:
    return "float64";
  case DType::float16:
    return "half";
  case DType::int8:
    return "int8";
  case DType::int16:
    return "int16";
  case DType::int32:
    return "int32";
  case DType::int64:
    return "int64";
  case DType::uint8:
    return "uint8";
  case DType::uint16:
    return "uint16";
  case DType::uint32:
    return "uint32";
  case DType::uint64:
    return "uint64";
  case DType::bool_:
    return "bool";
  default:
    break;
  }
  return "";
}
// Compile-time map from a C++ element type to its DType enumerator, exposed
// as the static member `dtype`. The primary template is declared but not
// defined, so using an unmapped type is a compile error.
template <typename T> struct TypeToDtypeTraits;
template <> struct TypeToDtypeTraits<int32_t> {
static constexpr DType dtype = int32;
};
// __half is only mapped in CUDA builds.
#ifdef TV_CUDA
template <> struct TypeToDtypeTraits<__half> {
static constexpr DType dtype = float16;
};
#endif
template <> struct TypeToDtypeTraits<float> {
static constexpr DType dtype = float32;
};
template <> struct TypeToDtypeTraits<double> {
static constexpr DType dtype = float64;
};
template <> struct TypeToDtypeTraits<int16_t> {
static constexpr DType dtype = int16;
};
template <> struct TypeToDtypeTraits<int8_t> {
static constexpr DType dtype = int8;
};
template <> struct TypeToDtypeTraits<int64_t> {
static constexpr DType dtype = int64;
};
template <> struct TypeToDtypeTraits<uint8_t> {
static constexpr DType dtype = uint8;
};
template <> struct TypeToDtypeTraits<uint16_t> {
static constexpr DType dtype = uint16;
};
template <> struct TypeToDtypeTraits<uint32_t> {
static constexpr DType dtype = uint32;
};
template <> struct TypeToDtypeTraits<uint64_t> {
static constexpr DType dtype = uint64;
};
template <> struct TypeToDtypeTraits<bool> {
static constexpr DType dtype = bool_;
};
template <> struct TypeToDtypeTraits<const int32_t> {
static constexpr DType dtype = int32;
};
#ifdef TV_CUDA
template <> struct TypeToDtypeTraits<const __half> {
static constexpr DType dtype = float16;
};
#endif
template <> struct TypeToDtypeTraits<const float> {
static constexpr DType dtype = float32;
};
template <> struct TypeToDtypeTraits<const double> {
static constexpr DType dtype = float64;
};
template <> struct TypeToDtypeTraits<const int16_t> {
static constexpr DType dtype = int16;
};
template <> struct TypeToDtypeTraits<const int8_t> {
static constexpr DType dtype = int8;
};
template <> struct TypeToDtypeTraits<const int64_t> {
static constexpr DType dtype = int64;
};
template <> struct TypeToDtypeTraits<const uint8_t> {
static constexpr DType dtype = uint8;
};
template <> struct TypeToDtypeTraits<const uint16_t> {
static constexpr DType dtype = uint16;
};
template <> struct TypeToDtypeTraits<const uint32_t> {
static constexpr DType dtype = uint32;
};
template <> struct TypeToDtypeTraits<const uint64_t> {
static constexpr DType dtype = uint64;
};
template <> struct TypeToDtypeTraits<const bool> {
static constexpr DType dtype = bool_;
};
} // namespace detail
// Shorthand for detail::TypeToDtypeTraits<T>::dtype, the DType tag of the
// C++ element type T.
template <class T> constexpr DType type_v = detail::TypeToDtypeTraits<T>::dtype;
// Calls `f` with a default-constructed value of the first type in Ts whose
// DType tag equals `t`. Returns true if such a type existed, false otherwise;
// no error is raised for an unmatched dtype.
template <class... Ts, typename F> bool dispatch_noexcept(DType t, F &&f) {
  static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
  bool matched = false;
  mp_for_each<mp_list<Ts...>>([=, &matched, &f](auto candidate) {
    using candidate_t = TV_DECLTYPE(candidate);
    // only the first matching candidate is dispatched
    if (matched || !(type_v<candidate_t> == t)) {
      return;
    }
    std::forward<F>(f)(candidate_t());
    matched = true;
  });
  return matched;
}
// Throwing variant of dispatch_noexcept: when no type in Ts matches `t`, a
// runtime error listing the dtype and the available type names is raised.
template <class... Ts, typename F> void dispatch(DType t, F &&f) {
  const bool handled = dispatch_noexcept<Ts...>(t, std::forward<F>(f));
  if (handled) {
    return;
  }
  std::stringstream available;
  mp_for_each<mp_list<Ts...>>([&available](auto candidate) {
    available << detail::TypeToString<TV_DECLTYPE(candidate)>::value << " ";
  });
  TV_THROW_RT_ERR("unknown type", detail::typeString(t),
                  ", available:", available.str());
}
// Calls `f` with the compile-time constant from Is... that equals the
// runtime value `idx` (first match only). Raises a runtime error listing the
// candidates when none matches.
template <typename T, T... Is, typename F> void dispatch_scalar(T idx, F &&f) {
  static_assert(sizeof...(Is) > 0,
                "you need to provide at least one candidate");
  bool matched = false;
  mp_for_each<mp_list_c<T, Is...>>([=, &matched, &f](auto constant) {
    if (matched || !(T(constant) == idx)) {
      return;
    }
    std::forward<F>(f)(constant);
    matched = true;
  });
  if (matched) {
    return;
  }
  std::stringstream candidates;
  mp_for_each<mp_list_c<T, Is...>>(
      [&candidates](auto constant) { candidates << T(constant) << " "; });
  TV_THROW_RT_ERR("unknown value", idx, ", available:", candidates.str());
}
// Calls `f` with the integral constant from Is... whose value equals `idx`
// (first match only). Returns whether a match was found.
template <int... Is, typename F> bool dispatch_int_noexcept(int idx, F &&f) {
  static_assert(sizeof...(Is) > 0,
                "you need to provide at least one candidate");
  bool matched = false;
  mp_for_each<mp_list_c<int, Is...>>([=, &matched, &f](auto constant) {
    if (matched || TV_DECLTYPE(constant)::value != idx) {
      return;
    }
    std::forward<F>(f)(constant);
    matched = true;
  });
  return matched;
}
// Predicate form: calls `f` with the first integral constant c from Is...
// for which p(idx, c) is true. Returns whether a match was found.
// The predicate is evaluated for every candidate, including those after the
// first match.
template <int... Is, typename F, class BinaryPredicate>
bool dispatch_int_noexcept(int idx, BinaryPredicate p, F &&f) {
  static_assert(sizeof...(Is) > 0,
                "you need to provide at least one candidate");
  bool matched = false;
  mp_for_each<mp_list_c<int, Is...>>([=, &matched, &f](auto constant) {
    const bool accepted = p(idx, TV_DECLTYPE(constant)::value);
    if (!accepted || matched) {
      return;
    }
    std::forward<F>(f)(constant);
    matched = true;
  });
  return matched;
}
// Throwing variant of dispatch_int_noexcept(idx, f): raises a runtime error
// listing the candidate values when `idx` matches none of Is....
template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
  const bool handled = dispatch_int_noexcept<Is...>(idx, std::forward<F>(f));
  if (handled) {
    return;
  }
  std::stringstream candidates;
  mp_for_each<mp_list_c<int, Is...>>([&candidates](auto constant) {
    candidates << TV_DECLTYPE(constant)::value << " ";
  });
  TV_THROW_RT_ERR("unknown value", idx, ", available:", candidates.str());
}
// Throwing variant of the predicate dispatch. `p` is called as
// p(idx, candidate); a runtime error listing the candidates is raised when
// the predicate accepts none of them.
template <int... Is, typename F, class BinaryPredicate>
void dispatch_int(int idx, BinaryPredicate p, F &&f) {
  const bool handled =
      dispatch_int_noexcept<Is...>(idx, p, std::forward<F>(f));
  if (handled) {
    return;
  }
  std::stringstream candidates;
  mp_for_each<mp_list_c<int, Is...>>([&candidates](auto constant) {
    candidates << TV_DECLTYPE(constant)::value << " ";
  });
  TV_THROW_RT_ERR("unknown value", idx, ", available:", candidates.str());
}
// Ts is a pack of mp_list_c value lists.
//
// Compares the runtime sequence [begin, end) against each compile-time value
// list in Ts and calls `f` with the first list that has the same length and
// the same values element by element. Returns whether a list matched.
//
// The former "iterator length invalid" throw was removed: the inner
// mp_for_each runs exactly mp_size<list> times and `count` is incremented at
// most once per run, so `count >= val_lst_size` could never hold.
template <class... Ts, typename Iterator, typename F>
bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
  static_assert(sizeof...(Ts) > 0,
                "you need to provide at least one candidate");
  bool notFound = true;
  mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
    using val_lst_t = TV_DECLTYPE(I);
    auto val_lst_size = mp_size<val_lst_t>::value;
    bool equal = true;
    std::size_t count = 0;
    auto iter = begin;
    mp_for_each<val_lst_t>([&](auto E) {
      // stop comparing once the input is exhausted or a mismatch was seen
      if (iter == end || !equal) {
        return;
      }
      constexpr auto c = TV_DECLTYPE(E)::value;
      if (c != *iter) {
        equal = false;
      }
      ++count;
      std::advance(iter, 1);
    });
    // a shorter or longer input than the candidate list is not a match
    if (count != val_lst_size || iter != end) {
      equal = false;
    }
    if (equal && notFound) {
      std::forward<F>(f)(I);
      notFound = false;
    }
  });
  return !notFound;
}
// Throwing variant of dispatch_container_noexcept: when no value list in Ts
// equals [begin, end), raises a runtime error that prints the input values
// and every candidate list.
template <class... Ts, typename Iterator, typename F>
void dispatch_container(Iterator begin, Iterator end, F &&f) {
  const bool handled =
      dispatch_container_noexcept<Ts...>(begin, end, std::forward<F>(f));
  if (handled) {
    return;
  }
  std::stringstream message;
  message << "unknown value [";
  auto cursor = begin;
  while (cursor != end) {
    message << *cursor << ",";
    std::advance(cursor, 1);
  }
  message << "], available: ";
  mp_for_each<mp_list<Ts...>>([&message](auto value_list) {
    message << "[";
    mp_for_each<TV_DECLTYPE(value_list)>([&message](auto element) {
      message << TV_DECLTYPE(element)::value << ",";
    });
    message << "]";
  });
  TV_THROW_RT_ERR(message.str());
}
/*
template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
return dispatch_scalar<int, Is...>(idx, f);
}
*/
// Functor form of tv::dispatch. The candidate types are taken from the
// template arguments of a type-list-like class T<Args...>.
template <class T> struct Dispatch;
template <template <class...> class T, class... Args>
struct Dispatch<T<Args...>> {
  // Forwards to dispatch<Args...>(t, f); throws when `t` matches no Args.
  template <typename F> void operator()(DType t, F &&f) {
    dispatch<Args...>(t, std::forward<F>(f));
  }
};
// Functor form of tv::dispatch_container. The candidate value lists are the
// template arguments of T<Args...>.
template <class T> struct DispatchContainer;
template <template <class...> class T, class... Args>
struct DispatchContainer<T<Args...>> {
  // Forwards to dispatch_container<Args...>(begin, end, f).
  template <typename Iterator, typename F>
  void operator()(Iterator begin, Iterator end, F &&f) {
    dispatch_container<Args...>(begin, end, std::forward<F>(f));
  }
};
// Functor form of tv::dispatch_container_noexcept. The candidate value lists
// are the template arguments of T<Args...>.
template <class T> struct DispatchContainerNoexcept;
template <template <class...> class T, class... Args>
struct DispatchContainerNoexcept<T<Args...>> {
  // Returns true when one of the candidate lists matched [begin, end).
  template <typename Iterator, typename F>
  bool operator()(Iterator begin, Iterator end, F &&f) {
    const bool matched = dispatch_container_noexcept<Args...>(
        begin, end, std::forward<F>(f));
    return matched;
  }
};
// Functor form of tv::dispatch_int.
// Each element of Args must expose an int `value` (for example
// std::integral_constant<int, N>), so the template argument has to be a
// container of such types; tv::mp_list_c works.
template <class T> struct DispatchInt;
template <template <class...> class T, class... Args>
struct DispatchInt<T<Args...>> {
  // Exact-match dispatch on `t`.
  template <typename F> void operator()(int t, F &&f) {
    dispatch_int<Args::value...>(t, std::forward<F>(f));
  }
  // Predicate dispatch: p(t, candidate) selects the candidate.
  template <typename F, typename BinaryPredicate>
  void operator()(int t, BinaryPredicate p, F &&f) {
    dispatch_int<Args::value...>(t, p, std::forward<F>(f));
  }
};
// Capacity (maximum number of dimensions) of the shape type used by Tensor.
constexpr size_t kTensorMaxDim = 10;
// Shape/stride container of Tensor: up to kTensorMaxDim int64_t entries.
using TensorShape = ShapeBase<kTensorMaxDim, int64_t>;
struct Tensor {
// Creates a tensor without storage. dtype_ has no default initializer and is
// left unset here; storage_ is a null shared_ptr, so members that dereference
// storage_ must not be used on such an object.
Tensor() {}
Tensor(TensorShape shape, TensorShape stride, DType dtype, int device = -1,
bool pinned = false, bool managed = false)
: dtype_(dtype) {
TV_ASSERT_INVALID_ARG(!shape.empty(), "dont support empty shape");
storage_ = std::make_shared<detail::TensorStorage<uint8_t>>(
shape.size() * detail::sizeof_dtype(dtype), device, managed, pinned);
shape_ = shape;
stride_ = stride;
}
Tensor(TensorShape shape, DType dtype, int device = -1, bool pinned = false,
bool managed = false)
: dtype_(dtype) {
TV_ASSERT_INVALID_ARG(!shape.empty(), "dont support empty shape");
storage_ = std::make_shared<detail::TensorStorage<uint8_t>>(
shape.size() * detail::sizeof_dtype(dtype), device, managed, pinned);
shape_ = shape;
stride_ = shape.stride_rowmajor();
}
Tensor(void *ptr, TensorShape shape, TensorShape stride, DType dtype,
int device = -1)
: dtype_(dtype) {
TV_ASSERT_INVALID_ARG(!shape.empty(), "dont support empty shape");
storage_ = std::make_shared<detail::TensorStorage<uint8_t>>(
reinterpret_cast<uint8_t *>(ptr),
shape.size() * detail::sizeof_dtype(dtype), device);
shape_ = shape;
stride_ = stride;
}
Tensor(void *ptr, TensorShape shape, DType dtype, int device = -1)
: dtype_(dtype) {
TV_ASSERT_INVALID_ARG(!shape.empty(), "dont support empty shape");
storage_ = std::make_shared<detail::TensorStorage<uint8_t>>(
reinterpret_cast<uint8_t *>(ptr),
shape.size() * detail::sizeof_dtype(dtype), device);
shape_ = shape;
stride_ = shape.stride_rowmajor();
}
Tensor(const void *ptr, TensorShape shape, TensorShape stride, DType dtype,
int device = -1)
: dtype_(dtype), writeable_(false) {
TV_ASSERT_INVALID_ARG(!shape.empty(), "dont support empty shape");
storage_ = std::make_shared<detail::TensorStorage<uint8_t>>(
reinterpret_cast<uint8_t *>(const_cast<void *>(ptr)),
shape.size() * detail::sizeof_dtype(dtype), device);
shape_ = shape;
stride_ = stride;
}
Tensor(const void *ptr, TensorShape shape, DType dtype, int device = -1)
: dtype_(dtype), writeable_(false) {
TV_ASSERT_INVALID_ARG(!shape.empty(), "dont support empty shape");
storage_ = std::make_shared<detail::TensorStorage<uint8_t>>(
reinterpret_cast<uint8_t *>(const_cast<void *>(ptr)),
shape.size() * detail::sizeof_dtype(dtype), device);
shape_ = shape;
stride_ = shape.stride_rowmajor();
}
// Builds a one-dimensional host tensor from a brace list; one overload per
// supported element type. The tensor is allocated through the
// (shape, dtype) constructor and the values are then copied in.
Tensor(std::initializer_list<int32_t> init)
    : Tensor({int(init.size())}, tv::int32) {
  std::copy_n(init.begin(), init.size(), data<int32_t>());
}
Tensor(std::initializer_list<int64_t> init)
    : Tensor({int(init.size())}, tv::int64) {
  std::copy_n(init.begin(), init.size(), data<int64_t>());
}
Tensor(std::initializer_list<float> init)
    : Tensor({int(init.size())}, tv::float32) {
  std::copy_n(init.begin(), init.size(), data<float>());
}
Tensor(std::initializer_list<double> init)
    : Tensor({int(init.size())}, tv::float64) {
  std::copy_n(init.begin(), init.size(), data<double>());
}
// Returns a writable fixed-rank TensorView over this tensor's data.
//
// Requirements, all checked at run time: the tensor must be writeable, its
// dtype must equal type_v<T>, and its number of dimensions must equal Rank.
// The rank check was missing from this overload although the const overload
// has it; without it a Rank larger than ndim() copied shape/stride entries
// past the valid dimensions.
template <typename T, int Rank = -1,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int,
          typename std::enable_if<(Rank > 0), int>::type = 0>
TensorView<T, Rank, PtrTraits, Tindex> tview() {
  using tv_shape_t =
      typename TensorView<T, Rank, PtrTraits, Tindex>::tv_shape_t;
  writable_check();
  static_assert(Rank == -1 || Rank > 0, "error");
  TV_ASSERT_RT_ERR(size_t(Rank) == ndim(), "error");
  TV_ASSERT_RT_ERR(dtype_ == type_v<T>, "error");
  tv_shape_t shape(Rank), stride(Rank);
  for (int i = 0; i < Rank; ++i) {
    shape[i] = shape_[i];
    stride[i] = stride_[i];
  }
  return TensorView<T, Rank, PtrTraits, Tindex>(
      reinterpret_cast<T *>(data<T>()), shape, stride);
}
// Returns a writable dynamic-rank (Rank == -1) TensorView over this tensor.
// The tensor must be writeable and its dtype must equal type_v<T>.
template <typename T, int Rank = -1,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int,
          typename std::enable_if<Rank == -1, int>::type = 0>
TensorView<T, Rank, PtrTraits, Tindex> tview() {
  using view_t = TensorView<T, Rank, PtrTraits, Tindex>;
  writable_check();
  static_assert(Rank == -1 || Rank > 0, "error");
  TV_ASSERT_RT_ERR(dtype_ == type_v<T>, "error");
  ShapeBase<TV_MAX_DIM, Tindex> shape(ndim()), stride(ndim());
  size_t axis = 0;
  while (axis < ndim()) {
    shape[axis] = shape_[axis];
    stride[axis] = stride_[axis];
    ++axis;
  }
  T *first = reinterpret_cast<T *>(data<T>());
  return view_t(first, shape, stride);
}
// Returns a read-only fixed-rank TensorView over this tensor.
// The tensor's number of dimensions must equal Rank and its dtype must equal
// type_v<T>; both are checked at run time.
template <typename T, int Rank = -1,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int,
          typename std::enable_if<(Rank > 0), int>::type = 0>
TensorView<const std::remove_const_t<T>, Rank, PtrTraits, Tindex>
tview() const {
  using const_elem_t = const std::remove_const_t<T>;
  using shape_t = ShapeBase<(Rank == -1 ? TV_MAX_DIM : Rank), Tindex>;
  static_assert(Rank == -1 || Rank > 0, "error");
  // Rank > 0 is guaranteed by the enable_if above, so the check always runs
  TV_ASSERT_RT_ERR(Rank == ndim(), "error");
  TV_ASSERT_RT_ERR(dtype_ == type_v<T>, "error");
  shape_t shape(Rank), stride(Rank);
  int axis = 0;
  while (axis < Rank) {
    shape[axis] = shape_[axis];
    stride[axis] = stride_[axis];
    ++axis;
  }
  const_elem_t *first = reinterpret_cast<const_elem_t *>(data<T>());
  return TensorView<const_elem_t, Rank, PtrTraits, Tindex>(first, shape,
                                                           stride);
}
// Returns a read-only dynamic-rank (Rank == -1) TensorView over this tensor.
// The tensor's dtype must equal type_v<T>.
template <typename T, int Rank = -1,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int,
          typename std::enable_if<Rank == -1, int>::type = 0>
TensorView<const std::remove_const_t<T>, Rank, PtrTraits, Tindex>
tview() const {
  using const_elem_t = const std::remove_const_t<T>;
  static_assert(Rank == -1 || Rank > 0, "error");
  // no rank check here: Rank is -1 in this overload, so there is nothing to
  // compare ndim() against
  TV_ASSERT_RT_ERR(dtype_ == type_v<T>, "error");
  ShapeBase<TV_MAX_DIM, Tindex> shape(ndim()), stride(ndim());
  const int rank = int(ndim());
  int axis = 0;
  while (axis < rank) {
    shape[axis] = shape_[axis];
    stride[axis] = stride_[axis];
    ++axis;
  }
  const_elem_t *first = reinterpret_cast<const_elem_t *>(data<T>());
  return TensorView<const_elem_t, Rank, PtrTraits, Tindex>(first, shape,
                                                           stride);
}
// Returns a tensor sharing this storage but shaped as `newShapes...`.
//
// At most one extent may be -1; it is replaced by the value that makes the
// element count equal size(). Every other extent must be > 0. Strides of the
// result are row-major for the new shape.
//
// All extents are validated BEFORE the -1 extent is computed. The previous
// single-pass loop computed it as soon as the -1 was met, so a later zero
// extent (e.g. view(-1, 0)) caused a division by zero before the argument
// check could reject it.
//
// Throws an invalid-argument error for a non-positive extent or a second -1,
// and a runtime error when the element counts do not match.
template <class... Inds> Tensor view(Inds... newShapes) const {
  static_assert(sizeof...(newShapes) > 0, "dont support empty for now");
  TensorShape shape{int(newShapes)...};
  int infer_axis = -1; // position of the -1 extent, if any
  for (size_t i = 0; i < shape.ndim(); ++i) {
    if (infer_axis < 0) {
      if (shape[i] == -1) {
        infer_axis = int(i);
      } else {
        TV_ASSERT_INVALID_ARG(shape[i] > 0,
                              "shape except -1 must larger than 0");
      }
    } else {
      TV_ASSERT_INVALID_ARG(shape[i] > 0, "multiple -1 in your argument.");
    }
  }
  if (infer_axis >= 0) {
    // with the placeholder set to 1, shape.size() is the product of the
    // explicitly given extents, all of which are known to be positive here
    shape[infer_axis] = 1;
    shape[infer_axis] = size() / shape.size();
  }
  TV_ASSERT_RT_ERR(shape.size() == size(), "error");
  Tensor res(*this);
  res.shape_ = shape;
  res.stride_ = shape.stride_rowmajor();
  return res;
}
// Returns a tensor sharing this storage but with the given shape and
// row-major strides. The element count of `shape` must equal size().
Tensor view(TensorShape shape) const {
  TV_ASSERT_RT_ERR(shape.size() == size(), "error");
  Tensor reshaped(*this);
  reshaped.stride_ = shape.stride_rowmajor();
  reshaped.shape_ = shape;
  return reshaped;
}
// Returns the sub-tensor at position `index` along the first axis. The result
// shares this tensor's storage and has one dimension fewer. A negative index
// counts from the end of the first axis.
//
// Fixes over the previous version:
//  * dtype_ is copied to the result (it was left unset, so the sub-tensor's
//    dtype was indeterminate);
//  * offset_ is a byte offset (raw_data() adds it to a uint8_t pointer), so
//    the stride step is multiplied by itemsize();
//  * an index that is still negative after wrapping is rejected.
// NOTE(review): stride_ is assumed to be expressed in elements, as produced
// by stride_rowmajor() and consumed by tview() -- confirm.
//
// Throws an invalid-argument error when ndim() <= 1 or the index is out of
// range.
Tensor operator[](int64_t index) {
  TV_ASSERT_INVALID_ARG(ndim() > 1, "error");
  const int64_t extent = shape_[0];
  if (index < 0) {
    index += extent;
  }
  TV_ASSERT_INVALID_ARG(index >= 0 && index < extent, "error");
  Tensor res = Tensor();
  res.dtype_ = dtype_;
  res.storage_ = storage_;
  res.shape_ = shape_.subshape(1);
  res.offset_ = offset_ + index * stride_[0] * itemsize();
  res.stride_ = stride_.subshape(1);
  res.writeable_ = writeable_;
  return res;
}
// View of this tensor reshaped to shape_.squeeze().
Tensor squeeze() const {
  const TensorShape squeezed = shape_.squeeze();
  return view(squeezed);
}
// View of this tensor reshaped to shape_.squeeze(axis); a negative axis
// counts from the last dimension.
Tensor squeeze(int axis) const {
  const int resolved = axis < 0 ? int(ndim() + axis) : axis;
  return view(shape_.squeeze(resolved));
}
// View of this tensor reshaped to shape_.unsqueeze(axis); a negative axis is
// offset by ndim().
Tensor unsqueeze(int axis) const {
  const int resolved = axis < 0 ? int(ndim() + axis) : axis;
  return view(shape_.unsqueeze(resolved));
}
// Pinned flag of the underlying storage. Dereferences storage_, so it must
// not be called on a default-constructed tensor (storage_ is null there).
bool pinned() const { return storage_->pinned(); }
// Returns a tensor sharing this storage that covers rows [start, end) of the
// first axis. Negative bounds count from the end of the first axis.
//
// Fixes over the previous version:
//  * the byte offset of the slice is added to this tensor's own offset_
//    (it used to replace it, so slicing an already sliced/indexed tensor
//    restarted from the beginning of the storage);
//  * bounds that are still negative, or an `end` past the first extent, are
//    rejected instead of producing a view outside the data.
//
// Throws an invalid-argument error for a non-contiguous tensor or bounds that
// do not satisfy 0 <= start < end <= shape[0].
Tensor slice_first_axis(int start, int end) const {
  TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
  if (start < 0) {
    start = shape_[0] + start;
  }
  if (end < 0) {
    end = shape_[0] + end;
  }
  TV_ASSERT_INVALID_ARG(start >= 0, "start must not be negative");
  TV_ASSERT_INVALID_ARG(start < shape_[0], "start must small than dim 0");
  TV_ASSERT_INVALID_ARG(start < end, "start must small than end");
  TV_ASSERT_INVALID_ARG(end <= shape_[0], "end must not exceed dim 0");
  // bytes to skip: `start` rows of prod(shape[1:]) elements each
  size_t new_offset = start * shape_.prod(1) * itemsize();
  Tensor res(*this);
  TensorShape newshape(shape_);
  newshape[0] = end - start;
  res.shape_ = newshape;
  res.stride_ = stride_;
  res.offset_ = offset_ + new_offset;
  return res;
}
// True when the tensor has no storage at all (default-constructed) or its
// storage reports itself empty. The null check avoids dereferencing the null
// storage_ pointer of a default-constructed tensor.
bool empty() const { return !storage_ || storage_->empty(); }
// Element type tag of this tensor.
DType dtype() const { return dtype_; }
// Device index of the underlying storage (-1 is used for host memory
// throughout this file). Dereferences storage_.
int device() const { return storage_->device(); }
// Number of dimensions.
size_t ndim() const { return shape_.ndim(); }
// Extents of every dimension.
const TensorShape &shape() const { return shape_; }
// Same as shape().
const TensorShape &sizes() const { return shape_; }
// Stride of every dimension.
const TensorShape &stride() const { return stride_; }
// Returns the extent of dimension `idx`; a negative idx counts from the last
// dimension (-1 is the last one).
//
// The bounds are checked against the number of dimensions. The previous code
// used shape_.size(), which for a shape is the PRODUCT of all extents, so
// negative indices addressed the wrong slot and out-of-range positive indices
// could pass the check.
//
// Throws a runtime error when idx is outside [-ndim(), ndim()).
int dim(int idx) const {
  const int rank = int(shape_.ndim());
  if (idx < 0) {
    TV_ASSERT_RT_ERR(idx >= -rank, idx, shape_);
    return shape_[rank + idx];
  }
  TV_ASSERT_RT_ERR(idx < rank, idx, shape_);
  return shape_[idx];
}
// Read-only byte pointer to this tensor's first element: start of the storage
// plus offset_ (offset_ is therefore a byte offset).
const uint8_t *raw_data() const { return storage_->data() + offset_; }
// Number of bytes covered by this tensor: element count times item size.
size_t raw_size() const { return size() * itemsize(); }
// Total number of elements (shape_.size()).
size_t size() const { return shape_.size(); }
// Extent of dimension `idx`; forwards to dim(idx).
size_t size(int64_t idx) const { return dim(idx); }
// Bytes per element for dtype_.
size_t itemsize() const { return detail::sizeof_dtype(dtype_); }
// Sets every byte of THIS tensor's data to zero and returns *this.
//
// Only the raw_size() bytes starting at raw_data() are cleared. The previous
// version called storage_->zero_(), which clears the entire storage; for a
// tensor obtained through slicing/indexing that also wiped the data of every
// other tensor sharing the storage.
//
// Throws when the tensor is not writeable or not contiguous, and for a device
// tensor when the library was built without TV_CUDA.
Tensor &zero_() {
  writable_check();
  TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
  if (device() == -1) {
    std::memset(raw_data(), 0, raw_size());
  } else {
#ifdef TV_CUDA
    checkCudaErrors(cudaMemset(raw_data(), 0, raw_size()));
#else
    TV_THROW_INVALID_ARG("don't compiled with cuda");
#endif
  }
  return *this;
}
// Mutable byte pointer to this tensor's first element (storage start plus
// offset_). Raises a runtime error when the tensor is not writeable.
uint8_t *raw_data() {
writable_check();
return storage_->data() + offset_;
}
// Assigns `value`, converted to the tensor's element type, to every element
// and returns *this. Host tensors only (device() must be -1); the tensor must
// be writeable. An invalid-argument error is raised when T is not convertible
// to the element type selected by dtype_.
template <typename T> Tensor &fill_(T value) {
  writable_check();
  TV_ASSERT_RT_ERR(device() == -1, "error");
  Dispatch<detail::all_tensor_types_t>()(dtype_, [&](auto I) {
    using Treal = TV_DECLTYPE(I);
    if (!std::is_convertible<T, Treal>::value) {
      TV_THROW_INVALID_ARG("not convertable from", type_s<T>, "to",
                           type_s<Treal>);
    } else {
      Treal *first = reinterpret_cast<Treal *>(raw_data());
      Treal *last = first + size();
      std::fill(first, last, Treal(value));
    }
  });
  return *this;
}
// Typed mutable pointer to the first element. Raises a runtime error when
// dtype_ differs from type_v<T> or the tensor is not writeable.
template <typename T> T *data() {
TV_ASSERT_RT_ERR(dtype_ == type_v<T>, "error");
writable_check();
return reinterpret_cast<T *>(raw_data());
}
// Typed read-only pointer to the first element. Raises a runtime error when
// dtype_ differs from type_v<T>.
template <typename T> const T *data() const {
TV_ASSERT_RT_ERR(dtype_ == type_v<T>, "error");
return reinterpret_cast<const T *>(raw_data());
}
// Aliases of data<T>().
template <typename T> T *data_ptr() { return data<T>(); }
template <typename T> const T *data_ptr() const { return data<T>(); }
// Untyped pointers to the first element; the mutable form goes through the
// writeable check of raw_data().
void *data_ptr() { return reinterpret_cast<void *>(raw_data()); }
const void *data_ptr() const {
return reinterpret_cast<const void *>(raw_data());
}
void copy_(const Tensor &tensor) {
writable_check();
TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
TV_ASSERT_RT_ERR(!empty() && !tensor.empty(), "must not empty");
TV_ASSERT_RT_ERR(size() == tensor.size(), "must have same size");
TV_ASSERT_RT_ERR(dtype() == tensor.dtype(), "must have same dtype",
detail::typeString(dtype()),
detail::typeString(tensor.dtype()));
if (device() == -1 && tensor.device() == -1) {
#ifdef TV_CUDA
host2host(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_));
#else
std::copy(tensor.raw_data(),
tensor.raw_data() + size() * detail::sizeof_dtype(dtype_),
storage_->data());
#endif
}
#ifdef TV_CUDA
else if (device() >= 0 && tensor.device() == -1) {
host2dev(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_));
} else if (device() == -1 && tensor.device() >= 0) {
dev2host(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_));
} else if (device() >= 0 && tensor.device() >= 0) {
dev2dev(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_));
}
#endif
else {
TV_THROW_RT_ERR("only support cpu tensor");
}
}
#ifdef TV_CUDA
void copy_(const Tensor &tensor, cudaStream_t stream) {
writable_check();
TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
TV_ASSERT_RT_ERR(!empty() && !tensor.empty(), "must not empty");
TV_ASSERT_RT_ERR(size() == tensor.size(), "must have same size");
TV_ASSERT_RT_ERR(dtype() == tensor.dtype(), "must have same dtype",
detail::typeString(dtype()),
detail::typeString(tensor.dtype()));
if (device() == -1 && tensor.device() == -1) {
host2host(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_), stream);
} else if (device() >= 0 && tensor.device() == -1) {
host2dev(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_), stream);
} else if (device() == -1 && tensor.device() >= 0) {
dev2host(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_), stream);
} else if (device() >= 0 && tensor.device() >= 0) {
dev2dev(storage_->data(), tensor.raw_data(),
size() * detail::sizeof_dtype(dtype_), stream);
} else {
TV_THROW_RT_ERR("only support cpu tensor");
}
}
#endif
// Returns a host copy of this tensor. The result is always a new tensor:
// a tensor that already lives on the host is cloned.
Tensor cpu() const {
  if (storage_->device() != -1) {
    // NOTE(review): the fifth parameter of this constructor is `pinned`, yet
    // the storage's managed flag is passed in that position -- confirm
    // whether that is intended.
    Tensor host_copy(shape_, stride_, dtype_, -1, storage_->managed());
    host_copy.copy_(*this);
    return host_copy;
  }
  // cpu() should always copy tensor.
  return clone();
}
template <typename T> void copy_(const TensorView<T> &tensor, int device) {
writable_check();
TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
Tensor src = from_blob(tensor, device);
return copy_(src);
}
Tensor &operator=(const Tensor &tensor) {
dtype_ = tensor.dtype_;
storage_ = tensor.storage_;
shape_ = tensor.shape_;
writeable_ = tensor.writeable_;
offset_ = tensor.offset_;
stride_ = tensor.stride_;
return *this;
}
Tensor(const Tensor &tensor) {
dtype_ = tensor.dtype_;
storage_ = tensor.storage_;
shape_ = tensor.shape_;
writeable_ = tensor.writeable_;
offset_ = tensor.offset_;
stride_ = tensor.stride_;
}
// Returns a deep copy on the same device with newly allocated storage.
// `pinned` is forwarded to the new storage; the managed flag is taken from
// the current storage. The tensor must be non-empty and contiguous.
Tensor clone(bool pinned = false) const {
  TV_ASSERT_RT_ERR(!empty(), "clone a empty tensor");
  TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
  const bool managed = storage_->managed();
  Tensor duplicate(shape_, stride_, dtype_, device(), pinned, managed);
  duplicate.copy_(*this);
  return duplicate;
}
// Returns a new tensor holding this tensor's values converted to `dtype`.
// When `dtype` equals the current dtype the result is a plain clone().
// Otherwise the tensor must live on the host, be non-empty and contiguous.
// Raises an invalid-argument error when the current element type is not
// convertible to the requested one.
// NOTE(review): the source is read through the non-const data<Tcur>(), which
// runs writable_check(); a read-only tensor therefore cannot be converted --
// confirm whether that restriction is intended.
Tensor astype(DType dtype) {
if (dtype == dtype_) {
return clone();
}
TV_ASSERT_INVALID_ARG(device() == -1, "only support cpu tensor");
TV_ASSERT_INVALID_ARG(!empty(), "can't be used in empty tensor");
TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
auto tensor = Tensor();
// outer dispatch resolves the destination element type ...
Dispatch<detail::all_tensor_types_t>()(dtype, [&](auto Idst) {
using Tdst = TV_DECLTYPE(Idst);
// ... inner dispatch resolves the current element type
Dispatch<detail::all_tensor_types_t>()(this->dtype_, [&](auto Icur) {
using Tcur = TV_DECLTYPE(Icur);
if (std::is_convertible<Tcur, Tdst>::value) {
auto ptr = this->data<Tcur>();
tensor = Tensor(this->shape_, this->stride_, dtype, this->device(),
this->pinned(), this->storage_->managed());
// element-wise copy performs the Tcur -> Tdst conversion
std::copy(ptr, ptr + this->size(), tensor.data<Tdst>());
} else {
TV_THROW_INVALID_ARG("not convertable from", type_s<Tcur>, "to",
type_s<Tdst>);
}
});
});
return tensor;
}
// Runs tv::dispatch<Ts...> on this tensor's dtype: `f` is called with a value
// of the type in Ts that matches dtype_; an error is raised if none matches.
template <class... Ts, typename F> inline void dispatch(F &&f) {
return tv::dispatch<Ts...>(dtype_, std::forward<F>(f));
}
protected:
// Raises a runtime error when the tensor is not writeable (tensors built from
// a const pointer are created with writeable_ == false).
inline void writable_check() {
TV_ASSERT_RT_ERR(writeable_,
"you cant do non-const operation when not writable");
}
// Element type tag; has no default value, so it is unset after Tensor().
DType dtype_;
// Shared byte buffer; copies of a tensor share the same storage object.
std::shared_ptr<detail::TensorStorage<uint8_t>> storage_;
// Extent of every dimension.
TensorShape shape_;
// Byte offset of this tensor's first element inside the storage.
size_t offset_ = 0;
// Stride of every dimension.
TensorShape stride_;
private:
// False for tensors wrapping const memory; checked by writable_check().
bool writeable_ = true;
// Checked by operations that only support contiguous data; never modified
// in this struct.
bool contiguous_ = true;
};
// Streams a textual representation of a host tensor to `os`.
// The element type is resolved from the tensor's dtype; float and double
// values are formatted with a precision of 4. Raises an invalid-argument
// error for a tensor that is not on the host.
template <typename Os> Os &operator<<(Os &os, const Tensor &tensor) {
  TV_ASSERT_INVALID_ARG(tensor.device() == -1, "must be cpu tensor");
  Dispatch<detail::all_tensor_types_t>()(tensor.dtype(), [&](auto I) {
    using T = TV_DECLTYPE(I);
    const bool is_floating =
        std::is_same<T, float>::value || std::is_same<T, double>::value;
    std::stringstream ss;
    if (is_floating) {
      ss << std::setprecision(4);
    }
    auto typed_view = tensor.tview<T, -1, DefaultPtrTraits, int64_t>();
    os << typed_view.repr(ss);
  });
  return os;
}
// Wraps existing mutable memory in a Tensor via the (void *, shape, dtype,
// device) constructor.
inline Tensor from_blob(void *ptr, TensorShape shape, DType dtype, int device) {
return Tensor(ptr, shape, dtype, device);
}
// Wraps existing read-only memory in a Tensor; the const-pointer constructor
// marks the result as not writeable.
inline Tensor from_blob(const void *ptr, TensorShape shape, DType dtype,
int device) {
return Tensor(ptr, shape, dtype, device);
}
} // namespace tv
\ No newline at end of file
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common.h"
#include "prettyprint.h"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <type_traits>
#include <vector>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
namespace tv {
// Function-qualifier and assertion macros, selected by compiler:
//  * clang in CUDA mode or nvcc: host+device qualifiers;
//  * NVRTC (__CUDACC_RTC__): TV_HOST_DEVICE_INLINE is device-only;
//  * any other compiler: plain `inline` / empty qualifiers. Note that
//    TV_DEVICE_INLINE is not defined in this last branch.
#if (defined(__clang__) && defined(__CUDA__)) || defined(__NVCC__)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__
#define TV_ASSERT(expr) assert(expr)
#elif defined(__CUDACC_RTC__)
#define TV_ASSERT(expr) assert(expr)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__
#else
#define TV_ASSERT(x) assert(x)
#define TV_HOST_DEVICE_INLINE inline
#define TV_HOST_DEVICE
#endif
// TV_REQUIRE(expr, fmt, ...): when `expr` is false, prints the printf-style
// message and then asserts on `expr`. The expansion is a bare { } block, and
// `expr` is evaluated a second time inside assert() when assertions are on.
#define TV_REQUIRE(expr, ...) \
{ \
if (!(expr)) { \
printf(__VA_ARGS__); \
assert(expr); \
} \
}
// TV_CHECK_CUDA_ERR(): reads cudaGetLastError() and, when it is not
// cudaSuccess, throws std::runtime_error whose message holds the file, line
// and numeric error code (plus whatever TV_BACKTRACE_PRINT appends).
#define TV_CHECK_CUDA_ERR() \
{ \
auto __macro_err = cudaGetLastError(); \
if (__macro_err != cudaSuccess) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << "cuda execution failed with error " << __macro_err; \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
} \
}
// TV_CHECK_CUDA_ERR_V2(args...): like TV_CHECK_CUDA_ERR, but the message also
// contains cudaGetErrorString() for the error and the caller-supplied
// arguments, which are written with tv::sstream_print.
#define TV_CHECK_CUDA_ERR_V2(...) \
{ \
auto __macro_err = cudaGetLastError(); \
if (__macro_err != cudaSuccess) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << "cuda execution failed with error " << __macro_err; \
__macro_s << " " << cudaGetErrorString(__macro_err) << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
} \
}
#ifdef TV_CUDA
// Tag type for GPU execution that carries the CUDA stream to use.
// getStream() is virtual, so the struct can serve as a polymorphic base; the
// virtual destructor makes destroying a derived object through a GPU pointer
// well defined (it was missing before).
struct GPU {
  GPU(cudaStream_t s = 0) : mStream(s) {}
  virtual ~GPU() = default;
  // Stream associated with this context; 0 unless one was supplied.
  virtual cudaStream_t getStream() const { return mStream; }
  cudaStream_t mStream = 0;
};
#endif
// Empty tag type for CPU execution.
struct CPU {};
// Default capacity of fixed-size shape/index vectors; may be overridden by
// defining TV_MAX_DIM before this header is included.
#ifndef TV_MAX_DIM
#define TV_MAX_DIM 6
#endif
// Pointer trait that yields a plain T *.
template <typename T> struct DefaultPtrTraits { typedef T *type; };
#if defined(__CUDACC__) || defined(__HIPCC__)
// Pointer trait that yields a __restrict__-qualified T *; only available when
// compiling with a CUDA or HIP compiler.
template <typename T> struct RestrictPtrTraits {
typedef T *__restrict__ type;
};
#endif
/*
template <typename T>
constexpr size_t calc_align(size_t ndim)
{
if (ndim * sizeof(T) == 1)
return 1;
else if (ndim * sizeof(T) == 2)
return 2;
else if (ndim * sizeof(T) <= 4 && ndim * sizeof(T) > 2)
return 4;
else if (ndim * sizeof(T) <= 8 && ndim * sizeof(T) > 4)
return 8;
else if (ndim * sizeof(T) <= 16 && ndim * sizeof(T) > 8)
return 16;
else if (ndim * sizeof(T) <= 32 && ndim * sizeof(T) > 16)
return 32;
else
return 64;
}
*/
namespace detail {
// SFINAE helper: names `void` when Iter's iterator_category is convertible to
// std::input_iterator_tag, and is ill-formed otherwise. Used as a defaulted
// template argument to restrict iterator-pair constructors.
template <typename Iter>
using _RequireInputIter = std::enable_if_t<std::is_convertible<
    typename std::iterator_traits<Iter>::iterator_category,
    std::input_iterator_tag>::value>;
} // namespace detail
template <typename T, size_t MaxDim = TV_MAX_DIM>
struct /*alignas(calc_align<T>(MaxDim))*/ SimpleVector {
public:
// Creates an empty vector (size 0); the element slots are left unset.
TV_HOST_DEVICE_INLINE SimpleVector() {}
// Creates a vector of `count` elements, each set to `init`.
// `count` must not exceed the fixed capacity MaxDim; this is now asserted,
// like in the other constructors, because a larger count would write past
// the end of the internal array.
TV_HOST_DEVICE_INLINE SimpleVector(size_t count, T init = T())
    : size_(count) {
  TV_ASSERT(count <= MaxDim);
  for (size_t i = 0; i < count; ++i) {
    array_[i] = init;
  }
}
// Creates a vector from the iterator range [first, last). Raises an
// invalid-argument error as soon as the range turns out to hold more than
// MaxDim elements.
template <typename Iterator, typename = detail::_RequireInputIter<Iterator>>
SimpleVector(Iterator first, Iterator last) {
  size_ = 0;
  while (first != last) {
    if (size_ >= MaxDim) {
      TV_THROW_INVALID_ARG("iterator too long");
    }
    array_[size_] = *first;
    ++size_;
    ++first;
  }
}
// Creates a vector from a brace list. The list length is asserted to fit
// into the fixed capacity MaxDim.
TV_HOST_DEVICE_INLINE SimpleVector(std::initializer_list<T> q) {
  TV_ASSERT(q.size() <= MaxDim);
  size_t filled = 0;
  for (T item : q) {
    array_[filled] = item;
    ++filled;
  }
  size_ = filled;
}
// Creates a vector holding a copy of the elements of a std::vector. The
// source length is asserted to fit into the fixed capacity MaxDim.
SimpleVector(const std::vector<T> &arr) {
  const size_t count = arr.size();
  TV_ASSERT(count <= MaxDim);
  size_ = count;
  size_t pos = 0;
  while (pos < count) {
    array_[pos] = arr[pos];
    ++pos;
  }
}
// Copy constructor: copies the first arr.size() elements and the size. Slots
// past the size are not copied.
TV_HOST_DEVICE_INLINE SimpleVector(const SimpleVector<T, MaxDim> &arr) {
  const size_t count = arr.size();
  TV_ASSERT(count <= MaxDim);
  size_t pos = 0;
  while (pos < count) {
    array_[pos] = arr[pos];
    ++pos;
  }
  size_ = count;
}
// Element access. The index is range-checked against the current size only
// when TV_DEBUG is defined; otherwise it is used unchecked.
TV_HOST_DEVICE_INLINE T &operator[](int idx) {
#ifdef TV_DEBUG
TV_ASSERT(idx >= 0 && idx < size_);
#endif
return array_[idx];
}
// Read-only element access; same checking rules as the mutable overload.
TV_HOST_DEVICE_INLINE const T &operator[](int idx) const {
#ifdef TV_DEBUG
TV_ASSERT(idx >= 0 && idx < size_);
#endif
return array_[idx];
}
// Appends `s` after the last element. The capacity (MaxDim) is only checked
// when TV_DEBUG is defined.
TV_HOST_DEVICE_INLINE void push_back(T s) {
#ifdef TV_DEBUG
TV_ASSERT(size_ < MaxDim);
#endif
array_[size_] = s;
size_++;
}
// Drops the last element by decrementing the size. Emptiness is only checked
// when TV_DEBUG is defined; size_ is unsigned, so popping an empty vector
// wraps the size around.
TV_HOST_DEVICE_INLINE void pop_back() {
#ifdef TV_DEBUG
TV_ASSERT(size_ > 0);
#endif
size_--;
}
// Number of elements currently stored.
TV_HOST_DEVICE_INLINE size_t size() const { return size_; }
// Read-only / mutable pointer to the first element slot.
TV_HOST_DEVICE_INLINE const T *data() const { return array_; }
TV_HOST_DEVICE_INLINE T *data() { return array_; }
// True when no element is stored. The return type is now bool (it was
// size_t holding 0 or 1); every use in a boolean context is unaffected.
TV_HOST_DEVICE_INLINE bool empty() const { return size_ == 0; }
typedef size_t size_type;
// Forward iterator over the elements of a mutable SimpleVector; a thin
// wrapper around a T pointer.
//
// Changes over the previous version:
//  * pre-increment returns a reference to *this (it returned a copy, which
//    deviates from the iterator requirements and costs an extra copy);
//  * operator* and operator-> are const, so they can be used on a const
//    iterator object, as standard algorithms may do.
class iterator {
public:
  typedef iterator self_type;
  typedef T value_type;
  typedef T &reference;
  typedef T *pointer;
  typedef std::forward_iterator_tag iterator_category;
  typedef std::ptrdiff_t difference_type;
  TV_HOST_DEVICE_INLINE iterator(pointer ptr) : ptr_(ptr) {}
  // post-increment: advance, return the previous position
  TV_HOST_DEVICE_INLINE self_type operator++(int) {
    self_type previous = *this;
    ++ptr_;
    return previous;
  }
  // pre-increment: advance, return this iterator
  TV_HOST_DEVICE_INLINE self_type &operator++() {
    ++ptr_;
    return *this;
  }
  TV_HOST_DEVICE_INLINE reference operator*() const { return *ptr_; }
  TV_HOST_DEVICE_INLINE pointer operator->() const { return ptr_; }
  TV_HOST_DEVICE_INLINE bool operator==(const self_type &rhs) const {
    return ptr_ == rhs.ptr_;
  }
  TV_HOST_DEVICE_INLINE bool operator!=(const self_type &rhs) const {
    return ptr_ != rhs.ptr_;
  }

private:
  pointer ptr_;
};
// Forward iterator over the elements of a const SimpleVector; a thin wrapper
// around a const T pointer.
//
// Changes over the previous version:
//  * pre-increment returns a reference to *this instead of a copy;
//  * operator* and operator-> are const, so dereferencing works on a const
//    iterator object.
class const_iterator {
public:
  typedef const_iterator self_type;
  typedef T value_type;
  typedef const T &reference;
  typedef const T *pointer;
  typedef std::ptrdiff_t difference_type;
  typedef std::forward_iterator_tag iterator_category;
  TV_HOST_DEVICE_INLINE const_iterator(pointer ptr) : ptr_(ptr) {}
  // post-increment: advance, return the previous position
  TV_HOST_DEVICE_INLINE self_type operator++(int) {
    self_type previous = *this;
    ++ptr_;
    return previous;
  }
  // pre-increment: advance, return this iterator
  TV_HOST_DEVICE_INLINE self_type &operator++() {
    ++ptr_;
    return *this;
  }
  TV_HOST_DEVICE_INLINE reference operator*() const { return *ptr_; }
  TV_HOST_DEVICE_INLINE pointer operator->() const { return ptr_; }
  TV_HOST_DEVICE_INLINE bool operator==(const self_type &rhs) const {
    return ptr_ == rhs.ptr_;
  }
  TV_HOST_DEVICE_INLINE bool operator!=(const self_type &rhs) const {
    return ptr_ != rhs.ptr_;
  }

private:
  pointer ptr_;
};
// Iterators over the stored elements: begin points at slot 0, end at slot
// size_ (one past the last stored element).
TV_HOST_DEVICE_INLINE iterator begin() { return iterator(array_); }
TV_HOST_DEVICE_INLINE iterator end() { return iterator(array_ + size_); }
// Read-only iterators for const vectors.
TV_HOST_DEVICE_INLINE const_iterator begin() const {
return const_iterator(array_);
}
TV_HOST_DEVICE_INLINE const_iterator end() const {
return const_iterator(array_ + size_);
}
// Explicitly read-only iterators, usable on non-const vectors as well.
TV_HOST_DEVICE_INLINE const_iterator cbegin() const {
return const_iterator(array_);
}
TV_HOST_DEVICE_INLINE const_iterator cend() const {
return const_iterator(array_ + size_);
}
protected:
// Inline fixed-capacity element buffer; only the first size_ slots hold
// stored elements.
T array_[MaxDim];
// Number of stored elements.
size_t size_ = 0;
};
// Two vectors are equal when they hold the same number of elements and every
// pair of elements at the same position compares equal.
template <typename T, size_t MaxDim>
bool operator==(const SimpleVector<T, MaxDim> &lfs,
                const SimpleVector<T, MaxDim> &rfs) {
  const size_t count = lfs.size();
  if (count != rfs.size()) {
    return false;
  }
  size_t pos = 0;
  // advance while the elements agree; stopping early means a mismatch
  while (pos < count && !(lfs[pos] != rfs[pos])) {
    ++pos;
  }
  return pos == count;
}
// Negation of operator== for SimpleVector.
template <typename T, size_t MaxDim>
bool operator!=(const SimpleVector<T, MaxDim> &lfs,
const SimpleVector<T, MaxDim> &rfs) {
return !(lfs == rfs);
}
// Holds up to three integers describing a slice. Components that are not
// supplied keep the value -1.
// NOTE(review): the three slots presumably mean (start, end, step) -- the
// meaning is not established by this struct itself.
struct Slice {
  // Builds a slice from zero to three integer-like arguments, each converted
  // to int; more than three arguments is a compile error.
  template <class... Integers> TV_HOST_DEVICE_INLINE Slice(Integers... ints) {
    static_assert(sizeof...(ints) <= 3, "slice init must smaller than 3");
    // the trailing 0 keeps the array non-empty for an empty argument pack
    const int given[sizeof...(ints) + 1] = {int(ints)..., 0};
    for (int slot = 0; slot < 3; ++slot) {
      slices_[slot] = -1;
    }
    for (size_t i = 0; i < sizeof...(ints); ++i) {
      slices_[i] = given[i];
    }
  }
  // Builds a slice with all three components set to -1.
  TV_HOST_DEVICE_INLINE Slice() {
    for (int slot = 0; slot < 3; ++slot) {
      slices_[slot] = -1;
    }
  }
  // Builds a slice from a brace list of at most three values (asserted), each
  // converted to int.
  template <typename T>
  TV_HOST_DEVICE_INLINE Slice(std::initializer_list<T> slice) {
    for (int slot = 0; slot < 3; ++slot) {
      slices_[slot] = -1;
    }
    TV_ASSERT(slice.size() <= 3);
    int *out = slices_;
    for (auto it = slice.begin(); it != slice.end(); ++it) {
      *out = int(*it);
      ++out;
    }
  }
  // Component access; the index is range-checked only when TV_DEBUG is
  // defined.
  TV_HOST_DEVICE_INLINE int &operator[](int idx) {
#ifdef TV_DEBUG
    TV_ASSERT(idx >= 0 && idx < 3);
#endif
    return slices_[idx];
  }
  TV_HOST_DEVICE_INLINE const int &operator[](int idx) const {
#ifdef TV_DEBUG
    TV_ASSERT(idx >= 0 && idx < 3);
#endif
    return slices_[idx];
  }

protected:
  int slices_[3];
};
// Fixed-capacity list of dimension extents built on SimpleVector<Tindex,
// MaxDim>. Adds shape-specific helpers (element count, sub-shapes,
// squeeze/unsqueeze, row-major strides).
//
// Note that size() here returns the PRODUCT of the extents, not the number of
// dimensions; use ndim() for the latter.
template <size_t MaxDim = TV_MAX_DIM, typename Tindex = int>
struct ShapeBase : public SimpleVector<Tindex, MaxDim> {
  // Empty shape (ndim() == 0).
  TV_HOST_DEVICE_INLINE ShapeBase() : SimpleVector<Tindex, MaxDim>() {}
  // Shape from a brace list of extents.
  TV_HOST_DEVICE_INLINE ShapeBase(std::initializer_list<Tindex> shape)
      : SimpleVector<Tindex, MaxDim>(shape) {}
  // Shape from an existing SimpleVector of the same element type/capacity.
  TV_HOST_DEVICE_INLINE ShapeBase(SimpleVector<Tindex, MaxDim> vec)
      : SimpleVector<Tindex, MaxDim>(vec) {}
  // Shape from any single-parameter-style container; forwarded unchanged to
  // the SimpleVector constructor.
  template <typename T, template <class...> class Container>
  ShapeBase(Container<T> shape) : SimpleVector<Tindex, MaxDim>(shape) {}
  // NOTE(review): the parameter type is ShapeBase<MaxDim>, which names
  // ShapeBase<MaxDim, int>. It is therefore the copy constructor only when
  // Tindex is int; left as-is because other code may rely on it as a
  // converting constructor -- confirm before changing.
  TV_HOST_DEVICE_INLINE ShapeBase(const ShapeBase<MaxDim> &shape)
      : SimpleVector<Tindex, MaxDim>(shape) {}
  // Shape from a std::vector of extents (host only).
  ShapeBase(const std::vector<Tindex> &arr)
      : SimpleVector<Tindex, MaxDim>(arr) {}
  ShapeBase<MaxDim, Tindex> &
  operator=(const ShapeBase<MaxDim, Tindex> &shape) = default;

  // Returns the extents in the half-open dimension range [start, end).
  // The range is validated only when TV_DEBUG is defined.
  TV_HOST_DEVICE ShapeBase<MaxDim, Tindex> subshape(Tindex start,
                                                    Tindex end) const {
#ifdef TV_DEBUG
    TV_ASSERT(start >= 0 && end <= this->size_ && end > start);
#endif
    ShapeBase<MaxDim, Tindex> shape;
    for (Tindex i = start; i < end; ++i) {
      shape.push_back(this->array_[i]);
    }
    return shape;
  }
  // Returns the extents from dimension `start` to the last dimension.
  // `start` is validated only when TV_DEBUG is defined.
  TV_HOST_DEVICE ShapeBase<MaxDim, Tindex> subshape(Tindex start) const {
#ifdef TV_DEBUG
    TV_ASSERT(start >= 0 && start <= this->size_);
#endif
    ShapeBase<MaxDim, Tindex> shape;
    for (size_t i = start; i < this->size_; ++i) {
      shape.push_back(this->array_[i]);
    }
    return shape;
  }
  // Number of elements described by the shape: the product of all extents,
  // or 0 when the shape has no dimensions.
  TV_HOST_DEVICE size_t size() const {
    if (this->size_ == 0)
      return 0;
    size_t s = 1;
    for (int i = 0; i < int(this->size_); ++i) {
      s *= this->array_[i];
    }
    return s;
  }
  // Number of dimensions.
  TV_HOST_DEVICE_INLINE size_t ndim() const { return this->size_; }
  // Drops every dimension whose extent is 1. If nothing remains, the result
  // is the one-dimensional shape {1}.
  TV_HOST_DEVICE ShapeBase<MaxDim, Tindex> squeeze() const {
    ShapeBase<MaxDim, Tindex> shape;
    for (size_t i = 0; i < this->size_; ++i) {
      if (this->array_[i] != 1)
        shape.push_back(this->array_[i]);
    }
    if (shape.empty()) {
      // dont support empty shape for now
      shape.push_back(1);
    }
    return shape;
  }
  // Drops dimension `dim` if (and only if) its extent is 1; all other
  // dimensions are copied unchanged.
  template <size_t MaxDim2 = MaxDim>
  TV_HOST_DEVICE ShapeBase<MaxDim2, Tindex> squeeze(int dim) const {
    static_assert(MaxDim2 >= MaxDim - 1, "error");
    ShapeBase<MaxDim2, Tindex> shape;
    for (size_t i = 0; i < this->size_; ++i) {
      if (i != size_t(dim) || this->array_[i] != 1)
        shape.push_back(this->array_[i]);
    }
    return shape;
  }
  // Inserts a dimension of extent 1 before position `dim`. `dim == ndim()`
  // appends the new dimension at the end. Any other out-of-range `dim`
  // (including negative values) leaves the shape unchanged.
  template <size_t MaxDim2 = MaxDim>
  TV_HOST_DEVICE ShapeBase<MaxDim2, Tindex> unsqueeze(int dim) const {
    static_assert(MaxDim2 >= MaxDim - 1, "error");
    ShapeBase<MaxDim2, Tindex> shape;
    for (size_t i = 0; i < this->size_; ++i) {
      if (i == size_t(dim))
        shape.push_back(1);
      shape.push_back(this->array_[i]);
    }
    // The loop above only inserts *before* an existing dimension, so the
    // append-at-end case has to be handled separately.
    if (size_t(dim) == this->size_)
      shape.push_back(1);
    return shape;
  }
  // Product of the extents from dimension `start` to the last dimension
  // (1 when that range is empty).
  TV_HOST_DEVICE size_t prod(Tindex start = 0) const {
    size_t res = 1;
    for (size_t i = start; i < this->size_; ++i) {
      res *= this->array_[i];
    }
    return res;
  }
  // Row-major (last dimension fastest) strides for this shape:
  // stride[last] == 1 and stride[i] == stride[i + 1] * extent[i + 1].
  template <size_t MaxDim2 = MaxDim>
  TV_HOST_DEVICE ShapeBase<MaxDim2, Tindex> stride_rowmajor() const {
    static_assert(MaxDim2 >= MaxDim, "error");
    Tindex p = Tindex(1);
    ShapeBase<MaxDim2, Tindex> res(this->size_);
    // Count down with an unsigned index so the loop also terminates when
    // Tindex is an unsigned type.
    for (size_t i = this->size_; i > 0; --i) {
      res[i - 1] = p;
      p *= this->array_[i - 1];
    }
    return res;
  }
};
// Default shape type: up to TV_MAX_DIM dimensions with int extents.
using Shape = ShapeBase<TV_MAX_DIM, int>;
// Flattens the multi-index `indexes...` into a row-major offset for `shape`
// (last index varies fastest). With TV_DEBUG the number of indexes must equal
// shape.size().
template <class... Inds>
TV_HOST_DEVICE_INLINE unsigned rowArrayIdx(std::vector<int> &shape,
                                           Inds... indexes) {
  int idx_arr[sizeof...(indexes)] = {indexes...};
#ifdef TV_DEBUG
  TV_ASSERT(sizeof...(indexes) == shape.size());
#endif
  unsigned flat = 0;
  unsigned scale = 1; // product of the extents to the right of dimension d
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
  for (int d = sizeof...(indexes) - 1; d >= 0; --d) {
    flat += scale * idx_arr[d];
    scale *= shape[d];
  }
  return flat;
}
// Flattens the multi-index stored in `indexes_vec` into a row-major offset
// for `shape`. One index is read per entry of `shape`.
TV_HOST_DEVICE_INLINE unsigned rowArrayIdx(std::vector<int> &shape,
                                           std::vector<int> &indexes_vec) {
  unsigned flat = 0;
  unsigned scale = 1; // product of the extents to the right of dimension d
  int d = shape.size() - 1;
  while (d >= 0) {
    flat += scale * indexes_vec[d];
    scale *= shape[d];
    --d;
  }
  return flat;
}
// Flattens the multi-index `indexes...` into a row-major offset for the
// Shape `shape` (last index varies fastest).
template <class... Inds>
TV_HOST_DEVICE_INLINE unsigned rowArrayIdx(const Shape &shape,
                                           Inds... indexes) {
  int idx_arr[sizeof...(indexes)] = {indexes...};
  unsigned flat = 0;
  unsigned scale = 1; // product of the extents to the right of dimension d
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
  for (int d = sizeof...(indexes) - 1; d >= 0; --d) {
    flat += scale * idx_arr[d];
    scale *= shape[d];
  }
  return flat;
}
// Flattens the multi-index stored in `indexes_vec` into a row-major offset
// for `shape`. The number of dimensions walked is indexes_vec.ndim().
TV_HOST_DEVICE_INLINE unsigned rowArrayIdx(const Shape &shape,
                                           const Shape &indexes_vec) {
  unsigned flat = 0;
  unsigned scale = 1; // product of the extents to the right of dimension d
  int d = indexes_vec.ndim() - 1;
  while (d >= 0) {
    flat += scale * indexes_vec[d];
    scale *= shape[d];
    --d;
  }
  return flat;
}
// Flattens the NDim-entry multi-index at `indexes` into a row-major offset
// for the NDim extents at `shape` (last index varies fastest).
template <typename Index, unsigned NDim>
TV_HOST_DEVICE_INLINE unsigned rowArrayIdx(const Index *indexes,
                                           const Index *shape) {
  unsigned flat = 0;
  unsigned scale = 1; // product of the extents to the right of dimension d
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
  for (int d = NDim - 1; d >= 0; --d) {
    flat += scale * indexes[d];
    scale *= shape[d];
  }
  return flat;
}
// Inverse of the row-major flattening for a compile-time number of
// dimensions: decomposes the flat offset `index` into NDim per-dimension
// indices written to `output[0..NDim-1]`, using the extents in `shape`.
// Returns what is left of `index` after dividing out every extent (0 when
// `index` was inside the shape).
template <typename Index, unsigned NDim>
TV_HOST_DEVICE_INLINE Index rowArrayIdxInv(Index index, Index *output,
                                           const Index *shape) {
// Only emit the unroll hint for device compilation, matching the other
// index helpers in this file; host compilers may warn on the pragma.
#if defined(__CUDA_ARCH__)
#pragma unroll
#endif
  for (int i = NDim - 1; i >= 0; --i) {
    output[i] = index % shape[i];
    index -= output[i];
    index /= shape[i];
  }
  return index;
}
// Inverse of the row-major flattening for a run-time number of dimensions:
// decomposes the flat offset `index` into `ndim` per-dimension indices written
// to `output[0..ndim-1]`, using the extents in `shape`. Returns what is left
// of `index` after dividing out every extent.
template <typename Index>
TV_HOST_DEVICE Index rowArrayIdxInv(Index index, Index *output,
                                    const Index *shape, int ndim) {
  int d = ndim - 1;
  while (d >= 0) {
    // Peel off the fastest-varying remaining dimension.
    output[d] = index % shape[d];
    index = (index - output[d]) / shape[d];
    --d;
  }
  return index;
}
// Compile-time recursive offset computation that consumes the indices in
// REVERSE dimension order: the first index argument belongs to dimension
// N - 1, the next one to dimension N - 2, and so on. Each step computes
//   index + shape[N - 1] * (offset of the remaining indices).
template <int N> struct ArrayIndexRowMajorReverse {
  // Variant taking the extents as a raw pointer.
  template <typename TShape, typename T, class... Ts>
  TV_HOST_DEVICE_INLINE static unsigned run(const TShape *shape, T index,
                                            Ts... inds) {
    return index +
           shape[N - 1] * ArrayIndexRowMajorReverse<N - 1>::run(shape, inds...);
  }
  // Variant taking the extents as a Shape. It must recurse through runShape:
  // run() expects a pointer and cannot be called with a Shape reference.
  template <typename T, class... Ts>
  TV_HOST_DEVICE_INLINE static unsigned runShape(const Shape &shape, T index,
                                                 Ts... inds) {
    return index +
           shape[N - 1] *
               ArrayIndexRowMajorReverse<N - 1>::runShape(shape, inds...);
  }
};
// Recursion terminator for ArrayIndexRowMajorReverse: with a single index
// left, the offset is that index itself. `shape` is accepted for signature
// compatibility and is not read.
template <> struct ArrayIndexRowMajorReverse<1> {
template <typename TShape, typename T>
TV_HOST_DEVICE_INLINE static unsigned run(const TShape *shape, T idx) {
return idx;
}
template <typename T>
TV_HOST_DEVICE_INLINE static unsigned runShape(const Shape &shape, T idx) {
return idx;
}
};
// Compile-time recursive row-major offset computation. N is the number of
// indices still to be consumed and Ndim the total number of dimensions. Each
// step folds the current index into the running value `start` and scales it
// by the extent of the next dimension:
//   next_start = (index + start) * shape[Ndim - N + 1]
// The <1, Ndim> specialization finishes the recursion.
template <int N, int Ndim> struct ArrayIndexRowMajor {
  // this array index provide almost same compiled code. compile it in
  // https://godbolt.org/ for more details.
  // Extents given as a raw pointer, indices as a parameter pack.
  template <typename TShape, typename Tinit, typename T, class... Ts>
  TV_HOST_DEVICE_INLINE static unsigned run(const TShape *shape, Tinit start,
                                            T index, Ts... inds) {
    using Next = ArrayIndexRowMajor<N - 1, Ndim>;
    const auto carried = (index + start) * shape[Ndim - N + 1];
    return Next::run(shape, carried, inds...);
  }
  // Extents given as a Shape, indices as a parameter pack.
  template <typename Tinit, typename T, class... Ts>
  TV_HOST_DEVICE_INLINE static unsigned
  runShape(const Shape &shape, Tinit start, T index, Ts... inds) {
    using Next = ArrayIndexRowMajor<N - 1, Ndim>;
    const auto carried = (index + start) * shape[Ndim - N + 1];
    return Next::runShape(shape, carried, inds...);
  }
  // Extents and indices both given as raw pointers; the index consumed at
  // this step is indexes[Ndim - N].
  template <typename TShape, typename Tinit>
  TV_HOST_DEVICE_INLINE static unsigned
  runPtrs(const TShape *indexes, const TShape *shape, Tinit start) {
    using Next = ArrayIndexRowMajor<N - 1, Ndim>;
    const auto carried = (indexes[Ndim - N] + start) * shape[Ndim - N + 1];
    return Next::runPtrs(indexes, shape, carried);
  }
};
// Final step of the ArrayIndexRowMajor recursion: the last index is added to
// the accumulated value without any further scaling. `shape` is not read.
template <int Ndim> struct ArrayIndexRowMajor<1, Ndim> {
  template <typename TShape, typename Tinit, typename T>
  TV_HOST_DEVICE_INLINE static unsigned run(const TShape *shape, Tinit start,
                                            T idx) {
    const auto total = start + idx;
    return total;
  }
  template <typename Tinit, typename T>
  TV_HOST_DEVICE_INLINE static unsigned runShape(const Shape &shape,
                                                 Tinit start, T idx) {
    const auto total = start + idx;
    return total;
  }
  // Pointer variant: the last index is indexes[Ndim - 1].
  template <typename TShape, typename Tinit>
  TV_HOST_DEVICE_INLINE static unsigned
  runPtrs(const TShape *indexes, const TShape *shape, Tinit start) {
    const auto total = start + indexes[Ndim - 1];
    return total;
  }
};
// Zero-dimensional case of ArrayIndexRowMajor: no indices are supplied, and
// every entry point returns offset 0 without reading its arguments.
template <> struct ArrayIndexRowMajor<0, 0> {
template <typename TShape, typename Tinit>
TV_HOST_DEVICE_INLINE static unsigned run(const TShape *shape, Tinit start) {
return 0;
}
template <typename Tinit>
TV_HOST_DEVICE_INLINE static unsigned runShape(const Shape &shape,
Tinit start) {
return 0;
}
template <typename TShape, typename Tinit>
TV_HOST_DEVICE_INLINE static unsigned
runPtrs(const TShape *indexes, const TShape *shape, Tinit start) {
return 0;
}
};
// Compile-time recursive strided offset computation:
//   offset = start + sum_k index_k * stride[k]
// N is the number of indices still to be consumed and Ndim the total number
// of dimensions, so the index consumed at this step belongs to dimension
// Ndim - N and must be multiplied by stride[Ndim - N]. (The <1, Ndim>
// specialization uses stride[Ndim - 1], which is the same formula for N == 1.)
template <int N, int Ndim> struct ArrayIndexStride {
  // this array index provide almost same compiled code. compile it in
  // https://godbolt.org/ for more details.
  template <typename TShape, typename Tinit, typename T, class... Ts>
  TV_HOST_DEVICE_INLINE static unsigned run(const TShape *stride, Tinit start,
                                            T index, Ts... inds) {
    return ArrayIndexStride<N - 1, Ndim>::run(
        stride, start + index * stride[Ndim - N], inds...);
  }
};
// Final step of the ArrayIndexStride recursion: adds the contribution of the
// last index, scaled by the stride of the last dimension.
template <int Ndim> struct ArrayIndexStride<1, Ndim> {
  template <typename TShape, typename Tinit, typename T>
  TV_HOST_DEVICE_INLINE static unsigned run(const TShape *stride, Tinit start,
                                            T idx) {
    const auto total = start + idx * stride[Ndim - 1];
    return total;
  }
};
#if __cplusplus >= 201703L
// C++17 fold-expression variant of the strided offset: for every position n
// in the explicitly supplied pack N..., adds stride[n] * (the n-th argument
// of ids...). N... cannot be deduced and must be given by the caller.
template <size_t... N, class T, class... Ts>
TV_HOST_DEVICE_INLINE T array_index_stride(const T *stride, Ts... ids) {
return ((stride[N] * std::get<N>(std::forward_as_tuple(ids...))) + ...);
}
#endif
namespace detail {
// Maps a supported element type to a short printable name through the static
// member `value`. Only the types specialized below are supported; any other
// type leaves the primary template undefined and fails to compile.
template <typename T> struct TypeToString;
// A const-qualified type shares the name of its unqualified counterpart, so
// it simply inherits `value` instead of repeating every specialization.
template <typename T> struct TypeToString<const T> : TypeToString<T> {};
template <> struct TypeToString<bool> {
  static constexpr const char *value = "bool";
};
template <> struct TypeToString<float> {
  static constexpr const char *value = "float";
};
template <> struct TypeToString<double> {
  static constexpr const char *value = "double";
};
template <> struct TypeToString<int8_t> {
  static constexpr const char *value = "int8";
};
template <> struct TypeToString<int16_t> {
  static constexpr const char *value = "int16";
};
template <> struct TypeToString<int32_t> {
  static constexpr const char *value = "int32";
};
template <> struct TypeToString<int64_t> {
  static constexpr const char *value = "int64";
};
template <> struct TypeToString<uint8_t> {
  static constexpr const char *value = "uint8";
};
template <> struct TypeToString<uint16_t> {
  static constexpr const char *value = "uint16";
};
template <> struct TypeToString<uint32_t> {
  static constexpr const char *value = "uint32";
};
template <> struct TypeToString<uint64_t> {
  static constexpr const char *value = "uint64";
};
} // namespace detail
// Convenience variable template: type_s<T> is the printable name that
// detail::TypeToString<T> provides for T.
template <typename T>
constexpr const char *type_s = detail::TypeToString<T>::value;
namespace detail {
// Common part of TensorAccesser: stores a data pointer together with a
// pointer to Rank strides, and offers full-index element access through
// operator()(i0, ..., i_{Rank-1}). Neither pointer is owned.
template <typename T, int Rank,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int>
struct TensorAccesserBase {
  static_assert(Rank > 0, "error");
  static constexpr int rank_value = Rank;
  using ptr_t = typename PtrTraits<T>::type;

  explicit TV_HOST_DEVICE_INLINE TensorAccesserBase(ptr_t ptr,
                                                    const Tindex *stride_ptr)
      : ptr_(ptr), stride_ptr_(stride_ptr) {}

  // Raw data pointer.
  TV_HOST_DEVICE_INLINE ptr_t data() { return ptr_; }
  TV_HOST_DEVICE_INLINE const ptr_t data() const { return ptr_; }

  // Element access with exactly Rank indices; the offset is computed from
  // the stored strides by ArrayIndexStride.
  template <class... Inds> TV_HOST_DEVICE_INLINE T &operator()(Inds... inds) {
    static_assert(sizeof...(inds) == Rank, "error");
    const unsigned offset =
        ArrayIndexStride<Rank, Rank>::run(stride_ptr_, 0, inds...);
    return ptr_[offset];
  }
  template <class... Inds>
  TV_HOST_DEVICE_INLINE const T &operator()(Inds... inds) const {
    static_assert(sizeof...(inds) == Rank, "error");
    const unsigned offset =
        ArrayIndexStride<Rank, Rank>::run(stride_ptr_, 0, inds...);
    return ptr_[offset];
  }

protected:
  ptr_t ptr_;
  const Tindex *stride_ptr_;
};
} // namespace detail
// Strided accessor of rank Rank. operator[](i) fixes the leading index and
// yields an accessor of rank Rank - 1 that points at
// ptr_ + stride_ptr_[0] * i and uses the remaining strides.
template <typename T, int Rank,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int>
struct TensorAccesser
    : public detail::TensorAccesserBase<T, Rank, PtrTraits, Tindex> {
  static_assert(Rank > 0, "error");
  using ptr_t = typename PtrTraits<T>::type;

  explicit TV_HOST_DEVICE_INLINE TensorAccesser(ptr_t ptr,
                                                const Tindex *stride_ptr)
      : detail::TensorAccesserBase<T, Rank, PtrTraits, Tindex>(ptr,
                                                               stride_ptr) {}

  TV_HOST_DEVICE_INLINE TensorAccesser<T, Rank - 1, PtrTraits, Tindex>
  operator[](int i) {
    using Sub = TensorAccesser<T, Rank - 1, PtrTraits, Tindex>;
    const Tindex *inner_strides = this->stride_ptr_ + 1;
    return Sub(this->ptr_ + this->stride_ptr_[0] * i, inner_strides);
  }
  TV_HOST_DEVICE_INLINE TensorAccesser<T, Rank - 1, PtrTraits, Tindex>
  operator[](int i) const {
    using Sub = TensorAccesser<T, Rank - 1, PtrTraits, Tindex>;
    const Tindex *inner_strides = this->stride_ptr_ + 1;
    return Sub(this->ptr_ + this->stride_ptr_[0] * i, inner_strides);
  }
};
// Rank-1 accessor: operator[](i) ends the chain and returns a reference to
// the element at ptr_[stride_ptr_[0] * i]. Both overloads return T&, so a
// const accessor still hands out a mutable reference.
template <typename T, template <class> class PtrTraits, typename Tindex>
struct TensorAccesser<T, 1, PtrTraits, Tindex>
    : public detail::TensorAccesserBase<T, 1, PtrTraits, Tindex> {
  using ptr_t = typename PtrTraits<T>::type;

  explicit TV_HOST_DEVICE_INLINE TensorAccesser(ptr_t ptr,
                                                const Tindex *stride_ptr)
      : detail::TensorAccesserBase<T, 1, PtrTraits, Tindex>(ptr, stride_ptr) {}

  TV_HOST_DEVICE_INLINE T &operator[](int i) {
    const auto offset = this->stride_ptr_[0] * i;
    return this->ptr_[offset];
  }
  TV_HOST_DEVICE_INLINE T &operator[](int i) const {
    const auto offset = this->stride_ptr_[0] * i;
    return this->ptr_[offset];
  }
};
template <typename T, int Rank = -1,
template <class> class PtrTraits = DefaultPtrTraits,
typename Tindex = int>
struct TensorView {
static constexpr int rank_value = Rank;
using ptr_t = typename PtrTraits<T>::type;
using tv_shape_t = ShapeBase<Rank == -1 ? TV_MAX_DIM : Rank, Tindex>;
using no_cv_type = typename std::remove_cv<T>::type;
static_assert(Rank == -1 || Rank > 0, "error");
TV_HOST_DEVICE_INLINE TensorView() {}
explicit TV_HOST_DEVICE_INLINE TensorView(ptr_t ptr, tv_shape_t shape)
: ptr_(ptr), shape_(shape), stride_(shape.stride_rowmajor()) {}
explicit TV_HOST_DEVICE_INLINE TensorView(ptr_t ptr, tv_shape_t shape,
tv_shape_t stride)
: ptr_(ptr), shape_(shape), stride_(stride) {}
operator TensorView<const no_cv_type, Rank, PtrTraits, Tindex>() {
return TensorView<const no_cv_type, Rank, PtrTraits, Tindex>(ptr_, shape_);
} // conversion function
template <class... Inds> TV_HOST_DEVICE_INLINE T &operator()(Inds... inds) {
static_assert(Rank == -1 || sizeof...(inds) == Rank, "error");
#if defined TV_DEBUG
int idxes[sizeof...(Inds)]{int(inds)...};
TV_REQUIRE(sizeof...(inds) == shape_.ndim(),
"you provide %d indexes, but dim is %d\n", sizeof...(inds),
shape_.ndim());
for (int i = 0; i < sizeof...(inds); ++i) {
TV_REQUIRE(idxes[i] >= 0 && idxes[i] < shape_[i],
"index-%d(%d) out-of-range: [0, %d)\n", i, idxes[i],
shape_[i]);
}
#endif
constexpr int Ndim = sizeof...(Inds);
return ptr_[ArrayIndexRowMajor<Ndim, Ndim>::runShape(shape_, 0, inds...)];
}
template <class... Inds>
TV_HOST_DEVICE_INLINE const T &operator()(Inds... inds) const {
static_assert(Rank == -1 || sizeof...(inds) == Rank, "error");
#if defined TV_DEBUG
int idxes[sizeof...(Inds)]{int(inds)...};
TV_REQUIRE(sizeof...(inds) == shape_.ndim(),
"you provide %d indexes, but dim is %d\n", sizeof...(inds),
shape_.ndim());
for (int i = 0; i < sizeof...(inds); ++i) {
TV_REQUIRE(idxes[i] >= 0 && idxes[i] < shape_[i],
"index-%d(%d) out-of-range: [0, %d)\n", i, idxes[i],
shape_[i]);
}
#endif
constexpr int Ndim = sizeof...(Inds);
return ptr_[ArrayIndexRowMajor<Ndim, Ndim>::runShape(shape_, 0, inds...)];
}
TV_HOST_DEVICE_INLINE T &operator()() {
static_assert(Rank == -1 || 0 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(ptr_ != nullptr, "you want get value but the view is empty.%s",
"\n");
TV_REQUIRE(shape_.ndim() == 0, "you provide 0 indexes, but dim is %ld\n",
shape_.ndim());
#endif
return ptr_[0];
}
TV_HOST_DEVICE_INLINE const T &operator()() const {
static_assert(Rank == -1 || 0 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(ptr_ != nullptr, "you want get value but the view is empty.%s",
"\n");
TV_REQUIRE(shape_.ndim() == 0, "you provide 0 indexes, but dim is %ld\n",
shape_.ndim());
#endif
return ptr_[0];
}
template <class T1> TV_HOST_DEVICE_INLINE T &operator()(T1 i1) {
static_assert(Rank == -1 || 1 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 1, "you provide 1 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, i1, shape_[0]);
#endif
return ptr_[i1];
}
template <class T1, class T2>
TV_HOST_DEVICE_INLINE T &operator()(T1 i1, T2 i2) {
static_assert(Rank == -1 || 2 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 2, "you provide 2 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
TV_REQUIRE(i2 >= 0 && i2 < shape_[1],
"index-%d(%d) out-of-range: [0, %d)\n", 1, int(i2), shape_[1]);
#endif
return ptr_[i1 * shape_[1] + i2];
}
template <class T1, class T2, class T3>
TV_HOST_DEVICE_INLINE T &operator()(T1 i1, T2 i2, T3 i3) {
static_assert(Rank == -1 || 3 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 3, "you provide 3 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
TV_REQUIRE(i2 >= 0 && i2 < shape_[1],
"index-%d(%d) out-of-range: [0, %d)\n", 1, int(i2), shape_[1]);
TV_REQUIRE(i3 >= 0 && i3 < shape_[2],
"index-%d(%d) out-of-range: [0, %d)\n", 2, int(i3), shape_[2]);
#endif
return ptr_[(i1 * shape_[1] + i2) * shape_[2] + i3];
}
template <class T1, class T2, class T3, class T4>
TV_HOST_DEVICE_INLINE T &operator()(T1 i1, T2 i2, T3 i3, T4 i4) {
static_assert(Rank == -1 || 4 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 4, "you provide 4 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
TV_REQUIRE(i2 >= 0 && i2 < shape_[1],
"index-%d(%d) out-of-range: [0, %d)\n", 1, int(i2), shape_[1]);
TV_REQUIRE(i3 >= 0 && i3 < shape_[2],
"index-%d(%d) out-of-range: [0, %d)\n", 2, int(i3), shape_[2]);
TV_REQUIRE(i4 >= 0 && i4 < shape_[3],
"index-%d(%d) out-of-range: [0, %d)\n", 3, int(i4), shape_[3]);
#endif
return ptr_[((i1 * shape_[1] + i2) * shape_[2] + i3) * shape_[3] + i4];
}
template <class T1> TV_HOST_DEVICE_INLINE const T &operator()(T1 i1) const {
static_assert(Rank == -1 || 1 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 1, "you provide 1 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
#endif
return ptr_[i1];
}
template <class T1, class T2>
TV_HOST_DEVICE_INLINE const T &operator()(T1 i1, T2 i2) const {
static_assert(Rank == -1 || 2 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 2, "you provide 2 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
TV_REQUIRE(i2 >= 0 && i2 < shape_[1],
"index-%d(%d) out-of-range: [0, %d)\n", 1, int(i2), shape_[1]);
#endif
return ptr_[i1 * shape_[1] + i2];
}
template <class T1, class T2, class T3>
TV_HOST_DEVICE_INLINE const T &operator()(T1 i1, T2 i2, T3 i3) const {
static_assert(Rank == -1 || 3 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 3, "you provide 3 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
TV_REQUIRE(i2 >= 0 && i2 < shape_[1],
"index-%d(%d) out-of-range: [0, %d)\n", 1, int(i2), shape_[1]);
TV_REQUIRE(i3 >= 0 && i3 < shape_[2],
"index-%d(%d) out-of-range: [0, %d)\n", 2, int(i3), shape_[2]);
#endif
return ptr_[(i1 * shape_[1] + i2) * shape_[2] + i3];
}
template <class T1, class T2, class T3, class T4>
TV_HOST_DEVICE_INLINE const T &operator()(T1 i1, T2 i2, T3 i3, T4 i4) const {
static_assert(Rank == -1 || 4 == Rank, "error");
#if defined TV_DEBUG
TV_REQUIRE(shape_.ndim() == 4, "you provide 4 indexes, but dim is %ld\n",
shape_.ndim());
TV_REQUIRE(i1 >= 0 && i1 < shape_[0],
"index-%d(%d) out-of-range: [0, %d)\n", 0, int(i1), shape_[0]);
TV_REQUIRE(i2 >= 0 && i2 < shape_[1],
"index-%d(%d) out-of-range: [0, %d)\n", 1, int(i2), shape_[1]);
TV_REQUIRE(i3 >= 0 && i3 < shape_[2],
"index-%d(%d) out-of-range: [0, %d)\n", 2, int(i3), shape_[2]);
TV_REQUIRE(i4 >= 0 && i4 < shape_[3],
"index-%d(%d) out-of-range: [0, %d)\n", 3, int(i4), shape_[3]);
#endif
return ptr_[((i1 * shape_[1] + i2) * shape_[2] + i3) * shape_[3] + i4];
}
TV_HOST_DEVICE_INLINE T &operator[](int idx) {
#ifdef TV_DEBUG
TV_REQUIRE(idx >= 0 && idx < size(), "index(%d) out-of-range: [0, %ld)\n",
int(idx), size());
#endif
return ptr_[idx];
}
TV_HOST_DEVICE_INLINE const T &operator[](int idx) const {
#ifdef TV_DEBUG
TV_REQUIRE(idx >= 0 && idx < size(), "index(%d) out-of-range: [0, %ld)\n",
int(idx), size());
#endif
return ptr_[idx];
}
TV_HOST_DEVICE_INLINE TensorAccesser<T, Rank - 1, PtrTraits, Tindex>
accessor(Tindex idx) {
static_assert(Rank > 1, "for Rank == 1, use accessor() or just use []");
return TensorAccesser<T, Rank - 1, PtrTraits, Tindex>(
ptr_ + stride_[0] * idx, stride_.data() + 1);
}
TV_HOST_DEVICE_INLINE TensorAccesser<T, Rank, PtrTraits, Tindex> accessor() {
static_assert(Rank > 0, "rank must higher than zero");
return TensorAccesser<T, Rank, PtrTraits, Tindex>(ptr_, stride_.data());
}
TV_HOST_DEVICE_INLINE
TensorAccesser<T, Rank - 1, PtrTraits, Tindex> accessor(Tindex idx) const {
static_assert(Rank > 1, "for Rank == 1, use accessor() or just use []");
return TensorAccesser<T, Rank - 1, PtrTraits, Tindex>(
ptr_ + stride_[0] * idx, stride_.data() + 1);
}
TV_HOST_DEVICE_INLINE
TensorAccesser<T, Rank, PtrTraits, Tindex> accessor() const {
static_assert(Rank > 0, "error");
return TensorAccesser<T, Rank, PtrTraits, Tindex>(
ptr_, stride_.data(), "rank must higher than zero");
}
TV_HOST_DEVICE_INLINE bool empty() const { return ptr_ == nullptr; }
TV_HOST_DEVICE_INLINE ptr_t data() { return ptr_; }
TV_HOST_DEVICE_INLINE const ptr_t data() const { return ptr_; }
TV_HOST_DEVICE_INLINE const tv_shape_t &shape() const { return shape_; }
TV_HOST_DEVICE_INLINE const tv_shape_t &stride() const { return stride_; }
TV_HOST_DEVICE_INLINE int dim(int idx) const { return shape_[idx]; }
TV_HOST_DEVICE_INLINE int ndim() const { return shape_.ndim(); }
template <class... Inds>
TV_HOST_DEVICE_INLINE
TensorView<T, Rank == -1 ? -1 : sizeof...(Inds), PtrTraits, Tindex>
view(Inds... newShapes) const {
ShapeBase<Rank == -1 ? TV_MAX_DIM : sizeof...(Inds), Tindex> shapes{
int(newShapes)...};
for (size_t i = 0; i < sizeof...(newShapes); ++i) {
if (shapes[i] == -1) {
shapes[i] = 1;
shapes[i] = size() / shapes.size();
break;
}
}
TV_ASSERT(shapes.size() == size());
return TensorView < T, Rank == -1 ? -1 : sizeof...(Inds), PtrTraits,
Tindex > (ptr_, shapes);
}
TV_HOST_DEVICE_INLINE TensorView<T, -1, PtrTraits, Tindex>
view(Shape shapes) const {
TV_ASSERT(shapes.size() == size());
return TensorView<T, -1, PtrTraits, Tindex>(ptr_, shapes);
}
TV_HOST_DEVICE_INLINE TensorView<T, -1, PtrTraits, Tindex> squeeze() const {
return TensorView<T, -1, PtrTraits, Tindex>(ptr_, shape_.squeeze());
}
TV_HOST_DEVICE_INLINE
TensorView<T, Rank == -1 ? -1 : Rank - 1, PtrTraits, Tindex>
squeeze(int dim) const {
return TensorView < T, Rank == -1 ? -1 : Rank - 1, PtrTraits,
Tindex > (ptr_, shape_.squeeze < Rank == -1 ? TV_MAX_DIM
: Rank - 1 > (dim));
}
TV_HOST_DEVICE_INLINE size_t size() const { return shape_.size(); }
template <class... Integers>
TV_HOST_DEVICE_INLINE TensorView<T, -1, PtrTraits, Tindex>
subview(int id, Integers... ints) {
tv_shape_t start = {id, ints...};
for (int i = 1 + sizeof...(ints); i < ndim(); ++i) {
start.push_back(0);
}
return TensorView<T, Rank, PtrTraits, Tindex>(
ptr_ + rowArrayIdx(shape_, start),
shape_.subshape(sizeof...(ints) + 1));
}
template <class... Integers>
TV_HOST_DEVICE_INLINE TensorView<T, -1, PtrTraits, Tindex>
subview(int id, Integers... ints) const {
tv_shape_t start = {id, ints...};
for (int i = 1 + sizeof...(ints); i < ndim(); ++i) {
start.push_back(0);
}
return TensorView<T, Rank, PtrTraits, Tindex>(
ptr_ + rowArrayIdx(shape_, start),
shape_.subshape(sizeof...(ints) + 1));
}
TV_HOST_DEVICE_INLINE TensorView<T, -1, PtrTraits, Tindex>
subview(SimpleVector<int> ids) const {
Shape start = ids;
for (int i = ids.size(); i < ndim(); ++i) {
start.push_back(0);
}
return TensorView<T, Rank, PtrTraits, Tindex>(
ptr_ + rowArrayIdx(shape_, start), shape_.subshape(ids.size()));
}
template <typename Os> std::string repr(Os &ss) const {
if (empty())
return "";
if (shape_.ndim() == 0) {
ss << "Tensor[" << type_s<T> << "]" << std::endl;
ss << *ptr_;
return ss.str();
}
SimpleVector<int64_t, TV_MAX_DIM> prev(ndim(), -1);
SimpleVector<int64_t, TV_MAX_DIM> nd_index(ndim());
SimpleVector<int64_t, TV_MAX_DIM> _shape;
for (auto s : shape()) {
_shape.push_back(s);
}
ss << "Tensor[" << type_s<T> << "]: shape=" << shape()
<< ", stride=" << stride() << std::endl;
auto ndimValue = ndim();
for (int64_t i = 0; i < int64_t(size()); ++i) {
rowArrayIdxInv(i, nd_index.data(), _shape.data(), ndimValue);
bool newline = false;
int end_count = 0;
for (int j = 0; j < ndimValue; ++j) {
if (nd_index[j] != prev[j] && nd_index[j] == 0 && prev[j] != 0 &&
prev[j] != -1) {
ss << "]";
++end_count;
newline = true;
}
}
if (prev[0] == -1) {
end_count = ndimValue;
}
if (newline) {
ss << "\n";
}
int starts_count = 0;
for (int j = 0; j < ndimValue; ++j) {
if (nd_index[j] != prev[j] && nd_index[j] == 0 && prev[j] != 0) {
++starts_count;
}
}
if (starts_count > 0) {
for (int j = 0; j < ndimValue - end_count; ++j) {
ss << " ";
}
for (int j = 0; j < starts_count; ++j) {
ss << "[";
}
}
if (std::is_same<T, uint8_t>::value ||
std::is_same<T, const uint8_t>::value) {
ss << unsigned((*this)[i]);
} else {
ss << (*this)[i];
}
if (nd_index[ndimValue - 1] != _shape[ndimValue - 1] - 1) {
ss << ",";
}
for (int j = 0; j < ndimValue; ++j) {
prev[j] = nd_index[j];
}
}
for (int j = 0; j < ndimValue; ++j) {
ss << "]";
}
return ss.str();
}
std::string repr() const {
std::ostringstream ss;
return repr(ss);
}
protected:
template <typename T1> TV_HOST_DEVICE_INLINE Slice to_slice(T1 s) const {
return Slice{int(s), -1, -1};
}
TV_HOST_DEVICE_INLINE Slice to_slice(Slice s) const { return Slice(s); }
ptr_t ptr_ = nullptr;
tv_shape_t shape_;
tv_shape_t stride_;
};
// Wraps a std::vector as a one-dimensional TensorView over its storage. The
// view does not own the data and is only valid while `arr` is neither
// destroyed nor reallocated.
template <typename T> TensorView<T> vector2tv(std::vector<T> &arr) {
  // The shape is an initializer list of int; converting explicitly avoids an
  // ill-formed size_t -> int narrowing inside the braces.
  return TensorView<T>(arr.data(), {static_cast<int>(arr.size())});
}
// Wraps a std::vector as a TensorView with the given shape. The product of
// the shape's extents must equal arr.size() (checked with
// TV_ASSERT_INVALID_ARG). The view does not own the data.
template <typename T>
TensorView<T> vector2tv(std::vector<T> &arr, Shape shape) {
  TV_ASSERT_INVALID_ARG(shape.prod() == arr.size(), "error");
  T *storage = arr.data();
  return TensorView<T>(storage, shape);
}
// Wraps a const std::vector as a read-only one-dimensional TensorView over
// its storage. The view does not own the data and is only valid while `arr`
// is neither destroyed nor reallocated.
template <typename T> TensorView<const T> vector2tv(const std::vector<T> &arr) {
  // The shape is an initializer list of int; converting explicitly avoids an
  // ill-formed size_t -> int narrowing inside the braces.
  return TensorView<const T>(arr.data(), {static_cast<int>(arr.size())});
}
// Streams the text produced by TensorView::repr() into `os` and returns `os`
// itself, so the result keeps the caller's stream type.
template <typename Os, typename T, int Rank, template <class> class PtrTraits,
          typename Tindex>
Os &operator<<(Os &os, const TensorView<T, Rank, PtrTraits, Tindex> &dt) {
  const std::string text = dt.repr();
  os << text;
  return os;
}
// Overload for views of const elements: streams the text produced by
// TensorView::repr() into `os` and returns `os` itself.
template <typename Os, typename T, int Rank, template <class> class PtrTraits,
          typename Tindex>
Os &operator<<(Os &os, const TensorView<const T, Rank, PtrTraits, Tindex> &dt) {
  const std::string text = dt.repr();
  os << text;
  return os;
}
namespace detail {
// Maps a supported element type to the printf conversion used when printing
// a value of that type. Only the types specialized below are supported.
template <typename T> struct TypePrintfFormat;
// Floating-point values are printed with two decimals.
template <> struct TypePrintfFormat<float> {
  static constexpr const char *value = "%.2f";
};
template <> struct TypePrintfFormat<double> {
  static constexpr const char *value = "%.2f";
};
template <> struct TypePrintfFormat<int8_t> {
  static constexpr const char *value = "%d";
};
template <> struct TypePrintfFormat<int16_t> {
  static constexpr const char *value = "%d";
};
template <> struct TypePrintfFormat<int32_t> {
  static constexpr const char *value = "%d";
};
template <> struct TypePrintfFormat<uint8_t> {
  static constexpr const char *value = "%u";
};
template <> struct TypePrintfFormat<uint16_t> {
  static constexpr const char *value = "%u";
};
template <> struct TypePrintfFormat<uint32_t> {
  static constexpr const char *value = "%u";
};
// NOTE(review): "%ld"/"%lu" match int64_t/uint64_t only where `long` is
// 64 bits wide -- confirm for every target platform.
template <> struct TypePrintfFormat<int64_t> {
  static constexpr const char *value = "%ld";
};
template <> struct TypePrintfFormat<uint64_t> {
  static constexpr const char *value = "%lu";
};
template <> struct TypePrintfFormat<bool> {
  static constexpr const char *value = "%d";
};
// Shorthand for TypePrintfFormat<T>::value.
template <typename T>
constexpr const char *type_printf_format_v = TypePrintfFormat<T>::value;
} // namespace detail
// Prints a nested-bracket rendering of `tensor` with printf, formatting each
// element with `format`. Does nothing for an empty view. Elements are read
// with operator[], i.e. in flat buffer order. The bracket layout mirrors
// TensorView::repr().
template <typename T, int Rank, template <class> class PtrTraits,
          typename Tindex>
TV_HOST_DEVICE void
printTensorView(const TensorView<T, Rank, PtrTraits, Tindex> &tensor,
                const char *format) {
  // used to print tensor in cuda kernel.
  if (tensor.empty())
    return;
  if (tensor.ndim() == 0) {
    // Read the scalar through data(): TensorView::operator()() carries a
    // static_assert that rejects every fixed-rank view, which would make
    // this function uncompilable for Rank > 0.
    printf(format, tensor.data()[0]);
    printf("\n");
    return;
  }
  // prev holds the multi-index of the previously printed element
  // (-1 everywhere before the first one).
  SimpleVector<int64_t, TV_MAX_DIM> prev(tensor.ndim(), -1);
  SimpleVector<int64_t, TV_MAX_DIM> nd_index(tensor.ndim());
  SimpleVector<int64_t, TV_MAX_DIM> shape(tensor.shape());
  auto ndim = tensor.ndim();
  for (int64_t i = 0; i < int64_t(tensor.size()); ++i) {
    rowArrayIdxInv(i, nd_index.data(), shape.data(), ndim);
    bool newline = false;
    int end_count = 0;
    // Close one bracket for every dimension that just wrapped around to 0.
    for (int j = 0; j < ndim; ++j) {
      if (nd_index[j] != prev[j] && nd_index[j] == 0 && prev[j] != 0 &&
          prev[j] != -1) {
        printf("]");
        ++end_count;
        newline = true;
      }
    }
    if (prev[0] == -1) {
      end_count = ndim;
    }
    if (newline) {
      printf("\n");
    }
    // Open one bracket for every dimension that starts a new run.
    int starts_count = 0;
    for (int j = 0; j < ndim; ++j) {
      if (nd_index[j] != prev[j] && nd_index[j] == 0 && prev[j] != 0) {
        ++starts_count;
      }
    }
    if (starts_count > 0) {
      for (int j = 0; j < ndim - end_count; ++j) {
        printf(" ");
      }
      for (int j = 0; j < starts_count; ++j) {
        printf("[");
      }
    }
    printf(format, tensor[i]);
    if (nd_index[ndim - 1] != shape[ndim - 1] - 1) {
      printf(",");
    }
    for (int j = 0; j < ndim; ++j) {
      prev[j] = nd_index[j];
    }
  }
  for (int j = 0; j < ndim; ++j) {
    printf("]");
  }
  printf("\n");
}
// Prints `tensor` using the default printf format registered for its element
// type (const qualification is ignored when looking up the format).
template <typename T, int Rank, template <class> class PtrTraits,
          typename Tindex>
TV_HOST_DEVICE void
printTensorView(TensorView<T, Rank, PtrTraits, Tindex> tensor) {
  using Element = typename std::remove_const<T>::type;
  const char *fmt = detail::type_printf_format_v<Element>;
  printTensorView(tensor, fmt);
}
// Prints the buffer at `ptr`, interpreted with `shape`, using the default
// printf format registered for the element type.
template <typename T>
TV_HOST_DEVICE void printTensorView(const T *ptr, Shape shape) {
  using Element = typename std::remove_const<T>::type;
  const char *fmt = detail::type_printf_format_v<Element>;
  const TensorView<const T> view(ptr, shape);
  printTensorView(view, fmt);
}
// Prints the buffer at `ptr`, interpreted with `shape`, formatting each
// element with the caller-supplied printf `format`.
template <typename T>
TV_HOST_DEVICE void printTensorView(const T *ptr, Shape shape,
                                    const char *format) {
  const TensorView<const T> view(ptr, shape);
  printTensorView(view, format);
}
#ifdef TV_CUDA
#ifdef __DRIVER_TYPES_H__
#ifndef DEVICE_RESET
#define DEVICE_RESET cudaDeviceReset();
#endif
#else
#ifndef DEVICE_RESET
#define DEVICE_RESET
#endif
#endif
// Error-check helper behind checkCudaErrors: when `result` is non-zero it
// reports the failing call (`func`), its location (`file`:`line`) and the
// numeric code on stderr, expands DEVICE_RESET, and terminates the process
// with EXIT_FAILURE. A zero `result` is a no-op.
template <typename T>
void check(T result, char const *const func, const char *const file,
           int const line) {
  if (!result) {
    return;
  }
  fprintf(stderr, "CUDA error at %s:%d code=%d \"%s\" \n", file, line,
          static_cast<unsigned int>(result), func);
  DEVICE_RESET
  // Make sure we call CUDA Device Reset before exiting
  exit(EXIT_FAILURE);
}
#define checkCudaErrors(val) tv::check((val), #val, __FILE__, __LINE__)
// Enqueues a cudaMemcpyAsync of `size` elements of T (size * sizeof(T) bytes)
// from `src` to `dst` with kind cudaMemcpyHostToDevice on stream `s`.
// A failing call is handled by checkCudaErrors.
template <typename T>
void host2dev(T *dst, const T *src, size_t size, cudaStream_t s = 0) {
checkCudaErrors(
cudaMemcpyAsync(dst, src, size * sizeof(T), cudaMemcpyHostToDevice, s));
}
// TensorView overload (const-element source): forwards to the raw-pointer
// host2dev, copying min(dst.size(), src.size()) elements on stream `s`.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void host2dev(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
              const TensorView<const T, Rank, PtrTraits2, Tindex2> src,
              cudaStream_t s = 0) {
  const size_t count = std::min(dst.size(), src.size());
  host2dev(dst.data(), src.data(), count, s);
}
// TensorView overload (same element type on both sides): forwards to the
// raw-pointer host2dev, copying min(dst.size(), src.size()) elements on
// stream `s`.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void host2dev(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
              const TensorView<T, Rank, PtrTraits2, Tindex2> src,
              cudaStream_t s = 0) {
  const size_t count = std::min(dst.size(), src.size());
  host2dev(dst.data(), src.data(), count, s);
}
/// Issues cudaMemcpy(..., cudaMemcpyHostToDevice) for `size` elements of T;
/// the CUDA status is routed through checkCudaErrors.
template <typename T> void host2dev_sync(T *dst, const T *src, size_t size) {
  checkCudaErrors(
      cudaMemcpy(dst, src, size * sizeof(T), cudaMemcpyHostToDevice));
}
/// TensorView overload (read-only source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void host2dev_sync(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
                   const TensorView<const T, Rank, PtrTraits2, Tindex2> src) {
  const auto count = std::min(dst.size(), src.size());
  host2dev_sync(dst.data(), src.data(), count);
}
/// TensorView overload (mutable source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void host2dev_sync(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
                   const TensorView<T, Rank, PtrTraits2, Tindex2> src) {
  const auto count = std::min(dst.size(), src.size());
  host2dev_sync(dst.data(), src.data(), count);
}
/// Issues cudaMemcpyAsync(..., cudaMemcpyDeviceToHost) for `size` elements of
/// T on stream `s`; the CUDA status is routed through checkCudaErrors.
template <typename T>
void dev2host(T *dst, const T *src, size_t size, cudaStream_t s = 0) {
  checkCudaErrors(
      cudaMemcpyAsync(dst, src, size * sizeof(T), cudaMemcpyDeviceToHost, s));
}
/// TensorView overload (read-only source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void dev2host(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
              const TensorView<const T, Rank, PtrTraits2, Tindex2> src,
              cudaStream_t s = 0) {
  const auto count = std::min(dst.size(), src.size());
  dev2host(dst.data(), src.data(), count, s);
}
/// TensorView overload (mutable source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void dev2host(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
              const TensorView<T, Rank, PtrTraits2, Tindex2> src,
              cudaStream_t s = 0) {
  const auto count = std::min(dst.size(), src.size());
  dev2host(dst.data(), src.data(), count, s);
}
/// Issues cudaMemcpyAsync(..., cudaMemcpyDeviceToDevice) for `size` elements
/// of T on stream `s`; the CUDA status is routed through checkCudaErrors.
template <typename T>
void dev2dev(T *dst, const T *src, size_t size, cudaStream_t s = 0) {
  checkCudaErrors(
      cudaMemcpyAsync(dst, src, size * sizeof(T), cudaMemcpyDeviceToDevice, s));
}
/// TensorView overload (read-only source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void dev2dev(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
             const TensorView<const T, Rank, PtrTraits2, Tindex2> src,
             cudaStream_t s = 0) {
  const auto count = std::min(dst.size(), src.size());
  dev2dev(dst.data(), src.data(), count, s);
}
/// TensorView overload (mutable source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void dev2dev(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
             const TensorView<T, Rank, PtrTraits2, Tindex2> src,
             cudaStream_t s = 0) {
  const auto count = std::min(dst.size(), src.size());
  dev2dev(dst.data(), src.data(), count, s);
}
/// Issues cudaMemcpyAsync(..., cudaMemcpyHostToHost) for `size` elements of T
/// on stream `s`; the CUDA status is routed through checkCudaErrors.
template <typename T>
void host2host(T *dst, const T *src, size_t size, cudaStream_t s = 0) {
  checkCudaErrors(
      cudaMemcpyAsync(dst, src, size * sizeof(T), cudaMemcpyHostToHost, s));
}
/// TensorView overload (read-only source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void host2host(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
               const TensorView<const T, Rank, PtrTraits2, Tindex2> src,
               cudaStream_t s = 0) {
  const auto count = std::min(dst.size(), src.size());
  host2host(dst.data(), src.data(), count, s);
}
/// TensorView overload (mutable source): copies as many elements as the
/// smaller of the two views holds.
template <typename T, int Rank, template <class> class PtrTraits1,
          template <class> class PtrTraits2, typename Tindex1, typename Tindex2>
void host2host(TensorView<T, Rank, PtrTraits1, Tindex1> dst,
               const TensorView<T, Rank, PtrTraits2, Tindex2> src,
               cudaStream_t s = 0) {
  const auto count = std::min(dst.size(), src.size());
  host2host(dst.data(), src.data(), count, s);
}
// Sets every byte of `tensor` (tensor.size() * sizeof(T) bytes starting at
// tensor.data()) to zero with cudaMemset; the status goes through
// checkCudaErrors.
template <typename T, int Rank, template <class> class PtrTraits,
typename Tindex>
void zero_dev(TensorView<T, Rank, PtrTraits, Tindex> tensor) {
checkCudaErrors(cudaMemset(tensor.data(), 0, tensor.size() * sizeof(T)));
}
// Same as above but uses cudaMemsetAsync on stream `s`.
template <typename T, int Rank, template <class> class PtrTraits,
typename Tindex>
void zero_dev(TensorView<T, Rank, PtrTraits, Tindex> tensor, cudaStream_t s) {
checkCudaErrors(
cudaMemsetAsync(tensor.data(), 0, tensor.size() * sizeof(T), s));
}
/// Assigns 0 to each of the tensor.size() elements starting at tensor.data()
/// using std::fill.
template <typename T, int Rank, template <class> class PtrTraits,
          typename Tindex>
void zero_host(TensorView<T, Rank, PtrTraits, Tindex> tensor) {
  auto first = tensor.data();
  auto last = tensor.data() + tensor.size();
  std::fill(first, last, 0);
}
#endif
} // namespace tv
\ No newline at end of file
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
#include <iostream>
namespace tv {
#ifdef TV_CUDA
/// Stopwatch that calls cudaDeviceSynchronize() before every clock reading.
///
/// @tparam TimeT std::chrono duration type used for the reported value.
template <typename TimeT = std::chrono::microseconds> struct CudaContextTimer {
  /// Synchronizes the device and records the start time.
  CudaContextTimer() {
    cudaDeviceSynchronize();
    mCurTime = std::chrono::steady_clock::now();
  }
  /// Synchronizes the device, returns the time elapsed since construction or
  /// the previous report() (in TimeT ticks) and restarts the measurement.
  typename TimeT::rep report() {
    cudaDeviceSynchronize();
    // A single clock reading is used both for the result and as the new
    // start point, so consecutive intervals are contiguous.
    const auto now = std::chrono::steady_clock::now();
    const auto res = std::chrono::duration_cast<TimeT>(now - mCurTime).count();
    mCurTime = now;
    return res;
  }

private:
  std::chrono::time_point<std::chrono::steady_clock> mCurTime;
};
#endif
/// Stopwatch based on std::chrono::steady_clock.
///
/// @tparam TimeT std::chrono duration type used for the reported value.
template <typename TimeT = std::chrono::microseconds> struct CPUTimer {
  /// Records the start time.
  CPUTimer() { mCurTime = std::chrono::steady_clock::now(); }
  /// Returns the time elapsed since construction or the previous report()
  /// (in TimeT ticks) and restarts the measurement.
  typename TimeT::rep report() {
    // A single clock reading is used both for the result and as the new
    // start point, so consecutive intervals are contiguous.
    const auto now = std::chrono::steady_clock::now();
    const auto res = std::chrono::duration_cast<TimeT>(now - mCurTime).count();
    mCurTime = now;
    return res;
  }

private:
  std::chrono::time_point<std::chrono::steady_clock> mCurTime;
};
} // namespace tv
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mp_helper.h"
#include <tensorview/tensorview.h>
#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
namespace tv {
#ifdef TV_CUDA
struct TorchGPU : public tv::GPU {
virtual cudaStream_t getStream() const override {
return at::cuda::getCurrentCUDAStream();
}
};
#endif
namespace detail {
// Maps a C++ element type to the matching torch dtype constant through
// TypeToTorchDtypeTraits<T>::value. The primary template is only declared, so
// using an unlisted type is a compile error.
template <typename T> struct TypeToTorchDtypeTraits;
template <> struct TypeToTorchDtypeTraits<int32_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt32;
};
template <> struct TypeToTorchDtypeTraits<int16_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt16;
};
template <> struct TypeToTorchDtypeTraits<int8_t> {
static constexpr decltype(torch::kInt8) value = torch::kInt8;
};
template <> struct TypeToTorchDtypeTraits<int64_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt64;
};
template <> struct TypeToTorchDtypeTraits<uint8_t> {
static constexpr decltype(torch::kInt32) value = torch::kUInt8;
};
template <> struct TypeToTorchDtypeTraits<bool> {
static constexpr decltype(torch::kInt32) value = torch::kBool;
};
template <> struct TypeToTorchDtypeTraits<float> {
static constexpr decltype(torch::kInt32) value = torch::kFloat32;
};
template <> struct TypeToTorchDtypeTraits<double> {
static constexpr decltype(torch::kInt32) value = torch::kFloat64;
};
template <> struct TypeToTorchDtypeTraits<at::Half> {
static constexpr decltype(torch::kInt32) value = torch::kHalf;
};
// Type list containing every element type that has a specialization above.
using all_torch_types_t = std::tuple<float, double, int8_t, int16_t, int32_t,
int64_t, uint8_t, bool, at::Half>;
} // namespace detail
// Shorthand for detail::TypeToTorchDtypeTraits<T>::value.
template <typename T>
constexpr decltype(torch::kInt32) torch_type_v =
detail::TypeToTorchDtypeTraits<T>::value;
// Runtime-to-compile-time dtype dispatch.
//
// Walks the candidate types Ts... and, for each one whose
// TypeToTorchDtypeTraits value equals `t`, calls `f` with a
// default-constructed value of that type (the callee recovers the type from
// its argument). If no candidate matches, the names of all candidates are
// collected and passed to TV_THROW_RT_ERR (presumably throws — confirm
// against the macro definition).
template <class... Ts, typename F>
void dispatch_torch(at::ScalarType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true;
// TV_DECLTYPE is used instead of plain decltype; see the macro definition
// for the compiler workaround it implements.
tv::mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
if (detail::TypeToTorchDtypeTraits<TV_DECLTYPE(I)>::value == t) {
std::forward<F>(f)(TV_DECLTYPE(I)());
notFound = false;
}
});
if (notFound) {
// Build the "available:" part of the error message.
std::stringstream ss;
tv::mp_for_each<mp_list<Ts...>>([=, &ss](auto I) {
ss << tv::detail::TypeToString<TV_DECLTYPE(I)>::value << " ";
});
TV_THROW_RT_ERR("unknown type", t, ", available:", ss.str());
}
}
/// Functor form of dispatch_torch: the candidate types are taken from the
/// template arguments of a type-list-like class (e.g. std::tuple<...>).
template <class T> struct DispatchTorch;
template <template <class...> class TypeList, class... Args>
struct DispatchTorch<TypeList<Args...>> {
  /// Forwards to dispatch_torch<Args...>(t, f).
  template <typename F> inline void operator()(at::ScalarType t, F &&f) {
    dispatch_torch<Args...>(t, std::forward<F>(f));
  }
};
// Verifies that the element type of `tensor` is T (cv-qualifiers on T are
// ignored). The tensor's scalar type is dispatched over all_torch_types_t and
// the matching C++ type is compared with T; a mismatch is reported through
// TV_ASSERT_RT_ERR.
template <typename T> void check_torch_dtype(const torch::Tensor &tensor) {
DispatchTorch<detail::all_torch_types_t>()(tensor.scalar_type(), [&](auto I) {
// Ttensor is the C++ type corresponding to the tensor's runtime dtype.
using Ttensor = TV_DECLTYPE(I);
constexpr bool val = std::is_same<std::remove_cv_t<T>, Ttensor>::value;
TV_ASSERT_RT_ERR(val, "error");
});
}
/// Creates a TensorView over the memory of `tensor`.
///
/// The dtype is validated with check_torch_dtype<T>; when Rank > 0 the number
/// of dimensions must equal Rank. The view's shape is filled from
/// tensor.sizes(); strides are not taken into account.
template <typename T, int Rank = -1,
          template <class> class PtrTraits = DefaultPtrTraits,
          typename Tindex = int>
TensorView<T, Rank, PtrTraits, Tindex> torch2tv(const torch::Tensor &tensor) {
  using view_t = TensorView<T, Rank, PtrTraits, Tindex>;
  using tv_shape_t = typename view_t::tv_shape_t;
  check_torch_dtype<T>(tensor);
  // TODO stride
  if (Rank > 0) {
    TV_ASSERT_INVALID_ARG(tensor.dim() == Rank, "error");
  }
  tv_shape_t shape;
  for (auto extent : tensor.sizes()) {
    shape.push_back(extent);
  }
  return view_t(tensor.data_ptr<std::remove_const_t<T>>(), shape);
}
/// Returns a tensor that aliases rows [start, end) of `tensor` along axis 0.
///
/// The result is built with torch::from_blob on the input's memory, offset by
/// start * stride(0) elements, and keeps the input's strides and options, so
/// non-contiguous inputs are viewed correctly.
/// NOTE(review): from_blob presumably does not take ownership of the memory,
/// so the input must outlive the returned tensor — confirm.
///
/// @param tensor tensor with at least one dimension.
/// @param start  first index along axis 0.
/// @param end    one past the last index along axis 0.
template <typename T>
torch::Tensor torch_slice_first_axis(torch::Tensor tensor, T start, T end) {
  // only torch >= 1.5 have tensor slice.
  // shape[0] is written below, so a 0-dim tensor must be rejected.
  TV_ASSERT_INVALID_ARG(tensor.dim() > 0, "error");
  auto tensor_shape = tensor.sizes();
  std::vector<int64_t> shape(tensor_shape.begin(), tensor_shape.end());
  shape[0] = end - start;
  uint8_t *ptr = reinterpret_cast<uint8_t *>(tensor.data_ptr());
  // Pass the original strides so the view stays correct for non-contiguous
  // tensors.
  return torch::from_blob(ptr + start * tensor.stride(0) * tensor.itemsize(),
                          torch::IntArrayRef(shape), tensor.strides(),
                          tensor.options());
}
namespace detail {
// Name used for at::Half when element types are printed through
// TypeToString (e.g. in the dispatch_torch error message).
template <> struct TypeToString<at::Half> {
static constexpr const char *value = "half";
};
} // namespace detail
} // namespace tv
\ No newline at end of file
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <tensorview/mp_helper.h>
#include <tensorview/tensorview.h>
#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
namespace tv {
#ifdef TV_CUDA
struct TorchGPU : public tv::GPU {
virtual cudaStream_t getStream() const override {
return at::cuda::getCurrentCUDAStream();
}
};
#endif
/// Verifies that the element type of `tensor` is T (a top-level const on T is
/// ignored). Supported dtypes are Double, Float, Int, Half and Long; a
/// mismatch or any other dtype is reported through TV_ASSERT_RT_ERR.
template <typename T> void check_torch_dtype(const torch::Tensor &tensor) {
  switch (tensor.scalar_type()) {
  case at::ScalarType::Double: {
    auto val = std::is_same<std::remove_const_t<T>, double>::value;
    TV_ASSERT_RT_ERR(val, "error");
    break;
  }
  case at::ScalarType::Float: {
    auto val = std::is_same<std::remove_const_t<T>, float>::value;
    TV_ASSERT_RT_ERR(val, "error");
    break;
  }
  case at::ScalarType::Int: {
    auto val = std::is_same<std::remove_const_t<T>, int>::value;
    TV_ASSERT_RT_ERR(val, "error");
    break;
  }
  case at::ScalarType::Half: {
    auto val = std::is_same<std::remove_const_t<T>, at::Half>::value;
    TV_ASSERT_RT_ERR(val, "error");
    break;
  }
  case at::ScalarType::Long: {
    // Long is a 64-bit dtype; `long` is only 32 bits on LLP64 platforms, so
    // compare against int64_t (the type TypeToTorchDtypeTraits maps to
    // torch::kInt64).
    auto val = std::is_same<std::remove_const_t<T>, int64_t>::value;
    TV_ASSERT_RT_ERR(val, "error");
    break;
  }
  default:
    TV_ASSERT_RT_ERR(false, "error");
  }
}
namespace detail {
// Maps a C++ element type to the matching torch dtype constant through
// TypeToTorchDtypeTraits<T>::value. The primary template is only declared, so
// using an unlisted type is a compile error.
template <typename T> struct TypeToTorchDtypeTraits;
template <> struct TypeToTorchDtypeTraits<int32_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt32;
};
template <> struct TypeToTorchDtypeTraits<int64_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt64;
};
template <> struct TypeToTorchDtypeTraits<float> {
static constexpr decltype(torch::kInt32) value = torch::kFloat32;
};
template <> struct TypeToTorchDtypeTraits<double> {
static constexpr decltype(torch::kInt32) value = torch::kFloat64;
};
template <> struct TypeToTorchDtypeTraits<at::Half> {
static constexpr decltype(torch::kInt32) value = torch::kHalf;
};
} // namespace detail
// Shorthand for detail::TypeToTorchDtypeTraits<T>::value.
template <typename T>
constexpr decltype(torch::kInt32) torch_type_v =
detail::TypeToTorchDtypeTraits<T>::value;
/// Creates a tv::TensorView<T> over the memory of `tensor`.
///
/// The dtype is validated with check_torch_dtype<T>; the view's shape is
/// filled from tensor.sizes().
template <typename T> tv::TensorView<T> torch2tv(const torch::Tensor &tensor) {
  check_torch_dtype<T>(tensor);
  tv::Shape shape;
  for (auto extent : tensor.sizes()) {
    shape.push_back(extent);
  }
  using view_t = tv::TensorView<T>;
  return view_t(tensor.data_ptr<std::remove_const_t<T>>(), shape);
}
namespace detail {
// Name used for at::Half when element types are printed through
// TypeToString (e.g. in the dispatch_torch error message).
template <> struct TypeToString<at::Half> {
static constexpr const char *value = "half";
};
} // namespace detail
// Runtime-to-compile-time dtype dispatch.
//
// Walks the candidate types Ts... and, for each one whose torch_type_v equals
// `t`, calls `f` with a default-constructed value of that type (the callee
// recovers the type from its argument). If no candidate matches, the names of
// all candidates are collected and passed to TV_THROW_RT_ERR (presumably
// throws — confirm against the macro definition).
// NOTE(review): mp_for_each / mp_list are qualified with spconv:: here even
// though this function is declared in namespace tv — confirm the enclosing
// namespace at the include site.
template <class... Ts, typename F>
void dispatch_torch(at::ScalarType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true;
spconv::tv::mp_for_each<spconv::mp_list<Ts...>>([=, &notFound, &f](auto I) {
if (torch_type_v<decltype(I)> == t) {
std::forward<F>(f)(decltype(I)());
notFound = false;
}
});
if (notFound) {
// Build the "available:" part of the error message.
std::stringstream ss;
spconv::tv::mp_for_each<spconv::mp_list<Ts...>>([=, &ss](auto I) {
ss << tv::detail::TypeToString<decltype(I)>::value << " ";
});
TV_THROW_RT_ERR("unknown type", t, ", available: ", ss.str());
}
}
} // namespace tv
\ No newline at end of file
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_GROWTH_POLICY_H
#define TSL_ROBIN_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
// Internal assertion: forwards to assert() when TSL_DEBUG is defined and
// expands to a no-op expression otherwise (the argument is not evaluated).
#ifdef TSL_DEBUG
#define tsl_rh_assert(expr) assert(expr)
#else
#define tsl_rh_assert(expr) (static_cast<void>(0))
#endif
/**
* If exceptions are enabled, throw the exception passed in parameter, otherwise
* call std::terminate.
*/
#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || \
(defined(_MSC_VER) && defined(_CPPUNWIND))) && \
!defined(TSL_NO_EXCEPTIONS)
// Exceptions available and not disabled via TSL_NO_EXCEPTIONS: throw ex(msg).
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) throw ex(msg)
#else
#ifdef NDEBUG
// No exceptions, release build: terminate without printing anything.
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) std::terminate()
#else
// No exceptions, debug build: print msg to stderr, then terminate.
// NOTE(review): msg is used as the printf format string, so it must not
// contain conversion specifiers.
#include <cstdio>
#define TSL_RH_THROW_OR_TERMINATE(ex, msg) \
do { \
std::fprintf(stderr, msg); \
std::terminate(); \
} while (0)
#endif
#endif
// Branch hint: marks `exp` as expected to be true via __builtin_expect on
// GCC/Clang; on other compilers it is just `exp`.
#if defined(__GNUC__) || defined(__clang__)
#define TSL_RH_LIKELY(exp) (__builtin_expect(!!(exp), true))
#else
#define TSL_RH_LIKELY(exp) (exp)
#endif
namespace tsl {
namespace rh {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a
* power of two. It allows the table to use a mask operation instead of a modulo
* operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template <std::size_t GrowthFactor> class power_of_two_growth_policy {
public:
  /**
   * Invoked when the table is created and on every rehash. The argument is
   * the minimum number of buckets requested; it is rounded up in place to the
   * next power of two (never lowered).
   *
   * A request of 0 is left at 0, and bucket_for_hash then always returns 0.
   */
  explicit power_of_two_growth_policy(std::size_t &min_bucket_count_in_out) {
    if (min_bucket_count_in_out > max_bucket_count()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }

    if (min_bucket_count_in_out == 0) {
      m_mask = 0;
      return;
    }

    min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
    m_mask = min_bucket_count_in_out - 1;
  }

  /**
   * Maps a hash to a bucket index in [0, bucket_count()) with a bit mask.
   * Always 0 when bucket_count() is 0.
   */
  std::size_t bucket_for_hash(std::size_t hash) const noexcept {
    return hash & m_mask;
  }

  /**
   * Bucket count to use on the next growth: the current count multiplied by
   * GrowthFactor. Fails if that would exceed max_bucket_count().
   */
  std::size_t next_bucket_count() const {
    const std::size_t current_count = m_mask + 1;
    if (current_count > max_bucket_count() / GrowthFactor) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }
    return current_count * GrowthFactor;
  }

  /**
   * Largest bucket count the policy supports.
   */
  std::size_t max_bucket_count() const {
    // Largest power of two representable in std::size_t.
    const std::size_t half_max = std::numeric_limits<std::size_t>::max() / 2;
    return half_max + 1;
  }

  /**
   * Puts the policy back into the state of a bucket count of 0; afterwards
   * bucket_for_hash always returns 0.
   */
  void clear() noexcept { m_mask = 0; }

private:
  // Smallest power of two >= value (1 for an input of 0).
  static std::size_t round_up_to_power_of_two(std::size_t value) {
    if (is_power_of_two(value)) {
      return value;
    }
    if (value == 0) {
      return 1;
    }

    // Smear the highest set bit of (value - 1) into all lower positions,
    // then add one.
    --value;
    const std::size_t bit_count = sizeof(std::size_t) * CHAR_BIT;
    for (std::size_t shift = 1; shift < bit_count; shift *= 2) {
      value |= value >> shift;
    }
    return value + 1;
  }

  static constexpr bool is_power_of_two(std::size_t value) {
    return value != 0 && (value & (value - 1)) == 0;
  }

protected:
  static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2,
                "GrowthFactor must be a power of two >= 2.");
  // bucket_count - 1, or 0 when there are no buckets.
  std::size_t m_mask;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo
* to map a hash to a bucket. Slower but it can be useful if you want a slower
* growth.
*/
template <class GrowthFactor = std::ratio<3, 2>> class mod_growth_policy {
public:
  /**
   * Uses the requested bucket count as modulus (1 when 0 is requested). The
   * in/out argument is never modified.
   */
  explicit mod_growth_policy(std::size_t &min_bucket_count_in_out) {
    if (min_bucket_count_in_out > max_bucket_count()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }
    m_mod = (min_bucket_count_in_out > 0) ? min_bucket_count_in_out : 1;
  }

  std::size_t bucket_for_hash(std::size_t hash) const noexcept {
    return hash % m_mod;
  }

  /**
   * ceil(current modulus * growth factor), capped at max_bucket_count().
   * Fails when the modulus already equals the maximum.
   */
  std::size_t next_bucket_count() const {
    if (m_mod == max_bucket_count()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }

    const double grown =
        std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
    if (!std::isnormal(grown)) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }

    if (grown > double(max_bucket_count())) {
      return max_bucket_count();
    }
    return std::size_t(grown);
  }

  std::size_t max_bucket_count() const { return MAX_BUCKET_COUNT; }

  void clear() noexcept { m_mod = 1; }

private:
  static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR =
      1.0 * GrowthFactor::num / GrowthFactor::den;
  static const std::size_t MAX_BUCKET_COUNT =
      std::size_t(double(std::numeric_limits<std::size_t>::max() /
                         REHASH_SIZE_MULTIPLICATION_FACTOR));
  static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1,
                "Growth factor should be >= 1.1.");
  // Current modulus; always >= 1.
  std::size_t m_mod;
};
namespace detail {
// Bucket counts available to prime_growth_policy, in ascending order (the
// policy searches this table with std::lower_bound). Entry 0 is 1, so
// MOD_PRIME[0] maps every hash to bucket 0.
static constexpr const std::array<std::size_t, 40> PRIMES = {
{1ul, 5ul, 17ul, 29ul, 37ul,
53ul, 67ul, 79ul, 97ul, 131ul,
193ul, 257ul, 389ul, 521ul, 769ul,
1031ul, 1543ul, 2053ul, 3079ul, 6151ul,
12289ul, 24593ul, 49157ul, 98317ul, 196613ul,
393241ul, 786433ul, 1572869ul, 3145739ul, 6291469ul,
12582917ul, 25165843ul, 50331653ul, 100663319ul, 201326611ul,
402653189ul, 805306457ul, 1610612741ul, 3221225473ul, 4294967291ul}};
// hash % PRIMES[IPrime], with the index fixed at compile time.
template <unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) {
return hash % PRIMES[IPrime];
}
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for
// faster modulo as the compiler can optimize the modulo code better with a
// constant known at the compilation.
// It must have one entry per element of PRIMES, in the same order.
static constexpr const std::array<std::size_t (*)(std::size_t), 40> MOD_PRIME =
{{&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>,
&mod<7>, &mod<8>, &mod<9>, &mod<10>, &mod<11>, &mod<12>, &mod<13>,
&mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>,
&mod<28>, &mod<29>, &mod<30>, &mod<31>, &mod<32>, &mod<33>, &mod<34>,
&mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>}};
} // namespace detail
/**
* Grow the hash table by using prime numbers as bucket count. Slower than
* tsl::rh::power_of_two_growth_policy in general but will probably distribute
* the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is
* used with constant primes numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize
* the operation by a series of multiplications, subtractions and shifts.
*
* The 'hash % 5' could become something like 'hash - ((hash * 0xCCCCCCCD) >> 34)
* * 5' in a 64 bits environment.
*/
class prime_growth_policy {
public:
  /**
   * Selects the first entry of detail::PRIMES that is >= the requested bucket
   * count and writes it back through the in/out argument. A request of 0
   * stays 0. Fails when the request is larger than every table entry.
   */
  explicit prime_growth_policy(std::size_t &min_bucket_count_in_out) {
    const auto primes_begin = detail::PRIMES.begin();
    const auto primes_end = detail::PRIMES.end();

    const auto it_prime =
        std::lower_bound(primes_begin, primes_end, min_bucket_count_in_out);
    if (it_prime == primes_end) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }

    m_iprime = static_cast<unsigned int>(std::distance(primes_begin, it_prime));
    min_bucket_count_in_out = (min_bucket_count_in_out > 0) ? *it_prime : 0;
  }

  std::size_t bucket_for_hash(std::size_t hash) const noexcept {
    return detail::MOD_PRIME[m_iprime](hash);
  }

  /**
   * Next entry of detail::PRIMES after the current one; fails when the
   * current entry is already the last.
   */
  std::size_t next_bucket_count() const {
    const auto next_index = m_iprime + 1;
    if (next_index >= detail::PRIMES.size()) {
      TSL_RH_THROW_OR_TERMINATE(std::length_error,
                                "The hash table exceeds its maxmimum size.");
    }
    return detail::PRIMES[next_index];
  }

  std::size_t max_bucket_count() const { return detail::PRIMES.back(); }

  void clear() noexcept { m_iprime = 0; }

private:
  // Index of the current bucket count inside detail::PRIMES.
  unsigned int m_iprime;
  static_assert(std::numeric_limits<decltype(m_iprime)>::max() >=
                    detail::PRIMES.size(),
                "The type of m_iprime is not big enough.");
};
} // namespace rh
} // namespace tsl
#endif
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_HASH_H
#define TSL_ROBIN_HASH_H
#include "robin_growth_policy.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
namespace tsl {
namespace detail_robin_hash {
// Maps any type to void; used below for member-type detection.
template <typename T> struct make_void { typedef void type; };

// True when T declares a nested type named `is_transparent`, false otherwise.
template <typename T, typename Enable = void>
struct has_is_transparent : std::integral_constant<bool, false> {};

template <typename T>
struct has_is_transparent<
    T, typename make_void<typename T::is_transparent>::type>
    : std::integral_constant<bool, true> {};
// True only when U is a specialization of tsl::rh::power_of_two_growth_policy
// (for any GrowthFactor); false for every other type.
template <typename U> struct is_power_of_two_policy : std::false_type {};
template <std::size_t GrowthFactor>
struct is_power_of_two_policy<tsl::rh::power_of_two_growth_policy<GrowthFactor>>
: std::true_type {};
// std::clamp only exists since C++17; this is a C++11-compatible stand-in.
// Returns a reference to one of the arguments: `hi` when max(lo, v) is not
// less than `hi`, otherwise `v` when `lo < v`, otherwise `lo`.
template <class T> const T &clamp(const T &v, const T &lo, const T &hi) {
  const T &lower_bounded = (lo < v) ? v : lo;
  return (lower_bounded < hi) ? lower_bounded : hi;
}
// Integer type holding the part of a hash that is kept inside a bucket.
using truncated_hash_type = std::uint_least32_t;

/**
 * Base class of a bucket: keeps a truncated hash when StoreHash is true and
 * holds no data otherwise.
 *
 * The primary template is the "no hash" flavour: every hash compares equal
 * and the stored hash reads as 0.
 */
template <bool StoreHash> class bucket_entry_hash {
public:
  bool bucket_hash_equal(std::size_t /*hash*/) const noexcept { return true; }

  truncated_hash_type truncated_hash() const noexcept { return 0; }

protected:
  void set_hash(truncated_hash_type /*hash*/) noexcept {}
};

/**
 * "Store hash" flavour: remembers the truncated hash given to set_hash and
 * compares it against the truncation of the hash passed to bucket_hash_equal.
 */
template <> class bucket_entry_hash<true> {
public:
  bool bucket_hash_equal(std::size_t hash) const noexcept {
    return static_cast<truncated_hash_type>(hash) == m_hash;
  }

  truncated_hash_type truncated_hash() const noexcept { return m_hash; }

protected:
  void set_hash(truncated_hash_type hash) noexcept { m_hash = hash; }

private:
  truncated_hash_type m_hash;
};
/**
* Each bucket entry has:
* - A value of type `ValueType`.
* - An integer to store how far the value of the bucket, if any, is from its
* ideal bucket (ex: if the current bucket 5 has the value 'foo' and
* `hash('foo') % nb_buckets` == 3, `dist_from_ideal_bucket()` will return 2 as
* the current value of the bucket is two buckets away from its ideal bucket) If
* there is no value in the bucket (i.e. `empty()` is true)
* `dist_from_ideal_bucket()` will be < 0.
* - A marker which tells us if the bucket is the last bucket of the bucket
* array (useful for the iterator of the hash table).
* - If `StoreHash` is true, 32 bits of the hash of the value, if any, are also
* stored in the bucket. If the size of the hash is more than 32 bits, it is
* truncated. We don't store the full hash as storing the hash is a potential
* opportunity to use the unused space due to the alignment of the bucket_entry
* structure. We can thus potentially store the hash without any extra space
* (which would not be possible with 64 bits of the hash).
*/
// One slot of the bucket array. The value lives in raw aligned storage and is
// constructed/destroyed manually; the slot is empty exactly when
// m_dist_from_ideal_bucket equals EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET.
template <typename ValueType, bool StoreHash>
class bucket_entry : public bucket_entry_hash<StoreHash> {
using bucket_hash = bucket_entry_hash<StoreHash>;
public:
using value_type = ValueType;
using distance_type = std::int_least16_t;
// Creates an empty bucket that is not the last bucket.
bucket_entry() noexcept
: bucket_hash(),
m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
m_last_bucket(false) {
tsl_rh_assert(empty());
}
// Creates an empty bucket with the given "last bucket" flag.
bucket_entry(bool last_bucket) noexcept
: bucket_hash(),
m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
m_last_bucket(last_bucket) {
tsl_rh_assert(empty());
}
// Copy constructor: starts out empty and only becomes non-empty after the
// value has been copy-constructed, so a throwing copy leaves an empty
// bucket behind.
bucket_entry(const bucket_entry &other) noexcept(
std::is_nothrow_copy_constructible<value_type>::value)
: bucket_hash(other),
m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
m_last_bucket(other.m_last_bucket) {
if (!other.empty()) {
::new (static_cast<void *>(std::addressof(m_value)))
value_type(other.value());
m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
}
}
/**
* Never really used, but still necessary as we must call resize on an empty
* `std::vector<bucket_entry>`. and we need to support move-only types. See
* robin_hash constructor for details.
*/
// The source bucket keeps its (moved-from) value; it is not cleared here.
bucket_entry(bucket_entry &&other) noexcept(
std::is_nothrow_move_constructible<value_type>::value)
: bucket_hash(std::move(other)),
m_dist_from_ideal_bucket(EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET),
m_last_bucket(other.m_last_bucket) {
if (!other.empty()) {
::new (static_cast<void *>(std::addressof(m_value)))
value_type(std::move(other.value()));
m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
}
}
// Copy assignment: destroys the current value (if any), then copy-constructs
// the other bucket's value in place and takes over its distance and
// last-bucket flag.
bucket_entry &operator=(const bucket_entry &other) noexcept(
std::is_nothrow_copy_constructible<value_type>::value) {
if (this != &other) {
clear();
bucket_hash::operator=(other);
if (!other.empty()) {
::new (static_cast<void *>(std::addressof(m_value)))
value_type(other.value());
}
m_dist_from_ideal_bucket = other.m_dist_from_ideal_bucket;
m_last_bucket = other.m_last_bucket;
}
return *this;
}
bucket_entry &operator=(bucket_entry &&) = delete;
~bucket_entry() noexcept { clear(); }
// Destroys the stored value (if any) and marks the bucket empty. The
// last-bucket flag is left unchanged.
void clear() noexcept {
if (!empty()) {
destroy_value();
m_dist_from_ideal_bucket = EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET;
}
}
bool empty() const noexcept {
return m_dist_from_ideal_bucket == EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET;
}
// Access to the stored value; the bucket must not be empty.
value_type &value() noexcept {
tsl_rh_assert(!empty());
return *reinterpret_cast<value_type *>(std::addressof(m_value));
}
const value_type &value() const noexcept {
tsl_rh_assert(!empty());
return *reinterpret_cast<const value_type *>(std::addressof(m_value));
}
distance_type dist_from_ideal_bucket() const noexcept {
return m_dist_from_ideal_bucket;
}
bool last_bucket() const noexcept { return m_last_bucket; }
void set_as_last_bucket() noexcept { m_last_bucket = true; }
// Constructs a value in an empty bucket from `value_type_args` and records
// its hash and distance. The distance is written last, so the bucket stays
// empty if the value's constructor throws.
template <typename... Args>
void set_value_of_empty_bucket(distance_type dist_from_ideal_bucket,
truncated_hash_type hash,
Args &&... value_type_args) {
tsl_rh_assert(dist_from_ideal_bucket >= 0);
tsl_rh_assert(empty());
::new (static_cast<void *>(std::addressof(m_value)))
value_type(std::forward<Args>(value_type_args)...);
this->set_hash(hash);
m_dist_from_ideal_bucket = dist_from_ideal_bucket;
tsl_rh_assert(!empty());
}
// Exchanges the bucket's value, distance and (when StoreHash is true) hash
// with the ones passed in. The bucket must not be empty.
void swap_with_value_in_bucket(distance_type &dist_from_ideal_bucket,
truncated_hash_type &hash, value_type &value) {
tsl_rh_assert(!empty());
using std::swap;
swap(value, this->value());
swap(dist_from_ideal_bucket, m_dist_from_ideal_bucket);
// Avoid warning of unused variable if StoreHash is false
(void)hash;
if (StoreHash) {
const truncated_hash_type tmp_hash = this->truncated_hash();
this->set_hash(hash);
hash = tmp_hash;
}
}
// Converts a full hash to the truncated type stored in buckets.
static truncated_hash_type truncate_hash(std::size_t hash) noexcept {
return truncated_hash_type(hash);
}
private:
// Runs the value's destructor; does not touch the empty marker.
void destroy_value() noexcept {
tsl_rh_assert(!empty());
value().~value_type();
}
private:
// Raw storage with the size and alignment of value_type.
using storage = typename std::aligned_storage<sizeof(value_type),
alignof(value_type)>::type;
static const distance_type EMPTY_MARKER_DIST_FROM_IDEAL_BUCKET = -1;
distance_type m_dist_from_ideal_bucket;
bool m_last_bucket;
storage m_value;
};
/**
* Internal common class used by `robin_map` and `robin_set`.
*
* ValueType is what will be stored by `robin_hash` (usually `std::pair<Key, T>`
* for map and `Key` for set).
*
* `KeySelect` should be a `FunctionObject` which takes a `ValueType` in
* parameter and returns a reference to the key.
*
* `ValueSelect` should be a `FunctionObject` which takes a `ValueType` in
* parameter and returns a reference to the value. `ValueSelect` should be void
* if there is no value (in a set for example).
*
* The strong exception guarantee only holds if the expression
* `std::is_nothrow_swappable<ValueType>::value &&
* std::is_nothrow_move_constructible<ValueType>::value` is true.
*
* Behaviour is undefined if the destructor of `ValueType` throws.
*/
template <class ValueType, class KeySelect, class ValueSelect, class Hash,
class KeyEqual, class Allocator, bool StoreHash, class GrowthPolicy>
class robin_hash : private Hash, private KeyEqual, private GrowthPolicy {
private:
/**
 * Compile-time boolean that is true unless `U` is `void`. It gates the
 * mapped-value accessors (`value()`, `at()`, `operator[]`) of this class.
 */
template <typename U>
using has_mapped_type =
typename std::integral_constant<bool, !std::is_same<U, void>::value>;
// The growth policy operations relied upon by this class must be noexcept.
static_assert(
noexcept(std::declval<GrowthPolicy>().bucket_for_hash(std::size_t(0))),
"GrowthPolicy::bucket_for_hash must be noexcept.");
static_assert(noexcept(std::declval<GrowthPolicy>().clear()),
"GrowthPolicy::clear must be noexcept.");
public:
// Forward declaration; the iterator template is defined later in this class.
template <bool IsConst> class robin_iterator;
// Container member types. The key type is taken from KeySelect and the value
// type is the ValueType template parameter.
using key_type = typename KeySelect::key_type;
using value_type = ValueType;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using hasher = Hash;
using key_equal = KeyEqual;
using allocator_type = Allocator;
using reference = value_type &;
using const_reference = const value_type &;
using pointer = value_type *;
using const_pointer = const value_type *;
// Mutable and const instantiations of the iterator template.
using iterator = robin_iterator<false>;
using const_iterator = robin_iterator<true>;
private:
/**
 * Either store the hash because we are asked by the `StoreHash` template
 * parameter or store the hash because it doesn't cost us anything in size and
 * can be used to speed up rehash.
 *
 * When `StoreHash` is false, the hash is still stored if all of these hold:
 *  - a bucket with a stored hash is the same size as one without;
 *  - the truncated hash is as wide as std::size_t, or the growth policy is a
 *    power-of-two policy;
 *  - the key is not an arithmetic type hashed with std::hash.
 */
static constexpr bool STORE_HASH =
StoreHash ||
((sizeof(tsl::detail_robin_hash::bucket_entry<value_type, true>) ==
sizeof(tsl::detail_robin_hash::bucket_entry<value_type, false>)) &&
(sizeof(std::size_t) == sizeof(truncated_hash_type) ||
is_power_of_two_policy<GrowthPolicy>::value) &&
// Don't store the hash for primitive types with default hash.
(!std::is_arithmetic<key_type>::value ||
!std::is_same<Hash, std::hash<key_type>>::value));
/**
 * Only use the stored hash on lookup if we are explicitly asked to (through
 * the `StoreHash` template parameter). We are not sure how slow the KeyEqual
 * operation is. An extra comparison may slow things down with a fast KeyEqual.
 */
static constexpr bool USE_STORED_HASH_ON_LOOKUP = StoreHash;
/**
 * Tells whether the hashes stored in the buckets can be reused when rehashing
 * into a table of `bucket_count` buckets.
 *
 * This is only possible if hashes are stored at all and either the stored
 * type is as wide as std::size_t, or a power of two modulo is used. In the
 * latter case only the least significant bits are masked, so we just have to
 * check that truncated_hash_type did not truncate any bit the mask needs.
 */
static bool USE_STORED_HASH_ON_REHASH(size_type bucket_count) {
  (void)bucket_count;

  if (!STORE_HASH) {
    return false;
  }
  if (sizeof(std::size_t) == sizeof(truncated_hash_type)) {
    return true;
  }
  if (is_power_of_two_policy<GrowthPolicy>::value) {
    tsl_rh_assert(bucket_count > 0);
    return (bucket_count - 1) <=
           std::numeric_limits<truncated_hash_type>::max();
  }
  return false;
}
// Bucket type used by this table; whether it carries a hash follows STORE_HASH.
using bucket_entry =
tsl::detail_robin_hash::bucket_entry<value_type, STORE_HASH>;
using distance_type = typename bucket_entry::distance_type;
// The user allocator rebound to allocate buckets, and the vector owning them.
using buckets_allocator = typename std::allocator_traits<
allocator_type>::template rebind_alloc<bucket_entry>;
using buckets_container_type = std::vector<bucket_entry, buckets_allocator>;
public:
/**
 * Forward iterator over the non-empty buckets of the table.
 *
 * The 'operator*()' and 'operator->()' methods return a const reference and
 * const pointer respectively to the stored value type.
 *
 * In case of a map, to get a mutable reference to the value associated to a
 * key (the '.second' in the stored pair), you have to call 'value()'.
 *
 * The main reason for this is that if we returned a `std::pair<Key, T>&`
 * instead of a `const std::pair<Key, T>&`, the user may modify the key which
 * will put the map in an undefined state.
 */
template <bool IsConst> class robin_iterator {
  friend class robin_hash;

private:
  using bucket_entry_ptr =
      typename std::conditional<IsConst, const bucket_entry *,
                                bucket_entry *>::type;

  /** Wraps a bucket pointer; only the owning table creates iterators. */
  robin_iterator(bucket_entry_ptr bucket) noexcept : m_bucket(bucket) {}

public:
  using iterator_category = std::forward_iterator_tag;
  using value_type = const typename robin_hash::value_type;
  using difference_type = std::ptrdiff_t;
  using reference = value_type &;
  using pointer = value_type *;

  /**
   * Default constructor. The bucket pointer is set to null so that two
   * value-initialized iterators can be compared (and compare equal) instead
   * of reading an indeterminate pointer.
   */
  robin_iterator() noexcept : m_bucket(nullptr) {}

  // Copy constructor from iterator to const_iterator.
  template <bool TIsConst = IsConst,
            typename std::enable_if<TIsConst>::type * = nullptr>
  robin_iterator(const robin_iterator<!TIsConst> &other) noexcept
      : m_bucket(other.m_bucket) {}

  robin_iterator(const robin_iterator &other) = default;
  robin_iterator(robin_iterator &&other) = default;
  robin_iterator &operator=(const robin_iterator &other) = default;
  robin_iterator &operator=(robin_iterator &&other) = default;

  /** Key of the pointed value, extracted with KeySelect. */
  const typename robin_hash::key_type &key() const {
    return KeySelect()(m_bucket->value());
  }

  /** Mapped value (const iterator flavour); only exists when there is one. */
  template <class U = ValueSelect,
            typename std::enable_if<has_mapped_type<U>::value &&
                                    IsConst>::type * = nullptr>
  const typename U::value_type &value() const {
    return U()(m_bucket->value());
  }

  /** Mapped value (mutable iterator flavour); only exists when there is one. */
  template <class U = ValueSelect,
            typename std::enable_if<has_mapped_type<U>::value &&
                                    !IsConst>::type * = nullptr>
  typename U::value_type &value() {
    return U()(m_bucket->value());
  }

  reference operator*() const { return m_bucket->value(); }
  pointer operator->() const { return std::addressof(m_bucket->value()); }

  /**
   * Advances to the next non-empty bucket, or one past the bucket flagged as
   * last when there is none left.
   */
  robin_iterator &operator++() {
    while (true) {
      if (m_bucket->last_bucket()) {
        // Step past the last bucket without inspecting what follows it.
        ++m_bucket;
        return *this;
      }

      ++m_bucket;
      if (!m_bucket->empty()) {
        return *this;
      }
    }
  }

  robin_iterator operator++(int) {
    robin_iterator tmp(*this);
    ++*this;
    return tmp;
  }

  friend bool operator==(const robin_iterator &lhs,
                         const robin_iterator &rhs) {
    return lhs.m_bucket == rhs.m_bucket;
  }

  friend bool operator!=(const robin_iterator &lhs,
                         const robin_iterator &rhs) {
    return !(lhs == rhs);
  }

private:
  bucket_entry_ptr m_bucket;
};
public:
#if defined(__cplusplus) && __cplusplus >= 201402L
/**
 * Builds a table with `bucket_count` default-inserted buckets (C++14 and
 * later flavour). Raises std::length_error through TSL_RH_THROW_OR_TERMINATE
 * when the requested count is above max_bucket_count().
 *
 * NOTE(review): `bucket_count` is handed to the GrowthPolicy constructor
 * before being used for the buckets; presumably the policy may adjust it —
 * confirm against the policy classes.
 */
robin_hash(size_type bucket_count, const Hash &hash, const KeyEqual &equal,
           const Allocator &alloc,
           float min_load_factor = DEFAULT_MIN_LOAD_FACTOR,
           float max_load_factor = DEFAULT_MAX_LOAD_FACTOR)
    : Hash(hash), KeyEqual(equal), GrowthPolicy(bucket_count),
      m_buckets_data(
          [&]() {
            if (bucket_count > max_bucket_count()) {
              TSL_RH_THROW_OR_TERMINATE(
                  std::length_error,
                  "The map exceeds its maximum bucket count.");
            }
            return bucket_count;
          }(),
          alloc),
      m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr()
                                       : m_buckets_data.data()),
      m_bucket_count(bucket_count), m_nb_elements(0),
      m_grow_on_next_insert(false), m_try_skrink_on_next_insert(false) {
  if (m_bucket_count != 0) {
    tsl_rh_assert(!m_buckets_data.empty());
    // Iterators rely on the flag of the final bucket to know where to stop.
    m_buckets_data.back().set_as_last_bucket();
  }

  // The setters clamp the factors; the max setter also derives the threshold.
  this->min_load_factor(min_load_factor);
  this->max_load_factor(max_load_factor);
}
#else
/**
 * C++11 doesn't support the creation of a std::vector with a custom allocator
 * and 'count' default-inserted elements. The needed constructor `explicit
 * vector(size_type count, const Allocator& alloc = Allocator());` is only
 * available in C++14 and later. We thus must resize after using the
 * `vector(const Allocator& alloc)` constructor.
 *
 * We can't use `vector(size_type count, const T& value, const Allocator&
 * alloc)` as it requires the value T to be copyable.
 *
 * Raises std::length_error through TSL_RH_THROW_OR_TERMINATE when the
 * requested count is above max_bucket_count(); the message matches the one
 * used by the C++14 constructor.
 */
robin_hash(size_type bucket_count, const Hash &hash, const KeyEqual &equal,
           const Allocator &alloc,
           float min_load_factor = DEFAULT_MIN_LOAD_FACTOR,
           float max_load_factor = DEFAULT_MAX_LOAD_FACTOR)
    : Hash(hash), KeyEqual(equal), GrowthPolicy(bucket_count),
      m_buckets_data(alloc), m_buckets(static_empty_bucket_ptr()),
      m_bucket_count(bucket_count), m_nb_elements(0),
      m_grow_on_next_insert(false), m_try_skrink_on_next_insert(false) {
  if (bucket_count > max_bucket_count()) {
    TSL_RH_THROW_OR_TERMINATE(std::length_error,
                              "The map exceeds its maximum bucket count.");
  }

  if (m_bucket_count > 0) {
    m_buckets_data.resize(m_bucket_count);
    m_buckets = m_buckets_data.data();

    tsl_rh_assert(!m_buckets_data.empty());
    // Iterators rely on the flag of the final bucket to know where to stop.
    m_buckets_data.back().set_as_last_bucket();
  }

  // The setters clamp the factors; the max setter also derives the threshold.
  this->min_load_factor(min_load_factor);
  this->max_load_factor(max_load_factor);
}
#endif
/**
 * Copy constructor: copies the functors, the growth policy, the buckets and
 * every bookkeeping field, then points m_buckets at the new storage (or at
 * the static empty bucket when there are no buckets).
 */
robin_hash(const robin_hash &other)
    : Hash(other),
      KeyEqual(other),
      GrowthPolicy(other),
      m_buckets_data(other.m_buckets_data),
      m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr()
                                       : m_buckets_data.data()),
      m_bucket_count(other.m_bucket_count),
      m_nb_elements(other.m_nb_elements),
      m_load_threshold(other.m_load_threshold),
      m_max_load_factor(other.m_max_load_factor),
      m_grow_on_next_insert(other.m_grow_on_next_insert),
      m_min_load_factor(other.m_min_load_factor),
      m_try_skrink_on_next_insert(other.m_try_skrink_on_next_insert) {
}
/**
 * Move constructor: takes over the buckets and bookkeeping of `other`, then
 * resets `other` to a table without buckets. `other` keeps its load factor
 * settings.
 */
robin_hash(robin_hash &&other) noexcept(
    std::is_nothrow_move_constructible<Hash>::value &&
    std::is_nothrow_move_constructible<KeyEqual>::value &&
    std::is_nothrow_move_constructible<GrowthPolicy>::value &&
    std::is_nothrow_move_constructible<buckets_container_type>::value)
    : Hash(std::move(static_cast<Hash &>(other))),
      KeyEqual(std::move(static_cast<KeyEqual &>(other))),
      GrowthPolicy(std::move(static_cast<GrowthPolicy &>(other))),
      m_buckets_data(std::move(other.m_buckets_data)),
      m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr()
                                       : m_buckets_data.data()),
      m_bucket_count(other.m_bucket_count),
      m_nb_elements(other.m_nb_elements),
      m_load_threshold(other.m_load_threshold),
      m_max_load_factor(other.m_max_load_factor),
      m_grow_on_next_insert(other.m_grow_on_next_insert),
      m_min_load_factor(other.m_min_load_factor),
      m_try_skrink_on_next_insert(other.m_try_skrink_on_next_insert) {
  // Leave the source in a well-defined state with no buckets.
  other.GrowthPolicy::clear();
  other.m_buckets_data.clear();
  other.m_buckets = static_empty_bucket_ptr();

  other.m_bucket_count = 0;
  other.m_nb_elements = 0;
  other.m_load_threshold = 0;

  other.m_grow_on_next_insert = false;
  other.m_try_skrink_on_next_insert = false;
}
/**
 * Copy assignment: copies functors, policy, buckets and bookkeeping from
 * `other`. Self-assignment is a no-op.
 */
robin_hash &operator=(const robin_hash &other) {
  if (&other == this) {
    return *this;
  }

  Hash::operator=(other);
  KeyEqual::operator=(other);
  GrowthPolicy::operator=(other);

  m_buckets_data = other.m_buckets_data;
  // The storage may have moved: refresh the cached bucket pointer.
  m_buckets = m_buckets_data.empty() ? static_empty_bucket_ptr()
                                     : m_buckets_data.data();

  m_bucket_count = other.m_bucket_count;
  m_nb_elements = other.m_nb_elements;
  m_load_threshold = other.m_load_threshold;
  m_max_load_factor = other.m_max_load_factor;
  m_grow_on_next_insert = other.m_grow_on_next_insert;
  m_min_load_factor = other.m_min_load_factor;
  m_try_skrink_on_next_insert = other.m_try_skrink_on_next_insert;

  return *this;
}
/**
 * Move assignment: swaps contents with `other`, then clears `other`, which
 * now holds the values this table had before.
 */
robin_hash &operator=(robin_hash &&other) {
  other.swap(*this);
  other.clear();

  return *this;
}
/** Returns the allocator of the bucket storage, converted to allocator_type. */
allocator_type get_allocator() const {
  const buckets_container_type &bucket_storage = m_buckets_data;
  return bucket_storage.get_allocator();
}
/*
* Iterators
*/
/**
 * Iterator to the first non-empty bucket; equals end() when every bucket is
 * empty.
 */
iterator begin() noexcept {
  std::size_t first_used = 0;
  for (; first_used < m_bucket_count; ++first_used) {
    if (!m_buckets[first_used].empty()) {
      break;
    }
  }
  return iterator(m_buckets + first_used);
}
/** Const overload; forwards to cbegin(). */
const_iterator begin() const noexcept {
  return cbegin();
}
/**
 * Const iterator to the first non-empty bucket; equals cend() when every
 * bucket is empty.
 */
const_iterator cbegin() const noexcept {
  std::size_t first_used = 0;
  for (; first_used < m_bucket_count; ++first_used) {
    if (!m_buckets[first_used].empty()) {
      break;
    }
  }
  return const_iterator(m_buckets + first_used);
}
/** Past-the-end iterator: points one past the final bucket. */
iterator end() noexcept {
  return iterator(m_buckets + m_bucket_count);
}

/** Const overload; forwards to cend(). */
const_iterator end() const noexcept {
  return cend();
}

/** Past-the-end const iterator: points one past the final bucket. */
const_iterator cend() const noexcept {
  return const_iterator(m_buckets + m_bucket_count);
}
/*
* Capacity
*/
/** True when the table holds no value. */
bool empty() const noexcept {
  return m_nb_elements == 0;
}

/** Number of values currently stored. */
size_type size() const noexcept {
  return m_nb_elements;
}

/** Largest size reported by the underlying bucket container. */
size_type max_size() const noexcept {
  return m_buckets_data.max_size();
}
/*
* Modifiers
*/
/**
 * Destroys every stored value while keeping the buckets themselves, and
 * cancels any pending growth request.
 */
void clear() noexcept {
  for (auto it = m_buckets_data.begin(); it != m_buckets_data.end(); ++it) {
    it->clear();
  }

  m_nb_elements = 0;
  m_grow_on_next_insert = false;
}
/**
 * Inserts `value` unless an equal key is already present.
 * Returns the iterator to the (new or existing) element and whether an
 * insertion took place.
 */
template <typename P>
std::pair<iterator, bool> insert(P &&value) {
  // Kept as one expression so the key reference and the forwarded value are
  // consumed within the same full-expression.
  return insert_impl(KeySelect()(value), std::forward<P>(value));
}
/**
 * Insert with a hint: when `hint` already designates an element with the
 * same key it is returned directly, otherwise a regular insert is done.
 */
template <typename P>
iterator insert_hint(const_iterator hint, P &&value) {
  const bool hint_has_same_key =
      hint != cend() && compare_keys(KeySelect()(*hint), KeySelect()(value));
  if (!hint_has_same_key) {
    return insert(std::forward<P>(value)).first;
  }
  return mutable_iterator(hint);
}
/**
 * Inserts every element of [first, last). For forward (or stronger)
 * iterators the range length is measured first and room is reserved up front
 * when the free capacity below the load threshold is not enough.
 */
template <class InputIt>
void insert(InputIt first, InputIt last) {
  using category =
      typename std::iterator_traits<InputIt>::iterator_category;

  if (std::is_base_of<std::forward_iterator_tag, category>::value) {
    const auto nb_elements_insert = std::distance(first, last);
    const size_type nb_free_buckets = m_load_threshold - size();
    tsl_rh_assert(m_load_threshold >= size());

    if (nb_elements_insert > 0 &&
        nb_free_buckets < size_type(nb_elements_insert)) {
      reserve(size() + size_type(nb_elements_insert));
    }
  }

  while (first != last) {
    insert(*first);
    ++first;
  }
}
/**
 * Inserts (key, obj) when the key is absent, otherwise assigns `obj` to the
 * existing mapped value. The boolean tells whether an insertion happened.
 */
template <class K, class M>
std::pair<iterator, bool> insert_or_assign(K &&key, M &&obj) {
  auto result = try_emplace(std::forward<K>(key), std::forward<M>(obj));
  const bool inserted = result.second;
  if (!inserted) {
    // The key was already there: overwrite the mapped value in place.
    result.first.value() = std::forward<M>(obj);
  }
  return result;
}
/**
 * Hinted insert_or_assign: when `hint` designates an element with the same
 * key its mapped value is overwritten directly, otherwise the unhinted
 * overload does the work.
 */
template <class K, class M>
iterator insert_or_assign(const_iterator hint, K &&key, M &&obj) {
  const bool hint_has_same_key =
      hint != cend() && compare_keys(KeySelect()(*hint), key);
  if (!hint_has_same_key) {
    return insert_or_assign(std::forward<K>(key), std::forward<M>(obj)).first;
  }

  auto target = mutable_iterator(hint);
  target.value() = std::forward<M>(obj);
  return target;
}
/**
 * Builds a value_type from `args` and inserts it. The value is constructed
 * even when the key turns out to be present already.
 */
template <class... Args>
std::pair<iterator, bool> emplace(Args &&... args) {
  return insert(value_type(std::forward<Args>(args)...));
}
/** Hinted flavour of emplace(): builds a value_type and calls insert_hint(). */
template <class... Args>
iterator emplace_hint(const_iterator hint, Args &&... args) {
  return insert_hint(hint, value_type(std::forward<Args>(args)...));
}
/**
 * Inserts a value piecewise-constructed from `key` and `args` only when the
 * key is absent; the arguments are passed on as tuples of forwarded
 * references.
 */
template <class K, class... Args>
std::pair<iterator, bool> try_emplace(K &&key, Args &&... args) {
  return insert_impl(
      key, std::piecewise_construct,
      std::forward_as_tuple(std::forward<K>(key)),
      std::forward_as_tuple(std::forward<Args>(args)...));
}
/**
 * Hinted try_emplace: returns the hinted element when it already has the
 * requested key, otherwise falls back to try_emplace().
 */
template <class K, class... Args>
iterator try_emplace_hint(const_iterator hint, K &&key, Args &&... args) {
  const bool hint_has_same_key =
      hint != cend() && compare_keys(KeySelect()(*hint), key);
  if (!hint_has_same_key) {
    return try_emplace(std::forward<K>(key), std::forward<Args>(args)...)
        .first;
  }
  return mutable_iterator(hint);
}
/**
 * Erases the element at `pos` and returns the iterator following it.
 *
 * This overload exists to avoid `template<class K> size_type erase(const K&
 * key)` being used when we use an `iterator` instead of a `const_iterator`.
 */
iterator erase(iterator pos) {
  erase_from_bucket(pos);
  m_try_skrink_on_next_insert = true;

  // erase_from_bucket() shifts the following values backward after clearing
  // the bucket. If no value was shifted into this bucket, move on to the next
  // non-empty one.
  if (pos.m_bucket->empty()) {
    ++pos;
  }
  return pos;
}
/** Const-iterator overload; forwards to erase(iterator). */
iterator erase(const_iterator pos) {
  return erase(mutable_iterator(pos));
}
/**
 * Erases every value in [first, last) and returns an iterator to the element
 * that follows the erased range (end() when the range runs to the end).
 *
 * Whenever the range is not empty, the table is flagged so that the next
 * insert may try to shrink it, including when `last` is end().
 *
 * NOTE(review): the backward shift below stops at the end of the bucket array
 * and does not wrap around to bucket 0 — confirm that values which overflowed
 * past the end of the array stay reachable after a range erase.
 */
iterator erase(const_iterator first, const_iterator last) {
  if (first == last) {
    return mutable_iterator(first);
  }

  auto first_mutable = mutable_iterator(first);
  auto last_mutable = mutable_iterator(last);
  for (auto it = first_mutable.m_bucket; it != last_mutable.m_bucket; ++it) {
    if (!it->empty()) {
      it->clear();
      m_nb_elements--;
    }
  }

  // Set on every path that erased a range, so that erasing up to end() can
  // also trigger a shrink on the next insert.
  m_try_skrink_on_next_insert = true;

  if (last_mutable == end()) {
    return end();
  }

  /*
   * Backward shift on the values which come after the deleted values.
   * We try to move the values closer to their ideal bucket.
   */
  std::size_t icloser_bucket =
      static_cast<std::size_t>(first_mutable.m_bucket - m_buckets);
  std::size_t ito_move_closer_value =
      static_cast<std::size_t>(last_mutable.m_bucket - m_buckets);
  tsl_rh_assert(ito_move_closer_value > icloser_bucket);

  // Bucket in which the value currently at `last` will end up: it moves back
  // by at most its own distance and at most the size of the hole.
  const std::size_t ireturn_bucket =
      ito_move_closer_value -
      std::min(
          ito_move_closer_value - icloser_bucket,
          std::size_t(
              m_buckets[ito_move_closer_value].dist_from_ideal_bucket()));

  while (ito_move_closer_value < m_bucket_count &&
         m_buckets[ito_move_closer_value].dist_from_ideal_bucket() > 0) {
    icloser_bucket =
        ito_move_closer_value -
        std::min(
            ito_move_closer_value - icloser_bucket,
            std::size_t(
                m_buckets[ito_move_closer_value].dist_from_ideal_bucket()));

    tsl_rh_assert(m_buckets[icloser_bucket].empty());
    const distance_type new_distance = distance_type(
        m_buckets[ito_move_closer_value].dist_from_ideal_bucket() -
        (ito_move_closer_value - icloser_bucket));
    m_buckets[icloser_bucket].set_value_of_empty_bucket(
        new_distance, m_buckets[ito_move_closer_value].truncated_hash(),
        std::move(m_buckets[ito_move_closer_value].value()));
    m_buckets[ito_move_closer_value].clear();

    ++icloser_bucket;
    ++ito_move_closer_value;
  }

  return iterator(m_buckets + ireturn_bucket);
}
/** Erases the element with the given key; returns how many were removed. */
template <class K>
size_type erase(const K &key) {
  const std::size_t key_hash = hash_key(key);
  return erase(key, key_hash);
}
/**
 * Erases the element with the given key, using a hash supplied by the
 * caller. Returns 1 when an element was removed, 0 otherwise.
 */
template <class K>
size_type erase(const K &key, std::size_t hash) {
  auto found = find(key, hash);
  if (found == end()) {
    return 0;
  }

  erase_from_bucket(found);
  m_try_skrink_on_next_insert = true;
  return 1;
}
/**
 * Exchanges the full state of two tables: functors, growth policy, buckets
 * and every bookkeeping field.
 */
void swap(robin_hash &other) {
  using std::swap;

  // Base sub-objects.
  swap(static_cast<Hash &>(*this), static_cast<Hash &>(other));
  swap(static_cast<KeyEqual &>(*this), static_cast<KeyEqual &>(other));
  swap(static_cast<GrowthPolicy &>(*this),
       static_cast<GrowthPolicy &>(other));

  // Bucket storage and the pointer cached on it.
  swap(m_buckets_data, other.m_buckets_data);
  swap(m_buckets, other.m_buckets);

  // Counters and load factor settings.
  swap(m_bucket_count, other.m_bucket_count);
  swap(m_nb_elements, other.m_nb_elements);
  swap(m_load_threshold, other.m_load_threshold);
  swap(m_max_load_factor, other.m_max_load_factor);
  swap(m_grow_on_next_insert, other.m_grow_on_next_insert);
  swap(m_min_load_factor, other.m_min_load_factor);
  swap(m_try_skrink_on_next_insert, other.m_try_skrink_on_next_insert);
}
/*
* Lookup
*/
/** Checked access to the mapped value of `key`; hashes the key itself. */
template <
    class K, class U = ValueSelect,
    typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
typename U::value_type &at(const K &key) {
  const std::size_t key_hash = hash_key(key);
  return at(key, key_hash);
}
/**
 * Checked mutable access with a caller-provided hash. Implemented on top of
 * the const overload; the constness it adds is cast away on the result.
 */
template <
    class K, class U = ValueSelect,
    typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
typename U::value_type &at(const K &key, std::size_t hash) {
  const robin_hash *const self = static_cast<const robin_hash *>(this);
  return const_cast<typename U::value_type &>(self->at(key, hash));
}
/** Checked read-only access to the mapped value of `key`; hashes the key. */
template <
    class K, class U = ValueSelect,
    typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
const typename U::value_type &at(const K &key) const {
  const std::size_t key_hash = hash_key(key);
  return at(key, key_hash);
}
/**
 * Checked read-only access with a caller-provided hash. Reports
 * std::out_of_range through TSL_RH_THROW_OR_TERMINATE when the key is absent.
 */
template <
    class K, class U = ValueSelect,
    typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
const typename U::value_type &at(const K &key, std::size_t hash) const {
  auto found = find(key, hash);
  if (found != cend()) {
    return found.value();
  } else {
    TSL_RH_THROW_OR_TERMINATE(std::out_of_range, "Couldn't find key.");
  }
}
/**
 * Returns the mapped value of `key`, inserting an element via try_emplace()
 * (with no mapped-value arguments) when the key is absent.
 */
template <
    class K, class U = ValueSelect,
    typename std::enable_if<has_mapped_type<U>::value>::type * = nullptr>
typename U::value_type &operator[](K &&key) {
  auto slot = try_emplace(std::forward<K>(key));
  return slot.first.value();
}
/** Number of elements with the given key (0 or 1); hashes the key itself. */
template <class K>
size_type count(const K &key) const {
  const std::size_t key_hash = hash_key(key);
  return count(key, key_hash);
}
/** Number of elements with the given key (0 or 1), using a supplied hash. */
template <class K>
size_type count(const K &key, std::size_t hash) const {
  const bool present = (find(key, hash) != cend());
  return present ? size_type(1) : size_type(0);
}
/** Looks up `key`; returns end() when absent. Hashes the key itself. */
template <class K>
iterator find(const K &key) {
  return find_impl(key, hash_key(key));
}

/** Looks up `key` with a caller-provided hash; returns end() when absent. */
template <class K>
iterator find(const K &key, std::size_t hash) {
  return find_impl(key, hash);
}

/** Const lookup of `key`; returns cend() when absent. Hashes the key itself. */
template <class K>
const_iterator find(const K &key) const {
  return find_impl(key, hash_key(key));
}

/** Const lookup with a caller-provided hash; returns cend() when absent. */
template <class K>
const_iterator find(const K &key, std::size_t hash) const {
  return find_impl(key, hash);
}
template <class K> std::pair<iterator, iterator> equal_range(const K &key) {
return equal_range(key, hash_key(key));
}
template <class K>
std::pair<iterator, iterator> equal_range(const K &key, std::size_t hash) {
iterator it = find(key, hash);
return std::make_pair(it, (it == end()) ? it : std::next(it));
}
template <class K>
std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
return equal_range(key, hash_key(key));
}
template <class K>
std::pair<const_iterator, const_iterator>
equal_range(const K &key, std::size_t hash) const {
const_iterator it = find(key, hash);
return std::make_pair(it, (it == cend()) ? it : std::next(it));
}
/*
* Bucket interface
*/
/** Number of buckets currently in the table. */
size_type bucket_count() const {
  return m_bucket_count;
}
/**
 * Largest supported bucket count: the smaller of the growth policy limit and
 * the bucket container's max_size().
 */
size_type max_bucket_count() const {
  return std::min(GrowthPolicy::max_bucket_count(),
                  m_buckets_data.max_size());
}
/*
* Hash policy
*/
float load_factor() const {
if (bucket_count() == 0) {
return 0;
}
return float(m_nb_elements) / float(bucket_count());
}
/** Current minimum load factor setting. */
float min_load_factor() const {
  return m_min_load_factor;
}

/** Current maximum load factor setting. */
float max_load_factor() const {
  return m_max_load_factor;
}
/**
 * Sets the minimum load factor, clamped to
 * [MINIMUM_MIN_LOAD_FACTOR, MAXIMUM_MIN_LOAD_FACTOR].
 */
void min_load_factor(float ml) {
  const float lower_bound = float(MINIMUM_MIN_LOAD_FACTOR);
  const float upper_bound = float(MAXIMUM_MIN_LOAD_FACTOR);
  m_min_load_factor = clamp(ml, lower_bound, upper_bound);
}
/**
 * Sets the maximum load factor, clamped to
 * [MINIMUM_MAX_LOAD_FACTOR, MAXIMUM_MAX_LOAD_FACTOR], and recomputes the
 * element count at which the table grows.
 */
void max_load_factor(float ml) {
  const float lower_bound = float(MINIMUM_MAX_LOAD_FACTOR);
  const float upper_bound = float(MAXIMUM_MAX_LOAD_FACTOR);
  m_max_load_factor = clamp(ml, lower_bound, upper_bound);

  m_load_threshold = size_type(float(bucket_count()) * m_max_load_factor);
}
/**
 * Rebuilds the table with at least `count` buckets, and never fewer than
 * needed to hold the current values under the maximum load factor.
 */
void rehash(size_type count) {
  const size_type needed_for_size =
      size_type(std::ceil(float(size()) / max_load_factor()));
  count = std::max(count, needed_for_size);
  rehash_impl(count);
}
/**
 * Makes room for `count` values: converts it to a bucket count through the
 * maximum load factor and rehashes.
 */
void reserve(size_type count) {
  const size_type buckets_needed =
      size_type(std::ceil(float(count) / max_load_factor()));
  rehash(buckets_needed);
}
/*
* Observers
*/
/** Copy of the hash functor (the Hash base sub-object). */
hasher hash_function() const {
  return static_cast<const Hash &>(*this);
}

/** Copy of the key equality functor (the KeyEqual base sub-object). */
key_equal key_eq() const {
  return static_cast<const KeyEqual &>(*this);
}
/*
* Other
*/
/** Converts a const_iterator into an iterator on the same bucket. */
iterator mutable_iterator(const_iterator pos) {
  bucket_entry *const target = const_cast<bucket_entry *>(pos.m_bucket);
  return iterator(target);
}
private:
/** Hashes `key` with the Hash base functor. */
template <class K>
std::size_t hash_key(const K &key) const {
  return Hash::operator()(key);
}
/** Compares two keys with the KeyEqual base functor. */
template <class K1, class K2>
bool compare_keys(const K1 &key1, const K2 &key2) const {
  return KeyEqual::operator()(key1, key2);
}
/**
 * Maps a hash to its ideal bucket through the growth policy. The result is
 * expected to be a valid index, or 0 when the table has no bucket (asserted).
 */
std::size_t bucket_for_hash(std::size_t hash) const {
  const std::size_t ideal = GrowthPolicy::bucket_for_hash(hash);
  tsl_rh_assert(ideal < m_bucket_count ||
                (ideal == 0 && m_bucket_count == 0));
  return ideal;
}
/**
 * Index following `index`, wrapping around with the policy's mask.
 * Selected when the growth policy is a power-of-two policy.
 */
template <class U = GrowthPolicy,
          typename std::enable_if<is_power_of_two_policy<U>::value>::type * =
              nullptr>
std::size_t next_bucket(std::size_t index) const noexcept {
  tsl_rh_assert(index < bucket_count());
  const std::size_t following = index + 1;
  return following & this->m_mask;
}
/**
 * Index following `index`, wrapping to 0 after the final bucket.
 * Selected when the growth policy is not a power-of-two policy.
 */
template <class U = GrowthPolicy,
          typename std::enable_if<!is_power_of_two_policy<U>::value>::type * =
              nullptr>
std::size_t next_bucket(std::size_t index) const noexcept {
  tsl_rh_assert(index < bucket_count());
  const std::size_t following = index + 1;
  if (following == bucket_count()) {
    return 0;
  }
  return following;
}
/** Mutable lookup: runs the const lookup and converts the resulting iterator. */
template <class K>
iterator find_impl(const K &key, std::size_t hash) {
  const robin_hash *const self = static_cast<const robin_hash *>(this);
  return mutable_iterator(self->find(key, hash));
}
/**
 * Probes for `key` starting from the ideal bucket of `hash`.
 *
 * The probe stops as soon as the visited bucket's stored distance is smaller
 * than the number of steps taken (empty buckets carry a negative distance).
 * The stored hash is compared first only when USE_STORED_HASH_ON_LOOKUP is
 * true.
 */
template <class K>
const_iterator find_impl(const K &key, std::size_t hash) const {
  std::size_t ibucket = bucket_for_hash(hash);

  for (distance_type probe_distance = 0;
       probe_distance <= m_buckets[ibucket].dist_from_ideal_bucket();
       ibucket = next_bucket(ibucket), probe_distance++) {
    if (TSL_RH_LIKELY(
            (!USE_STORED_HASH_ON_LOOKUP ||
             m_buckets[ibucket].bucket_hash_equal(hash)) &&
            compare_keys(KeySelect()(m_buckets[ibucket].value()), key))) {
      return const_iterator(m_buckets + ibucket);
    }
  }

  return cend();
}
/**
 * Removes the value at `pos`, then shifts the following values backward.
 *
 * Each value to the right of the hole moves one bucket to the left (one step
 * closer to its ideal bucket) until an empty bucket, or a value already
 * sitting in its ideal bucket (distance 0), is reached.
 */
void erase_from_bucket(iterator pos) {
  pos.m_bucket->clear();
  m_nb_elements--;

  std::size_t hole = static_cast<std::size_t>(pos.m_bucket - m_buckets);
  for (std::size_t source = next_bucket(hole);
       m_buckets[source].dist_from_ideal_bucket() > 0;
       source = next_bucket(source)) {
    tsl_rh_assert(m_buckets[hole].empty());

    const distance_type new_distance =
        distance_type(m_buckets[source].dist_from_ideal_bucket() - 1);
    m_buckets[hole].set_value_of_empty_bucket(
        new_distance, m_buckets[source].truncated_hash(),
        std::move(m_buckets[source].value()));
    m_buckets[source].clear();

    // The bucket that was just emptied becomes the new hole.
    hole = source;
  }
}
/**
 * Core insertion routine behind insert() and try_emplace().
 *
 * Probes for `key` from its ideal bucket. When an equal key is found, returns
 * an iterator to it together with `false`. Otherwise a value_type is built
 * from `value_type_args` and stored, and the function returns an iterator to
 * it together with `true`.
 */
template <class K, class... Args>
std::pair<iterator, bool> insert_impl(const K &key,
Args &&... value_type_args) {
const std::size_t hash = hash_key(key);
std::size_t ibucket = bucket_for_hash(hash);
distance_type dist_from_ideal_bucket = 0;
// Lookup phase: stop once the visited bucket's stored distance is smaller
// than ours (empty buckets carry a negative distance).
while (dist_from_ideal_bucket <=
m_buckets[ibucket].dist_from_ideal_bucket()) {
if ((!USE_STORED_HASH_ON_LOOKUP ||
m_buckets[ibucket].bucket_hash_equal(hash)) &&
compare_keys(KeySelect()(m_buckets[ibucket].value()), key)) {
return std::make_pair(iterator(m_buckets + ibucket), false);
}
ibucket = next_bucket(ibucket);
dist_from_ideal_bucket++;
}
// The key is absent. If the table was rehashed the probe position is stale:
// walk again from the ideal bucket of the new table to find the insertion
// point.
if (rehash_on_extreme_load()) {
ibucket = bucket_for_hash(hash);
dist_from_ideal_bucket = 0;
while (dist_from_ideal_bucket <=
m_buckets[ibucket].dist_from_ideal_bucket()) {
ibucket = next_bucket(ibucket);
dist_from_ideal_bucket++;
}
}
// Either construct directly in an empty bucket, or take over an occupied one
// and push its value further along.
if (m_buckets[ibucket].empty()) {
m_buckets[ibucket].set_value_of_empty_bucket(
dist_from_ideal_bucket, bucket_entry::truncate_hash(hash),
std::forward<Args>(value_type_args)...);
} else {
insert_value(ibucket, dist_from_ideal_bucket,
bucket_entry::truncate_hash(hash),
std::forward<Args>(value_type_args)...);
}
m_nb_elements++;
/*
* The value will be inserted in ibucket in any case, either because it was
* empty or by stealing the bucket (robin hood).
*/
return std::make_pair(iterator(m_buckets + ibucket), true);
}
template <class... Args>
void insert_value(std::size_t ibucket, distance_type dist_from_ideal_bucket,
truncated_hash_type hash, Args &&... value_type_args) {
value_type value(std::forward<Args>(value_type_args)...);
insert_value_impl(ibucket, dist_from_ideal_bucket, hash, value);
}
void insert_value(std::size_t ibucket, distance_type dist_from_ideal_bucket,
truncated_hash_type hash, value_type &&value) {
insert_value_impl(ibucket, dist_from_ideal_bucket, hash, value);
}
/*
 * Places `value` in the occupied bucket `ibucket` and pushes the displaced
 * values further along the probe sequence until an empty bucket is found.
 *
 * We don't use `value_type&& value` as last argument due to a bug in MSVC
 * when `value_type` is a pointer. The compiler is not able to see the
 * difference between `std::string*` and `std::string*&&`, resulting in a
 * compile error.
 *
 * The `value` will be in a moved state at the end of the function.
 */
void insert_value_impl(std::size_t ibucket,
distance_type dist_from_ideal_bucket,
truncated_hash_type hash, value_type &value) {
// Take over the target bucket; (dist, hash, value) now describe the value
// that was displaced.
m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash,
value);
ibucket = next_bucket(ibucket);
dist_from_ideal_bucket++;
// Carry the displaced value forward, swapping with any resident that sits
// closer to its ideal bucket than the carried value does.
while (!m_buckets[ibucket].empty()) {
if (dist_from_ideal_bucket >
m_buckets[ibucket].dist_from_ideal_bucket()) {
if (dist_from_ideal_bucket >= REHASH_ON_HIGH_NB_PROBES__NPROBES &&
load_factor() >= REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR) {
/**
* The number of probes is really high, rehash the map on the next
* insert. Difficult to do now as rehash may throw an exception.
*/
m_grow_on_next_insert = true;
}
m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket,
hash, value);
}
ibucket = next_bucket(ibucket);
dist_from_ideal_bucket++;
}
// An empty bucket was reached: move the carried value into it.
m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket, hash,
std::move(value));
}
/**
 * Rebuilds the table with `count` requested buckets: a fresh table is
 * created, every stored value is moved into it, and the two tables are
 * swapped. Stored hashes are reused when USE_STORED_HASH_ON_REHASH allows it,
 * otherwise each key is hashed again.
 */
void rehash_impl(size_type count) {
  robin_hash new_table(count, static_cast<Hash &>(*this),
                       static_cast<KeyEqual &>(*this), get_allocator(),
                       m_min_load_factor, m_max_load_factor);

  const bool use_stored_hash =
      USE_STORED_HASH_ON_REHASH(new_table.bucket_count());

  for (auto &bucket : m_buckets_data) {
    if (!bucket.empty()) {
      const std::size_t hash =
          use_stored_hash ? bucket.truncated_hash()
                          : new_table.hash_key(KeySelect()(bucket.value()));

      new_table.insert_value_on_rehash(new_table.bucket_for_hash(hash), 0,
                                       bucket_entry::truncate_hash(hash),
                                       std::move(bucket.value()));
    }
  }

  // insert_value_on_rehash() does not maintain the element counter.
  new_table.m_nb_elements = m_nb_elements;
  new_table.swap(*this);
}
/**
 * Insertion used while rehashing: no key comparison is performed.
 *
 * Walks the probe sequence from `ibucket`. A bucket whose stored distance is
 * smaller than the carried one either receives the value (when empty) or is
 * swapped with it, after which the displaced value is carried along.
 */
void insert_value_on_rehash(std::size_t ibucket,
                            distance_type dist_from_ideal_bucket,
                            truncated_hash_type hash, value_type &&value) {
  for (;; dist_from_ideal_bucket++, ibucket = next_bucket(ibucket)) {
    if (dist_from_ideal_bucket <=
        m_buckets[ibucket].dist_from_ideal_bucket()) {
      // The resident is at least as far from home as we are: skip it.
      continue;
    }

    if (m_buckets[ibucket].empty()) {
      m_buckets[ibucket].set_value_of_empty_bucket(dist_from_ideal_bucket,
                                                   hash, std::move(value));
      return;
    }

    m_buckets[ibucket].swap_with_value_in_bucket(dist_from_ideal_bucket, hash,
                                                 value);
  }
}
/**
 * Grow the table if m_grow_on_next_insert is true or we reached the
 * max_load_factor. Shrink the table if m_try_skrink_on_next_insert is true
 * (an erase occurred) and we're below the min_load_factor.
 *
 * Return true if the table has been rehashed.
 */
bool rehash_on_extreme_load() {
  const bool must_grow = m_grow_on_next_insert || size() >= m_load_threshold;
  if (must_grow) {
    rehash_impl(GrowthPolicy::next_bucket_count());
    m_grow_on_next_insert = false;
    return true;
  }

  if (!m_try_skrink_on_next_insert) {
    return false;
  }
  // The shrink request is consumed whether or not a shrink happens.
  m_try_skrink_on_next_insert = false;

  const bool below_minimum =
      m_min_load_factor != 0.0f && load_factor() < m_min_load_factor;
  if (!below_minimum) {
    return false;
  }

  reserve(size() + 1);
  return true;
}
public:
// Default initial bucket count. Not referenced inside this class; presumably
// used by the map/set wrappers — TODO confirm.
static const size_type DEFAULT_INIT_BUCKETS_SIZE = 0;
// Default for the constructors' max_load_factor parameter, and the bounds
// max_load_factor(float) clamps to.
static constexpr float DEFAULT_MAX_LOAD_FACTOR = 0.5f;
static constexpr float MINIMUM_MAX_LOAD_FACTOR = 0.2f;
static constexpr float MAXIMUM_MAX_LOAD_FACTOR = 0.95f;
// Default for the constructors' min_load_factor parameter, and the bounds
// min_load_factor(float) clamps to.
static constexpr float DEFAULT_MIN_LOAD_FACTOR = 0.0f;
static constexpr float MINIMUM_MIN_LOAD_FACTOR = 0.0f;
static constexpr float MAXIMUM_MIN_LOAD_FACTOR = 0.15f;
// Sanity checks: each range is non-empty and the min range lies entirely
// below the max range.
static_assert(MINIMUM_MAX_LOAD_FACTOR < MAXIMUM_MAX_LOAD_FACTOR,
"MINIMUM_MAX_LOAD_FACTOR should be < MAXIMUM_MAX_LOAD_FACTOR");
static_assert(MINIMUM_MIN_LOAD_FACTOR < MAXIMUM_MIN_LOAD_FACTOR,
"MINIMUM_MIN_LOAD_FACTOR should be < MAXIMUM_MIN_LOAD_FACTOR");
static_assert(MAXIMUM_MIN_LOAD_FACTOR < MINIMUM_MAX_LOAD_FACTOR,
"MAXIMUM_MIN_LOAD_FACTOR should be < MINIMUM_MAX_LOAD_FACTOR");
private:
// Thresholds used by insert_value_impl(): when a displaced value is at least
// NPROBES buckets away from its ideal bucket and the load factor is at least
// MIN_LOAD_FACTOR, the table is flagged to grow on the next insert.
static const distance_type REHASH_ON_HIGH_NB_PROBES__NPROBES = 128;
static constexpr float REHASH_ON_HIGH_NB_PROBES__MIN_LOAD_FACTOR = 0.15f;
/**
 * Return an always valid pointer to a static empty bucket_entry with
 * last_bucket() == true. It stands in for the bucket array when the table
 * owns no bucket.
 */
bucket_entry *static_empty_bucket_ptr() {
  static bucket_entry shared_empty_bucket(true);
  bucket_entry *const sentinel = &shared_empty_bucket;
  return sentinel;
}
private:
// Owns the buckets.
buckets_container_type m_buckets_data;
/**
 * Points to m_buckets_data.data() if !m_buckets_data.empty() otherwise points
 * to static_empty_bucket_ptr. This variable is useful to avoid the cost of
 * checking if m_buckets_data is empty when trying to find an element.
 *
 * TODO Remove m_buckets_data and only use a pointer instead of a
 * pointer+vector to save some space in the robin_hash object. Manage the
 * Allocator manually.
 */
bucket_entry *m_buckets;
/**
 * Used a lot in find, avoid the call to m_buckets_data.size() which is a bit
 * slower.
 */
size_type m_bucket_count;
// Number of values currently stored.
size_type m_nb_elements;
// Element count at which the next insert grows the table; recomputed by
// max_load_factor(float).
size_type m_load_threshold;
// Clamped maximum load factor.
float m_max_load_factor;
// Set when an insertion needed a very long probe; forces a growth on the
// next insert.
bool m_grow_on_next_insert;
// Clamped minimum load factor; 0 disables shrinking.
float m_min_load_factor;
/**
 * We can't shrink down the map on erase operations as the erase methods need
 * to return the next iterator. Shrinking the map would invalidate all the
 * iterators and we could not return the next iterator in a meaningful way. On
 * erase, we thus just indicate that we should try to shrink the hash table on
 * the next insert if we go below the min_load_factor.
 */
bool m_try_skrink_on_next_insert;
};
} // namespace detail_robin_hash
} // namespace tsl
#endif
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_ROBIN_MAP_H
#define TSL_ROBIN_MAP_H
#include "robin_hash.h"
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <utility>
namespace tsl {
/**
* Implementation of a hash map using open-addressing and the robin hood hashing
* algorithm with backward shift deletion.
*
* For operations modifying the hash map (insert, erase, rehash, ...), the
* strong exception guarantee is only guaranteed when the expression
* `std::is_nothrow_swappable<std::pair<Key, T>>::value &&
* std::is_nothrow_move_constructible<std::pair<Key, T>>::value` is true,
* otherwise if an exception is thrown during the swap or the move, the hash map
* may end up in an undefined state. Per the standard a `Key` or `T` with a
* noexcept copy constructor and no move constructor also satisfies the
* `std::is_nothrow_move_constructible<std::pair<Key, T>>::value` criterion (and
* will thus guarantee the strong exception for the map).
*
* When `StoreHash` is true, 32 bits of the hash are stored alongside the
* values. It can improve the performance during lookups if the `KeyEqual`
* function takes time (if it engenders a cache-miss for example) as we then
* compare the stored hashes before comparing the keys. When
* `tsl::rh::power_of_two_growth_policy` is used as `GrowthPolicy`, it may also
* speed-up the rehash process as we can avoid to recalculate the hash. When it
* is detected that storing the hash will not incur any memory penalty due to
* alignment (i.e. `sizeof(tsl::detail_robin_hash::bucket_entry<ValueType,
* true>) == sizeof(tsl::detail_robin_hash::bucket_entry<ValueType, false>)`)
* and `tsl::rh::power_of_two_growth_policy` is used, the hash will be stored
* even if `StoreHash` is false so that we can speed-up the rehash (but it will
* not be used on lookups unless `StoreHash` is true).
*
* `GrowthPolicy` defines how the map grows and consequently how a hash value is
* mapped to a bucket. By default the map uses
* `tsl::rh::power_of_two_growth_policy`. This policy keeps the number of
* buckets to a power of two and uses a mask to map the hash to a bucket instead
* of the slow modulo. Other growth policies are available and you may define
* your own growth policy, check `tsl::rh::power_of_two_growth_policy` for the
* interface.
*
* `std::pair<Key, T>` must be swappable.
*
* `Key` and `T` must be copy and/or move constructible.
*
* If the destructor of `Key` or `T` throws an exception, the behaviour of the
* class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective
* insert, invalidate the iterators.
* - erase: always invalidate the iterators.
*/
template <class Key, class T, class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
bool StoreHash = false,
class GrowthPolicy = tsl::rh::power_of_two_growth_policy<2>>
class robin_map {
private:
template <typename U>
using has_is_transparent = tsl::detail_robin_hash::has_is_transparent<U>;
/**
* Functor handed to the underlying table: returns the key (`first`) of a
* stored `std::pair<Key, T>`.
*/
class KeySelect {
public:
using key_type = Key;
const key_type &
operator()(const std::pair<Key, T> &key_value) const noexcept {
return key_value.first;
}
key_type &operator()(std::pair<Key, T> &key_value) noexcept {
return key_value.first;
}
};
/**
* Functor handed to the underlying table: returns the mapped value
* (`second`) of a stored `std::pair<Key, T>`.
*/
class ValueSelect {
public:
using value_type = T;
const value_type &
operator()(const std::pair<Key, T> &key_value) const noexcept {
return key_value.second;
}
value_type &operator()(std::pair<Key, T> &key_value) noexcept {
return key_value.second;
}
};
/**
* Underlying hash table type. Every public member function of robin_map
* forwards to the `m_ht` instance of this type.
*/
using ht = detail_robin_hash::robin_hash<std::pair<Key, T>, KeySelect,
ValueSelect, Hash, KeyEqual,
Allocator, StoreHash, GrowthPolicy>;
public:
using key_type = typename ht::key_type;
using mapped_type = T;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
public:
/*
* Constructors
*
* Every overload below ends up delegating to the
* (bucket_count, hash, equal, alloc) constructor, which builds `m_ht`.
* Overloads that omit bucket_count use ht::DEFAULT_INIT_BUCKETS_SIZE.
*/
robin_map() : robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {}
explicit robin_map(size_type bucket_count, const Hash &hash = Hash(),
const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: m_ht(bucket_count, hash, equal, alloc) {}
robin_map(size_type bucket_count, const Allocator &alloc)
: robin_map(bucket_count, Hash(), KeyEqual(), alloc) {}
robin_map(size_type bucket_count, const Hash &hash, const Allocator &alloc)
: robin_map(bucket_count, hash, KeyEqual(), alloc) {}
explicit robin_map(const Allocator &alloc)
: robin_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {}
/**
* Range constructor: builds an empty map, then inserts [first, last).
*/
template <class InputIt>
robin_map(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: robin_map(bucket_count, hash, equal, alloc) {
insert(first, last);
}
template <class InputIt>
robin_map(InputIt first, InputIt last, size_type bucket_count,
const Allocator &alloc)
: robin_map(first, last, bucket_count, Hash(), KeyEqual(), alloc) {}
template <class InputIt>
robin_map(InputIt first, InputIt last, size_type bucket_count,
const Hash &hash, const Allocator &alloc)
: robin_map(first, last, bucket_count, hash, KeyEqual(), alloc) {}
/**
* Initializer-list constructors: delegate to the range constructor with
* init.begin() / init.end().
*/
robin_map(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash &hash = Hash(), const KeyEqual &equal = KeyEqual(),
const Allocator &alloc = Allocator())
: robin_map(init.begin(), init.end(), bucket_count, hash, equal, alloc) {}
robin_map(std::initializer_list<value_type> init, size_type bucket_count,
const Allocator &alloc)
: robin_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(),
alloc) {}
robin_map(std::initializer_list<value_type> init, size_type bucket_count,
const Hash &hash, const Allocator &alloc)
: robin_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(),
alloc) {}
/**
* Replace the content with the elements of `ilist`: the table is cleared,
* space for ilist.size() elements is reserved, then the elements are
* inserted.
*/
robin_map &operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type &value) {
return m_ht.insert(value);
}
/**
* This overload only participates in the overload resolution if
* `value_type` is constructible from `P&&`. The value is forwarded to
* emplace.
*/
template <class P, typename std::enable_if<std::is_constructible<
value_type, P &&>::value>::type * = nullptr>
std::pair<iterator, bool> insert(P &&value) {
return m_ht.emplace(std::forward<P>(value));
}
std::pair<iterator, bool> insert(value_type &&value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type &value) {
return m_ht.insert_hint(hint, value);
}
/**
* This overload only participates in the overload resolution if
* `value_type` is constructible from `P&&`. The value is forwarded to
* emplace_hint.
*/
template <class P, typename std::enable_if<std::is_constructible<
value_type, P &&>::value>::type * = nullptr>
iterator insert(const_iterator hint, P &&value) {
return m_ht.emplace_hint(hint, std::forward<P>(value));
}
iterator insert(const_iterator hint, value_type &&value) {
return m_ht.insert_hint(hint, std::move(value));
}
template <class InputIt> void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
template <class M>
std::pair<iterator, bool> insert_or_assign(const key_type &k, M &&obj) {
return m_ht.insert_or_assign(k, std::forward<M>(obj));
}
template <class M>
std::pair<iterator, bool> insert_or_assign(key_type &&k, M &&obj) {
return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
}
template <class M>
iterator insert_or_assign(const_iterator hint, const key_type &k, M &&obj) {
return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
}
template <class M>
iterator insert_or_assign(const_iterator hint, key_type &&k, M &&obj) {
return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the
* key-value once. The method is equivalent to
* insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template <class... Args> std::pair<iterator, bool> emplace(Args &&... args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy
* the key-value once. The method is equivalent to insert(hint,
* value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template <class... Args>
iterator emplace_hint(const_iterator hint, Args &&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
template <class... Args>
std::pair<iterator, bool> try_emplace(const key_type &k, Args &&... args) {
return m_ht.try_emplace(k, std::forward<Args>(args)...);
}
template <class... Args>
std::pair<iterator, bool> try_emplace(key_type &&k, Args &&... args) {
return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
}
template <class... Args>
iterator try_emplace(const_iterator hint, const key_type &k,
Args &&... args) {
return m_ht.try_emplace_hint(hint, k, std::forward<Args>(args)...);
}
template <class... Args>
iterator try_emplace(const_iterator hint, key_type &&k, Args &&... args) {
return m_ht.try_emplace_hint(hint, std::move(k),
std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) {
return m_ht.erase(first, last);
}
size_type erase(const key_type &key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
size_type erase(const key_type &key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type erase(const K &key) {
return m_ht.erase(key);
}
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup to the value if you already have the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type erase(const K &key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(robin_map &other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
T &at(const Key &key) { return m_ht.at(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
T &at(const Key &key, std::size_t precalculated_hash) {
return m_ht.at(key, precalculated_hash);
}
const T &at(const Key &key) const { return m_ht.at(key); }
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const T &at(const Key &key, std::size_t precalculated_hash) const {
return m_ht.at(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
T &at(const K &key) {
return m_ht.at(key);
}
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
T &at(const K &key, std::size_t precalculated_hash) {
return m_ht.at(key, precalculated_hash);
}
/**
* @copydoc at(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const T &at(const K &key) const {
return m_ht.at(key);
}
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const T &at(const K &key, std::size_t precalculated_hash) const {
return m_ht.at(key, precalculated_hash);
}
T &operator[](const Key &key) { return m_ht[key]; }
T &operator[](Key &&key) { return m_ht[std::move(key)]; }
size_type count(const Key &key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
size_type count(const Key &key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type count(const K &key) const {
return m_ht.count(key);
}
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
size_type count(const K &key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
iterator find(const Key &key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
iterator find(const Key &key, std::size_t precalculated_hash) {
return m_ht.find(key, precalculated_hash);
}
const_iterator find(const Key &key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key &key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
iterator find(const K &key) {
return m_ht.find(key);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
iterator find(const K &key, std::size_t precalculated_hash) {
return m_ht.find(key, precalculated_hash);
}
/**
* @copydoc find(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const_iterator find(const K &key) const {
return m_ht.find(key);
}
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
const_iterator find(const K &key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
std::pair<iterator, iterator> equal_range(const Key &key) {
return m_ht.equal_range(key);
}
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key &key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key &key) const {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator>
equal_range(const Key &key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef
* KeyEqual::is_transparent exists. If so, K must be hashable and comparable
* to Key.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<iterator, iterator> equal_range(const K &key) {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The
* hash value should be the same as hash_function()(key). Useful to speed-up
* the lookup if you already have the hash.
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<iterator, iterator> equal_range(const K &key,
std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K &key) const {
return m_ht.equal_range(key);
}
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template <
class K, class KE = KeyEqual,
typename std::enable_if<has_is_transparent<KE>::value>::type * = nullptr>
std::pair<const_iterator, const_iterator>
equal_range(const K &key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float min_load_factor() const { return m_ht.min_load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
/**
* Set the `min_load_factor` to `ml`. When the `load_factor` of the map goes
* below `min_load_factor` after some erase operations, the map will be
* shrunk when an insertion occurs. The erase method itself never shrinks
* the map.
*
* The default value of `min_load_factor` is 0.0f, the map never shrinks by
* default.
*/
void min_load_factor(float ml) { m_ht.min_load_factor(ml); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count) { m_ht.rehash(count); }
void reserve(size_type count) { m_ht.reserve(count); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
/**
* Two maps compare equal when they have the same size and every key of
* `lhs` is found in `rhs` with a mapped value for which `!=` is false.
* Mapped values are compared with `operator!=` on `T`.
*/
friend bool operator==(const robin_map &lhs, const robin_map &rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (const auto &element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs.first);
if (it_element_rhs == rhs.cend() ||
element_lhs.second != it_element_rhs->second) {
return false;
}
}
return true;
}
friend bool operator!=(const robin_map &lhs, const robin_map &rhs) {
return !operator==(lhs, rhs);
}
friend void swap(robin_map &lhs, robin_map &rhs) { lhs.swap(rhs); }
private:
ht m_ht;
};
/**
* Convenience alias for a robin_map whose growth policy is fixed to
* `tsl::rh::prime_growth_policy`; all other template parameters keep the
* same defaults as robin_map.
*
* Same as `tsl::robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash,
* tsl::rh::prime_growth_policy>`.
*/
template <class Key, class T, class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
bool StoreHash = false>
using robin_pg_map = robin_map<Key, T, Hash, KeyEqual, Allocator, StoreHash,
tsl::rh::prime_growth_policy>;
} // end namespace tsl
#endif
// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
#include <iostream>
namespace spconv {
#ifdef TV_CUDA
/**
 * Stopwatch based on std::chrono::steady_clock that calls
 * cudaDeviceSynchronize() before every timestamp it takes (presumably so
 * that pending device work is included in the measured interval).
 *
 * @tparam TimeT std::chrono duration type the elapsed time is reported in.
 */
template <typename TimeT = std::chrono::microseconds> struct CudaContextTimer {
  /// Synchronizes the device, then starts the stopwatch.
  CudaContextTimer() {
    cudaDeviceSynchronize();
    mCurTime = std::chrono::steady_clock::now();
  }

  /**
   * @return the time elapsed since construction or since the previous
   *         report() call, as a tick count of TimeT. Restarts the stopwatch.
   */
  typename TimeT::rep report() {
    cudaDeviceSynchronize();
    // Take ONE timestamp and use it both to measure and to restart, so the
    // time spent inside report() is not dropped from the next interval.
    const auto now = std::chrono::steady_clock::now();
    const auto res = std::chrono::duration_cast<TimeT>(now - mCurTime).count();
    mCurTime = now;
    return res;
  }

private:
  std::chrono::time_point<std::chrono::steady_clock> mCurTime;
};
#endif
/**
 * Host-side stopwatch based on std::chrono::steady_clock.
 *
 * @tparam TimeT std::chrono duration type the elapsed time is reported in.
 */
template <typename TimeT = std::chrono::microseconds> struct CPUTimer {
  /// Starts the stopwatch.
  CPUTimer() : mCurTime(std::chrono::steady_clock::now()) {}

  /**
   * @return the time elapsed since construction or since the previous
   *         report() call, as a tick count of TimeT. Restarts the stopwatch.
   */
  typename TimeT::rep report() {
    // Single timestamp: consecutive reports then add up to the total elapsed
    // time (up to duration_cast truncation) instead of losing the gap between
    // two separate now() calls.
    const auto now = std::chrono::steady_clock::now();
    const auto res = std::chrono::duration_cast<TimeT>(now - mCurTime).count();
    mCurTime = now;
    return res;
  }

private:
  std::chrono::time_point<std::chrono::steady_clock> mCurTime;
};
} // namespace spconv
[build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.2.14", "cumm>=0.1.7"]
build-backend = "setuptools.build_meta"
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Note: To use the 'upload' functionality of this file, you must:
# $ pip install twine
import io
import os
import platform
import re
import subprocess
import shutil
import sys
from distutils.version import LooseVersion
from pathlib import Path
from shutil import rmtree
from typing import List
import pccm
from pccm.extension import ExtCallback, PCCMBuild, PCCMExtension
from setuptools import Command, find_packages, setup
from setuptools.extension import Extension
from ccimport import compat
import subprocess
import re
# Package meta-data.
NAME = 'spconv'
RELEASE_NAME = NAME
deps = ["cumm"]
# CUDA version used to suffix the release name and the cumm dependency.
# Taken from the CUMM_CUDA_VERSION environment variable when it is set,
# otherwise parsed from the output of `nvcc --version`.
cuda_ver = os.environ.get("CUMM_CUDA_VERSION", "")
if not cuda_ver:
    nvcc_version = subprocess.check_output(["nvcc", "--version"
                                            ]).decode("utf-8").strip()
    # Search the whole output instead of a hard-coded line index, and escape
    # the dot so only "<major>.<minor>" is accepted.
    release_match = re.search(r"release (\d+\.\d+)", nvcc_version)
    if release_match is None:
        raise RuntimeError(
            "unable to find a CUDA release number in `nvcc --version` "
            "output: {!r}".format(nvcc_version))
    cuda_ver = release_match.group(1)
cuda_ver = cuda_ver.replace(".", "")  # 10.2 to 102
RELEASE_NAME += "-cu{}".format(cuda_ver)
deps = ["cumm-cu{}".format(cuda_ver)]
# Static package metadata; these names are passed to setup() further down.
DESCRIPTION = 'spatial sparse convolution'
URL = 'https://github.com/traveller59/spconv'
EMAIL = 'yanyan.sub@outlook.com'
AUTHOR = 'Yan Yan'
REQUIRES_PYTHON = '>=3.6'
# Left as None so that the version is read from version.txt below.
VERSION = None
# What packages are required for this module to be executed?
# `deps` contributes the cumm requirement chosen above.
REQUIRED = ["pccm>=0.2.14", "pybind11>=2.6.0", "fire", "numpy", *deps]
# What packages are optional?
EXTRAS = {
# 'fancy feature': ['django'],
}
# The rest you shouldn't have to touch too much :)
# ------------------------------------------------
# Except, perhaps the License and Trove Classifiers!
# If you do change the License, remember to change the Trove Classifier for that!
# Absolute path of the directory that contains this script.
here = os.path.abspath(os.path.dirname(__file__))
# Append this script's directory to sys.path (presumably so the local `spconv`
# package can be imported during the build -- TODO confirm).
sys.path.append(str(Path(__file__).parent))
# Use README.md (located next to this script) as the long description and
# fall back to the short DESCRIPTION when the file is missing.
# Note: this will only work if 'README.md' is present in your MANIFEST.in file!
readme_path = os.path.join(here, 'README.md')
try:
    with io.open(readme_path, encoding='utf-8') as readme:
        long_description = '\n' + readme.read()
except FileNotFoundError:
    long_description = DESCRIPTION
# Resolve the package version: the VERSION override when it is set, otherwise
# the stripped contents of version.txt.
about = {}
if not VERSION:
    # Anchor the path on `here` (as is done for README.md) so that the lookup
    # does not depend on the current working directory.
    with open(os.path.join(here, 'version.txt'), 'r') as f:
        version = f.read().strip()
else:
    version = VERSION
cwd = os.path.dirname(os.path.abspath(__file__))
import torch
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
def _convert_build_number(build_number):
parts = build_number.split(".")
if len(parts) == 2:
return "{}{:03d}".format(int(parts[0]), int(parts[1]))
elif len(parts) == 1:
return build_number
else:
raise NotImplementedError
# if 'LIBTORCH_ROOT' not in os.environ:
# raise ValueError("You must set LIBTORCH_ROOT to your torch c++ library.")
# Directory of the installed torch package.
LIBTORCH_ROOT = str(Path(torch.__file__).parent)
# Optional development suffix: when SPCONV_VERSION_SUFFIX is non-empty, a
# ".dev<number>" tag is appended to the version.
env_suffix = os.environ.get("SPCONV_VERSION_SUFFIX", "")
if env_suffix:
    version = version + ".dev{}".format(_convert_build_number(env_suffix))
version_path = os.path.join(cwd, NAME, '__version__.py')
about['__version__'] = version
SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA")
# Write the resolved version into <package>/__version__.py.
with open(version_path, 'w') as version_file:
    version_file.write("__version__ = '{}'\n".format(version))
PYTHON_VERSION = "{}.{}".format(sys.version_info.major, sys.version_info.minor)
# NOTE(review): this span appears to interleave two revisions of setup.py --
# a torch/CMake based build (PYTORCH_VERSION*, CMakeExtension, CMakeBuild,
# `packages = ...`) and the UploadCommand based one -- and its indentation is
# missing, so it is not valid Python as written. The code is left untouched;
# reconcile it against the intended revision -- TODO confirm which lines are
# live.
class UploadCommand(Command):
"""Support setup.py upload."""
# Cut torch.__version__ at the first "+"/"." followed by dev|cu|cpu, split the
# remainder on "." into ints and pack them as major*10000 + minor*100 + patch.
remove_device = re.search(r"(\+|\.)(dev|cu|cpu)", torch.__version__)
PYTORCH_VERSION = torch.__version__
if remove_device is not None:
PYTORCH_VERSION = torch.__version__[:remove_device.start()]
PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split(".")))
PYTORCH_VERSION_NUMBER = PYTORCH_VERSION[0] * 10000 + PYTORCH_VERSION[1] * 100 + PYTORCH_VERSION[2]
# Extension with no sources that only records the absolute source directory.
# NOTE(review): `library_dirs=[]` is a mutable default argument.
class CMakeExtension(Extension):
def __init__(self, name, sourcedir='', library_dirs=[]):
Extension.__init__(self, name, sources=[], library_dirs=library_dirs)
self.sourcedir = os.path.abspath(sourcedir)
description = 'Build and publish the package.'
user_options = []
@staticmethod
def status(s):
"""Prints things in bold."""
print('\033[1m{0}\033[0m'.format(s))
def initialize_options(self):
pass
def finalize_options(self):
pass
# build_ext command that runs `cmake` to configure and then build each extension.
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
self.status('Removing previous builds...')
rmtree(os.path.join(here, 'dist'))
except OSError:
raise RuntimeError("CMake must be installed to build the following extensions: " +
", ".join(e.name for e in self.extensions))
if platform.system() == "Windows":
cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1))
if cmake_version < '3.13.0':
raise RuntimeError("CMake >= 3.13.0 is required on Windows")
for ext in self.extensions:
self.build_extension(ext)
def build_extension(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
cmake_args = [# '-G "Visual Studio 15 2017 Win64"',
'-DCMAKE_PREFIX_PATH={}'.format(LIBTORCH_ROOT),
'-DPYBIND11_PYTHON_VERSION={}'.format(PYTHON_VERSION),
'-DSPCONV_BuildTests=OFF',
'-DPYTORCH_VERSION={}'.format(PYTORCH_VERSION_NUMBER)
] # -arch=sm_61
# CUDA is switched off only when torch reports no CUDA device AND the
# SPCONV_FORCE_BUILD_CUDA environment variable is unset.
if not torch.cuda.is_available() and SPCONV_FORCE_BUILD_CUDA is None:
cmake_args += ['-DSPCONV_BuildCUDA=OFF']
else:
cuda_flags = ["\"--expt-relaxed-constexpr\""]
# must add following flags to use at::Half
# but will remove raw half operators.
cuda_flags += ["-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__"]
# cuda_flags += ["-D__CUDA_NO_HALF2_OPERATORS__"]
cmake_args += ['-DCMAKE_CUDA_FLAGS=' + " ".join(cuda_flags)]
cfg = 'Debug' if self.debug else 'Release'
assert cfg == "Release", "pytorch ops don't support debug build."
build_args = ['--config', cfg]
print(cfg)
if platform.system() == "Windows":
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), str(Path(extdir) / "spconv"))]
# cmake_args += ['-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), str(Path(extdir) / "spconv"))]
cmake_args += ['-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), str(Path(extdir) / "spconv"))]
cmake_args += ["-DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE"]
# 64-bit interpreter: ask CMake for the x64 architecture.
if sys.maxsize > 2**32:
cmake_args += ['-A', 'x64']
build_args += ['--', '/m']
else:
cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(str(Path(extdir) / "spconv"))]
cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg]
build_args += ['--', '-j4']
env = os.environ.copy()
env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''),
self.distribution.get_version())
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
print("|||||CMAKE ARGS|||||", cmake_args)
# First call configures the project in build_temp, second call builds it.
subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
packages = find_packages(exclude=('tools', 'tools.*'))
pass
self.status('Building Source and Wheel (universal) distribution...')
os.system('{0} setup.py sdist bdist_wheel --universal'.format(
sys.executable))
self.status('Uploading the package to PyPI via Twine...')
os.system('twine upload dist/*')
self.status('Pushing git tags...')
os.system('git tag v{0}'.format(about['__version__']))
os.system('git push --tags')
sys.exit()
# When SPCONV_DISABLE_JIT is exactly "1", register the PCCM build_ext command
# and declare the "spconv/core_cc" extension; otherwise only the upload
# command is registered and no extension module is declared.
disable_jit = os.getenv("SPCONV_DISABLE_JIT", None)
if disable_jit == "1":
    cmdclass = {
        'upload': UploadCommand,
        'build_ext': PCCMBuild,
    }
    from cumm.gemm.main import GemmMainUnitTest
    from spconv.core import SHUFFLE_SIMT_PARAMS, SHUFFLE_VOLTA_PARAMS, SHUFFLE_TURING_PARAMS
    from spconv.csrc.sparse.all import SpconvOps
    cu = GemmMainUnitTest(SHUFFLE_SIMT_PARAMS + SHUFFLE_VOLTA_PARAMS + SHUFFLE_TURING_PARAMS)
    cu.namespace = "cumm.gemm.main"
    # cuda_ver has had its dot removed, so e.g. 110 stands for CUDA 11.0.
    cuda_ver_number = int(cuda_ver)
    std = "c++14" if cuda_ver_number < 110 else "c++17"
    ext_modules: List[Extension] = [
        PCCMExtension([cu, SpconvOps()],
                      "spconv/core_cc",
                      Path(__file__).resolve().parent / "spconv",
                      objects_folder="objects",
                      std=std,
                      disable_pch=True)
    ]
else:
    cmdclass = {
        'upload': UploadCommand,
    }
    ext_modules = []
# Where the magic happens:
# NOTE(review): this call appears to merge the keyword arguments of two
# revisions of setup.py. `name`, `version`, `author`, `author_email`,
# `description`, `long_description`, `packages`, `ext_modules` and `cmdclass`
# are each given twice, which Python rejects as a repeated keyword argument.
# The code is left untouched -- TODO confirm which set of arguments is
# intended and drop the other.
setup(
name='spconv',
version='1.2.1',
author='Yan Yan',
author_email='scrin@foxmail.com',
description='spatial sparse convolution for pytorch',
long_description='',
setup_requires = ['torch>=1.3.0'],
packages=packages,
package_dir = {'spconv': 'spconv'},
ext_modules=[CMakeExtension('spconv', library_dirs=[])],
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,
name=RELEASE_NAME,
version=about['__version__'],
description=DESCRIPTION,
long_description=long_description,
long_description_content_type='text/markdown',
author=AUTHOR,
author_email=EMAIL,
python_requires=REQUIRES_PYTHON,
url=URL,
packages=find_packages(exclude=('tests', )),
# If your package is a single module, use this instead of 'packages':
# py_modules=['mypackage'],
entry_points={
'console_scripts': [],
},
install_requires=REQUIRED,
extras_require=EXTRAS,
include_package_data=True,
license='MIT',
classifiers=[
# Trove classifiers
# Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
'License :: OSI Approved :: MIT License',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy'
],
# $ setup.py publish support.
cmdclass=cmdclass,
ext_modules=ext_modules,
)
# Copyright 2019 Yan Yan
#
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
from pathlib import Path
from . import build as _build
import numpy as np
import torch
from spconv import ops, utils
from spconv.conv import (SparseConv2d, SparseConv3d, SparseConvTranspose2d,
SparseConvTranspose3d, SparseInverseConv2d,
SparseInverseConv3d, SubMConv2d, SubMConv3d)
from spconv.identity import Identity
from spconv.modules import SparseModule, SparseSequential
from spconv.ops import ConvAlgo
from spconv.pool import SparseMaxPool2d, SparseMaxPool3d
from spconv.tables import AddTable, ConcatTable, JoinTable
# File name of the compiled extension: a DLL on Windows, a shared object
# everywhere else.
_LIB_FILE_NAME = ("spconv.dll"
                  if platform.system() == "Windows" else "libspconv.so")
# The library is looked up in the directory that contains this module.
_LIB_PATH = str(Path(__file__).parent / _LIB_FILE_NAME)
# Register the custom operators with torch at import time.
torch.ops.load_library(_LIB_PATH)
def scatter_nd(indices, updates, shape):
    """PyTorch counterpart of TensorFlow's ``scatter_nd``.

    Creates a zero tensor of ``shape`` (dtype and device taken from
    ``updates``) and writes ``updates`` at the positions addressed by
    ``indices``.

    Args:
        indices: integer tensor; its last dimension (``ndim`` entries)
            addresses the first ``ndim`` axes of the output.
        updates: values to write; it is viewed as
            ``indices.shape[:-1] + shape[ndim:]`` before the assignment.
        shape: output shape. Any sequence of ints is accepted (list, tuple,
            ``torch.Size``).

    Returns:
        The scattered tensor.

    Note:
        No validation or error handling is performed, so use with care.
        Repeated indices are NOT accumulated (TensorFlow adds them up).
    """
    ndim = indices.shape[-1]
    ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device)
    # list(...) on both sides so that tuple / torch.Size shapes work too.
    output_shape = list(indices.shape[:-1]) + list(shape[ndim:])
    flatted_indices = indices.view(-1, ndim)
    # Multi-dimensional indexing must use a tuple; indexing with a plain
    # list of tensors is deprecated by PyTorch.
    index = tuple(flatted_indices[:, i] for i in range(ndim)) + (Ellipsis, )
    ret[index] = updates.view(*output_shape)
    return ret
class SparseConvTensor(object):
    """Container pairing per-point features with their integer coordinates."""

    def __init__(self, features, indices, spatial_shape, batch_size,
                 grid=None):
        """
        Args:
            features: [num_points, num_features] feature tensor
            indices: [num_points, ndim + 1] indice tensor. The batch index
                is stored in indices[:, 0]
            spatial_shape: spatial shape of the sparse data
            batch_size: batch size of the sparse data
            grid: pre-allocated grid tensor. Should be used when the volume
                of the spatial shape is very large.
        """
        self.features = features
        self.indices = indices
        self.spatial_shape = spatial_shape
        self.batch_size = batch_size
        # Lookup table served by find_indice_pair(); starts out empty.
        self.indice_dict = {}
        self.grid = grid

    @classmethod
    def from_dense(cls, x: torch.Tensor):
        """Create a sparse tensor from a channel-last dense tensor.

        ``x`` must be channel last (e.g. NHWC); it is converted with
        ``to_sparse`` keeping the last axis dense.
        """
        sparse_x = x.to_sparse(x.ndim - 1)
        # indices() is [sparse_dims, nnz]; store it as [nnz, sparse_dims].
        coords = sparse_x.indices().transpose(0, 1).contiguous().int()
        return cls(sparse_x.values(), coords, sparse_x.shape[1:-1],
                   sparse_x.shape[0])

    @property
    def spatial_size(self):
        """Product of all entries of ``spatial_shape``."""
        return np.prod(self.spatial_shape)

    def find_indice_pair(self, key):
        """Return the entry stored under ``key`` or None if absent/None."""
        if key is not None and key in self.indice_dict:
            return self.indice_dict[key]
        return None

    def dense(self, channels_first=True):
        """Scatter the features into a dense tensor.

        The scattered result is channel last; with ``channels_first=True``
        (default) the channel axis is moved to position 1 and the result is
        made contiguous.
        """
        num_channels = self.features.shape[1]
        full_shape = [self.batch_size, *self.spatial_shape, num_channels]
        coords = self.indices.to(self.features.device).long()
        channel_last = scatter_nd(coords, self.features, full_shape)
        if channels_first:
            ndim = len(self.spatial_shape)
            # Axis order: batch, channel, then the spatial axes.
            axes = [0, ndim + 1] + list(range(1, ndim + 1))
            return channel_last.permute(*axes).contiguous()
        return channel_last

    @property
    def sparity(self):
        """Number of stored points divided by spatial volume and batch size."""
        num_points = self.indices.shape[0]
        return num_points / np.prod(self.spatial_shape) / self.batch_size
# Module that turns a SparseConvTensor into a dense tensor.
class ToDense(SparseModule):
"""convert SparseConvTensor to NCHW dense tensor.
"""
# Delegates to SparseConvTensor.dense() with its default arguments
# (channels_first=True), so the channel axis ends up at position 1.
def forward(self, x: SparseConvTensor):
return x.dense()
# Module that drops the grid buffer attached to a SparseConvTensor.
class RemoveGrid(SparseModule):
"""remove pre-allocated grid buffer.
"""
# Mutates the input in place (sets x.grid to None) and returns the same
# object; no copy is made.
def forward(self, x: SparseConvTensor):
x.grid = None
return x
from .core import ConvAlgo, AlgoHint
from . import constants
\ No newline at end of file
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from cumm import tensorview as tv
from typing import Dict, List, Set, Tuple
from spconv.core_cc.cumm.gemm.main import GemmAlgoDesp, GemmMainUnitTest, GemmParams
# from spconv.core_cc.cumm.gemm.gather import GatherAll, ScatterAll
from cumm.gemm.algospec.core import ShuffleStrideType, get_min_arch_of_algo_str, get_available_algo_str_from_arch
from cumm.gemm.codeops import group_by, div_up
from typing import Optional
import time
import numpy as np
from .core import ConvAlgo, AlgoHint
# All GEMM algorithm descriptors reported by the compiled extension; queried
# once at import time.
ALL_ALGO_DESPS = GemmMainUnitTest.get_all_algo_desp()
# Type alias for the grouping key built by SimpleGemm.get_static_key:
# (trans_a, trans_b, trans_c, dtype_a, dtype_b, dtype_c, shuffle_type, algo).
_GEMM_STATIC_KEY = Tuple[bool, bool, bool, int, int, int, str, str]
# External gather/scatter helpers are currently disabled (their import is
# commented out as well).
# GATHER = GatherAll()
# SCATTER = ScatterAll()
class SimpleGemmAlgoMeta:
    """Tile-shape lookup data for one group of GEMM algorithm descriptors.

    All arguments are stored unchanged:
        tile_ms, tile_ns, tile_ks: candidate tile sizes for the three tile
            dimensions.
        tile_shape_to_algos: maps an integer tile-shape key to a list of
            algorithm indices.
    """

    def __init__(self, tile_ms: List[int], tile_ns: List[int],
                 tile_ks: List[int],
                 tile_shape_to_algos: Dict[int, List[int]]) -> None:
        self.tile_ms, self.tile_ns, self.tile_ks = tile_ms, tile_ns, tile_ks
        self.tile_shape_to_algos = tile_shape_to_algos
class BestAlgoByProfile:
    """Record describing one profiled GEMM configuration.

    Every constructor argument is stored unchanged under an attribute of
    the same name:
        algo_desp: the algorithm descriptor that was run.
        external_gather / external_scatter: flags stored alongside the
            descriptor (the profiler in this file always passes False).
        gather_params / scatter_params: optional 4-int tuples.
        splitk: split-k slice count used for the run (default 1).
    """

    def __init__(self,
                 algo_desp: GemmAlgoDesp,
                 external_gather: bool,
                 external_scatter: bool,
                 gather_params: Optional[Tuple[int, int, int, int]] = None,
                 scatter_params: Optional[Tuple[int, int, int, int]] = None,
                 splitk: int = 1) -> None:
        # What was run.
        self.algo_desp = algo_desp
        self.splitk = splitk
        # How the operands were gathered / scattered.
        self.external_gather, self.external_scatter = (external_gather,
                                                       external_scatter)
        self.gather_params, self.scatter_params = (gather_params,
                                                   scatter_params)
class SimpleGemm:
def __init__(self, desps: List[GemmAlgoDesp]) -> None:
    """Index the descriptors by static key and create empty profile caches.

    For every static-key group the distinct tile sizes are collected
    (sorted ascending) together with a table from packed tile shape to the
    positions of the matching descriptors inside that group.
    """
    self.desps = desps
    self.static_key_to_desps = group_by(self.get_static_key, desps)
    self.static_key_to_meta: Dict[_GEMM_STATIC_KEY,
                                  SimpleGemmAlgoMeta] = {}
    for static_key, group in self.static_key_to_desps.items():
        algos_by_tile: Dict[int, List[int]] = {}
        ms: Set[int] = set()
        ns: Set[int] = set()
        ks: Set[int] = set()
        for idx, desp in enumerate(group):
            shape = desp.tile_shape
            tile_m, tile_n, tile_k = shape[0], shape[1], shape[2]
            ms.add(tile_m)
            ns.add(tile_n)
            ks.add(tile_k)
            # Pack the three tile sizes into one int, 20 bits apart.
            packed = tile_m | (tile_n << 20) | (tile_k << 40)
            algos_by_tile.setdefault(packed, []).append(idx)
        self.static_key_to_meta[static_key] = SimpleGemmAlgoMeta(
            sorted(ms), sorted(ns), sorted(ks), algos_by_tile)
    # Profile results; which cache is used depends on the AlgoHint bit.
    # Forward hint, keyed by (n, k).
    self.nk_forward_cache: Dict[Tuple[int, int], BestAlgoByProfile] = {}
    # BackwardInput hint, keyed by (n, k).
    self.nk_dgrad_cache: Dict[Tuple[int, int], BestAlgoByProfile] = {}
    # BackwardWeight hint, keyed by (m, n).
    self.mn_cache: Dict[Tuple[int, int], BestAlgoByProfile] = {}
@staticmethod
def get_static_key(d: GemmAlgoDesp) -> _GEMM_STATIC_KEY:
    """Build the grouping key of a descriptor.

    The key is the 8-tuple (trans_a, trans_b, trans_c, dtype_a, dtype_b,
    dtype_c, shuffle_type, algo) read from ``d``.
    """
    transposes = (d.trans_a, d.trans_b, d.trans_c)
    dtypes = (d.dtype_a, d.dtype_b, d.dtype_c)
    return transposes + dtypes + (d.shuffle_type, d.algo)
# Thin wrapper: forwards to GemmMainUnitTest.device_synchronize() and returns
# whatever that call returns.
def device_synchronize(self):
return GemmMainUnitTest.device_synchronize()
def get_all_available(
        self,
        a: tv.Tensor,
        b: tv.Tensor,
        c: tv.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle):
    """List every descriptor usable for this problem on ``arch``.

    A descriptor qualifies when its static key matches the (normalised)
    transpose flags, operand dtypes, shuffle type and an algorithm string
    available for ``arch``, and its ``supported_ldx`` check accepts
    ``a.dim(1)``, ``b.dim(1)`` and ``c.dim(1)``.

    Returns:
        List of matching descriptors (possibly empty).
    """
    if trans_c:
        # Normalise to trans_c == False via (A @ B)^T == B^T @ A^T: swap
        # the operands and flip/swap their transpose flags.
        a, b = b, a
        trans_a, trans_b = not trans_b, not trans_a
        trans_c = False
    matched: List[GemmAlgoDesp] = []
    for algo in get_available_algo_str_from_arch(arch):
        lookup_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
                      shuffle_type.value, algo)
        group = self.static_key_to_desps.get(lookup_key, None)
        if not group:
            continue
        for desp in group:
            if desp.supported_ldx(a.dim(1), b.dim(1), c.dim(1)):
                matched.append(desp)
    return matched
def select(self,
           a: tv.Tensor,
           b: tv.Tensor,
           c: tv.Tensor,
           trans_a: bool,
           trans_b: bool,
           trans_c: bool,
           arch: Tuple[int, int],
           shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
           a_inds: tv.Tensor = tv.Tensor(),
           b_inds: tv.Tensor = tv.Tensor(),
           c_inds: tv.Tensor = tv.Tensor(),
           hint: int = AlgoHint.NoHint.value):
    """Heuristically pick one algorithm descriptor without profiling.

    The algorithm strings available for ``arch`` are visited in the order
    returned by ``get_available_algo_str_from_arch``. For each one the
    descriptors with a matching static key are narrowed down by
    ``GemmMainUnitTest.simple_select_tile_shape`` (all of them are kept if
    that selection is empty) and then filtered according to ``hint``:

    * ``hint == 0`` or no known hint bit set: the first candidate wins.
    * Forward / BackwardInput: for float16 ``a`` only descriptors whose
      ``dacc`` and ``dcomp`` are float16 are kept, otherwise all.
    * BackwardWeight: descriptors with ``split_k_serial`` are preferred
      (all are kept if there are none); int8 ``a`` is skipped; float16
      ``a`` requires float32 ``dacc`` and ``dcomp``.

    ``a`` above refers to the operand after the ``trans_c`` normalisation
    (operands are swapped when ``trans_c`` is true).

    Returns:
        The first surviving descriptor, or None when nothing matches.
    """
    m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                           trans_b, trans_c,
                                           shuffle_type.value, a_inds.shape,
                                           b_inds.shape, c_inds.shape)
    if trans_c:
        # Normalise to trans_c == False via (A @ B)^T == B^T @ A^T.
        trans_a = not trans_a
        trans_b = not trans_b
        trans_a, trans_b = trans_b, trans_a
        a, b = b, a
        trans_c = False
    # For shuffle stride algos the channel tile size should be as large as
    # possible, so with ShuffleAC the largest k is preferred.
    # The member is looked up on the enum class: reaching one member
    # through another (``shuffle_type.ShuffleAC``) is discouraged by the
    # enum documentation and not supported on every Python version.
    large_k_first = shuffle_type == ShuffleStrideType.ShuffleAC
    avail_algos = get_available_algo_str_from_arch(arch)
    finally_algos: List[GemmAlgoDesp] = []
    for algo in avail_algos:
        static_key = (trans_a, trans_b, trans_c, a.dtype, b.dtype, c.dtype,
                      shuffle_type.value, algo)
        desps = self.static_key_to_desps.get(static_key, None)
        if desps is None or len(desps) == 0:
            continue
        meta = self.static_key_to_meta[static_key]
        selected_algo_desps = GemmMainUnitTest.simple_select_tile_shape(
            m,
            n,
            k,
            meta.tile_ms,
            meta.tile_ns,
            meta.tile_ks,
            meta.tile_shape_to_algos,
            large_k_first=large_k_first)
        if not selected_algo_desps:
            candidate = desps
        else:
            candidate = [desps[i] for i in selected_algo_desps]
        # select by hint
        if hint == 0:
            return candidate[0]
        if hint & (AlgoHint.Fowrard.value | AlgoHint.BackwardInput.value):
            # m may be huge, n and k are small:
            # no mixed precision and no splitk needed.
            if a.dtype == tv.float16:
                finally_algos = [
                    can for can in candidate
                    if can.dacc == tv.float16 and can.dcomp == tv.float16
                ]
            else:
                finally_algos = candidate
        elif hint & AlgoHint.BackwardWeight.value:
            # k is huge; i8 is not supported; for f16 both the accumulator
            # and the compute type must be f32.
            finally_algos = []
            candidate_filtered: List[GemmAlgoDesp] = [
                can for can in candidate if can.split_k_serial
            ]
            if not candidate_filtered:
                candidate_filtered = candidate
            if a.dtype == tv.int8:
                continue
            elif a.dtype == tv.float16:
                finally_algos = [
                    can for can in candidate_filtered
                    if can.dacc == tv.float32 and can.dcomp == tv.float32
                ]
            else:
                finally_algos = candidate_filtered
        else:
            return candidate[0]
        if finally_algos:
            return finally_algos[0]
    return None
def get_profiled_algo(
        self,
        a_shape: List[int],
        b_shape: List[int],
        c_shape: List[int],
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds_shape: Optional[List[int]] = None,
        b_inds_shape: Optional[List[int]] = None,
        c_inds_shape: Optional[List[int]] = None,
        hint: int = AlgoHint.NoHint.value):
    """Look up a previously cached profile result for this problem size.

    The cache is chosen by the first matching hint bit, checked in this
    order: BackwardWeight -> ``mn_cache`` keyed by (m, n); BackwardInput ->
    ``nk_dgrad_cache`` keyed by (n, k); Forward -> ``nk_forward_cache``
    keyed by (n, k). Missing index shapes default to empty lists.

    Returns:
        The cached BestAlgoByProfile, or None if nothing is cached.

    Raises:
        NotImplementedError: if ``hint`` contains none of the three bits.
    """
    a_inds_shape = [] if a_inds_shape is None else a_inds_shape
    b_inds_shape = [] if b_inds_shape is None else b_inds_shape
    c_inds_shape = [] if c_inds_shape is None else c_inds_shape
    m, n, k = GemmMainUnitTest.extract_mnk(a_shape, b_shape, trans_a,
                                           trans_b, trans_c,
                                           shuffle_type.value, a_inds_shape,
                                           b_inds_shape, c_inds_shape)
    if hint & AlgoHint.BackwardWeight.value:
        return self.mn_cache.get((m, n), None)
    if hint & AlgoHint.BackwardInput.value:
        return self.nk_dgrad_cache.get((n, k), None)
    if hint & AlgoHint.Fowrard.value:
        return self.nk_forward_cache.get((n, k), None)
    raise NotImplementedError
def extract_mnk(
        self,
        a_shape: List[int],
        b_shape: List[int],
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds_shape: Optional[List[int]] = None,
        b_inds_shape: Optional[List[int]] = None,
        c_inds_shape: Optional[List[int]] = None,
        hint: int = AlgoHint.NoHint.value):
    """Return the GEMM problem size ``(m, n, k)`` for the given shapes.

    Forwards to ``GemmMainUnitTest.extract_mnk``; index shapes that are
    None are replaced by empty lists first. ``arch`` and ``hint`` are
    accepted but not used by this method.
    """
    a_inds_shape = [] if a_inds_shape is None else a_inds_shape
    b_inds_shape = [] if b_inds_shape is None else b_inds_shape
    c_inds_shape = [] if c_inds_shape is None else c_inds_shape
    m, n, k = GemmMainUnitTest.extract_mnk(a_shape, b_shape, trans_a,
                                           trans_b, trans_c,
                                           shuffle_type.value, a_inds_shape,
                                           b_inds_shape, c_inds_shape)
    return m, n, k
def profile_and_cache(
        self,
        a: tv.Tensor,
        b: tv.Tensor,
        c: tv.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds: tv.Tensor = tv.Tensor(),
        b_inds: tv.Tensor = tv.Tensor(),
        c_inds: tv.Tensor = tv.Tensor(),
        hint: int = AlgoHint.NoHint.value,
        alpha: float = 1.0,
        beta: float = 0.0,
        gather_data: tv.Tensor = tv.Tensor(),
        scatter_data: tv.Tensor = tv.Tensor(),
        # mm_func
        stream: int = 0):
    """Time every available algorithm, cache the fastest and return it.

    Each descriptor returned by ``get_all_available`` is run three times
    per split-k setting on a clone of ``c`` (``c`` itself is not written);
    the first run is treated as warm-up and the mean of the remaining runs
    is the score. Split-k values 1..64 (powers of two) are tried only for
    descriptors with ``split_k_serial`` under the BackwardWeight hint,
    otherwise only 1.

    The winner is stored in the cache selected by ``hint`` (same rule as
    ``get_profiled_algo``): BackwardWeight -> ``mn_cache[(m, n)]``,
    BackwardInput -> ``nk_dgrad_cache[(n, k)]``, Forward ->
    ``nk_forward_cache[(n, k)]``.

    ``gather_data`` and ``scatter_data`` are accepted but currently unused:
    profiling of external gather/scatter kernels is disabled (the
    GATHER/SCATTER helpers are commented out at module level).

    Returns:
        Tuple ``(best, seconds)`` with the winning BestAlgoByProfile and
        its mean run time in seconds.

    Raises:
        NotImplementedError: if ``hint`` contains none of the three bits.
            This is now checked before any kernel is launched.
        IndexError: if no algorithm is available for the inputs (type kept
            for backward compatibility, now with an explicit message).
    """
    m, n, k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                           trans_b, trans_c,
                                           shuffle_type.value, a_inds.shape,
                                           b_inds.shape, c_inds.shape)
    # Resolve the destination cache first so that an unsupported hint
    # fails before any profiling work is done.
    is_backward_weight = bool(hint & AlgoHint.BackwardWeight.value)
    if is_backward_weight:
        cache, key = self.mn_cache, (m, n)
    elif hint & AlgoHint.BackwardInput.value:
        cache, key = self.nk_dgrad_cache, (n, k)
    elif hint & AlgoHint.Fowrard.value:
        cache, key = self.nk_forward_cache, (n, k)
    else:
        raise NotImplementedError

    avail = self.get_all_available(a, b, c, trans_a, trans_b, trans_c, arch,
                                   shuffle_type)
    if not avail:
        raise IndexError(
            "no GEMM algorithm available to profile for the given inputs")

    # Scratch output so that profiling never touches the caller's tensor.
    c_ = c.clone()
    # Placeholders: external gather/scatter profiling is disabled.
    best_gather_params = (-1, -1, -1, -1)
    best_scatter_params = (-1, -1, -1, -1)

    times: List[float] = []
    all_profile_res: List[BestAlgoByProfile] = []
    for desp in avail:
        c_.zero_()
        params = GemmParams()
        params.a = a
        params.b = b
        params.c = c_
        params.a_inds = a_inds
        params.b_inds = b_inds
        params.c_inds = c_inds
        params.algo_desp = desp
        params.alpha = alpha
        params.beta = beta
        params.stream = stream
        if desp.split_k_serial and is_backward_weight:
            splitk_tests = [1, 2, 4, 8, 16, 32, 64]
        else:
            splitk_tests = [1]
        for spk in splitk_tests:
            params.split_k_slices = spk
            this_times = []
            for _ in range(3):
                GemmMainUnitTest.stream_synchronize(stream)
                # perf_counter is the monotonic clock meant for intervals.
                start = time.perf_counter()
                GemmMainUnitTest.matmul2(params)
                GemmMainUnitTest.stream_synchronize(stream)
                this_times.append(time.perf_counter() - start)
            # Drop the first (warm-up) run.
            times.append(np.mean(this_times[1:]))
            all_profile_res.append(
                BestAlgoByProfile(desp,
                                  False,
                                  False,
                                  best_gather_params,
                                  best_scatter_params,
                                  splitk=spk))

    # Index of the smallest time (first one on ties); no magic upper bound.
    best_idx = min(range(len(times)), key=times.__getitem__)
    res = all_profile_res[best_idx]
    min_time = times[best_idx]
    cache[key] = res
    return res, min_time
def run_profile(
        self,
        profile_res: BestAlgoByProfile,
        a: tv.Tensor,
        b: tv.Tensor,
        c: tv.Tensor,
        trans_a: bool,
        trans_b: bool,
        trans_c: bool,
        arch: Tuple[int, int],
        stream: int,
        shuffle_type: ShuffleStrideType = ShuffleStrideType.NoShuffle,
        a_inds: tv.Tensor = tv.Tensor(),
        b_inds: tv.Tensor = tv.Tensor(),
        c_inds: tv.Tensor = tv.Tensor(),
        hint: int = AlgoHint.NoHint.value,
        alpha: float = 1.0,
        beta: float = 0.0,
        gather_data: tv.Tensor = tv.Tensor(),
        workspace: tv.Tensor = tv.Tensor()):
    """Run one matmul with the algorithm stored in ``profile_res``.

    The split-k slice count is ``profile_res.splitk`` when that is greater
    than 1, otherwise 1. ``arch``, ``hint`` and ``gather_data`` are
    accepted but not used here (the external-gather path is disabled).

    Returns:
        The algorithm descriptor that was executed.
    """
    # The problem size is not used below; the call is kept as in the
    # original implementation.
    # NOTE(review): presumably retained for its shape checking -- confirm.
    _m, _n, _k = GemmMainUnitTest.extract_mnk(a.shape, b.shape, trans_a,
                                              trans_b, trans_c,
                                              shuffle_type.value,
                                              a_inds.shape, b_inds.shape,
                                              c_inds.shape)
    algo_desp = profile_res.algo_desp
    assert algo_desp is not None
    slices = profile_res.splitk if profile_res.splitk > 1 else 1

    params = GemmParams()
    for field, value in (
        ("a", a),
        ("b", b),
        ("c", c),
        ("a_inds", a_inds),
        ("b_inds", b_inds),
        ("c_inds", c_inds),
        ("algo_desp", algo_desp),
        ("split_k_slices", slices),
        ("stream", stream),
        ("alpha", alpha),
        ("beta", beta),
        ("workspace", workspace),
    ):
        setattr(params, field, value)
    # No stream synchronisation is done here; the launch is left
    # asynchronous with respect to the host.
    GemmMainUnitTest.matmul2(params)
    return algo_desp
# Module-level selector built over every known algorithm descriptor.
GEMM = SimpleGemm(ALL_ALGO_DESPS)

if __name__ == "__main__":
    # Smoke test: report the descriptor count, then time the heuristic
    # selector on a fixed float16 problem.
    print(len(ALL_ALGO_DESPS))
    print(ALL_ALGO_DESPS[0])
    num_rows = 64000
    repeats = 100
    a = tv.zeros([num_rows, 32], dtype=tv.float16)
    b = tv.zeros([32, 64], dtype=tv.float16)
    c = tv.zeros([num_rows, 64], dtype=tv.float16)
    a_inds = tv.zeros([num_rows], dtype=tv.int32)
    c_inds = tv.zeros([num_rows], dtype=tv.int32)
    t = time.time()
    for _ in range(repeats):
        # Operands are deliberately passed as (a, c, b): with trans_a=True
        # the [32, 64] tensor `b` is the output.
        algo = GEMM.select(a,
                           c,
                           b,
                           trans_a=True,
                           trans_b=False,
                           trans_c=False,
                           arch=(7, 5),
                           shuffle_type=ShuffleStrideType.ShuffleAB,
                           a_inds=a_inds,
                           b_inds=c_inds)
    # Average seconds per select() call, then the last selection.
    print((time.time() - t) / repeats)
    print(algo)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment