"examples/vscode:/vscode.git/clone" did not exist on "03bf877bf4e6653137119dd38006f30c76cbee04"
Commit 0a386c46 authored by Chao Liu's avatar Chao Liu
Browse files

use more constexpr for Array

parent 08c69243
...@@ -34,14 +34,6 @@ struct Array ...@@ -34,14 +34,6 @@ struct Array
__host__ __device__ TData& operator()(index_t i) { return mData[i]; } __host__ __device__ TData& operator()(index_t i) { return mData[i]; }
template <index_t I>
__host__ __device__ constexpr TData Get(Number<I>) const
{
static_assert(I < NSize, "wrong!");
return mData[I];
}
template <index_t I> template <index_t I>
__host__ __device__ constexpr void Set(Number<I>, TData x) __host__ __device__ constexpr void Set(Number<I>, TData x)
{ {
...@@ -50,16 +42,33 @@ struct Array ...@@ -50,16 +42,33 @@ struct Array
mData[I] = x; mData[I] = x;
} }
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }
struct lambda_PushBack // emulate constexpr lambda
{
const Array<TData, NSize>& old_array;
Array<TData, NSize + 1>& new_array;
__host__ __device__ constexpr lambda_PushBack(const Array<TData, NSize>& old_array_,
Array<TData, NSize + 1>& new_array_)
: old_array(old_array_), new_array(new_array_)
{
}
template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const
{
new_array.Set(Number<I>{}, old_array[I]);
}
};
__host__ __device__ constexpr auto PushBack(TData x) const __host__ __device__ constexpr auto PushBack(TData x) const
{ {
Array<TData, NSize + 1> new_array; Array<TData, NSize + 1> new_array;
static_for<0, NSize, 1>{}([&](auto I) { static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
constexpr index_t i = I.Get();
new_array(i) = mData[i];
});
new_array(NSize) = x; new_array.Set(Number<NSize>{}, x);
return new_array; return new_array;
} }
...@@ -81,18 +90,13 @@ __host__ __device__ constexpr auto make_zero_array() ...@@ -81,18 +90,13 @@ __host__ __device__ constexpr auto make_zero_array()
template <class TData, index_t NSize, index_t... IRs> template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> new2old) Sequence<IRs...> /*new2old*/)
{ {
Array<TData, NSize> new_array;
static_assert(NSize == sizeof...(IRs), "NSize not consistent"); static_assert(NSize == sizeof...(IRs), "NSize not consistent");
static_for<0, NSize, 1>{}([&](auto IDim) { static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
constexpr index_t idim = IDim.Get();
new_array[idim] = old_array[new2old.Get(IDim)];
});
return new_array; return Array<TData, NSize>{old_array.mSize[IRs]...};
} }
template <class TData, index_t NSize, class MapOld2New> template <class TData, index_t NSize, class MapOld2New>
...@@ -120,12 +124,14 @@ struct lambda_reorder_array_given_old2new ...@@ -120,12 +124,14 @@ struct lambda_reorder_array_given_old2new
template <class TData, index_t NSize, index_t... IRs> template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> old2new) Sequence<IRs...> /*old2new*/)
{ {
Array<TData, NSize> new_array; Array<TData, NSize> new_array;
static_assert(NSize == sizeof...(IRs), "NSize not consistent"); static_assert(NSize == sizeof...(IRs), "NSize not consistent");
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
static_for<0, NSize, 1>{}( static_for<0, NSize, 1>{}(
lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array)); lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array));
...@@ -141,25 +147,44 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_ ...@@ -141,25 +147,44 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert(new_size <= NSize, "wrong! too many extract"); static_assert(new_size <= NSize, "wrong! too many extract");
static_for<0, new_size, 1>{}([&](auto I) { static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
constexpr index_t i = I.Get();
new_array(i) = old_array[ExtractSeq::Get(I)];
});
return new_array; return new_array;
} }
template <class F, class X, class Y, class Z> // emulate constepxr lambda for array math
struct lambda_array_math
{
const F& f;
const X& x;
const Y& y;
Z& z;
__host__ __device__ constexpr lambda_array_math(const F& f_, const X& x_, const Y& y_, Z& z_)
: f(f_), x(x_), y(y_), z(z_)
{
}
template <index_t IDim_>
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
z.Set(IDim, f(x[IDim], y[IDim]));
}
};
// Array = Array + Array // Array = Array + Array
template <class TData, index_t NSize> template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b) __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{ {
Array<TData, NSize> result; Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) { auto f = mod_conv::plus<index_t>{};
constexpr index_t i = I.Get();
result(i) = a[i] + b[i]; static_for<0, NSize, 1>{}(
}); lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result; return result;
} }
...@@ -170,11 +195,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, ...@@ -170,11 +195,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
{ {
Array<TData, NSize> result; Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) { auto f = mod_conv::minus<index_t>{};
constexpr index_t i = I.Get();
result(i) = a[i] - b[i]; static_for<0, NSize, 1>{}(
}); lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result; return result;
} }
...@@ -187,11 +212,11 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is. ...@@ -187,11 +212,11 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
Array<TData, NSize> result; Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) { auto f = mod_conv::plus<index_t>{};
constexpr index_t i = I.Get();
result(i) = a[i] + b.Get(I); static_for<0, NSize, 1>{}(
}); lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result; return result;
} }
...@@ -204,11 +229,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is. ...@@ -204,11 +229,11 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
Array<TData, NSize> result; Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) { auto f = mod_conv::minus<index_t>{};
constexpr index_t i = I.Get();
result(i) = a[i] - b.Get(I); static_for<0, NSize, 1>{}(
}); lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result; return result;
} }
...@@ -221,11 +246,11 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is. ...@@ -221,11 +246,11 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
Array<TData, NSize> result; Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) { auto f = mod_conv::multiplies<index_t>{};
constexpr index_t i = I.Get();
result(i) = a[i] * b.Get(I); static_for<0, NSize, 1>{}(
}); lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result; return result;
} }
...@@ -238,11 +263,11 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi ...@@ -238,11 +263,11 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
Array<TData, NSize> result; Array<TData, NSize> result;
static_for<0, NSize, 1>{}([&](auto I) { auto f = mod_conv::minus<index_t>{};
constexpr index_t i = I.Get();
result(i) = a.Get(I) - b[i]; static_for<0, NSize, 1>{}(
}); lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
f, a, b, result));
return result; return result;
} }
...@@ -255,10 +280,7 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init) ...@@ -255,10 +280,7 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
static_assert(NSize > 0, "wrong"); static_assert(NSize > 0, "wrong");
static_for<0, NSize, 1>{}([&](auto I) { static_for<0, NSize, 1>{}([&](auto I) { result = f(result, a[I]); });
constexpr index_t i = I.Get();
result = f(result, a[i]);
});
return result; return result;
} }
......
...@@ -48,13 +48,13 @@ struct ConstantTensorDescriptor ...@@ -48,13 +48,13 @@ struct ConstantTensorDescriptor
template <index_t I> template <index_t I>
__host__ __device__ static constexpr index_t GetLength(Number<I>) __host__ __device__ static constexpr index_t GetLength(Number<I>)
{ {
return Lengths{}.Get(Number<I>{}); return Lengths::Get(Number<I>{});
} }
template <index_t I> template <index_t I>
__host__ __device__ static constexpr index_t GetStride(Number<I>) __host__ __device__ static constexpr index_t GetStride(Number<I>)
{ {
return Strides{}.Get(Number<I>{}); return Strides::Get(Number<I>{});
} }
struct lambda_AreDimensionsContinuous struct lambda_AreDimensionsContinuous
...@@ -131,7 +131,7 @@ struct ConstantTensorDescriptor ...@@ -131,7 +131,7 @@ struct ConstantTensorDescriptor
template <class X> template <class X>
__host__ __device__ constexpr void operator()(X IDim) const __host__ __device__ constexpr void operator()(X IDim) const
{ {
offset += multi_id.Get(IDim) * Type::GetStride(IDim); offset += multi_id[IDim] * Type::GetStride(IDim);
} }
}; };
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#include "integral_constant.hip.hpp" #include "integral_constant.hip.hpp"
#include "functional.hip.hpp" #include "functional.hip.hpp"
template <class Seq>
struct is_valid_sequence_map;
template <index_t... Is> template <index_t... Is>
struct Sequence struct Sequence
{ {
...@@ -40,27 +43,24 @@ struct Sequence ...@@ -40,27 +43,24 @@ struct Sequence
template <index_t... IRs> template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{ {
#if 0 // require sequence_sort, which is not implemented yet
static_assert(is_same<sequence_sort<Sequence<IRs...>>::SortedSeqType,
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
"wrong! invalid new2old map");
#endif
static_assert(sizeof...(Is) == sizeof...(IRs), static_assert(sizeof...(Is) == sizeof...(IRs),
"wrong! new2old map should have the same size as Sequence to be rerodered"); "wrong! reorder map should have the same size as Sequence to be rerodered");
return Sequence<Type{}.Get(Number<IRs>{})...>{}; static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
return Sequence<Type::Get(Number<IRs>{})...>{};
} }
#if 0 // require sequence_sort, which is not implemented yet #if 0 // require sequence_sort, which is not implemented yet
template <class MapOld2New> template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/) __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
{ {
#if 0 static_assert(sizeof...(Is) == MapOld2New::GetSize(),
static_assert(is_same<sequence_sort<MapOld2New>::SortedSeqType, "wrong! reorder map should have the same size as Sequence to be rerodered");
arithmetic_sequence_gen<0, mSize, 1>::SeqType>::value,
"wrong! invalid old2new map"); static_assert(is_valid_sequence_map<MapOld2New>::value,
#endif "wrong! invalid reorder map");
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{}; constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
return ReorderGivenNew2Old(map_new2old); return ReorderGivenNew2Old(map_new2old);
...@@ -106,13 +106,13 @@ struct Sequence ...@@ -106,13 +106,13 @@ struct Sequence
template <index_t... Ns> template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Number<Ns>...) __host__ __device__ static constexpr auto Extract(Number<Ns>...)
{ {
return Sequence<Type{}.Get(Number<Ns>{})...>{}; return Sequence<Type::Get(Number<Ns>{})...>{};
} }
template <index_t... Ns> template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>) __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
{ {
return Sequence<Type{}.Get(Number<Ns>{})...>{}; return Sequence<Type::Get(Number<Ns>{})...>{};
} }
template <index_t I, index_t X> template <index_t I, index_t X>
...@@ -316,6 +316,7 @@ struct sequence_map_inverse<Sequence<Is...>> ...@@ -316,6 +316,7 @@ struct sequence_map_inverse<Sequence<Is...>>
}; };
#endif #endif
template <class Seq> template <class Seq>
struct is_valid_sequence_map struct is_valid_sequence_map
{ {
......
#pragma once
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2>
struct is_same
{
static constexpr bool value = false;
};
template <class T>
struct is_same<T, T>
{
static constexpr bool value = true;
};
template <class X, class Y>
__host__ __device__ constexpr bool is_same_type(X, Y)
{
return is_same<X, Y>::value;
}
namespace mod_conv { // namespace mod_conv
template <class T, T s>
struct scales
{
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
};
template <class T>
struct plus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};
template <class T>
struct minus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
};
template <class T>
struct multiplies
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
template <class T>
struct integer_divide_ceiler
{
__host__ __device__ constexpr T operator()(T a, T b) const
{
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
return (a + b - 1) / b;
}
};
template <class T>
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
{
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
return (a + b - 1) / b;
}
template <class T>
__host__ __device__ constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <class T, class... Ts>
__host__ __device__ constexpr T max(T x, Ts... xs)
{
static_assert(sizeof...(xs) > 0, "not enough argument");
auto y = max(xs...);
static_assert(is_same<decltype(y), T>::value, "not the same type");
return x > y ? x : y;
}
template <class T>
__host__ __device__ constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <class T, class... Ts>
__host__ __device__ constexpr T min(T x, Ts... xs)
{
static_assert(sizeof...(xs) > 0, "not enough argument");
auto y = min(xs...);
static_assert(is_same<decltype(y), T>::value, "not the same type");
return x < y ? x : y;
}
// this is wrong
// TODO: implement correct least common multiple, instead of calling max()
template <class T, class... Ts>
__host__ __device__ constexpr T lcm(T x, Ts... xs)
{
return max(x, xs...);
}
} // namespace mod_conv
...@@ -203,20 +203,18 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -203,20 +203,18 @@ struct BlockwiseGenericTensorSliceCopy_v1
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) { static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if 0 #if 1
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto src_thread_data_multi_id_begin = const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
repeat_multi_id * data_per_cluster_per_dims; // cannot not constexpr, why?
const auto clipboard_data_multi_id_begin = const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
repeat_multi_id * thread_sub_tensor_lengths; // cannot not constexpr, why?
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex( const index_t src_offset =
src_thread_data_multi_id_begin); // cannot not constexpr, why? SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex( const index_t clipboard_offset =
clipboard_data_multi_id_begin); // cannot not constexpr, why? thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#else #else
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){}; constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
...@@ -258,20 +256,17 @@ struct BlockwiseGenericTensorSliceCopy_v1 ...@@ -258,20 +256,17 @@ struct BlockwiseGenericTensorSliceCopy_v1
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths); make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) { static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if 0 #if 1
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){}); constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto clipboard_data_multi_id_begin = const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
repeat_multi_id * thread_sub_tensor_lengths; // cannot not constexpr, why?
const auto dst_data_multi_id_begin = const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
repeat_multi_id * data_per_cluster_per_dims; // cannot not constexpr, why?
const index_t clipboard_offset = thread_tensor_desc.GetOffsetFromMultiIndex( const index_t clipboard_offset =
clipboard_data_multi_id_begin); // cannot not constexpr, why? thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex( const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
dst_data_multi_id_begin); // cannot not constexpr, why?
#else #else
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){}; constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
......
#pragma once #pragma once
#include "base.hip.hpp"
#include "vector_type.hip.hpp" #include "vector_type.hip.hpp"
#include "integral_constant.hip.hpp" #include "integral_constant.hip.hpp"
#include "Sequence.hip.hpp" #include "Sequence.hip.hpp"
...@@ -10,109 +11,3 @@ ...@@ -10,109 +11,3 @@
#if USE_AMD_INLINE_ASM #if USE_AMD_INLINE_ASM
#include "amd_inline_asm.hip.hpp" #include "amd_inline_asm.hip.hpp"
#endif #endif
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
template <class T1, class T2>
struct is_same
{
static constexpr bool value = false;
};
template <class T>
struct is_same<T, T>
{
static constexpr bool value = true;
};
template <class X, class Y>
__host__ __device__ constexpr bool is_same_type(X, Y)
{
return is_same<X, Y>::value;
}
namespace mod_conv { // namespace mod_conv
template <class T, T s>
struct scales
{
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
};
template <class T>
struct plus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};
template <class T>
struct multiplies
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
template <class T>
struct integer_divide_ceiler
{
__host__ __device__ constexpr T operator()(T a, T b) const
{
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
return (a + b - 1) / b;
}
};
template <class T>
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
{
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
return (a + b - 1) / b;
}
template <class T>
__host__ __device__ constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <class T, class... Ts>
__host__ __device__ constexpr T max(T x, Ts... xs)
{
static_assert(sizeof...(xs) > 0, "not enough argument");
auto y = max(xs...);
static_assert(is_same<decltype(y), T>::value, "not the same type");
return x > y ? x : y;
}
template <class T>
__host__ __device__ constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <class T, class... Ts>
__host__ __device__ constexpr T min(T x, Ts... xs)
{
static_assert(sizeof...(xs) > 0, "not enough argument");
auto y = min(xs...);
static_assert(is_same<decltype(y), T>::value, "not the same type");
return x < y ? x : y;
}
// this is wrong
// TODO: implement correct least common multiple, instead of calling max()
template <class T, class... Ts>
__host__ __device__ constexpr T lcm(T x, Ts... xs)
{
return max(x, xs...);
}
} // namespace mod_conv
...@@ -11,7 +11,7 @@ struct static_ford_impl ...@@ -11,7 +11,7 @@ struct static_ford_impl
// F signature: F(Sequence<...> multi_id) // F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...> // CurrentMultiIndex: Sequence<...>
template <class F, class CurrentMultiIndex> template <class F, class CurrentMultiIndex>
__host__ __device__ void operator()(F f, CurrentMultiIndex) const __host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
{ {
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
...@@ -28,7 +28,7 @@ struct static_ford_impl<Sequence<>> ...@@ -28,7 +28,7 @@ struct static_ford_impl<Sequence<>>
// F signature: F(Sequence<...> multi_id) // F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...> // CurrentMultiIndex: Sequence<...>
template <class F, class CurrentMultiIndex> template <class F, class CurrentMultiIndex>
__host__ __device__ void operator()(F f, CurrentMultiIndex) const __host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
{ {
f(CurrentMultiIndex{}); f(CurrentMultiIndex{});
} }
...@@ -40,7 +40,7 @@ struct static_ford ...@@ -40,7 +40,7 @@ struct static_ford
{ {
// F signature: F(Sequence<...> multi_id) // F signature: F(Sequence<...> multi_id)
template <class F> template <class F>
__host__ __device__ void operator()(F f) const __host__ __device__ constexpr void operator()(F f) const
{ {
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
...@@ -55,7 +55,7 @@ struct ford_impl ...@@ -55,7 +55,7 @@ struct ford_impl
// CurrentMultiIndex: Array<...> // CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...> // RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths> template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void __host__ __device__ constexpr void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{ {
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!"); static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
...@@ -77,7 +77,7 @@ struct ford_impl<1> ...@@ -77,7 +77,7 @@ struct ford_impl<1>
// CurrentMultiIndex: Array<...> // CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...> // RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths> template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void __host__ __device__ constexpr void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{ {
static_assert(RemainLengths::GetSize() == 1, "wrong!"); static_assert(RemainLengths::GetSize() == 1, "wrong!");
...@@ -97,7 +97,7 @@ struct ford ...@@ -97,7 +97,7 @@ struct ford
{ {
// F signature: F(Array<...> multi_id) // F signature: F(Array<...> multi_id)
template <class F> template <class F>
__host__ __device__ void operator()(F f) const __host__ __device__ constexpr void operator()(F f) const
{ {
constexpr index_t first_length = Lengths{}.Front(); constexpr index_t first_length = Lengths{}.Front();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment