Commit 7a7fe160 authored by Chao Liu

more utility code

parent 625838de
......@@ -47,6 +47,19 @@ template <index_t GridSize,
index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I11 = Number<11>{};
#if 0
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
......@@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
......@@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
Float* const __restrict__ p_out_global) const
{
#if 0
constexpr auto tmp = std::tuple<bool>{};
constexpr auto flag = std::get<0>(tmp);
#else
constexpr auto a = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
constexpr auto a = make_tuple(true, Sequence<1>{}, index_t(99));
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("adsas %d\n", a.At(Number<0>{}));
print_Sequence("seq", a.At(Number<1>{}));
printf("adsas %lu\n", a.At(Number<2>{}));
printf("[0] %d\n", a.At(I0));
print_Sequence("[1]", a.At(I1));
printf("[2] %lu\n", a.At(I2));
}
auto b = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
bool flag = true;
auto b = make_tuple(flag, Sequence<1>{}, 99);
b.At(Number<0>{}) = false;
b.At(I0) = false;
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("adsas %d\n", b.At(Number<0>{}));
print_Sequence("seq", b.At(Number<1>{}));
printf("adsas %lu\n", b.At(Number<2>{}));
printf("[0] %d\n", b.At(I0));
print_Sequence("[1]", b.At(I1));
printf("[2] %lu\n", b.At(I2));
printf("flag %d\n", flag);
}
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("adsas %d\n",
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<0>{}));
print_Sequence(
"seq", Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<1>{}));
printf("adsas %d\n",
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<2>{}));
printf("[0] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I0));
print_Sequence("[1]", make_tuple(true, Sequence<1>(), index_t(99)).At(I1));
printf("[2] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I2));
}
#endif
#if 0
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
#elif 1
// create a native tensor descriptor
constexpr auto in_n_c_h_w_global_desc =
constexpr auto in_c_h_w_n_global_desc =
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3);
constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
constexpr auto pass_c = PassThrough<C>{};
constexpr auto pass_n = PassThrough<N>{};
constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
constexpr auto lower_dim_groups =
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
constexpr auto upper_dim_groups =
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc);
}
print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
// transform the tensor descriptor once
//
// calculate the offset of some entry
printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
}
#endif
}
#endif
......
......@@ -178,7 +178,7 @@ struct ConstantTensorDescriptor
{
constexpr auto IDim = IDim_{};
constexpr index_t stride = PackedStrides::Get(IDim);
multi_id.Set(IDim, id / stride);
multi_id(IDim) = id / stride;
id -= multi_id[IDim] * stride;
}
};
......@@ -192,7 +192,7 @@ struct ConstantTensorDescriptor
// calculate index in each of the dimensions in the order of their dimension
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{}));
multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});
return multi_id;
}
......
......@@ -33,7 +33,7 @@ struct PassThrough
};
// LowLengths: Sequence<...>
template <class LowLengths, class LeftPads, class RightPads>
template <typename LowLengths, typename LeftPads, typename RightPads>
struct Pad
{
static constexpr index_t nDim = LowLengths::GetSize();
......@@ -67,7 +67,7 @@ struct Pad
#if 0
// LowLengths: Sequence<...>
template <class LowLengths>
template <typename LowLengths>
struct Merge
{
static constexpr index_t nDimLow = LowLengths::GetSize();
......@@ -113,7 +113,7 @@ struct Merge
#endif
// UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths>
template <index_t LowLength, typename UpLengths>
struct Unmerge
{
static constexpr index_t nDimLow = 1;
......@@ -161,7 +161,7 @@ struct Unmerge
// UpLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowLength, class UpLengths, class Coefficients>
template <index_t LowLength, typename UpLengths, typename Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
......
......@@ -7,12 +7,12 @@
namespace ck {
template <class... NativeDimensions>
template <typename... NativeDimensions>
struct NativeTensorDescriptor
{
using type = NativeTensorDescriptor;
static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
static constexpr index_t nDim = mDimensions.GetSize();
static constexpr index_t nDim = sizeof...(NativeDimensions);
static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);
using Index = MultiIndex<nDim>;
......@@ -20,7 +20,7 @@ struct NativeTensorDescriptor
struct lambda_GetLength
{
template <class IDim>
template <typename IDim>
__host__ __device__ constexpr auto operator()(IDim) const
{
return GetLength(IDim{});
......@@ -34,7 +34,7 @@ struct NativeTensorDescriptor
struct lambda_GetStride
{
template <class IDim>
template <typename IDim>
__host__ __device__ constexpr auto operator()(IDim) const
{
return GetStride(IDim{});
......@@ -49,16 +49,16 @@ struct NativeTensorDescriptor
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return mDimensions.Get(Number<IDim>{}).GetLength();
return mDimensions.At(Number<IDim>{}).GetLength();
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
return mDimensions.Get(Number<IDim>{}).GetStride();
return mDimensions.At(Number<IDim>{}).GetStride();
}
__host__ __device__ static constexpr index_t GetOffset(Index idx)
__host__ __device__ static constexpr index_t GetOffset(const Index& idx)
{
index_t offset = 0;
......@@ -67,7 +67,7 @@ struct NativeTensorDescriptor
return offset;
}
__host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff)
__host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff)
{
index_t offset_diff = 0;
......@@ -96,28 +96,65 @@ struct NativeTensorDescriptor
}
};
#if 0
// LowerTensorDescriptor
// Transforms: std::tuple<DimensionTransforms...>
// LowerDimensionIds: std::tuple<Sequence<...>>
// UpperDimensionIds: std::tuple<Sequence<...>>
template <class LowTensorDescriptor, class Transforms, class LowDimensionIds, class UpDimensionIds>
// Transforms: Tuple<DimensionTransforms...>
// LowerDimensionIds: Tuple<Sequence<...>>
// UpperDimensionIds: Tuple<Sequence<...>>
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
struct TransformedTensorDescriptor
{
using type = TransformedTensorDescriptor;
static constexpr index_t nDimUp = GetUpperNumOfDimension();
static constexpr index_t nDimLow = GetLowerNumOfDimension();
using type = TransformedTensorDescriptor;
static constexpr index_t nTransform = Transforms::Size();
struct lambda_merge_sequences
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
static constexpr index_t nTransform = Transforms::GetSize();
using duplicated_low_active_dims =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return low_active_dims::Size();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
using duplicated_up_active_dims =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return up_active_dims::Size();
}
static constexpr index_t nDimUp = GetNumOfUpperDimension();
static constexpr index_t nDimLow = GetNumOfLowerDimension();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ static constexpr TransformedTensorDescriptor()
__host__ __device__ constexpr TransformedTensorDescriptor()
{
static_assert(nTransform == Transforms::GetSize() &&
nTransform == LowDimensionIds::GetSize() &&
nTransform == UpDimensionIds::GetSize(),
static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
nTransform == UpDimensionIds::Size(),
"wrong! # of transformations not the same");
// TODO: sanity check: LowDimensionIds should include all low-dimensions,
......@@ -128,33 +165,17 @@ struct TransformedTensorDescriptor
// a low-dimension should be associated with only one transformation
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
constexpr auto low_active_dims = unique_sort_sequence(
merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
return low_active_dims.GetSize();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
constexpr auto up_active_dims =
unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
return up_active_dims.GetSize();
}
__host__ __device__ static constexpr auto GetNumOfDimension()
{
return GetNumOfUpperDimension();
}
__host__ __device__ static constexpr auto GetLengths()
#if 0
__host__ __device__ static constexpr auto GetUpperLengths()
{
struct lambda_get_upper_lengths
{
template <class Transform>
template <typename Transform>
__host__ __device__ constexpr auto operator()(Transform tran) const
{
return tran.GetUpperLengths();
......@@ -173,6 +194,7 @@ struct TransformedTensorDescriptor
using sort_dimension_ids =
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
......@@ -182,46 +204,48 @@ struct TransformedTensorDescriptor
return sorted_upper_lengths;
}
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
#endif
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{
return LowTensorDescriptor{};
}
__host__ __device__ static constexpr index_t GetLowerIndex(UpperIndex idx_up)
__host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms::Get(itran);
constexpr auto tran = Transforms{}.At(itran);
constexpr auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran));
constexpr auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran));
auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_part = tran.GetLowerIndex(idx_up_part);
idx_low_part = tran.GetLowerIndex(to_array(idx_up_part));
});
return idx_low;
}
__host__ __device__ static constexpr index_t GetLowerIndexDiff(UpperIndex idx_up_diff,
LowerIndex idx_low_old)
__host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff,
const LowerIndex& idx_low_old)
{
LowerIndex idx_low_diff;
static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms::Get(itran);
constexpr auto tran = Transforms::At(itran);
constexpr auto idx_up_diff_part =
pick_array_element(idx_up_diff, UpDimensionIds::Get(itran));
const auto idx_up_diff_part =
pick_array_element(idx_up_diff, UpDimensionIds::At(itran));
constexpr auto idx_low_diff_part =
pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds::At(itran));
constexpr auto idx_low_old_part =
pick_array_element(idx_low_old, LowDimensionIds::Get(itran));
const auto idx_low_old_part =
pick_array_element(idx_low_old, LowDimensionIds::At(itran));
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
......@@ -232,13 +256,14 @@ struct TransformedTensorDescriptor
return idx_low_diff;
}
__host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
__host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up)
{
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
}
#if 0
template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>);
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{
// not implemented
}
......@@ -257,8 +282,8 @@ struct TransformedTensorDescriptor
{
// not implemented
}
};
#endif
};
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
......@@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths.
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <class Lengths>
template <typename Lengths>
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
{
constexpr index_t strides = reverse_inclusive_scan_sequence(
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
constexpr auto strides = reverse_inclusive_scan_sequence(
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
return make_NativeTensorDescriptor(Lengths{}, strides);
}
// Factory: deduces and returns a TransformedTensorDescriptor from a lower
// tensor descriptor plus a set of per-dimension-group transforms.
//
// All four arguments are stateless tag types used only for template-argument
// deduction; the returned descriptor is a default-constructed constexpr value.
//   LowTensorDescriptor - descriptor of the underlying (lower) tensor layout
//   Transforms          - Tuple of dimension transforms (e.g. PassThrough, Pad)
//   LowDimensionIds     - Tuple of Sequence<...>: lower-dimension ids per transform
//   UpDimensionIds      - Tuple of Sequence<...>: upper-dimension ids per transform
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
} // namespace ck
#endif
......@@ -6,7 +6,7 @@
namespace ck {
template <class... NativeDimensions>
template <typename... NativeDimensions>
__host__ __device__ void print_tensor_descriptor(const char* s,
NativeTensorDescriptor<NativeDimensions...> desc)
{
......
......@@ -6,48 +6,78 @@
namespace ck {
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;
using type = Array<TData, NSize>;
using data_type = TData;
static constexpr index_t nSize = NSize;
index_t mData[NSize];
index_t mData[nSize];
__host__ __device__ explicit constexpr Array() {}
template <class... Xs>
__host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
template <typename X, typename... Xs>
__host__ __device__ explicit constexpr Array(X x, Xs... xs)
: mData{static_cast<TData>(x), static_cast<TData>(xs)...}
{
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
}
__host__ __device__ static constexpr index_t GetSize() { return NSize; }
#if 0
template <typename T>
__host__ __device__ explicit constexpr Array(const T& x)
{
static_assert(T::Size() == NSize, "wrong! size");
static_for<0, NSize, 1>{}([&](auto i){
mData[i] = x.At(i);
})
}
#endif
__host__ __device__ static constexpr index_t Size() { return NSize; }
__host__ __device__ static constexpr index_t GetSize() { return Size(); }
template <index_t I>
__host__ __device__ constexpr TData operator[](Number<I>) const
__host__ __device__ constexpr const TData& At(Number<I>) const
{
static_assert(I < NSize, "wrong!");
return mData[I];
}
__host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
template <index_t I>
__host__ __device__ TData& operator()(Number<I>)
__host__ __device__ constexpr TData& At(Number<I>)
{
static_assert(I < NSize, "wrong!");
return mData[I];
}
__host__ __device__ TData& operator()(index_t i) { return mData[i]; }
__host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
template <index_t I>
__host__ __device__ constexpr void Set(Number<I>, TData x)
__host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
template <typename I>
__host__ __device__ constexpr const TData& operator[](I i) const
{
static_assert(I < NSize, "wrong!");
return At(i);
}
mData[I] = x;
template <typename I>
__host__ __device__ constexpr TData& operator()(I i)
{
return At(i);
}
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }
template <typename T>
__host__ __device__ constexpr type& operator=(const T& x)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; });
return *this;
}
struct lambda_PushBack // emulate constexpr lambda
{
......@@ -63,7 +93,7 @@ struct Array
template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const
{
new_array.Set(Number<I>{}, old_array[I]);
new_array(Number<I>{}) = old_array[I];
}
};
......@@ -73,71 +103,98 @@ struct Array
static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
new_array.Set(Number<NSize>{}, x);
new_array(Number<NSize>{}) = x;
return new_array;
}
};
// A: Array
// Arr: Array
// Picks: Sequence<...>
template <class Arr, class Picks>
template <typename Arr, typename Picks>
struct ArrayElementPicker
{
using type = ArrayElementPicker;
using data_type = typename Arr::data_type;
__host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
__host__ __device__ constexpr ArrayElementPicker() = delete;
__host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
{
constexpr index_t imax =
accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
static_assert(imax < Picks::GetSize(), "wrong! exceeding max id");
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I>
__host__ __device__ constexpr data_type operator[](Number<I>) const
__host__ __device__ constexpr const data_type& At(Number<I>) const
{
constexpr auto IP = Picks::Get(Number<I>{});
return mData[IP];
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray[IP];
}
__host__ __device__ constexpr data_type operator[](index_t i) const
template <index_t I>
__host__ __device__ constexpr data_type& At(Number<I>)
{
constexpr index_t ip = Picks{}[i];
return mData[ip];
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[I];
return mArray(IP);
}
template <index_t I>
__host__ __device__ data_type& operator()(Number<I>)
template <typename I>
__host__ __device__ constexpr const data_type& operator[](I i) const
{
constexpr auto IP = Picks::Get(Number<I>{});
return mData[IP];
return At(i);
}
__host__ __device__ data_type& operator()(index_t i)
template <typename I>
__host__ __device__ constexpr data_type& operator()(I i)
{
constexpr index_t ip = Picks{}[i];
return mData[ip];
return At(i);
}
Arr& mData;
template <typename T>
__host__ __device__ constexpr type& operator=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
Arr& mArray;
};
template <class Arr, class Picks>
template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{
return ArrayElementPicker<Arr, Picks>(a);
}
#if 1
// Copies an array-like container x element-wise into a plain Array of the same
// size. T must expose data_type, a static Size(), and At(i) element access
// (e.g. ArrayElementPicker, whose reference semantics this strips away).
// NOTE: y is default-constructed and then every element is written by the
// static_for loop, so no element is read uninitialized.
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
Array<typename T::data_type, T::Size()> y;
static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
return y;
}
#endif
// Converts a compile-time Sequence<Is...> into a runtime-indexable
// Array<index_t, N> holding the same values in the same order.
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
return Array<index_t, sizeof...(Is)>{Is...};
}
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array()
{
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
......@@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array()
return zero_array;
}
template <class TData, index_t NSize, index_t... IRs>
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*new2old*/)
{
......@@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return Array<TData, NSize>{old_array[IRs]...};
}
template <class TData, index_t NSize, class MapOld2New>
template <typename TData, index_t NSize, typename MapOld2New>
struct lambda_reorder_array_given_old2new
{
const Array<TData, NSize>& old_array;
......@@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new
{
TData old_data = old_array[IOldDim];
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
constexpr index_t INewDim = MapOld2New::At(Number<IOldDim>{});
new_array.Set(Number<INewDim>{}, old_data);
new_array(Number<INewDim>{}) = old_data;
}
};
template <class TData, index_t NSize, index_t... IRs>
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*old2new*/)
{
......@@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array;
}
template <class TData, index_t NSize, class ExtractSeq>
template <typename TData, index_t NSize, typename ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{
Array<TData, ExtractSeq::GetSize()> new_array;
......@@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert(new_size <= NSize, "wrong! too many extract");
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; });
return new_array;
}
template <class F, class X, class Y, class Z> // emulate constepxr lambda for array math
template <typename F, typename X, typename Y, typename Z> // emulate constepxr lambda for array
// math
struct lambda_array_math
{
const F& f;
......@@ -226,13 +284,12 @@ struct lambda_array_math
__host__ __device__ constexpr void operator()(Number<IDim_>) const
{
constexpr auto IDim = Number<IDim_>{};
z.Set(IDim, f(x[IDim], y[IDim]));
z(IDim) = f(x[IDim], y[IDim]);
}
};
// Array = Array + Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{
Array<TData, NSize> result;
......@@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
}
// Array = Array - Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{
Array<TData, NSize> result;
......@@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
}
// Array += Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a + b;
......@@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
}
// Array -= Array
template <class TData, index_t NSize>
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a - b;
return a;
}
// Array = Array + Sequence
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
}
// Array = Array - Sequence
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
}
// Array = Array * Sequence
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
}
// Array = Sequence - Array
template <class TData, index_t NSize, index_t... Is>
template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
......@@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
return result;
}
template <class TData, index_t NSize, class Reduce>
template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{
......@@ -357,89 +414,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
return result;
}
template <class T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
constexpr index_t nsize = a.GetSize();
static_assert(nsize > 0 && nsize <= 10, "wrong!");
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
static_if<nsize == 3>{}(
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
static_if<nsize == 4>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
static_if<nsize == 5>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
});
static_if<nsize == 6>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
});
static_if<nsize == 7>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
});
static_if<nsize == 8>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
});
static_if<nsize == 9>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
});
static_if<nsize == 10>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
});
}
} // namespace ck
#endif
......@@ -12,22 +12,22 @@ struct static_for;
template <index_t...>
struct Sequence;
template <class Seq, index_t I>
template <typename Seq, index_t I>
struct sequence_split;
template <class>
template <typename>
struct sequence_reverse;
template <class>
template <typename>
struct sequence_map_inverse;
template <class>
template <typename>
struct is_valid_sequence_map;
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
template <class Seq>
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);
template <index_t... Is>
......@@ -38,9 +38,11 @@ struct Sequence
static constexpr index_t mSize = sizeof...(Is);
__host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }
__host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
__host__ __device__ static constexpr index_t GetImpl(index_t I)
__host__ __device__ static constexpr auto GetSize() { return Size(); }
__host__ __device__ static constexpr index_t At(index_t I)
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const index_t mData[mSize + 1] = {Is..., 0};
......@@ -48,23 +50,24 @@ struct Sequence
}
template <index_t I>
__host__ __device__ static constexpr auto Get(Number<I>)
__host__ __device__ static constexpr auto At(Number<I>)
{
static_assert(I < mSize, "wrong! I too large");
return Number<GetImpl(Number<I>{})>{};
return Number<At(I)>{};
}
__host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I>) const
__host__ __device__ static constexpr auto Get(Number<I>)
{
return Get(Number<I>{});
return At(Number<I>{});
}
// make sure I is constepxr if you want a constexpr return type
__host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }
template <typename I>
__host__ __device__ constexpr auto operator[](I i) const
{
return At(i);
}
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
......@@ -74,14 +77,14 @@ struct Sequence
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
return Sequence<Type::Get(Number<IRs>{})...>{};
return Sequence<Type::At(Number<IRs>{})...>{};
}
// MapOld2New is Sequence<...>
template <class MapOld2New>
template <typename MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
static_assert(MapOld2New::GetSize() == GetSize(),
static_assert(MapOld2New::Size() == Size(),
"wrong! reorder map should have the same size as Sequence to be rerodered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
......@@ -97,13 +100,13 @@ struct Sequence
__host__ __device__ static constexpr auto Front()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<0>{});
return At(Number<0>{});
}
__host__ __device__ static constexpr auto Back()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<mSize - 1>{});
return At(Number<mSize - 1>{});
}
__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
......@@ -137,19 +140,19 @@ struct Sequence
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
{
return Sequence<Type::Get(Number<Ns>{})...>{};
return Sequence<Type::At(Number<Ns>{})...>{};
}
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
{
return Sequence<Type::Get(Number<Ns>{})...>{};
return Sequence<Type::At(Number<Ns>{})...>{};
}
template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{
static_assert(I < GetSize(), "wrong!");
static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{};
......@@ -158,7 +161,7 @@ struct Sequence
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
template <class F>
template <typename F>
__host__ __device__ static constexpr auto Transform(F f)
{
return Sequence<f(Is)...>{};
......@@ -166,8 +169,11 @@ struct Sequence
};
// merge sequence
// Merge (concatenate) one or more Sequence<...> types; result is in member `type`.
// The stale non-variadic forward declaration that conflicted with this variadic
// primary template has been removed.
template <typename Seq, typename... Seqs>
struct sequence_merge
{
    // right-fold: merge the tail pack first, then prepend Seq
    using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
......@@ -175,8 +181,14 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using type = Sequence<Xs..., Ys...>;
};
// single-sequence base case: merging one sequence is the sequence itself
template <typename Seq>
struct sequence_merge<Seq>
{
    using type = Seq;
};
// generate sequence
template <index_t IBegin, index_t NRemain, class F>
template <index_t IBegin, index_t NRemain, typename F>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
......@@ -188,20 +200,20 @@ struct sequence_gen_impl
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
};
template <index_t I, class F>
template <index_t I, typename F>
struct sequence_gen_impl<I, 1, F>
{
static constexpr index_t Is = F{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, class F>
template <index_t I, typename F>
struct sequence_gen_impl<I, 0, F>
{
using type = Sequence<>;
};
template <index_t NSize, class F>
template <index_t NSize, typename F>
struct sequence_gen
{
using type = typename sequence_gen_impl<0, NSize, F>::type;
......@@ -235,10 +247,10 @@ struct uniform_sequence_gen
};
// reverse inclusive scan (with init) sequence
template <class, class, index_t>
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, class Reduce, index_t Init>
template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
......@@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};
template <index_t I, class Reduce, index_t Init>
template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
using type = Sequence<Reduce{}(I, Init)>;
};
template <class Reduce, index_t Init>
template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
using type = Sequence<>;
};
// split sequence
template <class Seq, index_t I>
template <typename Seq, index_t I>
struct sequence_split
{
static constexpr index_t NSize = Seq{}.GetSize();
static constexpr index_t NSize = Seq{}.Size();
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
......@@ -274,10 +286,10 @@ struct sequence_split
};
// reverse sequence
template <class Seq>
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.GetSize();
static constexpr index_t NSize = Seq{}.Size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
......@@ -297,19 +309,102 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>;
};
template <class Seq, class Compare>
template <typename Seq, typename Compare>
struct sequence_sort
{
// not implemented
template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front();
static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));
using new_left_seq =
typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
using new_right_seq =
typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
using type =
typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>::
type;
};
template <typename SeqLeft, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
{
using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
};
template <typename SeqRight, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
{
using type = typename sequence_merge<MergedSeq, SeqRight>::type;
};
template <typename Seq0, typename Seq1, typename Comp>
struct sorted_sequence_merge
{
using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type;
};
using split = sequence_split<Seq, Seq::Size() / 2>;
using unsorted_left = typename split::SeqType0;
using unsorted_right = typename split::SeqType1;
using sorted_left = typename sequence_sort<unsorted_left, Compare>::type;
using sorted_right = typename sequence_sort<unsorted_right, Compare>::type;
using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type;
};
// two-element base case: order the pair directly with Compare
template <index_t X, index_t Y, typename Compare>
struct sequence_sort<Sequence<X, Y>, Compare>
{
    // keep (X, Y) when Compare already places X first, otherwise swap
    using type = typename conditional<Compare{}(X, Y), Sequence<X, Y>, Sequence<Y, X>>::type;
};
// single-element base case: already sorted
template <index_t X, typename Compare>
struct sequence_sort<Sequence<X>, Compare>
{
    using type = Sequence<X>;
};
template <class Seq, class Compare>
template <typename Seq, typename Less, typename Equal>
struct sequence_unique_sort
{
// not implemented
template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t new_value = WorkInputSeq::Front();
using new_work_input_seq = decltype(WorkInputSeq::PopFront());
using new_working_output_seq =
typename conditional<new_value == WorkOutputSeq::Back(),
WorkOutputSeq,
decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type;
};
template <typename WorkInputSeq, typename Eq>
struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq>
{
using type = WorkInputSeq;
};
template <typename SortedSeq, typename Eq>
struct sorted_sequence_uniquify
{
using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type;
};
using sorted_seq = typename sequence_sort<Seq, Less>::type;
using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type;
};
template <class Seq>
template <typename Seq>
struct is_valid_sequence_map
{
// not implemented yet, always return true
......@@ -317,36 +412,35 @@ struct is_valid_sequence_map
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type,
// is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{};
};
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
private:
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
public:
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
};
template <class X2Y, class WorkingY2X, index_t XBegin>
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
template <class X2Y>
template <typename X2Y>
struct sequence_map_inverse
{
using type =
typename sequence_map_inverse_impl<X2Y,
typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
typename uniform_sequence_gen<X2Y::Size(), 0>::type,
0,
X2Y::GetSize()>::type;
X2Y::Size()>::type;
};
template <index_t... Xs, index_t... Ys>
......@@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
return Sequence<Is...>{};
}
// Remove the last element of Seq; Seq must be non-empty.
// Implemented as pop-front of the reversed sequence, reversed back.
// (Duplicated GetSize/Size static_assert lines collapsed to the Size() form.)
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
    return sequence_pop_front(Seq::Reverse()).Reverse();
}
// element-wise transform: apply constexpr functor f to each value of the sequence
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
    return Sequence<f(Xs)...>{};
}
template <class F, index_t... Xs, index_t... Ys>
// concatenate any number of Sequence<...> values into a single Sequence
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
    using merged_type = typename sequence_merge<Seqs...>::type;
    return merged_type{};
}
template <typename F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
......@@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
return Sequence<f(Xs, Ys)...>{};
}
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
......@@ -489,19 +589,19 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return Sequence<f(Xs, Ys, Zs)...>{};
}
// reverse (right-to-left) inclusive scan of Seq with binary op Reduce and
// initial value Init; result element i is Reduce over Seq[i..end] and Init
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}
// forward (left-to-right) inclusive scan, implemented as a reverse scan of the
// reversed sequence, reversed back
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
template <class Seq, class Reduce>
template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence
{
const Reduce& f;
......@@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence
{
}
template <class IDim>
template <typename IDim>
__host__ __device__ constexpr index_t operator()(IDim) const
{
return result = f(result, Seq::Get(IDim{}));
return result = f(result, Seq::At(IDim{}));
}
};
template <class Seq, class Reduce, index_t Init>
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
......@@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
return result;
}
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
constexpr index_t nsize = Sequence<Xs...>::GetSize();
static_assert(nsize <= 10, "wrong!");
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 5>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 6>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 7>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 8>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 9>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 10>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
} // namespace ck
#endif
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "Array.hpp"
namespace ck {
// Debug helper: print an Array<T, NSize> (1 <= NSize <= 10) as
// "<label> size N, {a0 a1 ...}" from host or device code.
// NOTE(review): the "%u" format assumes the elements and index_t print as
// unsigned int — TODO confirm for other T / index_t widths.
template <typename T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
    constexpr index_t nsize = a.GetSize();

    static_assert(nsize > 0 && nsize <= 10, "wrong!");

    // only the branch whose size matches nsize is instantiated and executed
    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
    static_if<nsize == 3>{}(
        [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
    static_if<nsize == 4>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
    static_if<nsize == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
    });
    static_if<nsize == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
    });
    static_if<nsize == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6]);
    });
    static_if<nsize == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7]);
    });
    static_if<nsize == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8]);
    });
    static_if<nsize == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
               s,
               nsize,
               a[0],
               a[1],
               a[2],
               a[3],
               a[4],
               a[5],
               a[6],
               a[7],
               a[8],
               a[9]);
    });
}
} // namespace ck
#endif
\ No newline at end of file
......@@ -4,14 +4,19 @@
#include "config.hpp"
#include "utility.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "Sequence.hpp"
#include "sequence_helper.hpp"
#include "Array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
......
......@@ -3,9 +3,11 @@
#include "integral_constant.hpp"
#include "Sequence.hpp"
#include "type.hpp"
namespace ck {
// TODO: right? wrong?
struct forwarder
{
template <typename T>
......@@ -17,7 +19,7 @@ struct forwarder
struct swallow
{
template <class... Ts>
template <typename... Ts>
__host__ __device__ constexpr swallow(Ts&&...)
{
}
......@@ -32,7 +34,7 @@ struct static_if<true>
{
using Type = static_if<true>;
template <class F>
template <typename F>
__host__ __device__ constexpr auto operator()(F f) const
{
// This is a trick for compiler:
......@@ -43,7 +45,7 @@ struct static_if<true>
return Type{};
}
template <class F>
template <typename F>
__host__ __device__ static constexpr auto Else(F)
{
return Type{};
......@@ -55,13 +57,13 @@ struct static_if<false>
{
using Type = static_if<false>;
template <class F>
template <typename F>
__host__ __device__ constexpr auto operator()(F) const
{
return Type{};
}
template <class F>
template <typename F>
__host__ __device__ static constexpr auto Else(F f)
{
// This is a trick for compiler:
......@@ -73,5 +75,23 @@ struct static_if<false>
}
};
// device-friendly re-implementation of std::conditional:
// member `type` is X when predicate is true, Y otherwise
template <bool predicate, class X, class Y>
struct conditional;

template <class X, class Y>
struct conditional<true, X, Y>
{
    using type = X;
};

template <class X, class Y>
struct conditional<false, X, Y>
{
    using type = Y;
};

// alias shorthand, mirrors std::conditional_t
template <bool predicate, class X, class Y>
using conditional_t = typename conditional<predicate, X, Y>::type;
} // namespace ck
#endif
......@@ -6,6 +6,8 @@
namespace ck {
namespace detail {
template <class>
struct static_for_impl;
......@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
}
};
} // namespace detail
// F signature: F(Number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
......@@ -33,7 +37,8 @@ struct static_for
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f);
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
}
};
......
......@@ -8,20 +8,7 @@
namespace ck {
template <class>
struct is_static : integral_constant<bool, false>
{
};
template <class T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
namespace detail {
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
......@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
}
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
__host__ __device__ constexpr static_ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
}
};
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
......@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
}
};
} // namespace detail
// Compile-time N-dimensional loop nest.
// Lengths is Sequence<...>: the length of each dimension.
// Orders is Sequence<...>: the order in which static_ford loops over the dimensions
// (defaults to natural order 0, 1, ..., N-1).
// Uses Size() instead of the legacy GetSize() alias for consistency with the
// rest of the Sequence API.
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
struct static_ford
{
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::Size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::Size() == Orders::Size(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
    }
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// dimension
......@@ -139,7 +128,8 @@ struct ford
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
{
ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f, Array<index_t, 1>{i});
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
Array<index_t, 1>{i});
}
}
};
......
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "Sequence.hpp"
#include "tuple.hpp"
#include "Array.hpp"
namespace ck {
namespace detail {
// helper for unpack(): expands container elements x.At(0), ..., x.At(N-1)
// into a single call f(e0, e1, ..., eN-1)
template <typename Indices>
struct unpack_impl;

template <index_t... Is>
struct unpack_impl<Sequence<Is...>>
{
    template <typename F, typename X>
    __host__ __device__ constexpr auto operator()(F f, const X& x) const
    {
        return f(x.At(Number<Is>{})...);
    }
};
} // namespace detail
// invoke f with the elements of x (a Tuple/Array-like container with At() and
// Size()) passed as individual arguments
template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x)
{
    // indices 0, 1, ..., X::Size() - 1 drive the element expansion
    using index_seq = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;
    return detail::unpack_impl<index_seq>{}(f, x);
}
} // namespace ck
#endif
......@@ -13,54 +13,5 @@ struct integral_constant
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
template <class X, class Y>
struct is_same : public integral_constant<bool, false>
{
};
template <class X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <class T>
using remove_cv_t = typename std::remove_cv<T>::type;
template <index_t N>
using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck
#endif
......@@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
return max(x, xs...);
}
// comparison function objects usable in constexpr and device contexts
// (local analogues of std::equal_to / std::less)
template <class T>
struct equal
{
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
};

template <class T>
struct less
{
    __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};
} // namespace math
} // namespace ck
......
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#include "integral_constant.hpp"
namespace ck {
template <index_t N>
using Number = integral_constant<index_t, N>;
// Compile-time arithmetic on Number<N> (integral_constant<index_t, N>).
// Each operator returns a new Number carrying the result in its type.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
    return Number<X + Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
    // the result is required to stay non-negative
    static_assert(Y <= X, "wrong!");
    return Number<X - Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
    return Number<X * Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
    // division by a non-positive Number is a compile-time error
    static_assert(Y > 0, "wrong!");
    return Number<X / Y>{};
}

template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    return Number<X % Y>{};
}
} // namespace ck
#endif
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "Sequence.hpp"
namespace ck {
// Debug helper: print a Sequence<Xs...> (size <= 10) as
// "<label> size N, {x0 x1 ...}" from host or device code.
// NOTE(review): "%u" assumes index_t prints as unsigned int — TODO confirm.
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
    constexpr index_t nsize = Sequence<Xs...>::Size();

    static_assert(nsize <= 10, "wrong!");

    // only the branch whose size matches nsize is instantiated and executed
    static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
    static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
    static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 5>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 6>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 7>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 8>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 9>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
    static_if<nsize == 10>{}(
        [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
} // namespace ck
#endif
......@@ -2,6 +2,7 @@
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "Sequence.hpp"
namespace ck {
......@@ -16,6 +17,8 @@ struct TupleElementKey
template <typename Key, typename Data>
struct TupleElement
{
__host__ __device__ explicit constexpr TupleElement() : mData() {}
template <typename T>
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
{
......@@ -48,6 +51,12 @@ struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
{
#if 1
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
{
}
#endif
template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
......@@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
}
};
// construct a Tuple from values, decaying away references and cv-qualifiers
// from the deduced element types (mirrors std::make_tuple)
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}
namespace detail {
// helper: apply f to each element of x (selected by Is...) and collect the
// results in a new tuple
template <typename X, typename F, index_t... Is>
__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Is...>)
{
    return make_tuple(f(x.At(Number<Is>{}))...);
}
} // namespace detail
// element-wise transform of tuple x by f, returning a new tuple
// NOTE(review): despite the name, this is a per-element map, not a matrix
// transpose — confirm the intended naming
template <typename X, typename F>
__host__ __device__ constexpr auto transpose_tuple(X& x, F f)
{
    return detail::transpose_tuple_impl(
        x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
} // namespace ck
#endif
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
namespace ck {
// trait: true iff X and Y are the same type (local analogue of std::is_same)
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};

template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
// trait: true for types whose value is fully encoded in the type and thus
// known at compile time (integral_constant and Sequence); false otherwise
template <typename>
struct is_static : integral_constant<bool, false>
{
};

template <typename T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};

template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
// C++14-style *_t shorthand aliases over the <type_traits> transformations
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
} // namespace ck
#endif
......@@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif
#if 0 // debug
constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr index_t GridSize = 1;
#endif
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment