Commit 0c05f427 authored by Chao Liu

adding dimension transformation

parent bd44e639
#ifndef CK_DIMENSION_TRANSFORM_HPP
#define CK_DIMENSION_TRANSFORM_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = Array<index_t, N>;
// LowLengths: Sequence<...>
template <class LowLengths>
struct PassThrough
{
static constexpr index_t nDim = LowLengths::GetSize();
using LowerId = MultiIndex<nDim>;
using UpperId = LowerId;
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
{
return GetLowerNumOfDimension();
}
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return GetLowerLengths(); }
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up; }
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
return id_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
// LowLengths: Sequence<...>
template <class LowLengths, class LeftPads, class RightPads>
struct Pad
{
static constexpr index_t nDim = LowLengths::GetSize();
using LowerId = MultiIndex<nDim>;
using UpperId = LowerId;
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetUpperNumOfDimension()
{
return GetLowerNumOfDimension();
}
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return GetLowerLengths() + LeftPads + RightPads;
}
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up) { return id_up - LeftPads; }
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
return id_up_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
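// Illustrative note (not part of this commit; example numbers are hypothetical): for a
// lower length of 8 with LeftPad 1 and RightPad 2, GetUpperLengths() yields 8 + 1 + 2 = 11,
// and GetLowerId maps upper index 0 to lower index -1, i.e. upper indices that fall inside
// the padding land outside the lower tensor.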
// LowLengths: Sequence<...>
template <class LowLengths>
struct Merge
{
static constexpr index_t nDimLow = LowLengths::GetSize();
static constexpr index_t nDimUp = 1;
using LowerId = MultiIndex<nDimLow>;
using UpperId = MultiIndex<nDimUp>;
__host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return LowLengths{}; }
__host__ __device__ static constexpr auto GetUpperLengths()
{
return Sequence<accumulate_on_sequence(
GetLowerLengths(), math::multiplies<index_t>{}, Number<1>{})>{};
}
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
{
LowerId id_low;
// not implemented
return id_low;
}
// id_low_diff depends on id_low_old, so id_low_old needs to be up-to-date
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff, LowerId id_low_old)
{
LowerId id_low_diff;
// not implemented
return id_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
};
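// Illustrative note (not part of this commit): Merge is meant to collapse all lower
// dimensions into one upper index, so for lower lengths {L0, ..., Ln-1} the upper length is
// L0 * ... * Ln-1. GetLowerId would invert that flattening (e.g. by successive division and
// modulo), and GetLowerIdDiff takes id_low_old because applying a difference to the flat
// upper index can carry or borrow across the lower dimensions.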
// UpLengths: Sequence<...>
template <index_t LowLength, class UpLengths>
struct Unmerge
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize();
using LowerId = MultiIndex<nDimLow>;
using UpperId = MultiIndex<nDimUp>;
__host__ __device__ constexpr Unmerge()
{
static_assert(LowLength == accumulate_on_sequence(
UpLengths{}, math::multiplies<index_t>{}, Number<1>{}),
"wrong! UpLengths need to be ");
}
__host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
{
constexpr auto scans = typename sequence_reverse_inclusive_scan<UpLengths,
math::multiplies<index_t>,
1>::type{};
LowerId id_low{0};
static_for<0, nDimUp, 1>{}([&](auto idim) { id_low[0] += id_up[idim] * scans[idim]; });
return id_low;
}
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
return GetLowerId(id_up_diff);
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
// UpLengths: Sequence<...>
// Coefficients: Sequence<...>
// id_low = coefficients[0, ...nDimUp-1] * id_up[0, ...nDimUp-1] + coefficients[nDimUp]
template <index_t LowLength, class UpLengths, class Coefficients>
struct Embed
{
static constexpr index_t nDimLow = 1;
static constexpr index_t nDimUp = UpLengths::GetSize();
using LowerId = MultiIndex<nDimLow>;
using UpperId = MultiIndex<nDimUp>;
static constexpr auto mCoefficients = Coefficients{};
__host__ __device__ constexpr Embed()
{
static_assert(UpLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1,
"wrong! # of dimensions not consistent");
constexpr index_t low_id_max =
mCoefficients.Back() + accumulate_on_sequence(UpLengths{} * Coefficients{}.PopBack(),
math::plus<index_t>{},
Number<0>{});
static_assert(low_id_max < LowLength, "wrong! lower-id will go out of range");
}
__host__ __device__ static constexpr auto GetUpperNumOfDimension() { return Number<nDimUp>{}; }
__host__ __device__ static constexpr auto GetLowerNumOfDimension() { return Number<nDimLow>{}; }
__host__ __device__ static constexpr auto GetLowerLengths() { return Sequence<LowLength>{}; }
__host__ __device__ static constexpr auto GetUpperLengths() { return UpLengths{}; }
__host__ __device__ static constexpr auto GetLowerId(UpperId id_up)
{
LowerId id_low{mCoefficients[nDimUp]};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { id_low[0] += id_up[idim] * mCoefficients[idim]; });
return id_low;
}
__host__ __device__ static constexpr auto GetLowerIdDiff(UpperId id_up_diff)
{
LowerId id_low_diff{0};
static_for<0, nDimUp, 1>{}(
[&](auto idim) { id_low_diff[0] += id_up_diff[idim] * mCoefficients[idim]; });
return id_low_diff;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
};
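// Illustrative example (hypothetical parameters, not part of this commit): for
// Embed<100, Sequence<4, 3>, Sequence<10, 2, 5>>, GetLowerId maps the upper index (i, j)
// to the lower index 10 * i + 2 * j + 5, so (2, 1) -> 27; the constructor's static_assert
// checks that the largest reachable lower id stays below LowLength = 100.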
} // namespace ck
#endif
@@ -3,75 +3,159 @@
#include "common_header.hpp"
#include "dimension.hpp"
+#include "multi_index_transform.hpp"
namespace ck {
-template <class Lengths, class Strides>
+template <class... NativeDimensions>
struct NativeTensorDescriptor
{
using type = NativeTensorDescriptor;
-static constexpr index_t nDim = Lengths::GetSize();
+static constexpr auto mDimensions = Tuple<NativeDimensions...>{};
+static constexpr index_t nDim = sizeof...(NativeDimensions);
-using Id = MultiIndex<nDim>;
+using Index = MultiIndex<nDim>;
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
-__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
+__host__ __device__ static constexpr auto GetLengths()
+{
+// not implemented
+}
-__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
+__host__ __device__ static constexpr auto GetStrides()
+{
+// not implemented
+}
-__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
+template <index_t IDim>
+__host__ __device__ static constexpr auto GetLength(Number<IDim>)
+{
+return mDimensions.Get(Number<IDim>{}).GetLength();
+}
-__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
+template <index_t IDim>
+__host__ __device__ static constexpr auto GetStride(Number<IDim>)
+{
+return mDimensions.Get(Number<IDim>{}).GetStride();
+}
-__host__ __device__ static constexpr index_t GetOffset(Id id)
+__host__ __device__ static constexpr index_t GetOffset(Index idx)
{
-// not implemented
+index_t offset = 0;
+static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); });
+return offset;
+}
+__host__ __device__ static constexpr index_t GetOffsetDiff(Index idx_diff)
+{
+index_t offset_diff = 0;
+static_for<0, nDim, 1>{}(
+[&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); });
+return offset_diff;
+}
+__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
+{
+// TODO: re-implement "Sequence", so that it can take other data-type (including bool) as
+// element
+return typename uniform_sequence_gen<nDim, 1>::type{};
+}
+__host__ __device__ static constexpr auto GetIndependentDimensionGroups()
+{
+// not implemented, should return Tuple<Sequence<0>, Sequence<1>, ...>
+return xxx;
}
};
// LowerTensorDescriptor
// Transforms: std::tuple<DimensionTransforms...>
-// LowerIds: std::tuple<Sequence<...>>
+// LowerDimensionIds: std::tuple<Sequence<...>>
-// UpperIds: std::tuple<Sequence<...>>
+// UpperDimensionIds: std::tuple<Sequence<...>>
-template <class LowTensorDescriptor,
-class Transforms,
-class LowDimensionMasks,
-class UpDimensionMasks>
+template <class LowTensorDescriptor, class Transforms, class LowDimensionIds, class UpDimensionIds>
struct TransformedTensorDescriptor
{
using type = TransformedTensorDescriptor;
-static constexpr index_t nDimUp = xxxx;
+static constexpr index_t nDimUp = GetNumOfUpperDimension();
-static constexpr index_t nDimLow = xxx;
+static constexpr index_t nDimLow = GetNumOfLowerDimension();
static constexpr index_t nTransform = Transforms::GetSize();
-using UpperId = MultiIndex<nDimUp>;
+using UpperIndex = MultiIndex<nDimUp>;
-using LowerId = MultiIndex<nDimLow>;
+using LowerIndex = MultiIndex<nDimLow>;
__host__ __device__ constexpr TransformedTensorDescriptor()
{
static_assert(nTransform == Transforms::GetSize() &&
-nTransform == LowDimensionMasks::GetSize() &&
+nTransform == LowDimensionIds::GetSize() &&
-nTransform == UpDimensionMasks::GetSize(),
+nTransform == UpDimensionIds::GetSize(),
"wrong! # of transformations not the same");
-// TODO: sanity check: LowDimensionMasks should include all low-dimensions,
+// TODO: sanity check: LowDimensionIds should include all low-dimensions,
-// UpDimensionMasks should include all up-dimensions
+// UpDimensionIds should include all up-dimensions
// TODO: sanity check: while an up-dimension could be associated with multiple
// transformations,
// a low-dimension should be associated with only one transformation
}
+__host__ __device__ static constexpr auto GetNumOfLowerDimension()
+{
+// Here, we assume all lower-dimensions are active
+// TODO: sanity-check all lower-dimensions are indeed active
+constexpr auto low_active_dims = unique_sort_sequence(
+merge_tuple_of_sequences(LowDimensionIds{}), math::less<index_t>{});
+return low_active_dims.GetSize();
+}
+__host__ __device__ static constexpr auto GetNumOfUpperDimension()
+{
+constexpr auto up_active_dims =
+unique_sort_sequence(merge_tuple_of_sequences(UpDimensionIds{}), math::less<index_t>{});
+return up_active_dims.GetSize();
+}
__host__ __device__ static constexpr auto GetNumOfDimension()
{
-// not implemented
+return GetNumOfUpperDimension();
}
__host__ __device__ static constexpr auto GetLengths()
{
-// not implemented
+struct lambda_get_upper_lengths
+{
+template <class Transform>
+__host__ __device__ constexpr auto operator()(Transform tran) const
+{
+return tran.GetUpperLengths();
+}
+};
+constexpr auto tuple_of_upper_lengths =
+transform_tuple(Transforms{}, lambda_get_upper_lengths{});
+constexpr auto all_upper_lengths = merge_tuple_of_sequences(tuple_of_upper_lengths);
+constexpr auto all_upper_dimension_ids = merge_tuple_of_sequences(UpDimensionIds{});
+// TODO: sanity-check all_upper_dimension_ids contain all upper-dimensions
+// TODO: sanity-check all_upper_lengths have no conflicting upper-length
+using sort_dimension_ids =
+sequence_unique_sort<decltype(all_upper_dimension_ids), math::less<index_t>>;
+constexpr auto sorted_upper_dimension_ids = typename sort_dimension_ids::type{};
+constexpr auto sorted2unsorted_map = typename sort_dimension_ids::sorted2unsorted_map_type{};
+constexpr auto sorted_upper_lengths =
+sequence_element_pick(all_upper_lengths, sorted2unsorted_map);
+return sorted_upper_lengths;
}
__host__ __device__ static constexpr auto GetLowerTensorDescriptor()
@@ -79,17 +163,57 @@ struct TransformedTensorDescriptor
return LowTensorDescriptor{};
}
-__host__ __device__ static constexpr index_t GetLowerId(UpperId id_up)
+__host__ __device__ static constexpr auto GetLowerIndex(UpperIndex idx_up)
{
-// not implemented
+LowerIndex idx_low;
+static_for<0, nTransform, 1>{}([&](auto itran) {
+constexpr auto tran = Transforms::Get(itran);
+auto idx_low_part = pick_array_element(idx_low, LowDimensionIds::Get(itran));
+const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds::Get(itran));
+// this assumes each (single) lower index is only associated with one transformation,
+// which is required for index transformation, and has been checked during constructor
+// of TransformedTensorDescriptor
+idx_low_part = tran.GetLowerIndex(idx_up_part);
+});
+return idx_low;
+}
+__host__ __device__ static constexpr auto GetLowerIndexDiff(UpperIndex idx_up_diff,
+LowerIndex idx_low_old)
+{
+LowerIndex idx_low_diff;
+static_for<0, nTransform, 1>{}([&](auto itran) {
+constexpr auto tran = Transforms::Get(itran);
+const auto idx_up_diff_part =
+pick_array_element(idx_up_diff, UpDimensionIds::Get(itran));
+auto idx_low_diff_part =
+pick_array_element(idx_low_diff, LowDimensionIds::Get(itran));
+const auto idx_low_old_part =
+pick_array_element(idx_low_old, LowDimensionIds::Get(itran));
+// this assumes each (single) lower index is associated with only one transformation,
+// which is required for index transformation, and has been checked during constructor
+// of TransformedTensorDescriptor
+idx_low_diff_part = tran.GetLowerIndexDiff(idx_up_diff_part, idx_low_old_part);
+});
+return idx_low_diff;
}
-__host__ __device__ static constexpr index_t GetOffset(UpperId id_up)
+__host__ __device__ static constexpr index_t GetOffset(UpperIndex idx_up)
{
-return GetLowerTensorDescriptor().GetOffset(GetLowerId(id_up));
+return GetLowerTensorDescriptor().GetOffset(GetLowerIndex(idx_up));
}
-__host__ __device__ static constexpr auto AreUpperId2OffsetLinear();
+__host__ __device__ static constexpr auto AreUpperIndex2OffsetTransformLinear()
{
// not implemented
}
...
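// Illustrative sketch (hypothetical instantiation, not part of this commit): merging both
// dimensions of a 2-D native descriptor into a single upper dimension would be expressed as
//     using MergedDesc = TransformedTensorDescriptor<NativeDesc,                    // hypothetical 2-D lower descriptor
//                                                    Tuple<Merge<Sequence<H, W>>>,  // H, W: hypothetical lower lengths
//                                                    Tuple<Sequence<0, 1>>,         // lower dimensions consumed
//                                                    Tuple<Sequence<0>>>;           // upper dimension produced
// and MergedDesc::GetOffset(idx_up) then equals NativeDesc::GetOffset applied to the lower
// index produced by the Merge transform.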
@@ -79,6 +79,56 @@ struct Array
}
};
+// Arr: Array
+// Picks: Sequence<...>
+template <class Arr, class Picks>
+struct ArrayElementPicker
+{
+using TData = typename Arr::data_type; // assumes the Array exposes its element type as data_type
+__host__ __device__ constexpr ArrayElementPicker(Arr& array) : mData{array}
+{
+// the largest picked position must be a valid index into the underlying array
+constexpr index_t imax =
+accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
+static_assert(imax < Arr::GetSize(), "wrong! exceeding max id");
+}
+__host__ __device__ static constexpr index_t GetSize() { return Picks::GetSize(); }
+template <index_t I>
+__host__ __device__ constexpr TData operator[](Number<I>) const
+{
+constexpr auto IP = Picks::Get(Number<I>{});
+return mData[IP];
+}
+__host__ __device__ constexpr TData operator[](index_t i) const
+{
+const index_t ip = Picks{}[i];
+return mData[ip];
+}
+template <index_t I>
+__host__ __device__ TData& operator()(Number<I>)
+{
+constexpr auto IP = Picks::Get(Number<I>{});
+return mData[IP];
+}
+__host__ __device__ TData& operator()(index_t i)
+{
+const index_t ip = Picks{}[i];
+return mData[ip];
+}
+Arr& mData;
+};
+template <class Arr, class Picks>
+__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks)
+{
+return ArrayElementPicker<Arr, Picks>(a);
+}
template <index_t... Is>
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
{
...
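// Illustrative usage sketch (not part of this commit); assumes an Array<index_t, 4> named
// "a" that has already been filled:
//     auto view = pick_array_element(a, Sequence<1, 3>{}); // writable view over a[1], a[3]
//     index_t x = view[Number<0>{}];                       // reads a[1]
//     view(Number<1>{}) = 0;                               // writes a[3]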
@@ -6,6 +6,9 @@
namespace ck {
+template <index_t, index_t, index_t>
+struct static_for;
template <index_t...>
struct Sequence;
@@ -294,6 +297,18 @@ struct sequence_reverse<Sequence<I0, I1>>
using type = Sequence<I1, I0>;
};
+template <class Seq, class Compare>
+struct sequence_sort
+{
+// not implemented
+};
+template <class Seq, class Compare>
+struct sequence_unique_sort
+{
+// not implemented
+};
template <class Seq>
struct is_valid_sequence_map
{
@@ -486,6 +501,35 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
+template <class Seq, class Reduce>
+struct lambda_accumulate_on_sequence
+{
+const Reduce& f;
+index_t& result;
+__host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
+: f(f_), result(result_)
+{
+}
+template <class IDim>
+__host__ __device__ constexpr index_t operator()(IDim) const
+{
+return result = f(result, Seq::Get(IDim{}));
+}
+};
+template <class Seq, class Reduce, index_t Init>
+__host__ __device__ constexpr index_t
+accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
+{
+index_t result = Init;
+static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
+return result;
+}
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
...
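// Illustrative example (not part of this commit): accumulate_on_sequence folds a Sequence
// with a binary functor and an initial value, e.g.
//     accumulate_on_sequence(Sequence<2, 3, 4>{}, math::multiplies<index_t>{}, Number<1>{})
// evaluates to 2 * 3 * 4 = 24, which is exactly how Merge computes its single upper length.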
@@ -37,34 +37,5 @@ struct static_for
}
};
-template <class Seq, class Reduce>
-struct lambda_accumulate_on_sequence
-{
-const Reduce& f;
-index_t& result;
-__host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
-: f(f_), result(result_)
-{
-}
-template <class IDim>
-__host__ __device__ constexpr index_t operator()(IDim) const
-{
-return result = f(result, Seq::Get(IDim{}));
-}
-};
-template <class Seq, class Reduce, index_t Init>
-__host__ __device__ constexpr index_t
-accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
-{
-index_t result = Init;
-static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
-return result;
-}
} // namespace ck
#endif
@@ -31,6 +31,12 @@ struct multiplies
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};
+template <class T>
+struct maxer
+{
+__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
+};
template <class T>
struct integer_divide_ceiler
{
...
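// Illustrative note (not part of this commit): math::maxer<index_t>{}(3, 7) returns 7;
// ArrayElementPicker combines it with accumulate_on_sequence to find the largest picked
// index for its bound check.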