Commit ca42e910 authored by Chao Liu's avatar Chao Liu
Browse files

adding merge transform

parent 7a7fe160
......@@ -528,34 +528,64 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
#elif 1
// create a native tensor descriptor
constexpr auto in_c_h_w_n_global_desc =
make_NativeTensorDescriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t Hi = in_c_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = in_c_h_w_n_global_desc.GetLength(I3);
constexpr auto pad_h_w = Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{};
constexpr auto pass_c = PassThrough<C>{};
constexpr auto pass_n = PassThrough<N>{};
// transformation: {c, h, w, n} --> {n, c, hp, wp}
// {h, w} --> {hp, wp}, {c} --> {c}, {n} --> {n}
constexpr auto in_n_c_hp_wp_global_desc = transform_tensor_descriptor(
in_c_h_w_n_global_desc,
make_tuple(
Pad<Sequence<Hi, Wi>, LowerPads, UpperPads>{}, PassThrough<C>{}, PassThrough<N>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}, Sequence<3>{}),
make_tuple(Sequence<2, 3>{}, Sequence<1>{}, Sequence<0>{}));
constexpr auto trans = make_tuple(pass_c, pad_h_w, pass_n);
constexpr auto lower_dim_groups =
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
constexpr auto upper_dim_groups =
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{});
constexpr auto in_c_h_w_n_padded_global_desc = transform_tensor_descriptor(
in_c_h_w_n_global_desc, trans, lower_dim_groups, upper_dim_groups);
#if 1
// transformation: {n, c, hp, wp} --> {c, b}
// {n, hp, wp} --> {b}, {c} --> {c}
constexpr auto in_c_b_global_desc = transform_tensor_descriptor(
in_n_c_hp_wp_global_desc,
make_tuple(Merge<decltype(in_n_c_hp_wp_global_desc.GetLengths(I0, I2, I3))>{},
PassThrough<in_n_c_hp_wp_global_desc.GetLength(I1)>{}),
make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
#endif
#if 1
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
// 0
print_tensor_descriptor("in_c_h_w_n_global_desc", in_c_h_w_n_global_desc);
printf("offset: %lu\n", in_c_h_w_n_global_desc.GetOffset({1, 2, 3, 4}));
// 1
print_tensor_descriptor("in_n_c_hp_wp_global_desc", in_n_c_hp_wp_global_desc);
// 2
print_tensor_descriptor("in_c_b_global_desc", in_c_b_global_desc);
constexpr auto idx2 = MultiIndex<2>{1, 4 * (16 * 16) + 5 * 16 + 6};
auto idx1 = in_c_b_global_desc.CalculateLowerIndex(idx2);
auto idx0 = in_c_b_global_desc.GetLowerTensorDescriptor().CalculateLowerIndex(idx1);
printf("padded offset: %lu\n", in_c_h_w_n_padded_global_desc.GetOffset({1, 4, 5, 4}));
print_array("idx2: ", idx2);
print_array("idx1: ", idx1);
print_array("idx0: ", idx0);
printf("in_c_b_global_desc offset: %lu\n", in_c_b_global_desc.CalculateOffset(idx2));
}
#else
{
index_t c = static_cast<index_t>(threadIdx.x);
index_t h = static_cast<index_t>(threadIdx.y);
index_t w = static_cast<index_t>(threadIdx.z);
p_out_global[0] = in_n_c_h_w_padded_global_desc.CalculateOffset({1, c, h, w});
}
#endif
#endif
}
#endif
......
......@@ -18,9 +18,9 @@ struct NativeDimension
__host__ __device__ static constexpr auto GetStride() { return Number<Stride>{}; }
__host__ __device__ static constexpr index_t GetOffset(index_t i) { return i * Stride; }
__host__ __device__ static constexpr index_t CalculateOffset(index_t i) { return i * Stride; }
__host__ __device__ static constexpr index_t GetOffsetDiff(index_t i_diff)
__host__ __device__ static constexpr index_t CalculateOffsetDiff(index_t i_diff)
{
return i_diff * Stride;
}
......
......@@ -22,9 +22,12 @@ struct PassThrough
__host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<Length>{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up) { return idx_up; }
__host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
{
return idx_up;
}
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
{
return idx_up_diff;
}
......@@ -36,7 +39,7 @@ struct PassThrough
template <typename LowLengths, typename LeftPads, typename RightPads>
struct Pad
{
static constexpr index_t nDim = LowLengths::GetSize();
static constexpr index_t nDim = LowLengths::Size();
using LowerIndex = MultiIndex<nDim>;
using UpperIndex = MultiIndex<nDim>;
......@@ -52,12 +55,12 @@ struct Pad
return GetLowerLengths() + LeftPads{} + RightPads{};
}
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
__host__ __device__ static constexpr auto CalculateLowerIndex(UpperIndex idx_up)
{
return idx_up - LeftPads{};
}
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(UpperIndex idx_up_diff)
{
return idx_up_diff;
}
......@@ -65,21 +68,20 @@ struct Pad
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
#if 0
// LowLengths: Sequence<...>
template <typename LowLengths>
struct Merge
{
static constexpr index_t nDimLow = LowLengths::GetSize();
static constexpr index_t nDimLow = LowLengths::Size();
static constexpr index_t nDimUp = 1;
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetNumOfUpperDimension(){return Number<nDimUp>{}};
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
......@@ -88,18 +90,56 @@ struct Merge
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
}
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
// emulate constexpr lambda
template <typename PseudoLowStrides>
struct lambda_CalculateLowerIndex
{
index_t& itmp;
LowerIndex& idx_low;
__host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_,
LowerIndex& idx_low_)
: itmp(itmp_), idx_low(idx_low_)
{
}
template <typename IDim>
__host__ __device__ constexpr void operator()(IDim idim) const
{
constexpr index_t stride = PseudoLowStrides::At(idim);
idx_low(idim) = itmp / stride;
itmp -= idx_low[idim] * stride;
}
};
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
// not implemeneted
index_t itmp = idx_up[0];
constexpr auto pseudo_low_strides =
reverse_inclusive_scan_sequence(
GetLowerLengths().PopFront(), math::multiplies<index_t>{}, Number<1>{})
.PushBack(Number<1>{});
// calculate index in each of the dimensions in the order of their dimension
#if 1
static_for<0, nDimLow - 1, 1>{}(
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
#else
static_for<0, nDimLow, 1>{}(
lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
#endif
return idx_low;
}
// idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
LowerIndex idx_low_old)
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff,
const LowerIndex& idx_low_old)
{
LowerIndex idx_low_diff;
......@@ -110,49 +150,48 @@ struct Merge
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
};
#endif
// UpLengths: Sequence<...>
template <index_t LowLength, typename UpLengths>
template <typename UpLengths>
struct Unmerge
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize();
static constexpr index_t nDimUp = UpLengths::Size();
using UpperIndex = MultiIndex<nDimUp>;
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Unmerge()
{
static_assert(LowLength == accumulate_on_sequence(
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
"wrong! UpLengths need to be ");
}
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths()
{
constexpr index_t low_length =
accumulate_on_sequence(UpLengths{}, math::multiplies<index_t>{}, Number<1>{});
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
return Sequence<low_length>{};
}
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
math::multiplies<index_t>,
1>::type{};
LowerIndex idx_low{0};
static_for<0, nDimUp, 1>{}([&](auto idim) { idx_low(0) += idx_up[idim] * scans[idim]; });
constexpr auto pseudo_up_strides =
typename sequence_reverse_inclusive_scan<UpLengths, math::multiplies<index_t>, 1>::
type{};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; });
return idx_low;
}
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
{
return GetLowerIndex(idx_up_diff);
return CalculateLowerIndex(idx_up_diff);
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
......@@ -165,12 +204,12 @@ template <index_t LowLength, typename UpLengths, typename Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize();
static constexpr index_t nDimUp = UpLengths::Size();
using LowerIndex = MultiIndex<nDimLow>;
using UpperIndex = MultiIndex<nDimUp>;
__host__ __device__ constexpr Embed()
__host__ __device__ explicit constexpr Embed()
{
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent");
......@@ -191,7 +230,7 @@ struct Embed
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
__host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low(Coefficients{}[nDimUp]);
......@@ -201,7 +240,7 @@ struct Embed
return idx_low;
}
__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff)
__host__ __device__ static constexpr auto CalculateLowerIndexDiff(const UpperIndex& idx_up_diff)
{
LowerIndex idx_low_diff{0};
......
......@@ -18,47 +18,53 @@ struct NativeTensorDescriptor
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
struct lambda_GetLength
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
template <typename IDim>
__host__ __device__ constexpr auto operator()(IDim) const
{
return GetLength(IDim{});
}
};
return mDimensions.At(Number<IDim>{}).GetLength();
}
__host__ __device__ static constexpr auto GetLengths()
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
return typename sequence_gen<nDim, lambda_GetLength>::type{};
return mDimensions.At(Number<IDim>{}).GetStride();
}
struct lambda_GetStride
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
template <typename IDim>
__host__ __device__ constexpr auto operator()(IDim) const
{
return GetStride(IDim{});
}
};
return Sequence<GetLength(Number<IDims>{})...>{};
}
__host__ __device__ static constexpr auto GetStrides()
template <index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Sequence<IDims...>)
{
return typename sequence_gen<nDim, lambda_GetStride>::type{};
return Sequence<GetStride(Number<IDims>{})...>{};
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return mDimensions.At(Number<IDim>{}).GetLength();
return GetLengths(Sequence<IDim, IDims...>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetStrides(Number<IDim>, Number<IDims>...)
{
return mDimensions.At(Number<IDim>{}).GetStride();
return GetStrides(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr index_t GetOffset(const Index& idx)
__host__ __device__ static constexpr auto GetLengths()
{
return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr auto GetStrides()
{
return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{});
}
__host__ __device__ static constexpr index_t CalculateOffset(const Index& idx)
{
index_t offset = 0;
......@@ -67,7 +73,7 @@ struct NativeTensorDescriptor
return offset;
}
__host__ __device__ static constexpr index_t GetOffsetDiff(const Index& idx_diff)
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
{
index_t offset_diff = 0;
......@@ -161,8 +167,10 @@ struct TransformedTensorDescriptor
// UpDimensionIds should include all up-dimensions
// TODO: sanity check: while a up-dimension could be associated with multille
// transformation,
// a low-dimension should be associated with only one transformation
// transformation, a low-dimension should be associated with only one transformation
// TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths
// of lower-tensor-descriptor
}
__host__ __device__ static constexpr auto GetNumOfDimension()
......@@ -170,49 +178,78 @@ struct TransformedTensorDescriptor
return GetNumOfUpperDimension();
}
#if 0
__host__ __device__ static constexpr auto GetUpperLengths()
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
{
return LowTensorDescriptor{};
}
__host__ __device__ static constexpr auto GetLowerLengths()
{
struct lambda_get_upper_lengths
return GetLowerTensorDescriptor().GetLengths();
}
struct lambda_GetUpperLengths
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(const Transform& tran) const
{
template <typename Transform>
__host__ __device__ constexpr auto operator()(Transform tran) const
{
return tran.GetUpperLengths();
}
};
return tran.GetUpperLengths();
}
};
constexpr auto tuple_of_upper_lengths =
transform_tuple(Transforms, lambda_get_upper_lengths{});
__host__ __device__ static constexpr auto GetUpperLengths()
{
constexpr auto tuple_of_up_lengths =
transform_tuple(lambda_GetUpperLengths{}, Transforms{});
constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths);
constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
constexpr auto mingled_up_dimension_ids =
unpack(lambda_merge_sequences{}, UpDimensionIds{});
// TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
// TODO: sanity-check all_upper_lengths have no conflicting upper-length
// TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions
// TODO: sanity-check mingled_up_lengths have no conflicting upper-length
using sort_dimension_ids =
sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
// sort by upper-dimension-ids
using sort_up_dimension_ids = sequence_unique_sort<decltype(mingled_up_dimension_ids),
math::less<index_t>,
math::equal<index_t>>;
constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type;
constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type;
// sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1>
static_assert(is_same<typename sort_up_dimension_ids::type,
typename arithmetic_sequence_gen<0, nDimUp, 1>::type>{},
"wrong! UpDimensionIds is not configured correctly");
constexpr auto sorted_upper_lengths =
sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
return sorted_upper_lengths;
constexpr auto sorted_up_lengths =
pick_sequence_elements(mingled_up_lengths, sorted2unsorted_map);
return sorted_up_lengths;
}
__host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
#endif
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
return LowTensorDescriptor{};
return GetLengths()[IDim];
}
template <index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Sequence<IDims...>)
{
return Sequence<GetLength(Number<IDims>{})...>{};
}
template <index_t IDim, index_t... IDims>
__host__ __device__ static constexpr auto GetLengths(Number<IDim>, Number<IDims>...)
{
return GetLengths(Sequence<IDim, IDims...>{});
}
__host__ __device__ static constexpr LowerIndex GetLowerIndex(const UpperIndex& idx_up)
// TODO: right now return value is constexpr because use of non-constepxr lambda
__host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
{
LowerIndex idx_low;
......@@ -225,14 +262,15 @@ struct TransformedTensorDescriptor
// this assume each lower (single) index is only assocaited with one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_part = tran.GetLowerIndex(to_array(idx_up_part));
idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part));
});
return idx_low;
}
__host__ __device__ static constexpr LowerIndex GetLowerIndexDiff(const UpperIndex& idx_up_diff,
const LowerIndex& idx_low_old)
// TODO: right now return value is constexpr because use of non-constepxr lambda
__host__ __device__ static constexpr LowerIndex
CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, const LowerIndex& idx_low_old)
{
LowerIndex idx_low_diff;
......@@ -250,15 +288,15 @@ struct TransformedTensorDescriptor
// this assume each lower (single) index is associated with only one transformation,
// which is required for index transformation, and has been checked during constructor
// of TransformedTensorDescriptor
idx_low_diff_part = tran.GetLowerIndex(idx_up_diff_part, idx_low_old_part);
idx_low_diff_part = tran.CalculateLowerIndex(idx_up_diff_part, idx_low_old_part);
});
return idx_low_diff;
}
__host__ __device__ static constexpr index_t GetOffset(const UpperIndex& idx_up)
__host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
{
return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
}
#if 0
......@@ -286,14 +324,14 @@ struct TransformedTensorDescriptor
};
template <index_t... Lengths, index_t... Strides>
__host__ __device__ constexpr auto make_NativeTensorDescriptor(Sequence<Lengths...>,
Sequence<Strides...>)
__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence<Lengths...>,
Sequence<Strides...>)
{
return NativeTensorDescriptor<NativeDimension<Lengths, Strides>...>{};
}
template <typename Lengths>
__host__ __device__ constexpr auto make_NativeTensorDescriptor_packed(Lengths)
__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths)
{
constexpr auto strides = reverse_inclusive_scan_sequence(
Lengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
......
......@@ -7,12 +7,19 @@
namespace ck {
template <typename... NativeDimensions>
__host__ __device__ void print_tensor_descriptor(const char* s,
NativeTensorDescriptor<NativeDimensions...> desc)
__host__ __device__ void
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
{
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
}
template <typename... Ts>
__host__ __device__ void print_tensor_descriptor(const char* s,
const TransformedTensorDescriptor<Ts...>& desc)
{
print_tensor_descriptor_impl(s, desc.GetLengths());
}
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
......@@ -113,5 +120,53 @@ print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strid
});
}
template <index_t... Lengths>
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
{
constexpr index_t nDim = sizeof...(Lengths);
static_assert(nDim > 0 && nDim <= 12, "wrong!");
static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
static_if<nDim == 2>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 3>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 4>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 5>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 6>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); });
static_if<nDim == 7>{}(
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
static_if<nDim == 8>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 9>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 10>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
static_if<nDim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
});
}
} // namespace ck
#endif
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
#include "Sequence.hpp"
#include "sequence.hpp"
#include "functional2.hpp"
namespace ck {
......@@ -17,7 +17,7 @@ struct Array
__host__ __device__ explicit constexpr Array() {}
template <typename X, typename... Xs>
__host__ __device__ explicit constexpr Array(X x, Xs... xs)
__host__ __device__ constexpr Array(X x, Xs... xs)
: mData{static_cast<TData>(x), static_cast<TData>(xs)...}
{
static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size");
......@@ -176,7 +176,6 @@ __host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
return ArrayElementPicker<Arr, Picks>(a);
}
#if 1
template <typename T>
__host__ __device__ constexpr auto to_array(const T& x)
{
......@@ -186,8 +185,8 @@ __host__ __device__ constexpr auto to_array(const T& x)
return y;
}
#endif
// TODO: remove this
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
......
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#include "Array.hpp"
#include "array.hpp"
namespace ck {
template <typename T, index_t NSize>
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
__host__ __device__ void print_array(const char* s, Array<T, NSize> a)
{
constexpr index_t nsize = a.GetSize();
......@@ -90,4 +90,4 @@ __host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
}
} // namespace ck
#endif
\ No newline at end of file
#endif
......@@ -9,9 +9,9 @@
#include "tuple.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "Array.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
......
......@@ -2,7 +2,7 @@
#define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
#include "type.hpp"
namespace ck {
......
......@@ -2,7 +2,7 @@
#define CK_FUNCTIONAL2_HPP
#include "functional.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
namespace ck {
......
......@@ -3,8 +3,8 @@
#include "functional.hpp"
#include "functional2.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
#include "sequence.hpp"
#include "array.hpp"
namespace ck {
......
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "Sequence.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
#include "Array.hpp"
#include "array.hpp"
namespace ck {
......
......@@ -3,6 +3,7 @@
#include "config.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
namespace ck {
namespace math {
......
......@@ -2,7 +2,9 @@
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "math.hpp"
namespace ck {
......@@ -155,8 +157,8 @@ struct Sequence
static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
......@@ -188,34 +190,34 @@ struct sequence_merge<Seq>
};
// generate sequence
template <index_t IBegin, index_t NRemain, typename F>
struct sequence_gen_impl
template <index_t NSize, typename F>
struct sequence_gen
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type =
typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
};
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename F>
struct sequence_gen_impl<I, 1, F>
{
static constexpr index_t Is = F{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, typename F>
struct sequence_gen_impl<I, 0, F>
{
using type = Sequence<>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
template <index_t NSize, typename F>
struct sequence_gen
{
using type = typename sequence_gen_impl<0, NSize, F>::type;
};
......@@ -281,8 +283,8 @@ struct sequence_split
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
using SeqType0 = decltype(Seq::Extract(range0{}));
using SeqType1 = decltype(Seq::Extract(range1{}));
using left_type = decltype(Seq::Extract(range0{}));
using right_type = decltype(Seq::Extract(range1{}));
};
// reverse sequence
......@@ -293,8 +295,8 @@ struct sequence_reverse
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::SeqType1>::type,
typename sequence_reverse<typename seq_split::SeqType0>::type>::type;
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
......@@ -309,138 +311,264 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>;
};
template <typename Seq, typename Compare>
struct sequence_sort
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
template <typename SeqLeft, typename SeqRight, typename MergedSeq, typename Comp>
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool pick_left = SeqLeft::Front() < SeqRight::Front();
static constexpr index_t next_value = pick_left ? SeqLeft::Front() : SeqRight::Front();
using new_merged_seq = decltype(MergedSeq::PushBack(Number<next_value>{}));
using new_left_seq =
typename conditional<pick_left, decltype(SeqLeft::PopFront()), SeqLeft>::type;
using new_right_seq =
typename conditional<pick_left, SeqRight, decltype(SeqRight::PopFront())>::type;
using type =
typename sorted_sequence_merge_impl<new_left_seq, new_right_seq, new_merged_seq, Comp>::
type;
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::Front() : RightValues::Front();
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
using new_left_values =
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
using new_left_ids =
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
using new_right_values =
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
using new_right_ids =
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
template <typename SeqLeft, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl<SeqLeft, Sequence<>, MergedSeq, Comp>
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
Sequence<>,
Sequence<>,
MergedValues,
MergedIds,
Comp>
{
using type = typename sequence_merge<MergedSeq, SeqLeft>::type;
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
template <typename SeqRight, typename MergedSeq, typename Comp>
struct sorted_sequence_merge_impl<Sequence<>, SeqRight, MergedSeq, Comp>
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<Sequence<>,
Sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using type = typename sequence_merge<MergedSeq, SeqRight>::type;
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
// Merge two sorted (values, ids) sequence pairs into a single sorted pair.
// The ids sequences travel alongside the values so callers can recover the
// permutation that the merge applied.
// Fix: removed stale pre-refactor lines left by diff residue — the old
// 2-sequence template header and a `using type` that referenced the removed
// `Seq0`/`Seq1` parameters.
template <typename LeftValues,
          typename LeftIds,
          typename RightValues,
          typename RightIds,
          typename Comp>
struct sorted_sequence_merge
{
    // start the recursion with empty accumulators
    using merge = sorted_sequence_merge_impl<LeftValues,
                                             LeftIds,
                                             RightValues,
                                             RightIds,
                                             Sequence<>,
                                             Sequence<>,
                                             Comp>;

    // this is output
    using merged_values = typename merge::merged_values;
    using merged_ids    = typename merge::merged_ids;
};
// Merge-sort recursion body: split (Values, Ids) in half, sort each half
// recursively, then merge the two sorted halves.
// Fix: removed stale pre-refactor lines left by diff residue — the old
// `split`/`unsorted_left`/`sorted_left`/`type` aliases referenced the removed
// single-sequence `Seq` parameter and no longer name anything.
static constexpr index_t nsize = Values::Size();

// split both the values and their id tags at the midpoint
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;

using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;

using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;

using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;

using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;

using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;

using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;

using merged_sorted = sorted_sequence_merge<left_sorted_values,
                                            left_sorted_ids,
                                            right_sorted_values,
                                            right_sorted_ids,
                                            Compare>;

// this is output
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
};
// Base case: sort a two-element (values, ids) pair with a single comparison.
// Fix: removed stale pre-refactor lines left by diff residue — the old
// `sequence_sort<Sequence<X, Y>, Compare>` header, the `x_first` member that
// referenced the removed X/Y parameters, and an orphan `using type` line that
// sat after the closing brace.
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
{
    // true iff ValueX is ordered before ValueY under Compare
    static constexpr bool choose_x = Compare{}(ValueX, ValueY);

    using sorted_values =
        typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;

    // ids are swapped in lockstep with their values
    using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
// Base case: a one-element (values, ids) pair is already sorted.
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
using sorted_values = Sequence<Value>;
using sorted_ids = Sequence<Id>;
};
// Public entry point: sort Values with Compare. Besides the sorted sequence
// (`type`), exposes `sorted2unsorted_map`, the permutation mapping each sorted
// position back to its original position.
// Fix: removed stale pre-refactor lines left by diff residue — the old
// one-element-specialization header and its `using type = Sequence<X>;` that
// referenced the removed X parameter.
template <typename Values, typename Compare>
struct sequence_sort
{
    // tag every value with its original position 0..N-1
    using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;

    using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;

    // this is output
    using type = typename sort::sorted_values;
    using sorted2unsorted_map = typename sort::sorted_ids;
};
// Sort Values with Less, then drop consecutive duplicates (as decided by
// Equal). Outputs the uniquified sorted sequence (`type`) and the map from
// each surviving sorted position back to its original position
// (`sorted2unsorted_map`).
// Fix: removed stale pre-refactor lines left by diff residue — the old
// `Seq`/`WorkInputSeq`/`WorkOutputSeq`-based headers, the `new_value`/
// `new_working_output_seq` members, and the old `using type`/`sorted_seq`
// aliases, all of which referenced removed template parameters.
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
    // Recursion step: pull the front of the remaining run; append it (and its
    // id) to the accumulators only when it differs from the last kept value.
    template <typename RemainValues,
              typename RemainIds,
              typename UniquifiedValues,
              typename UniquifiedIds,
              typename Eq>
    struct sorted_sequence_uniquify_impl
    {
        static constexpr index_t current_value = RemainValues::Front();
        static constexpr index_t current_id = RemainIds::Front();

        // input is sorted, so a duplicate can only be the previous kept value
        // NOTE(review): comparison uses != directly; presumably Eq mirrors it — confirm
        static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());

        using new_remain_values = decltype(RemainValues::PopFront());
        using new_remain_ids = decltype(RemainIds::PopFront());

        using new_uniquified_values =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedValues::PushBack(Number<current_value>{})),
                                 UniquifiedValues>::type;

        using new_uniquified_ids =
            typename conditional<is_unique_value,
                                 decltype(UniquifiedIds::PushBack(Number<current_id>{})),
                                 UniquifiedIds>::type;

        using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
                                                       new_remain_ids,
                                                       new_uniquified_values,
                                                       new_uniquified_ids,
                                                       Eq>;

        // this is output
        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids = typename uniquify::uniquified_ids;
    };

    // Termination case: nothing left to examine; accumulators are the answer.
    template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
    struct sorted_sequence_uniquify_impl<Sequence<>,
                                         Sequence<>,
                                         UniquifiedValues,
                                         UniquifiedIds,
                                         Eq>
    {
        using uniquified_values = UniquifiedValues;
        using uniquified_ids = UniquifiedIds;
    };

    // Driver: seed the accumulators with the first element (always kept) and
    // recurse over the rest.
    template <typename SortedValues, typename SortedIds, typename Eq>
    struct sorted_sequence_uniquify
    {
        using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
                                                       decltype(SortedIds::PopFront()),
                                                       Sequence<SortedValues::Front()>,
                                                       Sequence<SortedIds::Front()>,
                                                       Eq>;

        using uniquified_values = typename uniquify::uniquified_values;
        using uniquified_ids = typename uniquify::uniquified_ids;
    };

    // sort first, keeping the sorted->unsorted permutation
    using sort = sequence_sort<Values, Less>;
    using sorted_values = typename sort::type;
    using sorted_ids = typename sort::sorted2unsorted_map;

    using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;

    // this is output
    using type = typename uniquify::uniquified_values;
    using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
// A sequence is a valid permutation map iff sorting it ascending yields the
// identity sequence 0, 1, ..., N-1.
// Fix: removed stale pre-refactor lines left by diff residue — the old
// always-true placeholder (`integral_constant<bool, true>` plus its TODO
// comment), which coexisted with the real check and redefined `value`.
template <typename SeqMap>
struct is_valid_sequence_map
{
    static constexpr bool value =
        is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
                typename sequence_sort<SeqMap, math::less<index_t>>::type>{};
};
// Invert a permutation map: given x2y (position -> value), produce y2x so that
// y2x[x2y[i]] == i. Implemented by walking x2y and writing each index into the
// slot named by its value.
// Fix: removed stale pre-refactor lines left by diff residue — the old
// free-standing `sequence_map_inverse_impl` (with its `private:`/`public:`
// members), a duplicated copy of the XRemain==0 specialization, and the old
// `X2Y`-parameterized wrapper lines.
template <typename SeqMap>
struct sequence_map_inverse
{
    // Recursion step: record that position XBegin maps to value X2Y[XBegin],
    // then advance until XRemain reaches zero.
    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
    struct sequence_map_inverse_impl
    {
        static constexpr auto new_y2x =
            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});

        using type =
            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
                type;
    };

    // Termination case: all positions processed; the accumulator is the inverse.
    template <typename X2Y, typename WorkingY2X, index_t XBegin>
    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
    {
        using type = WorkingY2X;
    };

    // this is output; accumulator starts as an all-zero sequence of equal size
    using type =
        typename sequence_map_inverse_impl<SeqMap,
                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
                                           0,
                                           SeqMap::Size()>::type;
};
template <index_t... Xs, index_t... Ys>
// Inclusive scan (left-to-right prefix reduction) of Seq with Reduce, seeded
// with Init, expressed via the existing reverse scan: reverse, scan, reverse.
// Fix: the function's signature line was fused into a scraped diff-hunk marker
// ("......@@ -601,6 +729,12 @@ ..."); reconstructed from the context shown in
// that marker and the visible body. NOTE(review): template-parameter names are
// inferred — confirm against the repository copy.
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
// Gather: build a new compile-time sequence from the elements of Seq at
// positions Is..., in the order the positions are given.
// e.g. pick_sequence_elements(Sequence<10, 20, 30>{}, Sequence<2, 0>{})
//      yields Sequence<30, 10>{}.
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements(Seq, Sequence<Is...>)
{
return Sequence<Seq::At(Number<Is>{})...>{};
}
template <typename Seq, typename Reduce>
struct lambda_accumulate_on_sequence
{
......
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "Sequence.hpp"
#include "sequence.hpp"
namespace ck {
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
__host__ __device__ void print_sequence(const char* s, Sequence<Xs...>)
{
constexpr index_t nsize = Sequence<Xs...>::Size();
......
......@@ -3,7 +3,7 @@
#include "integral_constant.hpp"
#include "type.hpp"
#include "Sequence.hpp"
#include "sequence.hpp"
namespace ck {
......@@ -114,19 +114,19 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
namespace detail {

// Apply f to every element of tuple x; Is... enumerates the element indices.
// Fix: removed stale pre-rename lines left by diff residue — the old
// `transpose_tuple_impl`/`transpose_tuple` headers and bodies were interleaved
// with the renamed `transform_tuple` versions, yielding duplicate definitions.
template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transform_tuple_impl(F f, const X& x, Sequence<Is...>)
{
    return make_tuple(f(x.At(Number<Is>{}))...);
}

} // namespace detail

// Return a new tuple whose i-th element is f(x.At(i)); x is not modified.
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuple(F f, const X& x)
{
    return detail::transform_tuple_impl(
        f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
} // namespace ck
......
......@@ -2,10 +2,12 @@
#define CK_TYPE_HPP
#include "integral_constant.hpp"
#include "Sequence.hpp"
namespace ck {
template <index_t... Is>
struct Sequence;
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
......
......@@ -84,8 +84,8 @@ int main(int argc, char* argv[])
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 2;
constexpr index_t WPad = 2;
constexpr index_t HPad = 3;
constexpr index_t WPad = 3;
#elif 1
// 3x3, 34x34
constexpr index_t N = 64;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment