"...composable_kernel.git" did not exist on "ad541ad6b9de9b0579d5254f82e9d5b86103d309"
Commit 7a7fe160 authored by Chao Liu's avatar Chao Liu
Browse files

more utility code

parent 625838de
...@@ -47,6 +47,19 @@ template <index_t GridSize, ...@@ -47,6 +47,19 @@ template <index_t GridSize,
index_t OutThreadCopyDataPerAccess_N> index_t OutThreadCopyDataPerAccess_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
{ {
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I11 = Number<11>{};
#if 0 #if 0
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
...@@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded ...@@ -60,11 +73,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
GemmNPerThreadSubC % NPerThread == 0)), GemmNPerThreadSubC % NPerThread == 0)),
"wrong!"); "wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{}; constexpr auto False = integral_constant<bool, false>{};
...@@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded ...@@ -487,58 +495,67 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_out_global) const
{ {
#if 0 #if 0
constexpr auto tmp = std::tuple<bool>{}; constexpr auto a = make_tuple(true, Sequence<1>{}, index_t(99));
constexpr auto flag = std::get<0>(tmp);
#else
constexpr auto a = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99);
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
printf("adsas %d\n", a.At(Number<0>{})); printf("[0] %d\n", a.At(I0));
print_Sequence("seq", a.At(Number<1>{})); print_Sequence("[1]", a.At(I1));
printf("adsas %lu\n", a.At(Number<2>{})); printf("[2] %lu\n", a.At(I2));
} }
auto b = Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>{}, 99); bool flag = true;
auto b = make_tuple(flag, Sequence<1>{}, 99);
b.At(Number<0>{}) = false; b.At(I0) = false;
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
printf("adsas %d\n", b.At(Number<0>{})); printf("[0] %d\n", b.At(I0));
print_Sequence("seq", b.At(Number<1>{})); print_Sequence("[1]", b.At(I1));
printf("adsas %lu\n", b.At(Number<2>{})); printf("[2] %lu\n", b.At(I2));
printf("flag %d\n", flag);
} }
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
printf("adsas %d\n", printf("[0] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I0));
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<0>{})); print_Sequence("[1]", make_tuple(true, Sequence<1>(), index_t(99)).At(I1));
print_Sequence( printf("[2] %d\n", make_tuple(true, Sequence<1>(), index_t(99)).At(I2));
"seq", Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<1>{}));
printf("adsas %d\n",
Tuple<bool, Sequence<1>, index_t>(true, Sequence<1>(), 99).At(Number<2>{}));
} }
#endif #elif 1
#if 0
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// create a native tensor descriptor // create a native tensor descriptor
constexpr auto in_n_c_h_w_global_desc = constexpr auto in_c_h_w_n_global_desc =
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides()); make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3);
constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
constexpr auto pass_c = PassThrough<C>{};
constexpr auto pass_n = PassThrough<N>{};
constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
constexpr auto lower_dim_groups =
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
constexpr auto upper_dim_groups =
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
print_tensor_descriptor("in_n_c_h_w_global_desc", in_n_c_h_w_global_desc); print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
}
// transform the tensor descriptor once printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
//
// calculate the offset of some entry printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
}
#endif #endif
} }
#endif #endif
......
...@@ -178,7 +178,7 @@ struct ConstantTensorDescriptor ...@@ -178,7 +178,7 @@ struct ConstantTensorDescriptor
{ {
constexpr auto IDim = IDim_{}; constexpr auto IDim = IDim_{};
constexpr index_t stride = PackedStrides::Get(IDim); constexpr index_t stride = PackedStrides::Get(IDim);
multi_id.Set(IDim, id / stride); multi_id(IDim) = id / stride;
id -= multi_id[IDim] * stride; id -= multi_id[IDim] * stride;
} }
}; };
...@@ -192,7 +192,7 @@ struct ConstantTensorDescriptor ...@@ -192,7 +192,7 @@ struct ConstantTensorDescriptor
// calculate index in each of the dimensions in the order of their dimension // calculate index in each of the dimensions in the order of their dimension
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id)); static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{})); multi_id(Number<nDim - 1>{}) = id / PackedStrides::Get(Number<nDim - 1>{});
return multi_id; return multi_id;
} }
......
...@@ -33,7 +33,7 @@ struct PassThrough ...@@ -33,7 +33,7 @@ struct PassThrough
}; };
// LowLengths: Sequence<...> // LowLengths: Sequence<...>
template <class LowLengths, class LeftPads, class RightPads> template <typename LowLengths, typename LeftPads, typename RightPads>
struct Pad struct Pad
{ {
static constexpr index_t nDim = LowLengths::GetSize(); static constexpr index_t nDim = LowLengths::GetSize();
...@@ -67,7 +67,7 @@ struct Pad ...@@ -67,7 +67,7 @@ struct Pad
#if 0 #if 0
// LowLengths: Sequence<...> // LowLengths: Sequence<...>
template <class LowLengths> template <typename LowLengths>
struct Merge struct Merge
{ {
static constexpr index_t nDimLow = LowLengths::GetSize(); static constexpr index_t nDimLow = LowLengths::GetSize();
...@@ -113,7 +113,7 @@ struct Merge ...@@ -113,7 +113,7 @@ struct Merge
#endif #endif
// UpLengths: Sequence<...> // UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths> template <index_t LowLength, typename UpLengths>
struct Unmerge struct Unmerge
{ {
static constexpr index_t nDimLow = 1; static constexpr index_t nDimLow = 1;
...@@ -161,7 +161,7 @@ struct Unmerge ...@@ -161,7 +161,7 @@ struct Unmerge
// UpLengths: Sequence<...> // UpLengths: Sequence<...>
// Coefficients: Sequence<...> // Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp] // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowLength, class UpLengths, class Coefficients> template <index_t LowLength, typename UpLengths, typename Coefficients>
struct Embed struct Embed
{ {
static constexpr index_t nDimLow = 1; static constexpr index_t nDimLow = 1;
......
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
namespace ck { namespace ck {
template <class... NativeDimensions> template <typename... NativeDimensions>
struct NativeTensorDescriptor struct NativeTensorDescriptor
{ {
using type = NativeTensorDescriptor; using type = NativeTensorDescriptor;
static constexpr auto mDimensions = Tuple<NativeDimensions...>{}; static constexpr index_t nDim = sizeof...(NativeDimensions);
static constexpr index_t nDim = mDimensions.GetSize(); static constexpr auto mDimensions = make_tuple(NativeDimensions{}...);
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
...@@ -20,7 +20,7 @@ struct NativeTensorDescriptor ...@@ -20,7 +20,7 @@ struct NativeTensorDescriptor
struct lambda_GetLength struct lambda_GetLength
{ {
template <class IDim> template <typename IDim>
__host__ __device__ constexpr auto operator()(IDim) const __host__ __device__ constexpr auto operator()(IDim) const
{ {
return GetLength(IDim{}); return GetLength(IDim{});
...@@ -34,7 +34,7 @@ struct NativeTensorDescriptor ...@@ -34,7 +34,7 @@ struct NativeTensorDescriptor
struct lambda_GetStride struct lambda_GetStride
{ {
template <class IDim> template <typename IDim>
__host__ __device__ constexpr auto operator()(IDim) const __host__ __device__ constexpr auto operator()(IDim) const
{ {
return GetStride(IDim{}); return GetStride(IDim{});
...@@ -49,16 +49,16 @@ struct NativeTensorDescriptor ...@@ -49,16 +49,16 @@ struct NativeTensorDescriptor
template <index_t IDim> template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>) __host__ __device__ static constexpr auto GetLength(Number<IDim>)
{ {
return mDimensions.Get(Number<IDim>{}).GetLength(); return mDimensions.At(Number<IDim>{}).GetLength();
} }
template <index_t IDim> template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>) __host__ __device__ static constexpr auto GetStride(Number<IDim>)
{ {
return mDimensions.Get(Number<IDim>{}).GetStride(); return mDimensions.At(Number<IDim>{}).GetStride();
} }
__host__ __device__ static constexpr index_t GetOffset(Index idx) __host__ __device__ static constexpr index_t GetOffset(const Index& idx)
{ {
index_t offset = 0; index_t offset = 0;
...@@ -67,7 +67,7 @@ struct NativeTensorDescriptor ...@@ -67,7 +67,7 @@ struct NativeTensorDescriptor
return offset; return offset;
} }
__host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff) __host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff)
{ {
index_t offset_diff = 0; index_t offset_diff = 0;
...@@ -96,28 +96,65 @@ struct NativeTensorDescriptor ...@@ -96,28 +96,65 @@ struct NativeTensorDescriptor
} }
}; };
#if 0
// LowerTensorDescriptor // LowerTensorDescriptor
// Transforms: std::tuple<DimensionTransforms...> // Transforms: Tuple<DimensionTransforms...>
// LowerDimensionIds: std::tuple<Sequence<...>> // LowerDimensionIds: Tuple<Sequence<...>>
// UpperDimensionIds: std::tuple<Sequence<...>> // UpperDimensionIds: Tuple<Sequence<...>>
template <class LowTensorDescriptor, class Transforms, class LowDimensionIds, class UpDimensionIds> template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
struct TransformedTensorDescriptor struct TransformedTensorDescriptor
{ {
using type = TransformedTensorDescriptor; using type = TransformedTensorDescriptor;
static constexpr index_t nDimUp = GetUpperNumOfDimension(); static constexpr index_t nTransform = Transforms::Size();
static constexpr index_t nDimLow = GetLowerNumOfDimension();
// Named function object standing in for a generic lambda: its call operator
// forwards every argument it receives to merge_sequences() and returns that
// result unchanged.
// NOTE(review): presumably each argument is a Sequence<...> and the result is
// their concatenation -- confirm against the definition of merge_sequences.
struct lambda_merge_sequences
{
template <typename... Seqs>
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
{
return merge_sequences(seqs...);
}
};
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
static constexpr index_t nTransform = Transforms::GetSize(); using duplicated_low_active_dims =
decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{}));
using low_active_dims = typename sequence_unique_sort<duplicated_low_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return low_active_dims::Size();
}
// Returns the number of upper dimensions, computed purely from types:
// the entries of UpDimensionIds are combined through lambda_merge_sequences
// (via unpack), the combined sequence is fed to sequence_unique_sort, and the
// Size() of the resulting type is returned.
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
// merged dimension ids of all transformations; may still contain repeats
using duplicated_up_active_dims =
decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{}));
// NOTE(review): presumably sorts ascending (math::less) and drops repeated
// ids (math::equal) -- confirm against sequence_unique_sort
using up_active_dims = typename sequence_unique_sort<duplicated_up_active_dims,
math::less<index_t>,
math::equal<index_t>>::type;
return up_active_dims::Size();
}
static constexpr index_t nDimUp = GetNumOfUpperDimension();
static constexpr index_t nDimLow = GetNumOfLowerDimension();
using UpperIndex = MultiIndex<nDimUp>; using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>; using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ static constexpr TransformedTensorDescriptor() __host__ __device__ constexpr TransformedTensorDescriptor()
{ {
static_assert(nTransform == Transforms::GetSize() && static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() &&
nTransform == LowDimensionIds::GetSize() && nTransform == UpDimensionIds::Size(),
nTransform == UpDimensionIds::GetSize(),
"wrong! # of transformations not the same"); "wrong! # of transformations not the same");
// TODO: sanity check: LowDimensionIds should include all low-dimensions, // TODO: sanity check: LowDimensionIds should include all low-dimensions,
...@@ -128,33 +165,17 @@ struct TransformedTensorDescriptor ...@@ -128,33 +165,17 @@ struct TransformedTensorDescriptor
// a low-dimension should be associated with only one transformation // a low-dimension should be associated with only one transformation
} }
__host__ __device__ static constexpr auto GetNumOfLowerDimension()
{
// Here, we assume all lower-dimensions are active
// TODO: sanity-check all lower-dimension are indeed active
constexpr auto low_active_dims = unique_sort_sequence(
merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
return low_active_dims.GetSize();
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension()
{
constexpr auto up_active_dims =
unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
return up_active_dims.GetSize();
}
__host__ __device__ static constexpr auto GetNumOfDimension() __host__ __device__ static constexpr auto GetNumOfDimension()
{ {
return GetNumOfUpperDimension(); return GetNumOfUpperDimension();
} }
__host__ __device__ static constexpr auto GetLengths() #if 0
__host__ __device__ static constexpr auto GetUpperLengths()
{ {
struct lambda_get_upper_lengths struct lambda_get_upper_lengths
{ {
template <class Transform> template <typename Transform>
__host__ __device__ constexpr auto operator()(Transform tran) const __host__ __device__ constexpr auto operator()(Transform tran) const
{ {
return tran.GetUpperLengths(); return tran.GetUpperLengths();
...@@ -173,6 +194,7 @@ struct TransformedTensorDescriptor ...@@ -173,6 +194,7 @@ struct TransformedTensorDescriptor
using sort_dimension_ids = using sort_dimension_ids =
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>; sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type; constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type; constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
...@@ -182,46 +204,48 @@ struct TransformedTensorDescriptor ...@@ -182,46 +204,48 @@ struct TransformedTensorDescriptor
return sorted_upper_lengths; return sorted_upper_lengths;
} }
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
#endif
__host__ __device__ static constexpr auto GetLowerTensorDescriptor() __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{ {
return LowTensorDescriptor{}; return LowTensorDescriptor{};
} }
__host__ __device__ static constexpr index_t GetLowerIndex(UpperIndex idx_up) __host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up)
{ {
LowerIndex idx_low; LowerIndex idx_low;
static_for<0, nTransform, 1>{}([&](auto itran) { static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms::Get(itran); constexpr auto tran = Transforms{}.At(itran);
constexpr auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran)); auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran));
constexpr auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran)); const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));
// this assumes each lower (single) index is only associated with one transformation, // this assumes each lower (single) index is only associated with one transformation,
// which is required for index transformation, and has been checked during constructor // which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor // of TransformedTensorDescriptor
idx_low_part = tran.GetLowerIndex(idx_up_part); idx_low_part = tran.GetLowerIndex(to_array(idx_up_part));
}); });
return idx_low; return idx_low;
} }
__host__ __device__ static constexpr index_t GetLowerIndexDiff(UpperIndex idx_up_diff, __host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff,
LowerIndex idx_low_old) const LowerIndex& idx_low_old)
{ {
LowerIndex idx_low_diff; LowerIndex idx_low_diff;
static_for<0, nTransform, 1>{}([&](auto itran) { static_for<0, nTransform, 1>{}([&](auto itran) {
constexpr auto tran = Transforms::Get(itran); constexpr auto tran = Transforms::At(itran);
constexpr auto idx_up_diff_part = const auto idx_up_diff_part =
pick_array_element(idx_up_diff, UpDimensionIds::Get(itran)); pick_array_element(idx_up_diff, UpDimensionIds::At(itran));
constexpr auto idx_low_diff_part = auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds::At(itran));
pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
constexpr auto idx_low_old_part = const auto idx_low_old_part =
pick_array_element(idx_low_old, LowDimensionIds::Get(itran)); pick_array_element(idx_low_old, LowDimensionIds::At(itran));
// this assumes each lower (single) index is associated with only one transformation, // this assumes each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor // which is required for index transformation, and has been checked during constructor
...@@ -232,13 +256,14 @@ struct TransformedTensorDescriptor ...@@ -232,13 +256,14 @@ struct TransformedTensorDescriptor
return idx_low_diff; return idx_low_diff;
} }
__host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up) __host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up)
{ {
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up)); return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
} }
#if 0
template <index_t IDim> template <index_t IDim>
__host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>); __host__ __device__ static constexpr bool IsLinearDimension(Number<IDim>)
{ {
// not implemented // not implemented
} }
...@@ -257,8 +282,8 @@ struct TransformedTensorDescriptor ...@@ -257,8 +282,8 @@ struct TransformedTensorDescriptor
{ {
// not implemented // not implemented
} }
};
#endif #endif
};
template <index_t... Lengths, index_t... Strides> template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>, __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
...@@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths. ...@@ -267,15 +292,28 @@ __host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths.
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{}; return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
} }
template <class Lengths> template <typename Lengths>
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths) __host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
{ {
constexpr index_t strides = reverse_inclusive_scan_sequence( constexpr auto strides = reverse_inclusive_scan_sequence(
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{}) Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{}); .PushBack(Number<1>{});
return make_NativeTensorDescriptor(Lengths{}, strides); return make_NativeTensorDescriptor(Lengths{}, strides);
} }
// Factory helper for TransformedTensorDescriptor.
// The four parameters are unnamed: only their types are used, as the template
// arguments of the descriptor. The values passed in are never read.
// Returns a value-initialized TransformedTensorDescriptor<LowTensorDescriptor,
// Transforms, LowDimensionIds, UpDimensionIds>.
template <typename LowTensorDescriptor,
typename Transforms,
typename LowDimensionIds,
typename UpDimensionIds>
__host__ __device__ constexpr auto
transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds)
{
return TransformedTensorDescriptor<LowTensorDescriptor,
Transforms,
LowDimensionIds,
UpDimensionIds>{};
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
namespace ck { namespace ck {
template <class... NativeDimensions> template <typename... NativeDimensions>
__host__ __device__ void print_tensor_descriptor(const char* s, __host__ __device__ void print_tensor_descriptor(const char* s,
NativeTensorDescriptor<NativeDimensions...> desc) NativeTensorDescriptor<NativeDimensions...> desc)
{ {
......
...@@ -6,48 +6,78 @@ ...@@ -6,48 +6,78 @@
namespace ck { namespace ck {
template <class TData, index_t NSize> template <typename TData, index_t NSize>
struct Array struct Array
{ {
using Type = Array<TData, NSize>; using type = Array<TData, NSize>;
using data_type = TData; using data_type = TData;
static constexpr index_t nSize = NSize; index_t mData[NSize];
index_t mData[nSize]; __host__ __device__ explicit constexpr Array() {}
template <class... Xs> template <typename X, typename... Xs>
__host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...} __host__ __device__ explicit constexpr Array(X x, Xs... xs)
: mData{static_cast<TData>(x), static_cast<TData>(xs)...}
{ {
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
} }
__host__ __device__ static constexpr index_t GetSize() { return NSize; } #if 0
template <typename T>
__host__ __device__ explicit constexpr Array(const T& x)
{
static_assert(T::Size() == NSize, "wrong! size");
static_for<0, NSize, 1>{}([&](auto i){
mData[i] = x.At(i);
})
}
#endif
__host__ __device__ static constexpr index_t Size() { return NSize; }
__host__ __device__ static constexpr index_t GetSize() { return Size(); }
template <index_t I> template <index_t I>
__host__ __device__ constexpr TData operator[](Number<I>) const __host__ __device__ constexpr const TData& At(Number<I>) const
{ {
static_assert(I < NSize, "wrong!");
return mData[I]; return mData[I];
} }
__host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
template <index_t I> template <index_t I>
__host__ __device__ TData& operator()(Number<I>) __host__ __device__ constexpr TData& At(Number<I>)
{ {
static_assert(I < NSize, "wrong!");
return mData[I]; return mData[I];
} }
__host__ __device__ TData& operator()(index_t i) { return mData[i]; } __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
template <index_t I> __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
__host__ __device__ constexpr void Set(Number<I>, TData x)
template <typename I>
__host__ __device__ constexpr const TData& operator[](I i) const
{ {
static_assert(I < NSize, "wrong!"); return At(i);
}
mData[I] = x; template <typename I>
__host__ __device__ constexpr TData& operator()(I i)
{
return At(i);
} }
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; } template <typename T>
__host__ __device__ constexpr type& operator=(const T& x)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; });
return *this;
}
struct lambda_PushBack // emulate constexpr lambda struct lambda_PushBack // emulate constexpr lambda
{ {
...@@ -63,7 +93,7 @@ struct Array ...@@ -63,7 +93,7 @@ struct Array
template <index_t I> template <index_t I>
__host__ __device__ constexpr void operator()(Number<I>) const __host__ __device__ constexpr void operator()(Number<I>) const
{ {
new_array.Set(Number<I>{}, old_array[I]); new_array(Number<I>{}) = old_array[I];
} }
}; };
...@@ -73,71 +103,98 @@ struct Array ...@@ -73,71 +103,98 @@ struct Array
static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array)); static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
new_array.Set(Number<NSize>{}, x); new_array(Number<NSize>{}) = x;
return new_array; return new_array;
} }
}; };
// A: Array // Arr: Array
// Picks: Sequence<...> // Picks: Sequence<...>
template <class Arr, class Picks> template <typename Arr, typename Picks>
struct ArrayElementPicker struct ArrayElementPicker
{ {
using type = ArrayElementPicker;
using data_type = typename Arr::data_type; using data_type = typename Arr::data_type;
__host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array} __host__ __device__ constexpr ArrayElementPicker() = delete;
__host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
{ {
constexpr index_t imax = constexpr index_t imax =
accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{}); accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
static_assert(imax < Picks::GetSize(), "wrong! exceeding max id"); static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
} }
__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); } __host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I> template <index_t I>
__host__ __device__ constexpr data_type operator[](Number<I>) const __host__ __device__ constexpr const data_type& At(Number<I>) const
{ {
constexpr auto IP = Picks::Get(Number<I>{}); static_assert(I < Size(), "wrong!");
return mData[IP];
constexpr auto IP = Picks{}[I];
return mArray[IP];
} }
__host__ __device__ constexpr data_type operator[](index_t i) const template <index_t I>
__host__ __device__ constexpr data_type& At(Number<I>)
{ {
constexpr index_t ip = Picks{}[i]; static_assert(I < Size(), "wrong!");
return mData[ip];
constexpr auto IP = Picks{}[I];
return mArray(IP);
} }
template <index_t I> template <typename I>
__host__ __device__ data_type& operator()(Number<I>) __host__ __device__ constexpr const data_type& operator[](I i) const
{ {
constexpr auto IP = Picks::Get(Number<I>{}); return At(i);
return mData[IP];
} }
__host__ __device__ data_type& operator()(index_t i) template <typename I>
__host__ __device__ constexpr data_type& operator()(I i)
{ {
constexpr index_t ip = Picks{}[i]; return At(i);
return mData[ip];
} }
Arr& mData; template <typename T>
__host__ __device__ constexpr type& operator=(const T& a)
{
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
Arr& mArray;
}; };
template <class Arr, class Picks> template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks) __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
{ {
return ArrayElementPicker<Arr, Picks>(a); return ArrayElementPicker<Arr, Picks>(a);
} }
#if 1
// Copies the elements of x into a new Array and returns it by value.
// T must provide a nested data_type, a static Size(), and At(i) element
// access; the result is Array<typename T::data_type, T::Size()>.
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
// default-constructed here; every element is assigned by the loop below
Array<typename T::data_type, T::Size()> y;
// static_for invokes the lambda once per index in [0, T::Size())
static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); });
return y;
}
#endif
template <index_t... Is> template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>) __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{ {
return Array<index_t, sizeof...(Is)>{Is...}; return Array<index_t, sizeof...(Is)>{Is...};
} }
template <class TData, index_t NSize> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto make_zero_array() __host__ __device__ constexpr auto make_zero_array()
{ {
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{}; constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::type{};
...@@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array() ...@@ -145,7 +202,7 @@ __host__ __device__ constexpr auto make_zero_array()
return zero_array; return zero_array;
} }
template <class TData, index_t NSize, index_t... IRs> template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*new2old*/) Sequence<IRs...> /*new2old*/)
{ {
...@@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData ...@@ -156,7 +213,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
return Array<TData, NSize>{old_array[IRs]...}; return Array<TData, NSize>{old_array[IRs]...};
} }
template <class TData, index_t NSize, class MapOld2New> template <typename TData, index_t NSize, typename MapOld2New>
struct lambda_reorder_array_given_old2new struct lambda_reorder_array_given_old2new
{ {
const Array<TData, NSize>& old_array; const Array<TData, NSize>& old_array;
...@@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new ...@@ -173,13 +230,13 @@ struct lambda_reorder_array_given_old2new
{ {
TData old_data = old_array[IOldDim]; TData old_data = old_array[IOldDim];
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{}); constexpr index_t INewDim = MapOld2New::At(Number<IOldDim>{});
new_array.Set(Number<INewDim>{}, old_data); new_array(Number<INewDim>{}) = old_data;
} }
}; };
template <class TData, index_t NSize, index_t... IRs> template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array, __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
Sequence<IRs...> /*old2new*/) Sequence<IRs...> /*old2new*/)
{ {
...@@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData ...@@ -195,7 +252,7 @@ __host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData
return new_array; return new_array;
} }
template <class TData, index_t NSize, class ExtractSeq> template <typename TData, index_t NSize, typename ExtractSeq>
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq) __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
{ {
Array<TData, ExtractSeq::GetSize()> new_array; Array<TData, ExtractSeq::GetSize()> new_array;
...@@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_ ...@@ -204,12 +261,13 @@ __host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_
static_assert(new_size <= NSize, "wrong! too many extract"); static_assert(new_size <= NSize, "wrong! too many extract");
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; }); static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; });
return new_array; return new_array;
} }
template <class F, class X, class Y, class Z> // emulate constepxr lambda for array math template <typename F, typename X, typename Y, typename Z> // emulate constepxr lambda for array
// math
struct lambda_array_math struct lambda_array_math
{ {
const F& f; const F& f;
...@@ -226,13 +284,12 @@ struct lambda_array_math ...@@ -226,13 +284,12 @@ struct lambda_array_math
__host__ __device__ constexpr void operator()(Number<IDim_>) const __host__ __device__ constexpr void operator()(Number<IDim_>) const
{ {
constexpr auto IDim = Number<IDim_>{}; constexpr auto IDim = Number<IDim_>{};
z(IDim) = f(x[IDim], y[IDim]);
z.Set(IDim, f(x[IDim], y[IDim]));
} }
}; };
// Array = Array + Array // Array = Array + Array
template <class TData, index_t NSize> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b) __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
{ {
Array<TData, NSize> result; Array<TData, NSize> result;
...@@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, ...@@ -247,7 +304,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData,
} }
// Array = Array - Array // Array = Array - Array
template <class TData, index_t NSize> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b) __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
{ {
Array<TData, NSize> result; Array<TData, NSize> result;
...@@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, ...@@ -262,7 +319,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
} }
// Array += Array // Array += Array
template <class TData, index_t NSize> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b) __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
{ {
a = a + b; a = a + b;
...@@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat ...@@ -270,14 +327,14 @@ __host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TDat
} }
// Array -= Array // Array -= Array
template <class TData, index_t NSize> template <typename TData, index_t NSize>
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b) __host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
{ {
a = a - b; a = a - b;
return a; return a;
} }
// Array = Array + Sequence // Array = Array + Sequence
template <class TData, index_t NSize, index_t... Is> template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b) __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
{ {
static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
...@@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is. ...@@ -294,7 +351,7 @@ __host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is.
} }
// Array = Array - Sequence // Array = Array - Sequence
template <class TData, index_t NSize, index_t... Is> template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b) __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
{ {
static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
...@@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is. ...@@ -311,7 +368,7 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is.
} }
// Array = Array * Sequence // Array = Array * Sequence
template <class TData, index_t NSize, index_t... Is> template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b) __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
{ {
static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
...@@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is. ...@@ -328,7 +385,7 @@ __host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is.
} }
// Array = Sequence - Array // Array = Sequence - Array
template <class TData, index_t NSize, index_t... Is> template <typename TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b) __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
{ {
static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
...@@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi ...@@ -344,7 +401,7 @@ __host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSi
return result; return result;
} }
template <class TData, index_t NSize, class Reduce> template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr TData __host__ __device__ constexpr TData
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init) accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
{ {
...@@ -357,89 +414,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init) ...@@ -357,89 +414,5 @@ accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
return result; return result;
} }
// Debug helper: prints the label `s`, the array size, and every element of `a`
// on a single line, in the form "<s> size <n>, {e0 e1 ...}".
//
// s : caller-supplied label, printed through "%s".
// a : array to print, taken by value; its size must be in [1, 10]
//     (enforced by the static_assert below).
//
// Each supported size has its own branch that issues exactly one printf call
// with a format string written out for that element count.
// NOTE(review): static_if<cond> is presumably a project helper that invokes the
// lambda only when cond is true, so exactly one branch should run -- confirm
// against its definition.
// NOTE(review): every value, including nsize, is printed with "%u"; this
// assumes index_t and T are unsigned 32-bit integers -- TODO confirm.
template <class T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
constexpr index_t nsize = a.GetSize();
// Only sizes 1..10 have a matching branch below.
static_assert(nsize > 0 && nsize <= 10, "wrong!");
// The lambda's `auto` parameter is unused in every branch.
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
static_if<nsize == 3>{}(
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
static_if<nsize == 4>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
static_if<nsize == 5>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
});
static_if<nsize == 6>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
});
static_if<nsize == 7>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
});
static_if<nsize == 8>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
});
static_if<nsize == 9>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
});
static_if<nsize == 10>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
});
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -12,22 +12,22 @@ struct static_for; ...@@ -12,22 +12,22 @@ struct static_for;
template <index_t...> template <index_t...>
struct Sequence; struct Sequence;
template <class Seq, index_t I> template <typename Seq, index_t I>
struct sequence_split; struct sequence_split;
template <class> template <typename>
struct sequence_reverse; struct sequence_reverse;
template <class> template <typename>
struct sequence_map_inverse; struct sequence_map_inverse;
template <class> template <typename>
struct is_valid_sequence_map; struct is_valid_sequence_map;
template <index_t I, index_t... Is> template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>); __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
template <class Seq> template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq); __host__ __device__ constexpr auto sequence_pop_back(Seq);
template <index_t... Is> template <index_t... Is>
...@@ -38,9 +38,11 @@ struct Sequence ...@@ -38,9 +38,11 @@ struct Sequence
static constexpr index_t mSize = sizeof...(Is); static constexpr index_t mSize = sizeof...(Is);
__host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; } __host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
__host__ __device__ static constexpr index_t GetImpl(index_t I) __host__ __device__ static constexpr auto GetSize() { return Size(); }
__host__ __device__ static constexpr index_t At(index_t I)
{ {
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0 // the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const index_t mData[mSize + 1] = {Is..., 0}; const index_t mData[mSize + 1] = {Is..., 0};
...@@ -48,23 +50,24 @@ struct Sequence ...@@ -48,23 +50,24 @@ struct Sequence
} }
template <index_t I> template <index_t I>
__host__ __device__ static constexpr auto Get(Number<I>) __host__ __device__ static constexpr auto At(Number<I>)
{ {
static_assert(I < mSize, "wrong! I too large"); static_assert(I < mSize, "wrong! I too large");
return Number<GetImpl(Number<I>{})>{}; return Number<At(I)>{};
} }
__host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }
template <index_t I> template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I>) const __host__ __device__ static constexpr auto Get(Number<I>)
{ {
return Get(Number<I>{}); return At(Number<I>{});
} }
// make sure I is constepxr if you want a constexpr return type template <typename I>
__host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); } __host__ __device__ constexpr auto operator[](I i) const
{
return At(i);
}
template <index_t... IRs> template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
...@@ -74,14 +77,14 @@ struct Sequence ...@@ -74,14 +77,14 @@ struct Sequence
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map"); static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
return Sequence<Type::Get(Number<IRs>{})...>{}; return Sequence<Type::At(Number<IRs>{})...>{};
} }
// MapOld2New is Sequence<...> // MapOld2New is Sequence<...>
template <class MapOld2New> template <typename MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{ {
static_assert(MapOld2New::GetSize() == GetSize(), static_assert(MapOld2New::Size() == Size(),
"wrong! reorder map should have the same size as Sequence to be rerodered"); "wrong! reorder map should have the same size as Sequence to be rerodered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map"); static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
...@@ -97,13 +100,13 @@ struct Sequence ...@@ -97,13 +100,13 @@ struct Sequence
__host__ __device__ static constexpr auto Front() __host__ __device__ static constexpr auto Front()
{ {
static_assert(mSize > 0, "wrong!"); static_assert(mSize > 0, "wrong!");
return Get(Number<0>{}); return At(Number<0>{});
} }
__host__ __device__ static constexpr auto Back() __host__ __device__ static constexpr auto Back()
{ {
static_assert(mSize > 0, "wrong!"); static_assert(mSize > 0, "wrong!");
return Get(Number<mSize - 1>{}); return At(Number<mSize - 1>{});
} }
__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
...@@ -137,19 +140,19 @@ struct Sequence ...@@ -137,19 +140,19 @@ struct Sequence
template <index_t... Ns> template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Number<Ns>...) __host__ __device__ static constexpr auto Extract(Number<Ns>...)
{ {
return Sequence<Type::Get(Number<Ns>{})...>{}; return Sequence<Type::At(Number<Ns>{})...>{};
} }
template <index_t... Ns> template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>) __host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
{ {
return Sequence<Type::Get(Number<Ns>{})...>{}; return Sequence<Type::At(Number<Ns>{})...>{};
} }
template <index_t I, index_t X> template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>) __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{ {
static_assert(I < GetSize(), "wrong!"); static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>; using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_left = typename seq_split::SeqType0{};
...@@ -158,7 +161,7 @@ struct Sequence ...@@ -158,7 +161,7 @@ struct Sequence
return seq_left.PushBack(Number<X>{}).PushBack(seq_right); return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
} }
template <class F> template <typename F>
__host__ __device__ static constexpr auto Transform(F f) __host__ __device__ static constexpr auto Transform(F f)
{ {
return Sequence<f(Is)...>{}; return Sequence<f(Is)...>{};
...@@ -166,8 +169,11 @@ struct Sequence ...@@ -166,8 +169,11 @@ struct Sequence
}; };
// merge sequence // merge sequence
template <class, class> template <typename Seq, typename... Seqs>
struct sequence_merge; struct sequence_merge
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
template <index_t... Xs, index_t... Ys> template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>> struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
...@@ -175,8 +181,14 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>> ...@@ -175,8 +181,14 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using type = Sequence<Xs..., Ys...>; using type = Sequence<Xs..., Ys...>;
}; };
template <typename Seq>
struct sequence_merge<Seq>
{
using type = Seq;
};
// generate sequence // generate sequence
template <index_t IBegin, index_t NRemain, class F> template <index_t IBegin, index_t NRemain, typename F>
struct sequence_gen_impl struct sequence_gen_impl
{ {
static constexpr index_t NRemainLeft = NRemain / 2; static constexpr index_t NRemainLeft = NRemain / 2;
...@@ -188,20 +200,20 @@ struct sequence_gen_impl ...@@ -188,20 +200,20 @@ struct sequence_gen_impl
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type; typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
}; };
template <index_t I, class F> template <index_t I, typename F>
struct sequence_gen_impl<I, 1, F> struct sequence_gen_impl<I, 1, F>
{ {
static constexpr index_t Is = F{}(Number<I>{}); static constexpr index_t Is = F{}(Number<I>{});
using type = Sequence<Is>; using type = Sequence<Is>;
}; };
template <index_t I, class F> template <index_t I, typename F>
struct sequence_gen_impl<I, 0, F> struct sequence_gen_impl<I, 0, F>
{ {
using type = Sequence<>; using type = Sequence<>;
}; };
template <index_t NSize, class F> template <index_t NSize, typename F>
struct sequence_gen struct sequence_gen
{ {
using type = typename sequence_gen_impl<0, NSize, F>::type; using type = typename sequence_gen_impl<0, NSize, F>::type;
...@@ -235,10 +247,10 @@ struct uniform_sequence_gen ...@@ -235,10 +247,10 @@ struct uniform_sequence_gen
}; };
// reverse inclusive scan (with init) sequence // reverse inclusive scan (with init) sequence
template <class, class, index_t> template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan; struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, class Reduce, index_t Init> template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init> struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{ {
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type; using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
...@@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init> ...@@ -248,23 +260,23 @@ struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type; using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
}; };
template <index_t I, class Reduce, index_t Init> template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init> struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{ {
using type = Sequence<Reduce{}(I, Init)>; using type = Sequence<Reduce{}(I, Init)>;
}; };
template <class Reduce, index_t Init> template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init> struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{ {
using type = Sequence<>; using type = Sequence<>;
}; };
// split sequence // split sequence
template <class Seq, index_t I> template <typename Seq, index_t I>
struct sequence_split struct sequence_split
{ {
static constexpr index_t NSize = Seq{}.GetSize(); static constexpr index_t NSize = Seq{}.Size();
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type; using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
...@@ -274,10 +286,10 @@ struct sequence_split ...@@ -274,10 +286,10 @@ struct sequence_split
}; };
// reverse sequence // reverse sequence
template <class Seq> template <typename Seq>
struct sequence_reverse struct sequence_reverse
{ {
static constexpr index_t NSize = Seq{}.GetSize(); static constexpr index_t NSize = Seq{}.Size();
using seq_split = sequence_split<Seq, NSize / 2>; using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge< using type = typename sequence_merge<
...@@ -297,19 +309,102 @@ struct sequence_reverse<Sequence<I0, I1>> ...@@ -297,19 +309,102 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>; using type = Sequence<I1, I0>;
}; };
template <class Seq, class Compare> template <typename Seq, typename Compare>
struct sequence_sort struct sequence_sort
{ {
// not implemented template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front();
static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));
using new_left_seq =
typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
using new_right_seq =
typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
using type =
typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>::
type;
};
template <typename SeqLeft, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
{
using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
};
template <typename SeqRight, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
{
using type = typename sequence_merge<MergedSeq, SeqRight>::type;
};
template <typename Seq0, typename Seq1, typename Comp>
struct sorted_sequence_merge
{
using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type;
};
using split = sequence_split<Seq, Seq::Size() / 2>;
using unsorted_left = typename split::SeqType0;
using unsorted_right = typename split::SeqType1;
using sorted_left = typename sequence_sort<unsorted_left, Compare>::type;
using sorted_right = typename sequence_sort<unsorted_right, Compare>::type;
using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type;
};
template <index_t X, index_t Y, typename Compare>
struct sequence_sort<Sequence<X, Y>, Compare>
{
static constexpr bool x_first = Compare{}(X, Y);
using type = typename conditional<x_first, Sequence<X, Y>, Sequence<Y, X>>::type;
};
template <index_t X, typename Compare>
struct sequence_sort<Sequence<X>, Compare>
{
using type = Sequence<X>;
}; };
template <class Seq, class Compare> template <typename Seq, typename Less, typename Equal>
struct sequence_unique_sort struct sequence_unique_sort
{ {
// not implemented template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t new_value = WorkInputSeq::Front();
using new_work_input_seq = decltype(WorkInputSeq::PopFront());
using new_working_output_seq =
typename conditional<new_value == WorkOutputSeq::Back(),
WorkOutputSeq,
decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type;
};
template <typename WorkInputSeq, typename Eq>
struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq>
{
using type = WorkInputSeq;
};
template <typename SortedSeq, typename Eq>
struct sorted_sequence_uniquify
{
using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type;
};
using sorted_seq = typename sequence_sort<Seq, Less>::type;
using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type;
}; };
template <class Seq> template <typename Seq>
struct is_valid_sequence_map struct is_valid_sequence_map
{ {
// not implemented yet, always return true // not implemented yet, always return true
...@@ -317,36 +412,35 @@ struct is_valid_sequence_map ...@@ -317,36 +412,35 @@ struct is_valid_sequence_map
// TODO: add proper check for is_valid, something like: // TODO: add proper check for is_valid, something like:
// static constexpr bool value = // static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::type, // is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{}; // typename sequence_sort<Seq>::SortedSeqType>{};
}; };
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain> template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl struct sequence_map_inverse_impl
{ {
private: private:
static constexpr auto new_y2x = static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
public: public:
using type = using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type; typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
}; };
template <class X2Y, class WorkingY2X, index_t XBegin> template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0> struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{ {
using type = WorkingY2X; using type = WorkingY2X;
}; };
template <class X2Y> template <typename X2Y>
struct sequence_map_inverse struct sequence_map_inverse
{ {
using type = using type =
typename sequence_map_inverse_impl<X2Y, typename sequence_map_inverse_impl<X2Y,
typename uniform_sequence_gen<X2Y::GetSize(), 0>::type, typename uniform_sequence_gen<X2Y::Size(), 0>::type,
0, 0,
X2Y::GetSize()>::type; X2Y::Size()>::type;
}; };
template <index_t... Xs, index_t... Ys> template <index_t... Xs, index_t... Ys>
...@@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>) ...@@ -457,20 +551,26 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
return Sequence<Is...>{}; return Sequence<Is...>{};
} }
template <class Seq> template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq) __host__ __device__ constexpr auto sequence_pop_back(Seq)
{ {
static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!"); static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
return sequence_pop_front(Seq::Reverse()).Reverse(); return sequence_pop_front(Seq::Reverse()).Reverse();
} }
template <class F, index_t... Xs> template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>) __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{ {
return Sequence<f(Xs)...>{}; return Sequence<f(Xs)...>{};
} }
template <class F, index_t... Xs, index_t... Ys> template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
return typename sequence_merge<Seqs...>::type{};
}
template <typename F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>) __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{ {
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same"); static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
...@@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq ...@@ -478,7 +578,7 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
return Sequence<f(Xs, Ys)...>{}; return Sequence<f(Xs, Ys)...>{};
} }
template <class F, index_t... Xs, index_t... Ys, index_t... Zs> template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>) transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{ {
...@@ -489,19 +589,19 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>) ...@@ -489,19 +589,19 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return Sequence<f(Xs, Ys, Zs)...>{}; return Sequence<f(Xs, Ys, Zs)...>{};
} }
template <class Seq, class Reduce, index_t Init> template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>) __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{ {
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{}; return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
} }
template <class Seq, class Reduce, index_t Init> template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>) __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{ {
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse(); return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
} }
template <class Seq, class Reduce> template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence struct lambda_accumulate_on_sequence
{ {
const Reduce& f; const Reduce& f;
...@@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence ...@@ -512,14 +612,14 @@ struct lambda_accumulate_on_sequence
{ {
} }
template <class IDim> template <typename IDim>
__host__ __device__ constexpr index_t operator()(IDim) const __host__ __device__ constexpr index_t operator()(IDim) const
{ {
return result = f(result, Seq::Get(IDim{})); return result = f(result, Seq::At(IDim{}));
} }
}; };
template <class Seq, class Reduce, index_t Init> template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t __host__ __device__ constexpr index_t
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/) accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{ {
...@@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/) ...@@ -530,41 +630,5 @@ accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
return result; return result;
} }
// Debug helper: prints the label `s`, the number of entries, and every entry of
// a compile-time Sequence<Xs...> on a single line, in the form
// "<s> size <n>, {x0 x1 ...}".
//
// s               : caller-supplied label, printed through "%s".
// Sequence<Xs...> : unnamed tag argument; only its template values Xs are used.
//
// Sequences of 0..10 entries are supported (enforced by the static_assert
// below). Each size has its own branch that issues exactly one printf call
// with a format string written out for that entry count; the entries are
// passed by expanding the pack `Xs...` as the trailing printf arguments.
// NOTE(review): static_if<cond> is presumably a project helper that invokes the
// lambda only when cond is true, so exactly one branch should run -- confirm
// against its definition.
// NOTE(review): every value, including nsize, is printed with "%u"; this
// assumes index_t is an unsigned 32-bit integer -- TODO confirm.
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
constexpr index_t nsize = Sequence<Xs...>::GetSize();
static_assert(nsize <= 10, "wrong!");
// For an empty Sequence the pack `Xs...` expands to no arguments at all.
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 5>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 6>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 7>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 8>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 9>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
static_if<nsize == 10>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
}
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "Array.hpp"
namespace ck {
// Debug helper: prints the label `s`, the array size, and every element of `a`
// on a single line, in the form "<s> size <n>, {e0 e1 ...}".
//
// s : caller-supplied label, printed through "%s".
// a : array to print, taken by value; its size must be in [1, 10]
//     (enforced by the static_assert below).
//
// Each supported size has its own branch that issues exactly one printf call
// with a format string written out for that element count.
// NOTE(review): static_if<cond> is presumably a project helper that invokes the
// lambda only when cond is true, so exactly one branch should run -- confirm
// against its definition.
// NOTE(review): every value, including nsize, is printed with "%u"; this
// assumes index_t and T are unsigned 32-bit integers -- TODO confirm.
template <typename T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
{
constexpr index_t nsize = a.GetSize();
// Only sizes 1..10 have a matching branch below.
static_assert(nsize > 0 && nsize <= 10, "wrong!");
// The lambda's `auto` parameter is unused in every branch.
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
static_if<nsize == 3>{}(
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
static_if<nsize == 4>{}(
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
static_if<nsize == 5>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
});
static_if<nsize == 6>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
});
static_if<nsize == 7>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6]);
});
static_if<nsize == 8>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7]);
});
static_if<nsize == 9>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8]);
});
static_if<nsize == 10>{}([&](auto) {
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
s,
nsize,
a[0],
a[1],
a[2],
a[3],
a[4],
a[5],
a[6],
a[7],
a[8],
a[9]);
});
}
} // namespace ck
#endif
\ No newline at end of file
...@@ -4,14 +4,19 @@ ...@@ -4,14 +4,19 @@
#include "config.hpp" #include "config.hpp"
#include "utility.hpp" #include "utility.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp" #include "tuple.hpp"
#include "math.hpp" #include "math.hpp"
#include "vector_type.hpp" #include "vector_type.hpp"
#include "Sequence.hpp" #include "Sequence.hpp"
#include "sequence_helper.hpp"
#include "Array.hpp" #include "Array.hpp"
#include "array_helper.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp"
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "Sequence.hpp" #include "Sequence.hpp"
#include "type.hpp"
namespace ck { namespace ck {
// TODO: right? wrong?
struct forwarder struct forwarder
{ {
template <typename T> template <typename T>
...@@ -17,7 +19,7 @@ struct forwarder ...@@ -17,7 +19,7 @@ struct forwarder
struct swallow struct swallow
{ {
template <class... Ts> template <typename... Ts>
__host__ __device__ constexpr swallow(Ts&&...) __host__ __device__ constexpr swallow(Ts&&...)
{ {
} }
...@@ -32,7 +34,7 @@ struct static_if<true> ...@@ -32,7 +34,7 @@ struct static_if<true>
{ {
using Type = static_if<true>; using Type = static_if<true>;
template <class F> template <typename F>
__host__ __device__ constexpr auto operator()(F f) const __host__ __device__ constexpr auto operator()(F f) const
{ {
// This is a trick for compiler: // This is a trick for compiler:
...@@ -43,7 +45,7 @@ struct static_if<true> ...@@ -43,7 +45,7 @@ struct static_if<true>
return Type{}; return Type{};
} }
template <class F> template <typename F>
__host__ __device__ static constexpr auto Else(F) __host__ __device__ static constexpr auto Else(F)
{ {
return Type{}; return Type{};
...@@ -55,13 +57,13 @@ struct static_if<false> ...@@ -55,13 +57,13 @@ struct static_if<false>
{ {
using Type = static_if<false>; using Type = static_if<false>;
template <class F> template <typename F>
__host__ __device__ constexpr auto operator()(F) const __host__ __device__ constexpr auto operator()(F) const
{ {
return Type{}; return Type{};
} }
template <class F> template <typename F>
__host__ __device__ static constexpr auto Else(F f) __host__ __device__ static constexpr auto Else(F f)
{ {
// This is a trick for compiler: // This is a trick for compiler:
...@@ -73,5 +75,23 @@ struct static_if<false> ...@@ -73,5 +75,23 @@ struct static_if<false>
} }
}; };
// Compile-time type selection.
// conditional<predicate, X, Y>::type is X when predicate is true and Y when it
// is false. The primary template covers the `true` case; the partial
// specialization below covers `false`.
template <bool predicate, typename X, typename Y>
struct conditional
{
    using type = X;
};

template <typename X, typename Y>
struct conditional<false, X, Y>
{
    using type = Y;
};

// Shorthand for typename conditional<predicate, X, Y>::type.
template <bool predicate, typename X, typename Y>
using conditional_t = typename conditional<predicate, X, Y>::type;
} // namespace ck } // namespace ck
#endif #endif
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
namespace ck { namespace ck {
namespace detail {
template <class> template <class>
struct static_for_impl; struct static_for_impl;
...@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>> ...@@ -19,6 +21,8 @@ struct static_for_impl<Sequence<Is...>>
} }
}; };
} // namespace detail
// F signature: F(Number<Iter>) // F signature: F(Number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment> template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for struct static_for
...@@ -33,7 +37,8 @@ struct static_for ...@@ -33,7 +37,8 @@ struct static_for
template <class F> template <class F>
__host__ __device__ constexpr void operator()(F f) const __host__ __device__ constexpr void operator()(F f) const
{ {
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(f); detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
} }
}; };
......
...@@ -8,20 +8,7 @@ ...@@ -8,20 +8,7 @@
namespace ck { namespace ck {
template <class> namespace detail {
struct is_static : integral_constant<bool, false>
{
};
template <class T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
// RemainLengths: Sequence<...> // RemainLengths: Sequence<...>
// Orders: Sequence<...> // Orders: Sequence<...>
...@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders> ...@@ -58,29 +45,6 @@ struct static_ford_impl<Sequence<>, Orders>
} }
}; };
// Compile-time N-dimensional loop.
// Lengths: Sequence<...> holding the length of each dimension.
// Orders:  Sequence<...> giving the order in which static_ford loops over the
//          dimensions; defaults to arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type.
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    // Construction checks that Lengths is non-empty and that Orders has the
    // same number of entries as Lengths.
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index.
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Reorder the lengths by Orders and hand them to the recursive
        // implementation, starting from an empty index.
        constexpr auto lengths_in_loop_order = Lengths::ReorderGivenNew2Old(Orders{});
        using Impl = static_ford_impl<decltype(lengths_in_loop_order), Orders>;
        Impl{}(f, Sequence<>{});
    }
};
// RemainLengths: Sequence<...> // RemainLengths: Sequence<...>
// Orders: Sequence<...> // Orders: Sequence<...>
template <class RemainLengths, class Orders> template <class RemainLengths, class Orders>
...@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders> ...@@ -117,6 +81,31 @@ struct ford_impl<Sequence<>, Orders>
} }
}; };
} // namespace detail
// Compile-time N-dimensional loop.
// Lengths: Sequence<...> holding the length of each dimension.
// Orders:  Sequence<...> giving the order in which static_ford loops over the
//          dimensions; defaults to arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type.
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
    // Construction checks that Lengths is non-empty and that Orders has the
    // same number of entries as Lengths.
    __host__ __device__ constexpr static_ford()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }

    // F signature: F(Sequence<...> multi_id)
    // multi_id is the unordered multi-index.
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        // Reorder the lengths by Orders and hand them to the recursive
        // implementation in namespace detail, starting from an empty index.
        constexpr auto lengths_in_loop_order = Lengths::ReorderGivenNew2Old(Orders{});
        using Impl = detail::static_ford_impl<decltype(lengths_in_loop_order), Orders>;
        Impl{}(f, Sequence<>{});
    }
};
// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop // Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each // Orders is Sequence<...>, it is the order of dimension in which ford will loop over each
// dimension // dimension
...@@ -139,7 +128,8 @@ struct ford ...@@ -139,7 +128,8 @@ struct ford
for(index_t i = 0; i < ordered_lengths.Front(); ++i) for(index_t i = 0; i < ordered_lengths.Front(); ++i)
{ {
ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f, Array<index_t, 1>{i}); detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
Array<index_t, 1>{i});
} }
} }
}; };
......
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "Sequence.hpp"
#include "tuple.hpp"
#include "Array.hpp"
namespace ck {
namespace detail {
template <typename Indices>
struct unpack_impl;
template <index_t... Is>
struct unpack_impl<Sequence<Is...>>
{
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F f, const X& x) const
{
return f(x.At(Number<Is>{})...);
}
};
} // namespace detail
// Calls f with the elements of x expanded into separate arguments:
// f(x.At(Number<0>{}), ..., x.At(Number<X::Size() - 1>{})), and returns f's result.
// X must provide a static Size() and an At(Number<I>) accessor.
template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F f, const X& x)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::unpack_impl<Indices>{}(f, x);
}
} // namespace ck
#endif
...@@ -13,54 +13,5 @@ struct integral_constant ...@@ -13,54 +13,5 @@ struct integral_constant
__host__ __device__ constexpr value_type operator()() const noexcept { return value; } __host__ __device__ constexpr value_type operator()() const noexcept { return value; }
}; };
// is_same<X, Y> derives from integral_constant<bool, false> in general and from
// integral_constant<bool, true> when both arguments are the same type.
template <class First, class Second>
struct is_same : integral_constant<bool, false>
{
};

template <class Same>
struct is_same<Same, Same> : integral_constant<bool, true>
{
};
// Alias for std::remove_cv: strips top-level const and volatile from Type.
template <class Type>
using remove_cv_t = std::remove_cv_t<Type>;
// Number<N> is a compile-time index: an integral_constant of type index_t.
template <index_t N>
using Number = integral_constant<index_t, N>;

// Arithmetic on Number<> values. Each operator ignores the (empty) operand
// objects and returns a Number<> whose value is computed from the operands'
// template arguments at compile time.

// Number<X> + Number<Y> -> Number<X + Y>
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
    using Sum = Number<X + Y>;
    return Sum{};
}

// Number<X> - Number<Y> -> Number<X - Y>; rejected at compile time unless Y <= X.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
    static_assert(Y <= X, "wrong!");
    using Difference = Number<X - Y>;
    return Difference{};
}

// Number<X> * Number<Y> -> Number<X * Y>
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
    using Product = Number<X * Y>;
    return Product{};
}

// Number<X> / Number<Y> -> Number<X / Y>; rejected at compile time unless Y > 0.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    using Quotient = Number<X / Y>;
    return Quotient{};
}

// Number<X> % Number<Y> -> Number<X % Y>; rejected at compile time unless Y > 0.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    using Remainder = Number<X % Y>;
    return Remainder{};
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs) ...@@ -104,6 +104,18 @@ __host__ __device__ constexpr T lcm(T x, Ts... xs)
return max(x, xs...); return max(x, xs...);
} }
// Binary predicate functor: operator() returns the result of lhs == rhs.
template <typename T>
struct equal
{
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const
    {
        return lhs == rhs;
    }
};
// Binary predicate functor: operator() returns the result of lhs < rhs.
template <typename T>
struct less
{
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const
    {
        return lhs < rhs;
    }
};
} // namespace math } // namespace math
} // namespace ck } // namespace ck
......
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#include "integral_constant.hpp"
namespace ck {
// Number<N> is a compile-time index: an integral_constant of type index_t.
template <index_t N>
using Number = integral_constant<index_t, N>;

// Arithmetic on Number<> values. Each operator ignores the (empty) operand
// objects and returns a Number<> whose value is computed from the operands'
// template arguments at compile time.

// Number<X> + Number<Y> -> Number<X + Y>
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
    using Sum = Number<X + Y>;
    return Sum{};
}

// Number<X> - Number<Y> -> Number<X - Y>; rejected at compile time unless Y <= X.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
    static_assert(Y <= X, "wrong!");
    using Difference = Number<X - Y>;
    return Difference{};
}

// Number<X> * Number<Y> -> Number<X * Y>
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
    using Product = Number<X * Y>;
    return Product{};
}

// Number<X> / Number<Y> -> Number<X / Y>; rejected at compile time unless Y > 0.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    using Quotient = Number<X / Y>;
    return Quotient{};
}

// Number<X> % Number<Y> -> Number<X % Y>; rejected at compile time unless Y > 0.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
    static_assert(Y > 0, "wrong!");
    using Remainder = Number<X % Y>;
    return Remainder{};
}
} // namespace ck
#endif
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "Sequence.hpp"
namespace ck {
// Debug helper: prints the label `s`, the number of entries and every entry of
// a Sequence on a single line, e.g. "<s> size 3, {x0 x1 x2}".
// Sequences holding 0 to 10 entries are accepted; longer ones are rejected at
// compile time by the static_assert.
// Each size has its own printf call, wrapped in a static_if branch keyed on the
// size, so that only the branch matching the actual size is meant to run.
// NOTE(review): the size and all entries are formatted with %u -- presumably
// index_t is a 32-bit unsigned integer; confirm.
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
    constexpr index_t count = Sequence<Xs...>::Size();

    static_assert(count <= 10, "wrong!");

    static_if<count == 0>{}([&](auto) {
        printf("%s size %u, {}\n", s, count, Xs...);
    });

    static_if<count == 1>{}([&](auto) {
        printf("%s size %u, {%u}\n", s, count, Xs...);
    });

    static_if<count == 2>{}([&](auto) {
        printf("%s size %u, {%u %u}\n", s, count, Xs...);
    });

    static_if<count == 3>{}([&](auto) {
        printf("%s size %u, {%u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 4>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 5>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 6>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 7>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 8>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 9>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, count, Xs...);
    });

    static_if<count == 10>{}([&](auto) {
        printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, count, Xs...);
    });
}
} // namespace ck
#endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define CK_TUPLE_HPP #define CK_TUPLE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "type.hpp"
#include "Sequence.hpp" #include "Sequence.hpp"
namespace ck { namespace ck {
...@@ -16,6 +17,8 @@ struct TupleElementKey ...@@ -16,6 +17,8 @@ struct TupleElementKey
template <typename Key, typename Data> template <typename Key, typename Data>
struct TupleElement struct TupleElement
{ {
// Default constructor: value-initializes the stored element mData.
__host__ __device__ explicit constexpr TupleElement() : mData() {}
template <typename T> template <typename T>
__host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v)) __host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast<T&&>(v))
{ {
...@@ -48,6 +51,12 @@ struct TupleImpl; ...@@ -48,6 +51,12 @@ struct TupleImpl;
template <index_t... Is, typename... Xs> template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>... struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
{ {
#if 1
// Default constructor: default-constructs every
// TupleElement<TupleElementKey<Is>, Xs> base. Compiled in unconditionally while
// the surrounding #if stays 1.
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
{
}
#endif
template <typename... Ys> template <typename... Ys>
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys) __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))... : TupleElement<TupleElementKey<Is>, Xs>(static_cast<Ys&&>(ys))...
...@@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -97,5 +106,28 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
} }
}; };
// Builds a Tuple from the given arguments. Each element type is the matching
// argument type with its reference removed first and top-level cv-qualifiers
// removed second; every argument is forwarded, keeping its value category, to
// the Tuple constructor.
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
    using Result = Tuple<remove_cv_t<remove_reference_t<Xs>>...>;

    // static_cast<Xs&&> is the forwarding idiom used by the constructors in this file.
    return Result(static_cast<Xs&&>(xs)...);
}
namespace detail {

// Implementation helper for transpose_tuple(): applies f to x.At(Number<I>{})
// for every index I in the Sequence and collects the results with make_tuple.
template <typename X, typename F, index_t... Ids>
__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Ids...>)
{
    return make_tuple(f(x.At(Number<Ids>{}))...);
}

} // namespace detail
// Applies f to every element of x (indices 0 .. X::Size() - 1) and returns the
// results as a new tuple built by make_tuple.
// X must provide a static Size() and an At(Number<I>) accessor.
template <typename X, typename F>
__host__ __device__ constexpr auto transpose_tuple(X& x, F f)
{
    using Indices = typename arithmetic_sequence_gen<0, X::Size(), 1>::type;

    return detail::transpose_tuple_impl(x, f, Indices{});
}
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
namespace ck {
// is_same<X, Y> derives from integral_constant<bool, false> in general and from
// integral_constant<bool, true> when both arguments are the same type.
template <typename First, typename Second>
struct is_same : integral_constant<bool, false>
{
};

template <typename Same>
struct is_same<Same, Same> : integral_constant<bool, true>
{
};
template <typename>
struct is_static : integral_constant<bool, false>
{
};
template <typename T, T X>
struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
{
};
template <index_t... Is>
struct is_static<Sequence<Is...>> : integral_constant<bool, true>
{
};
// Alias for std::remove_reference: maps Type&, Type&& and Type to Type.
template <typename Type>
using remove_reference_t = std::remove_reference_t<Type>;

// Alias for std::remove_cv: strips top-level const and volatile from Type.
template <typename Type>
using remove_cv_t = std::remove_cv_t<Type>;
} // namespace ck
#endif
...@@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc, ...@@ -115,8 +115,12 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
constexpr index_t OutThreadCopyDataPerAccess_N = 4; constexpr index_t OutThreadCopyDataPerAccess_N = 4;
#endif #endif
#if 0 // debug
constexpr index_t GridSize = constexpr index_t GridSize =
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock); (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
#else
constexpr index_t GridSize = 1;
#endif
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment