Commit ca42e910 authored by Chao Liu

adding merge transform

parent 7a7fe160
@@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
 #elif 1
     // create a native tensor descriptor
     constexpr auto in_c_h_w_n_global_desc =
-        make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
+        make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());

     constexpr index_t C  = in_c_h_w_n_global_desc.GetLength(I0);
     constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
     constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
     constexpr index_t N  = in_c_h_w_n_global_desc.GetLength(I3);

-    constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
-    constexpr auto pass_c  = PassThrough<C>{};
-    constexpr auto pass_n  = PassThrough<N>{};
-
-    constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
-    constexpr auto lower_dim_groups =
-        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
-    constexpr auto upper_dim_groups =
-        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
-
-    constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
-        in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
+    // transformation: {c, h, w, n} --> {n, c, hp, wp}
+    // {h, w} --> {hp, wp}, {c} --> {c}, {n} --> {n}
+    constexpr auto in_n_c_hp_wp_global_desc = transform_tensor_descriptor(
+        in_c_h_w_n_global_desc,
+        make_tuple(
+            Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{}, PassThrough<C>{}, PassThrough<N>{}),
+        make_tuple(Sequence<1, 2>{}, Sequence<0>{}, Sequence<3>{}),
+        make_tuple(Sequence<2, 3>{}, Sequence<1>{}, Sequence<0>{}));
+
+#if 1
+    // transformation: {n, c, hp, wp} --> {c, b}
+    // {n, hp, wp} --> {b}, {c} --> {c}
+    constexpr auto in_c_b_global_desc = transform_tensor_descriptor(
+        in_n_c_hp_wp_global_desc,
+        make_tuple(Merge<decltype(in_n_c_hp_wp_global_desc.GetLengths(I0, I2, I3))>{},
+                   PassThrough<in_n_c_hp_wp_global_desc.GetLength(I1)>{}),
+        make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
+        make_tuple(Sequence<1>{}, Sequence<0>{}));
+#endif

+#if 1
     if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
     {
+        // 0
         print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
-        printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
+        // 1
+        print_tensor_descriptor("in_n_c_hp_wp_global_desc", in_n_c_hp_wp_global_desc);
+        // 2
+        print_tensor_descriptor("in_c_b_global_desc", in_c_b_global_desc);

-        printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
+        constexpr auto idx2 = MultiIndex<2>{1, 4 * (16 * 16) + 5 * 16 + 6};
+        auto idx1 = in_c_b_global_desc.CalculateLowerIndex(idx2);
+        auto idx0 = in_c_b_global_desc.GetLowerTensorDescriptor().CalculateLowerIndex(idx1);
+
+        print_array("idx2: ", idx2);
+        print_array("idx1: ", idx1);
+        print_array("idx0: ", idx0);
+
+        printf("in_c_b_global_desc offset: %lu\n", in_c_b_global_desc.CalculateOffset(idx2));
     }
+#else
+    {
+        index_t c = static_cast<index_t>(threadIdx.x);
+        index_t h = static_cast<index_t>(threadIdx.y);
+        index_t w = static_cast<index_t>(threadIdx.z);
+
+        p_out_global[0] = in_n_c_h_w_padded_global_desc.CalculateOffset({1, c, h, w});
+    }
+#endif
 #endif
 }
 #endif
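An illustrative, standalone sketch (not part of this commit): the debug block above feeds the merged index idx2 = {c, b} back through CalculateLowerIndex twice, first decomposing b into {n, hp, wp} via the new Merge transform and then landing in the native {c, h, w, n} descriptor. The decomposition arithmetic, assuming merged lengths {N, 16, 16} purely to match the constants used in idx2:

#include <cstdio>

int main()
{
    // assumed for illustration: merged lower lengths {N, Hp, Wp} with Hp = Wp = 16,
    // so Merge's pseudo strides are {Hp * Wp, Wp, 1} = {256, 16, 1}
    constexpr int Hp = 16, Wp = 16;
    constexpr int b  = 4 * (Hp * Wp) + 5 * Wp + 6; // the "b" component of idx2 above

    int itmp     = b;
    const int n  = itmp / (Hp * Wp);
    itmp         = itmp - n * (Hp * Wp);
    const int hp = itmp / Wp;
    itmp         = itmp - hp * Wp;
    const int wp = itmp;

    printf("b = %d -> n %d, hp %d, wp %d\n", b, n, hp, wp); // b = 1110 -> n 4, hp 5, wp 6
    return 0;
}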
......
@@ -18,9 +18,9 @@ struct NativeDimension
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; } __host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
__host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; } __host__ __device__ static constexpr index_t CalculateOffset(index_t i) { return i * Stride; }
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff) __host__ __device__ static constexpr index_t CalculateOffsetDiff(index_t i_diff)
{ {
return i_diff * Stride; return i_diff * Stride;
} }
......
@@ -22,9 +22,12 @@ struct PassThrough
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; } __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; } __host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
{
return idx_up;
}
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) __host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
{ {
return idx_up_diff; return idx_up_diff;
} }
@@ -36,7 +39,7 @@ struct PassThrough
template <typename LowLengths, typename LeftPads, typename RightPads> template <typename LowLengths, typename LeftPads, typename RightPads>
struct Pad struct Pad
{ {
static constexpr index_t nDim = LowLengths::GetSize(); static constexpr index_t nDim = LowLengths::Size();
using LowerIndex = MultiIndex<nDim>; using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>; using UpperIndex = MultiIndex<nDim>;
@@ -52,12 +55,12 @@ struct Pad
return GetLowerLengths() + LeftPads{} + RightPads{}; return GetLowerLengths() + LeftPads{} + RightPads{};
} }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) __host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
{ {
return idx_up - LeftPads{}; return idx_up - LeftPads{};
} }
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) __host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
{ {
return idx_up_diff; return idx_up_diff;
} }
@@ -65,21 +68,20 @@ struct Pad
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
}; };
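An illustrative, standalone sketch (not part of this commit) of the Pad transform above: CalculateLowerIndex simply subtracts LeftPads, so padded coordinates that fall in the left or right pad map to lower coordinates outside [0, length). Assumed values for illustration: length 4, pads {2, 2}.

#include <cstdio>

int main()
{
    constexpr int Hi = 4, LeftPad = 2, RightPad = 2; // assumed for illustration
    constexpr int Hp = Hi + LeftPad + RightPad;      // padded (upper) length

    for(int hp = 0; hp < Hp; ++hp)
    {
        const int h = hp - LeftPad; // mirrors Pad::CalculateLowerIndex
        const bool in_pad = (h < 0) || (h >= Hi);
        printf("hp %d -> h %d%s\n", hp, h, in_pad ? " (padding)" : "");
    }
    return 0;
}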
-#if 0
 // LowLengths: Sequence<...>
 template <typename LowLengths>
 struct Merge
 {
-    static constexpr index_t nDimLow = LowLengths::GetSize();
+    static constexpr index_t nDimLow = LowLengths::Size();
     static constexpr index_t nDimUp  = 1;

     using LowerIndex = MultiIndex<nDimLow>;
     using UpperIndex = MultiIndex<nDimUp>;

-    __host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
-
     __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }

+    __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
+
     __host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }

     __host__ __device__ static constexpr auto GetUpperLengths()
@@ -88,18 +90,56 @@ struct Merge
             GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
     }

-    __host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
+    // emulate constexpr lambda
+    template <typename PseudoLowStrides>
+    struct lambda_CalculateLowerIndex
+    {
+        index_t& itmp;
+        LowerIndex& idx_low;
+
+        __host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
+                                                                          LowerIndex& idx_low_)
+            : itmp(itmp_), idx_low(idx_low_)
+        {
+        }
+
+        template <typename IDim>
+        __host__ __device__ constexpr void operator()(IDim idim) const
+        {
+            constexpr index_t stride = PseudoLowStrides::At(idim);
+            idx_low(idim) = itmp / stride;
+            itmp -= idx_low[idim] * stride;
+        }
+    };
+
+    __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
     {
         LowerIndex idx_low;

-        // not implemeneted
+        index_t itmp = idx_up[0];
+
+        constexpr auto pseudo_low_strides =
+            reverse_inclusive_scan_sequence(
+                GetLowerLengths().PopFront(), math::multiplies<index_t>{}, Number<1>{})
+                .PushBack(Number<1>{});
+
+        // calculate index in each of the dimensions in the order of their dimension
+#if 1
+        static_for<0, nDimLow - 1, 1>{}(
+            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
+
+        idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
+#else
+        static_for<0, nDimLow, 1>{}(
+            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
+#endif

         return idx_low;
     }

     // idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
-    __host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
-                                                                LowerIndex idx_low_old)
+    __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
+                                                                      const LowerIndex& idx_low_old)
     {
         LowerIndex idx_low_diff;

@@ -110,49 +150,48 @@ struct Merge
     __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
 };
-#endif
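An illustrative, standalone sketch (not part of this commit) of why Merge reports IsLinearTransform() == false and why CalculateLowerIndexDiff takes idx_low_old: for a fixed step in the merged index, the step in the lower index depends on the current position, because a carry can ripple across dimensions. Assumed merged lengths {2, 3, 4}, i.e. pseudo strides {12, 4, 1}.

#include <cstdio>

// decompose a merged index, analogous to lambda_CalculateLowerIndex above
static void decompose(int b, int idx_low[3])
{
    const int pseudo_low_strides[3] = {12, 4, 1};
    for(int i = 0; i < 2; ++i)
    {
        idx_low[i] = b / pseudo_low_strides[i];
        b -= idx_low[i] * pseudo_low_strides[i];
    }
    idx_low[2] = b;
}

int main()
{
    int a[3], b[3], c[3];
    decompose(10, a); // {0, 2, 2}
    decompose(11, b); // {0, 2, 3}: the +1 step changed only the last dimension
    decompose(12, c); // {1, 0, 0}: the same +1 step carried through every dimension
    printf("{%d %d %d} {%d %d %d} {%d %d %d}\n",
           a[0], a[1], a[2], b[0], b[1], b[2], c[0], c[1], c[2]);
    return 0;
}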
// UpLengths: Sequence<...> // UpLengths: Sequence<...>
template <index_t LowLength, typename UpLengths> template <typename UpLengths>
struct Unmerge struct Unmerge
{ {
static constexpr index_t nDimLow = 1; static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize(); static constexpr index_t nDimUp = UpLengths::Size();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>; using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Unmerge() __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
{
static_assert(LowLength == accumulate_on_sequence(
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
"wrong! UpLengths need to be ");
}
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; } __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; } __host__ __device__ static constexpr auto GetLowerLengths()
{
constexpr index_t low_length =
accumulate_on_sequence(UpLengths{}, math::multiplies<index_t>{}, Number<1>{});
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; } return Sequence<low_length>{};
}
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{ {
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
math::multiplies<index_t>,
1>::type{};
LowerIndex idx_low{0}; LowerIndex idx_low{0};
static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; }); constexpr auto pseudo_up_strides =
typename sequence_reverse_inclusive_scan<UpLengths, math::multiplies<index_t>, 1>::
type{};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });
return idx_low; return idx_low;
} }
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
{ {
return GetLowerIndex(idx_up_diff); return CalculateLowerIndex(idx_up_diff);
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; } __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
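An illustrative, standalone sketch (not part of this commit) of the Unmerge direction above: the single lower index is a weighted sum of the upper index with per-dimension strides derived from UpLengths, mirroring the static_for accumulation in CalculateLowerIndex. Assumed packed upper lengths {16, 20, 20} for illustration.

#include <cstdio>

int main()
{
    const int up_lengths[3] = {16, 20, 20}; // assumed for illustration

    // packed strides: {20 * 20, 20, 1} = {400, 20, 1}
    int strides[3];
    strides[2] = 1;
    for(int i = 1; i >= 0; --i)
        strides[i] = strides[i + 1] * up_lengths[i + 1];

    const int idx_up[3] = {4, 5, 6};

    int idx_low = 0;
    for(int i = 0; i < 3; ++i)
        idx_low += idx_up[i] * strides[i]; // weighted sum, as in the static_for loop above

    printf("idx_low = %d\n", idx_low); // 4 * 400 + 5 * 20 + 6 = 1706
    return 0;
}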
@@ -165,12 +204,12 @@ template <index_t LowLength, typename UpLengths, typename Coefficients>
struct Embed struct Embed
{ {
static constexpr index_t nDimLow = 1; static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize(); static constexpr index_t nDimUp = UpLengths::Size();
using LowerIndex = MultiIndex<nDimLow>; using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>; using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Embed() __host__ __device__ explicit constexpr Embed()
{ {
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1, static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent"); "wrong! # of dimensions not consistent");
@@ -191,7 +230,7 @@ struct Embed
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; } __host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{ {
LowerIndex idx_low(Coefficients{}[nDimUp]); LowerIndex idx_low(Coefficients{}[nDimUp]);
@@ -201,7 +240,7 @@ struct Embed
return idx_low; return idx_low;
} }
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff) __host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
{ {
LowerIndex idx_low_diff{0}; LowerIndex idx_low_diff{0};
......
@@ -18,47 +18,53 @@ struct NativeTensorDescriptor
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; } __host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
struct lambda_GetLength template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{ {
template <typename IDim> return mDimensions.At(Number<IDim>{}).GetLength();
__host__ __device__ constexpr auto operator()(IDim) const }
{
return GetLength(IDim{});
}
};
__host__ __device__ static constexpr auto GetLengths() template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{ {
return typename sequence_gen<nDim, lambda_GetLength>::type{}; return mDimensions.At(Number<IDim>{}).GetStride();
} }
struct lambda_GetStride template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{ {
template <typename IDim> return Sequence<GetLength(Number<IDims>{})...>{};
__host__ __device__ constexpr auto operator()(IDim) const }
{
return GetStride(IDim{});
}
};
__host__ __device__ static constexpr auto GetStrides() template <index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
{ {
return typename sequence_gen<nDim, lambda_GetStride>::type{}; return Sequence<GetStride(Number<IDims>{})...>{};
} }
template <index_t IDim> template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLength(Number<IDim>) __host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{ {
return mDimensions.At(Number<IDim>{}).GetLength(); return GetLengths(Sequence<IDim, IDims...>{});
} }
template <index_t IDim> template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetStride(Number<IDim>) __host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
{ {
return mDimensions.At(Number<IDim>{}).GetStride(); return GetStrides(Sequence<IDim, IDims...>{});
} }
__host__ __device__ static constexpr index_t GetOffset(const Index& idx) __host__ __device__ static constexpr auto GetLengths()
{
return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr auto GetStrides()
{
return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
{ {
index_t offset = 0; index_t offset = 0;
@@ -67,7 +73,7 @@ struct NativeTensorDescriptor
return offset; return offset;
} }
__host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff) __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{ {
index_t offset_diff = 0; index_t offset_diff = 0;
@@ -161,8 +167,10 @@ struct TransformedTensorDescriptor
// UpDimensionIds should include all up-dimensions // UpDimensionIds should include all up-dimensions
// TODO: sanity check: while a up-dimension could be associated with multille // TODO: sanity check: while a up-dimension could be associated with multille
// transformation, // transformation, a low-dimension should be associated with only one transformation
// a low-dimension should be associated with only one transformation
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
// of lower-tensor-descriptor
} }
__host__ __device__ static constexpr auto GetNumOfDimension() __host__ __device__ static constexpr auto GetNumOfDimension()
@@ -170,49 +178,78 @@ struct TransformedTensorDescriptor
return GetNumOfUpperDimension(); return GetNumOfUpperDimension();
} }
#if 0 __host__ __device__ static constexpr auto GetLowerTensorDescriptor()
__host__ __device__ static constexpr auto GetUpperLengths() {
return LowTensorDescriptor{};
}
__host__ __device__ static constexpr auto GetLowerLengths()
{ {
struct lambda_get_upper_lengths return GetLowerTensorDescriptor().GetLengths();
}
struct lambda_GetUpperLengths
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{ {
template <typename Transform> return tran.GetUpperLengths();
__host__ __device__ constexpr auto operator()(Transform tran) const }
{ };
return tran.GetUpperLengths();
}
};
constexpr auto tuple_of_upper_lengths = __host__ __device__ static constexpr auto GetUpperLengths()
transform_tuple(Transforms, lambda_get_upper_lengths{}); {
constexpr auto tuple_of_up_lengths =
transform_tuple(lambda_GetUpperLengths{}, Transforms{});
constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths); constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{}); constexpr auto mingled_up_dimension_ids =
unpack(lambda_merge_sequences{}, UpDimensionIds{});
// TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions // TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
// TODO: sanity-check all_upper_lengths have no conflicting upper-length // TODO: sanity-check mingled_up_lengths have no conflicting upper-length
using sort_dimension_ids = // sort by upper-dimension-ids
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>; using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
math::less<index_t>,
math::equal<index_t>>;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type; // sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type; static_assert(is_same<typename sort_up_dimension_ids::type,
typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
"wrong! UpDimensionIds is not configured correctly");
constexpr auto sorted_upper_lengths = constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
return sorted_upper_lengths; constexpr auto sorted_up_lengths =
pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths;
} }
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); } __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
#endif
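An illustrative, standalone sketch (not part of this commit) of what the new GetUpperLengths() assembles for the in_c_b_global_desc example at the top of this commit: the per-transform upper lengths are mingled in transform order (Merge first, PassThrough second), the upper-dimension ids (Sequence<1>, Sequence<0>) are sorted, and the lengths are re-picked through sorted2unsorted_map. Assumed illustration values: C = 8, N = 16, Hp = Wp = 20.

#include <cstdio>

int main()
{
    // mingled upper lengths in transform order: {N * Hp * Wp (from Merge), C (from PassThrough)}
    const int mingled_up_lengths[2]  = {16 * 20 * 20, 8};
    // sorting the mingled upper-dimension ids {1, 0} gives sorted2unsorted_map = {1, 0}
    const int sorted2unsorted_map[2] = {1, 0};

    int sorted_up_lengths[2];
    for(int i = 0; i < 2; ++i)
        sorted_up_lengths[i] = mingled_up_lengths[sorted2unsorted_map[i]];

    // expected upper lengths of {c, b}: {8, 6400}
    printf("upper lengths: {%d, %d}\n", sorted_up_lengths[0], sorted_up_lengths[1]);
    return 0;
}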
__host__ __device__ static constexpr auto GetLowerTensorDescriptor() template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{ {
return LowTensorDescriptor{}; return GetLengths()[IDim];
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
} }
__host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up) // TODO: right now return value is constexpr because use of non-constepxr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
{ {
LowerIndex idx_low; LowerIndex idx_low;
@@ -225,14 +262,15 @@ struct TransformedTensorDescriptor
// this assume each lower (single) index is only assocaited with one transformation, // this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor // which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor // of TransformedTensorDescriptor
idx_low_part = tran.GetLowerIndex(to_array(idx_up_part)); idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
}); });
return idx_low; return idx_low;
} }
__host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff, // TODO: right now return value is constexpr because use of non-constepxr lambda
const LowerIndex& idx_low_old) __host__ __device__ static constexpr LowerIndex
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, const LowerIndex& idx_low_old)
{ {
LowerIndex idx_low_diff; LowerIndex idx_low_diff;
@@ -250,15 +288,15 @@ struct TransformedTensorDescriptor
// this assume each lower (single) index is associated with only one transformation, // this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor // which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor // of TransformedTensorDescriptor
idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part); idx_low_diff_part = tran.CalculateLowerIndex(idx_up_diff_part, idx_low_old_part);
}); });
return idx_low_diff; return idx_low_diff;
} }
__host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up) __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
{ {
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up)); return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
} }
#if 0 #if 0
@@ -286,14 +324,14 @@ struct TransformedTensorDescriptor
}; };
template <index_t... Lengths, index_t... Strides> template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>, __host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>) Sequence<Strides...>)
{ {
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{}; return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
} }
template <typename Lengths> template <typename Lengths>
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths) __host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{ {
constexpr auto strides = reverse_inclusive_scan_sequence( constexpr auto strides = reverse_inclusive_scan_sequence(
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{}) Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
......
@@ -7,12 +7,19 @@
namespace ck { namespace ck {
template <typename... NativeDimensions> template <typename... NativeDimensions>
__host__ __device__ void print_tensor_descriptor(const char* s, __host__ __device__ void
NativeTensorDescriptor<NativeDimensions...> desc) print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
{ {
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides()); print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
} }
template <typename... Ts>
__host__ __device__ void print_tensor_descriptor(const char* s,
const TransformedTensorDescriptor<Ts...>& desc)
{
print_tensor_descriptor_impl(s, desc.GetLengths());
}
template <index_t... Lengths, index_t... Strides> template <index_t... Lengths, index_t... Strides>
__host__ __device__ void __host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>) print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
@@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strid
}); });
} }
template <index_t... Lengths>
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
{
constexpr index_t nDim = sizeof...(Lengths);
static_assert(nDim > 0 && nDim <= 12, "wrong!");
static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
static_if<nDim == 2>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 3>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 4>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 5>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 6>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); });
static_if<nDim == 7>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
}
} // namespace ck } // namespace ck
#endif #endif
#ifndef CK_ARRAY_HPP #ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP #define CK_ARRAY_HPP
#include "Sequence.hpp" #include "sequence.hpp"
#include "functional2.hpp" #include "functional2.hpp"
namespace ck { namespace ck {
@@ -17,7 +17,7 @@ struct Array
__host__ __device__ explicit constexpr Array() {} __host__ __device__ explicit constexpr Array() {}
template <typename X, typename... Xs> template <typename X, typename... Xs>
__host__ __device__ explicit constexpr Array(X x, Xs... xs) __host__ __device__ constexpr Array(X x, Xs... xs)
: mData{static_cast<TData>(x), static_cast<TData>(xs)...} : mData{static_cast<TData>(x), static_cast<TData>(xs)...}
{ {
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size"); static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
@@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
return ArrayElementPicker<Arr, Picks>(a); return ArrayElementPicker<Arr, Picks>(a);
} }
#if 1
template <typename T> template <typename T>
__host__ __device__ constexpr auto to_array(const T& x) __host__ __device__ constexpr auto to_array(const T& x)
{ {
@@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x)
return y; return y;
} }
#endif
// TODO: remove this
template <index_t... Is> template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>) __host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{ {
......
#ifndef CK_ARRAY_HELPER_HPP #ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP #define CK_ARRAY_HELPER_HPP
#include "Array.hpp" #include "array.hpp"
namespace ck { namespace ck {
template <typename T, index_t NSize> template <typename T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a) __host__ __device__ void print_array(const char* s, Array<T, NSize> a)
{ {
constexpr index_t nsize = a.GetSize(); constexpr index_t nsize = a.GetSize();
@@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
} }
} // namespace ck } // namespace ck
#endif #endif
\ No newline at end of file
@@ -9,9 +9,9 @@
#include "tuple.hpp" #include "tuple.hpp"
#include "math.hpp" #include "math.hpp"
#include "vector_type.hpp" #include "vector_type.hpp"
#include "Sequence.hpp" #include "sequence.hpp"
#include "sequence_helper.hpp" #include "sequence_helper.hpp"
#include "Array.hpp" #include "array.hpp"
#include "array_helper.hpp" #include "array_helper.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
......
@@ -2,7 +2,7 @@
#define CK_FUNCTIONAL_HPP #define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "Sequence.hpp" #include "sequence.hpp"
#include "type.hpp" #include "type.hpp"
namespace ck { namespace ck {
......
@@ -2,7 +2,7 @@
#define CK_FUNCTIONAL2_HPP #define CK_FUNCTIONAL2_HPP
#include "functional.hpp" #include "functional.hpp"
#include "Sequence.hpp" #include "sequence.hpp"
namespace ck { namespace ck {
......
@@ -3,8 +3,8 @@
#include "functional.hpp" #include "functional.hpp"
#include "functional2.hpp" #include "functional2.hpp"
#include "Sequence.hpp" #include "sequence.hpp"
#include "Array.hpp" #include "array.hpp"
namespace ck { namespace ck {
......
#ifndef CK_FUNCTIONAL4_HPP #ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP
#include "Sequence.hpp" #include "sequence.hpp"
#include "tuple.hpp" #include "tuple.hpp"
#include "Array.hpp" #include "array.hpp"
namespace ck { namespace ck {
......
@@ -3,6 +3,7 @@
#include "config.hpp" #include "config.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "type.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
......
@@ -2,7 +2,9 @@
#define CK_SEQUENCE_HPP #define CK_SEQUENCE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp" #include "functional.hpp"
#include "math.hpp"
namespace ck { namespace ck {
@@ -155,8 +157,8 @@ struct Sequence
static_assert(I < Size(), "wrong!"); static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>; using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{}; constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront(); constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right); return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
} }
@@ -188,34 +190,34 @@ struct sequence_merge<Seq>
}; };
// generate sequence // generate sequence
template <index_t IBegin, index_t NRemain, typename F> template <index_t NSize, typename F>
struct sequence_gen_impl struct sequence_gen
{ {
static constexpr index_t NRemainLeft = NRemain / 2; template <index_t IBegin, index_t NRemain, typename G>
static constexpr index_t NRemainRight = NRemain - NRemainLeft; struct sequence_gen_impl
static constexpr index_t IMiddle = IBegin + NRemainLeft; {
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type = using type = typename sequence_merge<
typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type, typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type; typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
}; };
template <index_t I, typename F> template <index_t I, typename G>
struct sequence_gen_impl<I, 1, F> struct sequence_gen_impl<I, 1, G>
{ {
static constexpr index_t Is = F{}(Number<I>{}); static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>; using type = Sequence<Is>;
}; };
template <index_t I, typename F> template <index_t I, typename G>
struct sequence_gen_impl<I, 0, F> struct sequence_gen_impl<I, 0, G>
{ {
using type = Sequence<>; using type = Sequence<>;
}; };
template <index_t NSize, typename F>
struct sequence_gen
{
using type = typename sequence_gen_impl<0, NSize, F>::type; using type = typename sequence_gen_impl<0, NSize, F>::type;
}; };
@@ -281,8 +283,8 @@ struct sequence_split
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type; using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
using SeqType0 = decltype(Seq::Extract(range0{})); using left_type = decltype(Seq::Extract(range0{}));
using SeqType1 = decltype(Seq::Extract(range1{})); using right_type = decltype(Seq::Extract(range1{}));
}; };
// reverse sequence // reverse sequence
@@ -293,8 +295,8 @@ struct sequence_reverse
using seq_split = sequence_split<Seq, NSize / 2>; using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge< using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::SeqType1>::type, typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::SeqType0>::type>::type; typename sequence_reverse<typename seq_split::left_type>::type>::type;
}; };
template <index_t I> template <index_t I>
@@ -309,138 +311,264 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>; using type = Sequence<I1, I0>;
}; };
template <typename Seq, typename Compare> template <typename Values, typename Ids, typename Compare>
struct sequence_sort struct sequence_sort_impl
{ {
template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp> template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl struct sorted_sequence_merge_impl
{ {
static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front(); static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
static constexpr index_t chosen_value =
using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{})); choose_left ? LeftValues::Front() : RightValues::Front();
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
using new_left_seq =
typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type; using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
using new_right_seq = using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
using new_left_values =
using type = typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>:: using new_left_ids =
type; typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
using new_right_values =
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
using new_right_ids =
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
}; };
template <typename SeqLeft, typename MergedSeq, typename Comp> template <typename LeftValues,
struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp> typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
Sequence<>,
Sequence<>,
MergedValues,
MergedIds,
Comp>
{ {
using type = typename sequence_merge<MergedSeq, SeqLeft>::type; using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
}; };
template <typename SeqRight, typename MergedSeq, typename Comp> template <typename RightValues,
struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp> typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<Sequence<>,
Sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{ {
using type = typename sequence_merge<MergedSeq, SeqRight>::type; using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
}; };
template <typename Seq0, typename Seq1, typename Comp> template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge struct sorted_sequence_merge
{ {
using type = typename sorted_sequence_merge_impl<Seq0, Seq1, Sequence<>, Comp>::type; using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
Sequence<>,
Sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
}; };
using split = sequence_split<Seq, Seq::Size() / 2>; static constexpr index_t nsize = Values::Size();
using unsorted_left = typename split::SeqType0;
using unsorted_right = typename split::SeqType1; using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using sorted_left = typename sequence_sort<unsorted_left, Compare>::type; using left_unsorted_values = typename split_unsorted_values::left_type;
using sorted_right = typename sequence_sort<unsorted_right, Compare>::type; using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using type = typename sorted_sequence_merge<sorted_left, sorted_right, Compare>::type; using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
}; };
template <index_t X, index_t Y, typename Compare> template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort<Sequence<X, Y>, Compare> struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
{ {
static constexpr bool x_first = Compare{}(X, Y); static constexpr bool choose_x = Compare{}(ValueX, ValueY);
using sorted_values =
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
using type = typename conditional<x_first, Sequence<X, Y>, Sequence<Y, X>>::type; template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
using sorted_values = Sequence<Value>;
using sorted_ids = Sequence<Id>;
}; };
template <index_t X, typename Compare> template <typename Values, typename Compare>
struct sequence_sort<Sequence<X>, Compare> struct sequence_sort
{ {
using type = Sequence<X>; using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
}; };
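A hypothetical compile-time check (not part of this commit, and assuming the ck sequence header above is included) of what the reworked sequence_sort exposes: the sorted values plus a sorted2unsorted_map giving each sorted element's position in the original sequence.

using sorted = ck::sequence_sort<ck::Sequence<30, 10, 20>, ck::math::less<ck::index_t>>;

static_assert(ck::is_same<sorted::type, ck::Sequence<10, 20, 30>>{}, "sorted values");
static_assert(ck::is_same<sorted::sorted2unsorted_map, ck::Sequence<1, 2, 0>>{},
              "original position of each sorted element");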
template <typename Seq, typename Less, typename Equal> template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort struct sequence_unique_sort
{ {
template <typename WorkInputSeq, typename WorkOutputSeq, typename Eq> template <typename RemainValues,
typename RemainIds,
typename UniquifiedValues,
typename UniquifiedIds,
typename Eq>
struct sorted_sequence_uniquify_impl struct sorted_sequence_uniquify_impl
{ {
static constexpr index_t new_value = WorkInputSeq::Front(); static constexpr index_t current_value = RemainValues::Front();
using new_work_input_seq = decltype(WorkInputSeq::PopFront()); static constexpr index_t current_id = RemainIds::Front();
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
using new_remain_values = decltype(RemainValues::PopFront());
using new_remain_ids = decltype(RemainIds::PopFront());
using new_uniquified_values =
typename conditional<is_unique_value,
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
UniquifiedValues>::type;
using new_working_output_seq = using new_uniquified_ids =
typename conditional<new_value == WorkOutputSeq::Back(), typename conditional<is_unique_value,
WorkOutputSeq, decltype(UniquifiedIds::PushBack(Number<current_id>{})),
decltype(WorkOutputSeq::PopBack(Number<new_value>{}))>::type; UniquifiedIds>::type;
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
new_remain_ids,
new_uniquified_values,
new_uniquified_ids,
Eq>;
// this is output
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
}; };
template <typename WorkInputSeq, typename Eq> template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
struct sorted_sequence_uniquify_impl<WorkInputSeq, Sequence<>, Eq> struct sorted_sequence_uniquify_impl<Sequence<>,
Sequence<>,
UniquifiedValues,
UniquifiedIds,
Eq>
{ {
using type = WorkInputSeq; using uniquified_values = UniquifiedValues;
using uniquified_ids = UniquifiedIds;
}; };
template <typename SortedSeq, typename Eq> template <typename SortedValues, typename SortedIds, typename Eq>
struct sorted_sequence_uniquify struct sorted_sequence_uniquify
{ {
using type = typename sorted_sequence_uniquify_impl<SortedSeq, Sequence<>, Eq>::type; using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
decltype(SortedIds::PopFront()),
Sequence<SortedValues::Front()>,
Sequence<SortedIds::Front()>,
Eq>;
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
}; };
using sorted_seq = typename sequence_sort<Seq, Less>::type; using sort = sequence_sort<Values, Less>;
using sorted_values = typename sort::type;
using sorted_ids = typename sort::sorted2unsorted_map;
using type = typename sorted_sequence_uniquify<sorted_seq, Equal>::type; using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
// this is output
using type = typename uniquify::uniquified_values;
using sorted2unsorted_map = typename uniquify::uniquified_ids;
}; };
template <typename Seq> template <typename SeqMap>
struct is_valid_sequence_map struct is_valid_sequence_map
{ {
// not implemented yet, always return true static constexpr bool value =
static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{}; is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
typename sequence_sort<SeqMap, math::less<index_t>>::type>{};
// TODO: add proper check for is_valid, something like:
// static constexpr bool value =
// is_same<typename arithmetic_sequence_gen<0, Seq::Size(), 1>::type,
// typename sequence_sort<Seq>::SortedSeqType>{};
}; };
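A hypothetical compile-time check (not part of this commit, and assuming the ck sequence header above is included): with the new definition, a map is valid exactly when it is a permutation of 0, 1, ..., Size()-1.

static_assert(ck::is_valid_sequence_map<ck::Sequence<2, 0, 1>>::value, "a permutation is valid");
static_assert(!ck::is_valid_sequence_map<ck::Sequence<2, 2, 0>>::value, "a repeated id is not");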
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain> template <typename SeqMap>
struct sequence_map_inverse_impl struct sequence_map_inverse
{ {
private: template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
static constexpr auto new_y2x = WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{}); struct sequence_map_inverse_impl
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
public: using type =
using type = typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type; type;
}; };
template <typename X2Y, typename WorkingY2X, index_t XBegin> template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0> struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{ {
using type = WorkingY2X; using type = WorkingY2X;
}; };
template <typename X2Y>
struct sequence_map_inverse
{
using type = using type =
typename sequence_map_inverse_impl<X2Y, typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<X2Y::Size(), 0>::type, typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0, 0,
X2Y::Size()>::type; SeqMap::Size()>::type;
}; };
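A hypothetical compile-time check (not part of this commit, and assuming the ck sequence header above is included): sequence_map_inverse produces the permutation y2x with y2x[x2y[i]] == i.

using x2y = ck::Sequence<1, 2, 0>;
using y2x = ck::sequence_map_inverse<x2y>::type;

static_assert(ck::is_same<y2x, ck::Sequence<2, 0, 1>>{}, "inverse of {1, 2, 0} is {2, 0, 1}");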
template <index_t... Xs, index_t... Ys> template <index_t... Xs, index_t... Ys>
@@ -601,6 +729,12 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse(); return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
} }
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
{
return Sequence<Seq::At(Number<Is>{})...>{};
}
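A hypothetical compile-time check of the new helper (not part of this commit, and assuming the ck sequence header above is included): elements are picked by position, in the order given.

using picked = decltype(ck::pick_sequence_elements(ck::Sequence<10, 20, 30, 40>{},
                                                   ck::Sequence<3, 1>{}));

static_assert(ck::is_same<picked, ck::Sequence<40, 20>>{}, "picks elements 3 and 1, in that order");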
template <typename Seq, typename Reduce> template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence struct lambda_accumulate_on_sequence
{ {
......
#ifndef CK_SEQUENCE_HELPER_HPP #ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP #define CK_SEQUENCE_HELPER_HPP
#include "Sequence.hpp" #include "sequence.hpp"
namespace ck { namespace ck {
template <index_t... Xs> template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>) __host__ __device__ void print_sequence(const char* s, Sequence<Xs...>)
{ {
constexpr index_t nsize = Sequence<Xs...>::Size(); constexpr index_t nsize = Sequence<Xs...>::Size();
......
@@ -3,7 +3,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "type.hpp" #include "type.hpp"
#include "Sequence.hpp" #include "sequence.hpp"
namespace ck { namespace ck {
@@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
namespace detail { namespace detail {
template <typename X, typename F, index_t... Is> template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transpose_tuple_impl(X& x, F f, Sequence<Is...>) __host__ __device__ constexpr auto transform_tuple_impl(F f, const X& x, Sequence<Is...>)
{ {
return make_tuple(f(x.At(Number<Is>{}))...); return make_tuple(f(x.At(Number<Is>{}))...);
} }
} // namespace detail } // namespace detail
template <typename X, typename F> template <typename F, typename X>
__host__ __device__ constexpr auto transpose_tuple(X& x, F f) __host__ __device__ constexpr auto transform_tuple(F f, const X& x)
{ {
return detail::transpose_tuple_impl( return detail::transform_tuple_impl(
x, f, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
} }
} // namespace ck } // namespace ck
......
@@ -2,10 +2,12 @@
#define CK_TYPE_HPP #define CK_TYPE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "Sequence.hpp"
namespace ck { namespace ck {
template <index_t... Is>
struct Sequence;
template <typename X, typename Y> template <typename X, typename Y>
struct is_same : public integral_constant<bool, false> struct is_same : public integral_constant<bool, false>
{ {
......
@@ -84,8 +84,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 2; constexpr index_t HPad = 3;
constexpr index_t WPad = 2; constexpr index_t WPad = 3;
#elif 1 #elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
......