Unverified Commit 0e92deb7 authored by Chao Liu, committed by GitHub

Tile program init bulk PR (#4)



Tile Program init bulk PR

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
parent 0077eeb3
......@@ -8,9 +8,56 @@
namespace ck {
enum struct IndexTransformEnum
{
Undefined,
PassThrough,
Pad,
Embed,
Merge,
UnMerge,
Replicate,
Xor,
};
template <index_t NDimLow, index_t NDimUp>
struct BaseTransform
{
__host__ __device__ static constexpr auto GetTypeEnum()
{
return IndexTransformEnum::Undefined;
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
// return a safe value for vector length/stride, based only on compile-time known
// variables
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ static constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths&, const LowVectorStrides&)
{
if constexpr(NDimUp > 0)
{
Array<index_t, NDimUp> up_vector_lengths{-1};
Array<index_t, NDimUp> up_vector_strides{-1};
return make_tuple(up_vector_lengths, up_vector_strides);
}
else
{
return make_tuple(Array<index_t, 0>{}, Array<index_t, 0>{});
}
}
};
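// --- Editor's illustrative sketch (not part of this diff): a hypothetical transform with 2
// lower and 3 upper dimensions that does not override the query above simply gets the
// "unknown" (-1) defaults back:
//
//   auto safe_default = BaseTransform<2, 3>::CalculateUpperDimensionSafeVectorLengthStrides(
//       Array<index_t, 2>{-1}, Array<index_t, 2>{-1});
//   // tuple of (up_vector_lengths, up_vector_strides), both Array<index_t, 3> built from the
//   // -1 "unknown" marker, i.e. no vectorization guarantee can be derived from the base class.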
template <typename LowLength>
struct PassThrough
struct PassThrough : public BaseTransform<1, 1>
{
static constexpr auto type_enum = IndexTransformEnum::PassThrough;
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
......@@ -25,9 +72,10 @@ struct PassThrough
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ static constexpr auto GetTypeEnum()
{
return IndexTransformEnum::PassThrough;
}
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
......@@ -41,16 +89,11 @@ struct PassThrough
idx_low(Number<0>{}) = idx_up[Number<0>{}];
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>)
const UpIdx&)
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
......@@ -63,8 +106,6 @@ struct PassThrough
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
......@@ -82,12 +123,24 @@ struct PassThrough
return is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ static constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
__host__ __device__ void Print() const
{
printf("{");
printf("PassThrough, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("PassThrough{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
......@@ -96,7 +149,7 @@ template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck = false>
struct Pad
struct Pad : public BaseTransform<1, 1>
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
......@@ -107,7 +160,7 @@ struct Pad
LeftPadLength left_pad_length_;
RightPadLength right_pad_length_;
__host__ __device__ constexpr Pad() = default;
__host__ __device__ constexpr Pad() : up_lengths_{}, left_pad_length_{}, right_pad_length_{} {}
__host__ __device__ constexpr Pad(const LowLength& low_length,
const LeftPadLength& left_pad_length,
......@@ -118,10 +171,6 @@ struct Pad
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
......@@ -134,16 +183,11 @@ struct Pad
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>)
const UpIdx&)
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
......@@ -156,8 +200,6 @@ struct Pad
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return SkipIsValidCheck;
......@@ -181,12 +223,22 @@ struct Pad
__host__ __device__ void Print() const
{
printf("{");
printf("Pad, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("left_pad_length %d", index_t{left_pad_length_});
printf("right_pad_length %d", index_t{right_pad_length_});
printf("Pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
......@@ -210,10 +262,6 @@ struct LeftPad
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
......@@ -226,16 +274,11 @@ struct LeftPad
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>)
const UpIdx&)
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
......@@ -248,8 +291,6 @@ struct LeftPad
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return SkipIsValidCheck;
......@@ -270,17 +311,23 @@ struct LeftPad
__host__ __device__ void Print() const
{
printf("{");
printf("LeftPad, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("left_pad_length_ %d", index_t{left_pad_length_});
printf("LeftPad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct RightPad
struct RightPad : public BaseTransform<1, 1>
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
......@@ -301,10 +348,6 @@ struct RightPad
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
......@@ -317,16 +360,11 @@ struct RightPad
idx_low(Number<0>{}) = idx_up[Number<0>{}];
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>)
const UpIdx&)
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
......@@ -339,8 +377,6 @@ struct RightPad
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return SkipIsValidCheck;
......@@ -362,12 +398,17 @@ struct RightPad
__host__ __device__ void Print() const
{
printf("{");
printf("RightPad, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("low_length_ %d", index_t{low_length_});
printf("left_pad_length_ %d", index_t{right_pad_length_});
printf("LeftPad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
......@@ -381,7 +422,7 @@ struct RightPad
template <typename UpLengths,
typename Coefficients,
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed
struct Embed : public BaseTransform<1, UpLengths::Size()>
{
static constexpr index_t NDimUp = UpLengths::Size();
......@@ -399,9 +440,7 @@ struct Embed
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
__host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Embed; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
......@@ -417,735 +456,29 @@ struct Embed
static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
});
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(Number<0>{}) = 0;
static_for<0, NDimUp, 1>{}(
[&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<Coefficients>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Embed, ");
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("coefficients_ ");
print_multi_index(coefficients_);
printf("}");
}
};
// Implementation of the "Merge" transformation primitive that uses regular integer division to
// do lowering of the multi-index, and a carry-and-borrow check to do lowering of the
// multi-index delta
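// --- Editor's illustrative sketch (not part of this diff): a worked carry example.
// Merging low lengths {4, 8} (up length 32): starting at low index {1, 6} (up index 14) and
// moving by an up-index delta of +3 gives the constant low delta {3/8, 3%8} = {0, 3}.
// The carry check sees 6 + 3 >= 8, so dimension 1 wraps: the applied low delta becomes {1, -5}
// and the new low index is {2, 1}, which matches up index 14 + 3 = 17 = 2*8 + 1.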
template <typename LowLengths>
struct Merge_v1_carry_check
{
static constexpr index_t NDimLow = LowLengths::Size();
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_;
__host__ __device__ constexpr Merge_v1_carry_check() = default;
__host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[Number<0>{}];
// normal division
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp -= idx_low[i] * this->low_lengths_scan_[i];
});
idx_low(Number<NDimLow - 1>{}) = tmp;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& /* idx_up_new */,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
// CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
// However,
// 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
// can be calculated at compile-time.
// 2) If idx_diff_up is not known at compile-time, but its value
// doesn't change during the whole kernel execution, then idx_diff_low_const
// also doesn't change during the whole kernel execution. Compiler-generated
// ISA should only calculate idx_diff_low_const once and save it during the
// whole kernel execution.
// If neither 1) nor 2) is satisfied, then the calculation will be computed at
// run-time each time this function is called, and can be very expensive.
LowerIndex idx_diff_low_const;
LowerIndex idx_low_length_minus_idx_diff_low_const;
LowerIndex idx_low_length_plus_idx_diff_low_const;
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
});
idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
static_for<0, NDimLow, 1>{}([&](auto i) {
idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
});
#else
// Hack: this forces the result into SGPR. Need to make sure the result is thread-invariant
index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
});
idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
static_for<0, NDimLow, 1>{}([&](auto i) {
idx_low_length_minus_idx_diff_low_const(i) =
__builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
idx_low_length_plus_idx_diff_low_const(i) =
__builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
});
#endif
if constexpr(Hack == 1)
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
index_t carry = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t idx_low_tmp = idx_low[i] + carry;
bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
idx_diff_low(i) =
do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
idx_diff_low(i) += carry;
carry = do_carry ? 1 : 0;
});
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
idx_low += idx_diff_low;
}
else if constexpr(Hack == 2)
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
index_t borrow = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t idx_low_tmp = idx_low[i] - borrow;
bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
idx_diff_low(i) =
do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
idx_diff_low(i) -= borrow;
borrow = do_borrow ? 1 : 0;
});
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
idx_low += idx_diff_low;
}
else
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
index_t carry = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t idx_low_tmp = idx_low[i] + carry;
bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
idx_diff_low(i) =
do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
idx_diff_low(i) =
do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
idx_diff_low(i) += carry;
carry = do_carry ? 1 : 0;
carry = do_borrow ? -1 : carry;
});
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
idx_low += idx_diff_low;
}
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& /* idx_up_new */,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
// CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
// However,
// 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
// can be calculated at compile-time.
// 2) If idx_diff_up is not known at compile-time, but its value
// doesn't change during the whole kernel execution, then idx_diff_low_const
// also doesn't change during the whole kernel execution. Compiler-generated
// ISA should only calculate idx_diff_low_const once and save it during the
// whole kernel execution.
// If neither 1) nor 2) is satisfied, then the calculation will be computed at
// run-time each time this function is called, and can be very expensive.
LowerIndex idx_diff_low_const;
LowerIndex idx_low_length_minus_idx_diff_low_const;
LowerIndex idx_low_length_plus_idx_diff_low_const;
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
});
idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
static_for<0, NDimLow, 1>{}([&](auto i) {
idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
});
#else
// Hack: this forces the result into SGPR. Need to make sure the result is thread-invariant
index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
});
idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
static_for<0, NDimLow, 1>{}([&](auto i) {
idx_low_length_minus_idx_diff_low_const(i) =
__builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
});
#endif
if constexpr(Hack == 1)
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
index_t carry = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t idx_low_tmp = idx_low[i] + carry;
bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
idx_diff_low(i) =
do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
idx_diff_low(i) += carry;
carry = do_carry ? 1 : 0;
});
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
idx_low += idx_diff_low;
}
else if constexpr(Hack == 2)
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
index_t borrow = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t negative_idx_low_tmp = borrow - idx_low[i];
bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];
idx_diff_low(i) =
do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
idx_diff_low(i) -= borrow;
borrow = do_borrow ? 1 : 0;
});
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
idx_low += idx_diff_low;
}
else
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
index_t carry = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t idx_low_tmp = idx_low[i] + carry;
bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
idx_diff_low(i) =
do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
idx_diff_low(i) =
do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
idx_diff_low(i) += carry;
carry = do_carry ? 1 : 0;
carry = do_borrow ? -1 : carry;
});
idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
idx_low += idx_diff_low;
}
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& /* idx_up_new */,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
// CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
// However,
// 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
// can be calculated at compile-time.
// 2) If idx_diff_up is not known at compile-time, but its value
// doesn't change during the whole kernel execution, then idx_diff_low_const
// also doesn't change during the whole kernel execution. Compiler-generated
// ISA should only calculate idx_diff_low_const once and save it during the
// whole kernel execution.
// If neither 1) nor 2) is satisfied, then the calculation will be computed at
// run-time each time this function is called, and can be very expensive.
LowerIndex idx_diff_low_const;
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
});
idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
#else
// Hack: this forces the result into SGPR. Need to make sure the result is thread-invariant
index_t tmp = idx_diff_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
});
idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
#endif
if constexpr(Hack == 1)
{
// do carry check on each low dimension in reversed order
// do not need to check the first dimension
bool do_carry = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
idx_diff_low(i) = idx_diff_low_const[i] + do_carry;
index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
do_carry = idx_low_tmp >= low_lengths_[i];
#if 0
// TODO: use exec-mask inline asm, which use 1 VALU
if(do_carry)
{
idx_diff_low(i) -= low_lengths_[i];
}
#elif 1
// this use 2 VALU
idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
#elif 1
// this use 2 VALU
index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
#endif
idx_low(i) += idx_diff_low[i];
});
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;
idx_low(I0) += idx_diff_low[I0];
}
else if constexpr(Hack == 2)
{
// do borrow check on each low dimension in reversed order
// do not need to check the first dimension
bool do_borrow = 0;
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;
index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
do_borrow = idx_low_tmp < 0;
#if 0
// TODO: use exec-mask inline asm
if(do_borrow)
{
idx_diff_low(i) += low_lengths_[i];
}
#elif 1
idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
#elif 1
index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
#endif
idx_low(i) += idx_diff_low[i];
});
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;
idx_low(I0) += idx_diff_low[I0];
}
else
{
// not implemented
}
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
#if 1
UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#elif 0
UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#else
UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#endif
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsScan>::value &&
is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Merge_v1_carry_check, ");
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_scan_ ");
print_multi_index(low_lengths_scan_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("}");
}
};
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
}
};
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_shift
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
}
};
// Implementation of the "Merge" transformation primitive that uses magic-number division to do
// lowering of both the multi-index and the delta of the multi-index
// Caution:
// 1. The magic number division implementation being used produces the correct result only if the
// dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
// 2. Magic number division for an int32_t dividend has not been implemented; an int32_t
// dividend is bit-wise reinterpreted as uint32_t and the magic number division implementation for
// uint32_t is then used.
// 3. For the Merge primitive, the upper-index is the dividend.
// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
// 5. When the upper-index is of int32_t type (when index_t is int32_t), its value needs to be
// non-negative.
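// --- Editor's illustrative sketch (not part of this diff), assuming the MagicDivision helpers
// used below: lowering up index 7 through a Merge of low lengths {2, 3}.
//
//   const auto mult  = MagicDivision::CalculateMagicMultiplier(3);
//   const auto shift = MagicDivision::CalculateMagicShift(3);
//   const index_t q  = MagicDivision::DoMagicDivision(7, mult, shift); // 7 / 3 = 2
//   const index_t r  = 7 - q * 3;                                      // 7 % 3 = 1
//
// i.e. up index 7 lowers to {2, 1}, consistent with 2 * 3 + 1 = 7, and well within the 31-bit
// dividend constraint described above.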
template <typename LowLengths>
struct Merge_v2_magic_division
{
static constexpr index_t NDimLow = LowLengths::Size();
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
Number<NDimLow>{}));
using LowLengthsMagicDivisorShift = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
Number<NDimLow>{}));
LowLengths low_lengths_;
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
UpLengths up_lengths_;
__host__ __device__ constexpr Merge_v2_magic_division() = default;
__host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_magic_divisor_multiplier_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
Number<NDimLow>{})},
low_lengths_magic_divisor_shift_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[Number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
});
idx_low(Number<0>{}) = tmp;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up_new[Number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]);
index_t idx_low_old = idx_low[i];
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
idx_diff_low(i) = idx_low[i] - idx_low_old;
});
}
idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = tmp;
}
idx_diff_low(Number<0>{}) = 0;
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
static_for<0, NDimUp, 1>{}(
[&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsMagicDivisorMultipiler>::value &&
is_known_at_compile_time<LowLengthsMagicDivisorShift>::value &&
is_known_at_compile_time<UpLengths>::value;
return true;
}
template <typename UpIdx>
......@@ -1155,22 +488,39 @@ struct Merge_v2_magic_division
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<Coefficients>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Merge_v2_magic_division, ");
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_magic_divisor_multiplier_ ");
print_multi_index(low_lengths_magic_divisor_multiplier_);
printf("low_lengths_magic_divisor_shift_ ");
print_multi_index(low_lengths_magic_divisor_shift_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("Embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return MagicDivision::CalculateMagicNumbers(LowLengths{}[i]);
}
};
// Implementation of the "Merge" transformation primitive that uses magic-number division to do
// lowering of both the multi-index and the delta of the multi-index
// Caution:
......@@ -1184,53 +534,40 @@ struct Merge_v2_magic_division
// 5. When the upper-index is of int32_t type (when index_t is int32_t), its value needs to be
// non-negative.
template <typename LowLengths>
struct Merge_v2r2_magic_division
struct Merge_v2_magic_division : public BaseTransform<LowLengths::Size(), 1>
{
static constexpr index_t NDimLow = LowLengths::Size();
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
Number<NDimLow>{}));
using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
Number<NDimLow>{}));
using LowLengthsMagicDivisor = decltype(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
Number<NDimLow>{}));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_;
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
LowLengthsMagicDivisor low_lengths_magic_divisor_;
UpLengths up_lengths_;
__host__ __device__ constexpr Merge_v2r2_magic_division() = default;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ constexpr Merge_v2_magic_division() = default;
__host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
__host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
low_lengths_magic_divisor_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicNumbers(low_lengths[i]); },
Number<NDimLow>{})},
low_lengths_scan_magic_divisor_shift_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, I1))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Merge; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
......@@ -1241,30 +578,24 @@ struct Merge_v2r2_magic_division
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
idx_low(i) =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_scan_magic_divisor_multiplier_[i],
this->low_lengths_scan_magic_divisor_shift_[i]);
index_t tmp = idx_up[I0];
tmp -= idx_low[i] * this->low_lengths_scan_[i];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 = MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_[i][I0],
this->low_lengths_magic_divisor_[i][I1]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
});
idx_low(Number<NDimLow - 1>{}) = tmp;
idx_low(Number<0>{}) = tmp;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
const UpIdx& idx_up_new) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
......@@ -1272,26 +603,24 @@ struct Merge_v2r2_magic_division
index_t tmp = idx_up_new[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 = MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_[i][I0],
this->low_lengths_magic_divisor_[i][I1]);
index_t idx_low_old = idx_low[i];
idx_low(i) =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_scan_magic_divisor_multiplier_[i],
this->low_lengths_scan_magic_divisor_shift_[i]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
idx_diff_low(i) = idx_low[i] - idx_low_old;
tmp -= idx_low[i] * this->low_lengths_scan_[i];
});
idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];
idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
idx_low(Number<NDimLow - 1>{}) = tmp;
idx_low(Number<0>{}) = tmp;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
......@@ -1300,8 +629,7 @@ struct Merge_v2r2_magic_division
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsScanMagicDivisorMultipiler>::value &&
is_known_at_compile_time<LowLengthsScanMagicDivisorShift>::value &&
is_known_at_compile_time<LowLengthsMagicDivisor>::value &&
is_known_at_compile_time<UpLengths>::value;
}
......@@ -1312,20 +640,34 @@ struct Merge_v2r2_magic_division
return true;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ static constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
Array<index_t, 1> up_vector_lengths{-1};
Array<index_t, 1> up_vector_strides{-1};
up_vector_lengths(0) = low_vector_lengths[Number<NDimLow - 1>{}];
up_vector_strides(0) = low_vector_strides[Number<NDimLow - 1>{}];
return make_tuple(up_vector_lengths, up_vector_strides);
}
__host__ __device__ void Print() const
{
printf("{");
printf("Merge_v2r2_magic_division, ");
printf("Merge_v2_magic_division{");
//
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_scan ");
print_multi_index(low_lengths_scan_);
printf("low_lengths_scan_magic_divisor_multiplier_ ");
print_multi_index(low_lengths_scan_magic_divisor_multiplier_);
printf("low_lengths_scan_magic_divisor_shift_ ");
print_multi_index(low_lengths_scan_magic_divisor_shift_);
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print_multi_index(up_lengths_);
print(up_lengths_);
printf("}");
}
};
......@@ -1334,7 +676,7 @@ struct Merge_v2r2_magic_division
// be used for low_lengths that are known at compile time and are powers of 2; otherwise
// performance will be very bad
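// --- Editor's illustrative sketch (not part of this diff): with compile-time power-of-2 low
// lengths such as {4, 8}, lowering up index 29 needs 29 / 8 = 3 and 29 % 8 = 5, which the
// compiler can turn into a shift and a mask; for runtime or non-power-of-2 lengths these stay
// as real div/mod instructions, hence the warning above.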
template <typename LowLengths>
struct Merge_v3_division_mod
struct Merge_v3_division_mod : public BaseTransform<LowLengths::Size(), 1>
{
static constexpr index_t NDimLow = LowLengths::Size();
......@@ -1362,10 +704,6 @@ struct Merge_v3_division_mod
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
......@@ -1386,16 +724,11 @@ struct Merge_v3_division_mod
idx_low(Number<NDimLow - 1>{}) = tmp;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
const UpIdx& idx_up_new) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
......@@ -1418,8 +751,6 @@ struct Merge_v3_division_mod
idx_diff_low(INm1) = idx_low[INm1] - tmp2;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
......@@ -1439,22 +770,45 @@ struct Merge_v3_division_mod
return true;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ static constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
Array<index_t, 1> up_vector_lengths{-1};
Array<index_t, 1> up_vector_strides{-1};
up_vector_lengths(0) = low_vector_lengths[Number<NDimLow - 1>{}];
up_vector_strides(0) = low_vector_strides[Number<NDimLow - 1>{}];
return make_tuple(up_vector_lengths, up_vector_strides);
}
__host__ __device__ void Print() const
{
printf("{");
printf("Merge_v3_direct_division_mod, ");
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print_multi_index(low_lengths_);
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print_multi_index(low_lengths_scan_);
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print_multi_index(up_lengths_);
print(up_lengths_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct UnMerge
struct UnMerge : public BaseTransform<1, UpLengths::Size()>
{
static constexpr index_t NDimUp = UpLengths::Size();
......@@ -1476,9 +830,7 @@ struct UnMerge
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
__host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::UnMerge; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
......@@ -1505,24 +857,17 @@ struct UnMerge
}
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>) const
const UpIdx&) const
{
CalculateLowerIndex(idx_diff_low, idx_diff_up);
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
......@@ -1541,20 +886,49 @@ struct UnMerge
is_known_at_compile_time<UpLengthsScan>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ static constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
Array<index_t, NDimUp> up_vector_lengths{-1};
Array<index_t, NDimUp> up_vector_strides{-1};
constexpr auto up_length_last = UpLengths{}[Number<NDimUp - 1>{}];
if constexpr(is_known_at_compile_time<decltype(up_length_last)>::value)
{
if(low_vector_lengths[0] != -1)
{
up_vector_lengths(NDimUp - 1) = math::gcd(low_vector_lengths[0], up_length_last);
}
}
up_vector_strides(NDimUp - 1) = low_vector_strides[0];
return make_tuple(up_vector_lengths, up_vector_strides);
}
__host__ __device__ void Print() const
{
printf("{");
printf("UnMerge, ");
printf("UnMerge{");
//
printf("up_lengths_");
print_multi_index(up_lengths_);
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print_multi_index(up_lengths_scan_);
print(up_lengths_scan_);
printf("}");
}
};
template <typename LowerIndex>
struct Freeze
struct Freeze : public BaseTransform<1, 0>
{
LowerIndex low_idx_;
......@@ -1562,10 +936,6 @@ struct Freeze
__host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }
__host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }
template <typename LowIdx, typename UpIdx>
......@@ -1578,22 +948,15 @@ struct Freeze
idx_low(Number<0>{}) = low_idx_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /* idx_diff_up */,
LowIdx& /* idx_low */,
const UpIdx& /* idx_up_new */,
Number<Hack>)
const UpIdx& /* idx_up_new */)
{
idx_diff_low(Number<0>{}) = 0;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
......@@ -1613,54 +976,46 @@ struct Freeze
__host__ __device__ void Print() const
{
printf("Freeze");
printf("low_idx_ %d", index_t{low_idx_});
printf("Freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
// Insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct Insert
// Replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct Replicate : public BaseTransform<0, UpLengths::Size()>
{
using UpLengths = decltype(make_tuple(UpperLength{}));
UpLengths up_lengths_;
static constexpr index_t NDimUp = UpLengths::Size();
__host__ __device__ constexpr Insert() = default;
__host__ __device__ constexpr Replicate() = default;
__host__ __device__ constexpr Insert(const UpperLength& up_length)
: up_lengths_{make_tuple(up_length)}
__host__ __device__ constexpr Replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
{
static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
static_assert(LowIdx::Size() == 0 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void
UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
{
static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
UpIdx::Size() == 1,
static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == NDimUp &&
LowIdx::Size() == 0 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension");
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
......@@ -1675,39 +1030,47 @@ struct Insert
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpperLength>::value;
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("Insert");
print_multi_index(up_lengths_);
printf("Replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename VectorSize, typename UpLength>
struct Vectorize
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct Slice : public BaseTransform<1, 1>
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
UpLengths up_lengths_;
VectorSize vector_size_;
SliceBegin slice_begin_;
SliceEnd slice_end_;
__host__ __device__ constexpr Vectorize() = default;
__host__ __device__ constexpr Slice() = default;
__host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
const UpLength& up_length)
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
__host__ __device__ constexpr Slice(const LowLength&,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
......@@ -1717,19 +1080,14 @@ struct Vectorize
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>) const
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
......@@ -1737,67 +1095,72 @@ struct Vectorize
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
idx_diff_low(I0) = idx_diff_up[I0];
idx_low += idx_diff_low;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<SliceBegin>::value &&
is_known_at_compile_time<SliceEnd>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Vectorize, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("Slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
}
};
} // namespace ck
}; // namespace ck
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct Slice
/*
* \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
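// --- Editor's illustrative sketch (not part of this diff): with modulus_ = 8, upper indices
// 5, 13, and 21 all lower to 5, since lower_idx = upper_idx % 8.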
template <typename Modulus, typename UpLength>
struct Modulo : public BaseTransform<1, 1>
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
Modulus modulus_;
UpLengths up_lengths_;
SliceBegin slice_begin_;
SliceEnd slice_end_;
__host__ __device__ constexpr Slice() = default;
__host__ __device__ constexpr Modulo() = default;
__host__ __device__ constexpr Slice(const LowLength&,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
__host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
: modulus_{modulus}, up_lengths_{make_tuple(up_length)}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
......@@ -1807,19 +1170,14 @@ struct Slice
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>)
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& up_idx) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
......@@ -1827,67 +1185,62 @@ struct Slice
constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = idx_diff_up[I0];
idx_low += idx_diff_low;
const auto idx_low_old = idx_low;
idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
idx_diff_low(I0) = idx_low - idx_low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
template <typename UpIdx>
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<SliceBegin>::value &&
is_known_at_compile_time<SliceEnd>::value;
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("Slice, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("slice_begin_ %d", index_t{slice_begin_});
printf("slice_end %d", index_t{slice_end_});
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
/*
* \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
template <typename Modulus, typename UpLength>
struct Modulo
// 2D XOR
template <typename LowLengths, typename RightShift>
struct Xor : public BaseTransform<2, 2>
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
static constexpr auto type_enum = IndexTransformEnum::Xor;
using LowerIndex = MultiIndex<2>;
using UpperIndex = MultiIndex<2>;
using UpLengths = LowLengths;
Modulus modulus_;
UpLengths up_lengths_;
RightShift right_shift_;
__host__ __device__ constexpr Modulo() = default;
__host__ __device__ constexpr Xor() : up_lengths_{}, right_shift_{} {}
__host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
: modulus_{modulus}, up_lengths_{make_tuple(up_length)}
__host__ __device__ constexpr Xor(const LowLengths& low_lengths, const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Xor; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
......@@ -1895,35 +1248,36 @@ struct Modulo
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2,
"wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
idx_low(Number<0>{}) = idx_up[Number<0>{}];
const auto idx_low_1_tmp =
(idx_up[Number<1>{}] - idx_up[Number<0>{}] * right_shift_) % up_lengths_[Number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[Number<1>{}] + idx_low_1_tmp;
idx_low(Number<1>{}) = idx_low_1;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& up_idx,
Number<Hack>) const
const UpIdx& idx_up) const
{
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1,
static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 &&
UpIdx::Size() == 2,
"wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
const auto idx_low_old = idx_low;
idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
idx_diff_low(I0) = idx_low - idx_low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
CalculateLowerIndex(idx_low, idx_up);
idx_diff_low = idx_low - idx_low_old;
}
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
......@@ -1939,16 +1293,45 @@ struct Modulo
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<RightShift>::value;
}
// NOTE: unlike most transforms this is not a static function, because it reads right_shift_;
// callers may still invoke it on a default-constructed instance when RightShift is known at
// compile time
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides) const
{
Array<index_t, 2> up_vector_lengths = low_vector_lengths;
Array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = math::gcd(low_vector_lengths[1], math::abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
__host__ __device__ void Print() const
{
printf("{");
printf("Modulus, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("Xor{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
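// Illustration only (added for exposition, not part of this diff): the dimension-1 mapping
// performed by Xor above, written as plain integer arithmetic. Here `len1` stands for
// up_lengths_[1] and `right_shift` for right_shift_. E.g. with len1 = 8 and right_shift = 1:
// upper (2, 3) maps to lower (2, 1), and upper (3, 0) maps to lower (3, 5).
__host__ __device__ constexpr index_t
example_xor_lower_index_dim1(index_t up0, index_t up1, index_t len1, index_t right_shift)
{
    const index_t tmp = (up1 - up0 * right_shift) % len1; // may be negative in C++
    return (tmp >= 0) ? tmp : len1 + tmp;                 // wrap back into [0, len1)
}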
} // namespace ck
......@@ -51,32 +51,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return make_merge_transform_v2_magic_division(low_lengths);
#else
return make_merge_transform_v1_carry_check(low_lengths);
#endif
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
return Merge_v1_carry_check<LowLengths>{low_lengths};
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
#if 1
return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
}
template <typename LowLengths>
......@@ -86,6 +65,12 @@ make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
return Merge_v3_division_mod<LowLengths>{low_lengths};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return make_merge_transform_v2_magic_division(low_lengths);
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths,
......@@ -100,10 +85,10 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return Freeze<LowerIndex>{low_idx};
}
template <typename UpperIndex>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
template <typename UpLengths>
__host__ __device__ constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{
return Insert<UpperIndex>{up_idx};
return Replicate<UpLengths>{up_lengths};
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
......@@ -114,17 +99,18 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length)
{
return Modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
{
return Xor<LowLengths, RightShift>{low_lengths, right_shift};
}
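// Hypothetical usage sketch (not part of this diff): build a compile-time 8 x 32 XOR-style
// swizzle of the kind typically used to avoid shared-memory access conflicts; row k of the
// upper space is rotated by k * right_shift in the lower space.
__host__ __device__ constexpr auto make_example_xor_swizzle()
{
    return make_xor_transform(make_tuple(Number<8>{}, Number<32>{}), Number<1>{});
}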
} // namespace ck
......@@ -4,8 +4,6 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
namespace ck {
......@@ -35,21 +33,23 @@ struct TensorAdaptor
return UpperDimensionHiddenIdss{};
}
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
{
return TopDimensionHiddenIds{};
return BottomDimensionHiddenIds{};
}
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
{
return BottomDimensionHiddenIds{};
return TopDimensionHiddenIds{};
}
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<idim_hidden>{});
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
......@@ -69,12 +69,12 @@ struct TensorAdaptor
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
template <index_t IDimHidden>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDimHidden>)
{
constexpr auto idim_top = Number<IDim>{};
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
// FIXME: length of bottom dimension is not known, since info about lower dim lengths is not
// saved in the transformation
static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
index_t itran_found = 0;
index_t idim_up_found = 0;
......@@ -84,7 +84,7 @@ struct TensorAdaptor
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
if constexpr(up_dim_ids[idim_up] == IDimHidden)
{
itran_found = itran;
idim_up_found = idim_up;
......@@ -138,11 +138,7 @@ struct TensorAdaptor
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorAdaptor() = default;
#else
__host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
#endif
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
......@@ -157,17 +153,52 @@ struct TensorAdaptor
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
#if 0 // debug
template <index_t I>
__host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
// FIXME: this logic is wrong when getting bottom dimension lengths
template <index_t IDimHidden>
__host__ __device__ constexpr auto GetHiddenDimensionLength(Number<IDimHidden>) const
{
// TODO: not implemented
static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDimHidden>{});
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
}
template <index_t I>
__host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
template <index_t IDimTop>
__host__ __device__ constexpr auto GetTopDimensionLength(Number<IDimTop> idim_top) const
{
// TODO: not implemented
return GetHiddenDimensionLength(TopDimensionHiddenIds::At(idim_top));
}
#if 0
// FIXME: GetHiddenDimensionLength is wrong when getting bottom dimension lengths
template <index_t IDimBottom>
__host__ __device__ constexpr index_t
GetBottomDimensionLength(Number<IDimBottom> idim_bottom) const
{
return GetHiddenDimensionLength(TopDimensionHiddenIds::At(idim_bottom));
}
#endif
__host__ __device__ constexpr auto GetTopDimensionLengths() const
{
return generate_tuple([&](auto i) { return GetTopDimensionLength(i); },
Number<ndim_top_>{});
}
#if 0
// FIXME: GetHiddenDimensionLength is wrong when getting bottom dimension lengths
__host__ __device__ constexpr auto GetBottomDimensionLengths() const
{
return generate_tuple([&](auto i) { return GetBottomDimensionLength(i); },
Number<ndim_bottom_>{});
}
#endif
......@@ -204,7 +235,7 @@ struct TensorAdaptor
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
__host__ __device__ static constexpr bool IsStatic()
{
bool is_known = true;
......@@ -215,23 +246,81 @@ struct TensorAdaptor
return is_known && is_known_at_compile_time<ElementSize>::value;
}
__host__ __device__ void Print() const
__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return IsStatic(); }
__host__ __device__ static constexpr auto GetTopDimensionSafeVectorLengthStrides(
const Array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const Array<index_t, ndim_hidden_>& guaranteed_vector_strides)
{
printf("{");
printf("TensorAdaptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionHiddenIds:");
LowerDimensionHiddenIdss{}.At(i).Print();
printf("UpperDimensionHiddenIds:");
UpperDimensionHiddenIdss{}.At(i).Print();
auto vector_lengths = guaranteed_vector_lengths;
auto vector_strides = guaranteed_vector_strides;
static_for<0, GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr auto low_dims = GetLowerDimensionHiddenIdss().At(itran);
constexpr auto up_dims = GetUpperDimensionHiddenIdss().At(itran);
const auto up_guaranteed_vector_lengths =
get_container_subset(guaranteed_vector_lengths, up_dims);
const auto up_guaranteed_vector_strides =
get_container_subset(guaranteed_vector_strides, up_dims);
// only the transform's type matters here, so invoke on a default-constructed instance
auto [up_vector_lengths, up_vector_strides] =
Transforms{}.At(itran).CalculateUpperDimensionSafeVectorLengthStrides(
get_container_subset(vector_lengths, low_dims),
get_container_subset(vector_strides, low_dims));
if constexpr(up_dims.Size() > 0)
{
for(index_t i = 0; i < up_dims.Size(); ++i)
{
up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
? up_guaranteed_vector_lengths[i]
: up_vector_lengths[i];
up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
? up_guaranteed_vector_strides[i]
: up_vector_strides[i];
}
}
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
});
printf("BottomDimensionHiddenIds:");
BottomDimensionHiddenIds::Print();
printf("TopDimensionHiddenIds:");
TopDimensionHiddenIds::Print();
constexpr auto top_dims = TopDimensionHiddenIds{};
return make_tuple(get_container_subset(vector_lengths, top_dims),
get_container_subset(vector_strides, top_dims));
}
__host__ __device__ void Print() const
{
printf("TensorAdaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}");
}
......@@ -241,6 +330,161 @@ struct TensorAdaptor
ElementSize element_size_;
};
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<remove_cvref_t<Transforms>,
remove_cvref_t<decltype(low_dim_hidden_idss)>,
remove_cvref_t<decltype(up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
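// Hypothetical usage sketch (not part of this diff): a single-stage adaptor whose top is one
// merged dimension of length 4 * 8 = 32 and whose bottom is the original (4, 8) index space.
// CalculateBottomIndex(make_multi_index(13)) would then return (1, 5), since 13 == 1 * 8 + 5.
__host__ __device__ constexpr auto make_example_merge_adaptor()
{
    return make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0, 1>{}), // lower: old top dims 0 and 1
        make_tuple(Sequence<0>{}));   // upper: the single new top dim
}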
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor, and it is placed outside the scope where it is used
// (transform_tensor_adaptor) because a template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
return Number<Tran::GetNumOfUpperDimension()>{};
}
};
template <typename OldTensorAdaptor,
typename NewTransforms,
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
__host__ __device__ constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
const NewTransforms& new_transforms,
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
// sanity check
{
static_assert(NewTransforms::Size() == NewLowerDimensionOldTopIdss::Size() &&
NewTransforms::Size() == NewUpperDimensionNewTopIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldTopIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!");
}
// lower dimension's hidden idss
// convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension top ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_top_ids) constexpr {
return transform_sequences(
// convert lower dimension top id to hidden id
[](auto low_dim_top_id) constexpr {
return OldTensorAdaptor::GetTopDimensionHiddenIds()[low_dim_top_id];
},
low_dim_top_ids);
},
NewLowerDimensionOldTopIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
// upper dimension's hidden idss
constexpr index_t old_hidden_dim_number = OldTensorAdaptor::GetNumOfHiddenDimension();
constexpr auto up_dim_numbers =
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
Number<num_new_transform>{});
// new top dimension's hidden ids
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_top_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
constexpr auto new_top_dim_hidden_ids =
unordered_new_top_dim_hidden_ids.ReorderGivenOld2New(new_top_dim_unordered2ordered);
// put everything together
const auto all_transforms =
container_concat(old_tensor_adaptor.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
container_concat(OldTensorAdaptor::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
container_concat(OldTensorAdaptor::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss);
return TensorAdaptor<remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(OldTensorAdaptor::GetBottomDimensionHiddenIds())>,
remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
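// Hypothetical usage sketch (not part of this diff): chain a second stage onto an existing
// adaptor. Starting from a merge adaptor (top: one dim of length 32, bottom: (4, 8)), the new
// stage unmerges that single top dimension into (8, 4), so the resulting adaptor maps an
// (8, 4) top index space onto the original (4, 8) bottom index space.
__host__ __device__ constexpr auto make_example_two_stage_adaptor()
{
    const auto adaptor_1d = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));

    return transform_tensor_adaptor(
        adaptor_1d,
        make_tuple(make_unmerge_transform(make_tuple(Number<8>{}, Number<4>{}))),
        make_tuple(Sequence<0>{}),     // lower: old top dim 0
        make_tuple(Sequence<0, 1>{})); // upper: new top dims 0 and 1
}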
template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
const TensorAdaptor1& adaptor1)
......@@ -415,62 +659,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
// put everything together
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<remove_cv_t<Transforms>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
return TensorAdaptor<remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct TensorAdaptorCoordinate
{
static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::Size();
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using BottomIndex = MultiIndex<ndim_bottom_>;
using TopIndex = MultiIndex<ndim_top_>;
public:
__host__ __device__ constexpr TensorAdaptorCoordinate() = default;
__host__ __device__ constexpr TensorAdaptorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
{
}
__host__ __device__ constexpr auto GetTopIndex() const
{
return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
}
__host__ __device__ constexpr auto GetBottomIndex() const
{
return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
}
__host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }
__host__ __device__ constexpr auto& GetHiddenIndex() { return idx_hidden_; }
//
HiddenIndex idx_hidden_;
};
template <typename Adaptor, typename TopIndex>
__host__ __device__ constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
const TopIndex& idx_top)
{
static_assert(Adaptor::GetNumOfTopDimension() == TopIndex::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = Adaptor::GetNumOfTransform();
constexpr index_t ndim_hidden = Adaptor::GetNumOfHiddenDimension();
constexpr auto bottom_dim_ids = Adaptor::GetBottomDimensionHiddenIds();
constexpr auto top_dim_ids = Adaptor::GetTopDimensionHiddenIds();
MultiIndex<ndim_hidden> idx_hidden;
// initialize top index
set_container_subset(idx_hidden, top_dim_ids, idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
auto itran = itran_p1 - Number<1>{};
const auto& tran = adaptor.GetTransforms().At(itran);
constexpr auto dims_low = Adaptor::GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = Adaptor::GetUpperDimensionHiddenIdss().At(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return TensorAdaptorCoordinate<ndim_hidden,
remove_cvref_t<decltype(bottom_dim_ids)>,
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
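// Hypothetical usage sketch (not part of this diff), assuming the merge adaptor from the
// earlier sketch and ck's make_multi_index() helper: resolving top index 13 through the
// adaptor yields bottom index (1, 5), because the merge transform splits 13 as 1 * 8 + 5.
__host__ __device__ inline auto example_bottom_index_of_13()
{
    const auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));

    const auto coord = make_tensor_adaptor_coordinate(adaptor, make_multi_index(13));

    return coord.GetBottomIndex(); // (1, 5)
}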
template <bool JudgeDoTransforms = true,
typename Adaptor,
typename AdaptorCoord,
typename TopIndex,
typename BottomIndex>
__host__ __device__ constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top,
BottomIndex& idx_diff_bottom)
{
constexpr index_t ndim_hidden = Adaptor::GetNumOfHiddenDimension();
constexpr index_t ndim_top = Adaptor::GetNumOfTopDimension();
// constexpr index_t ndim_bottom = Adaptor::GetNumOfBottomDimension();
constexpr index_t ntransform = Adaptor::GetNumOfTransform();
// STATIC_ASSERT(TopIndex::Size() == ndim_top && BottomIndex::Size() == ndim_bottom, "");
// judge whether calculation of lower diff is needed for each transform
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
if constexpr(JudgeDoTransforms)
{
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
// decide do_transforms by checking non-zero index diff components
MultiIndex<ndim_top> non_zero_diff_pick_top;
static_for<0, ndim_top, 1>{}(
[&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
set_container_subset(
is_non_zero_diff, Adaptor::GetTopDimensionHiddenIds(), non_zero_diff_pick_top);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = Adaptor::GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = Adaptor::GetUpperDimensionHiddenIdss().At(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of the upper index diff components is non-zero, then
// 1) this transform needs to be done
// 2) all components of the lower index diff are assumed to be non-zero and need to be
// computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
}
else
{
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
}
// this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize top index diff
set_container_subset(idx_diff_hidden, Adaptor::GetTopDimensionHiddenIds(), idx_diff_top);
// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
// update top index
auto idx_hidden_pick_top =
get_container_subset(idx_hidden, Adaptor::GetTopDimensionHiddenIds());
idx_hidden_pick_top += idx_diff_top;
set_container_subset(idx_hidden, Adaptor::GetTopDimensionHiddenIds(), idx_hidden_pick_top);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(do_transforms[itran])
{
const auto& tran = adaptor.GetTransforms().At(itran);
constexpr auto dims_low = Adaptor::GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = Adaptor::GetUpperDimensionHiddenIdss().At(itran);
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_diff_low;
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
// set bottom index diff
idx_diff_bottom = get_container_subset(idx_diff_hidden, Adaptor::GetBottomDimensionHiddenIds());
}
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
__host__ __device__ constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top)
{
constexpr index_t ndim_bottom = Adaptor::GetNumOfBottomDimension();
MultiIndex<ndim_bottom> tmp;
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}
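// Hypothetical usage sketch (not part of this diff), continuing the merge-adaptor example:
// stepping the top index from 7 to 8 carries the bottom index from (0, 7) to (1, 0).
__host__ __device__ inline auto example_move_adaptor_coordinate()
{
    const auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));

    auto coord = make_tensor_adaptor_coordinate(adaptor, make_multi_index(7)); // bottom (0, 7)

    move_tensor_adaptor_coordinate(adaptor, coord, make_multi_index(1)); // bottom becomes (1, 0)

    return coord.GetBottomIndex();
}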
template <typename Adaptor, typename AdaptorCoord>
__host__ __device__ constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
const AdaptorCoord& coord)
{
bool valid = true;
constexpr index_t ntransform = Adaptor::GetNumOfTransform();
const auto& idx_hidden = coord.GetHiddenIndex();
static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
const auto tran = adaptor.GetTransforms().At(itran);
// check validity, only if current transformation does not always have a valid mapping
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_up =
get_container_subset(idx_hidden, Adaptor::GetUpperDimensionHiddenIdss().At(itran));
// Comment: using valid = valid && .. will result in weird control flow in ISA
valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
}
});
return valid;
}
template <typename Adaptor, typename AdaptorCoord>
__host__ __device__ constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
const AdaptorCoord& coord)
{
// check top index
const auto& idx_top = coord.GetTopIndex();
bool is_top_index_valid = true;
static_for<0, Adaptor::GetNumOfDimension(), 1>{}(
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
is_top_index_valid =
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.GetLength(i));
});
// check other hidden index
return is_top_index_valid &&
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
namespace ck {
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct TensorCoordinate
: public TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>
{
using Base = TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>;
// TODO make these private
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using TopIndex = MultiIndex<ndim_top_>;
public:
__host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) : Base{idx_hidden}
{
}
// construct from TensorAdaptorCoordinate base class
__host__ __device__ constexpr TensorCoordinate(const Base& adaptor_coord) : Base{adaptor_coord}
{
}
__host__ __device__ constexpr auto GetIndex() const { return Base::GetTopIndex(); }
__host__ __device__ constexpr index_t GetOffset() const
{
return Base::GetBottomIndex()[Number<0>{}];
}
__host__ __device__ constexpr const auto& GetHiddenIndex() const
{
return Base::GetHiddenIndex();
}
__host__ __device__ auto& GetHiddenIndex() { return Base::GetHiddenIndex(); }
};
template <typename TensorDesc, typename TopIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const TopIndex& idx_top)
{
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
return TensorCoordinate<TensorDesc::GetNumOfHiddenDimension(),
remove_cvref_t<decltype(TensorDesc::GetTopDimensionHiddenIds())>>{
adaptor_coord};
}
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
__host__ __device__ constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
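// Hypothetical usage sketch (not part of this diff). It assumes make_naive_tensor_descriptor_packed()
// from tensor_descriptor_helper.hpp is available: for a packed (4, 8) descriptor, index (1, 2)
// sits at offset 1 * 8 + 2 = 10, and moving by (0, 3) advances the offset to 13.
__host__ __device__ inline index_t example_offset_after_move()
{
    const auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    auto coord = make_tensor_coordinate(desc, make_multi_index(1, 2)); // offset 10

    move_tensor_coordinate(desc, coord, make_multi_index(0, 3)); // offset 13

    return coord.GetOffset();
}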
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid(tensor_desc, coord);
}
} // namespace ck
......@@ -4,612 +4,200 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/multi_index_transform.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep;
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
// UpperDimensionIdss : Tuple<Sequence<...>, ...>
// VisibleDimensionIds> : Sequence<...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// TopDimensionHiddenIds> : Sequence<...>
template <typename Transforms,
typename LowerDimensionIdss,
typename UpperDimensionIdss,
typename VisibleDimensionIds,
typename ElementSpaceSize>
struct TensorDescriptor
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename TopDimensionHiddenIds,
typename ElementSpaceSize,
typename GuaranteedVectorLengths_,
typename GuaranteedVectorSrides_>
struct TensorDescriptor : public TensorAdaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
Sequence<0>,
TopDimensionHiddenIds>
{
// TODO make these private
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
__host__ __device__ static constexpr index_t GetNumOfVisibleDimension()
{
return VisibleDimensionIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using Base = TensorAdaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
Sequence<0>,
TopDimensionHiddenIds>;
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
math::less<index_t>,
math::equal<index_t>>::type;
using ElementSpaceSizeType = ElementSpaceSize;
return unique_sort_all_dim_ids::Size();
}
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_visible) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
constexpr static index_t ntransform_ = Base::GetNumOfTransform();
constexpr static index_t ndim_hidden_ = Base::GetNumOfHiddenDimension();
constexpr static index_t ndim_top_ = Base::GetNumOfTopDimension();
const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
using GuaranteedVectorLengths = GuaranteedVectorLengths_;
using GuaranteedVectorStrides = GuaranteedVectorSrides_;
return length;
},
Number<ndim_visible_>{});
// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
{
constexpr auto idim_visible = Number<IDim>{};
static_assert(GuaranteedVectorLengths::Size() == ndim_hidden_ &&
GuaranteedVectorStrides::Size() == ndim_hidden_,
"wrong! inconsistent # of hidden dimensions");
constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
constexpr static index_t ntransform_ = GetNumOfTransform();
constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension();
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
using VisibleIndex = MultiIndex<ndim_visible_>;
using HiddenIndex = MultiIndex<ndim_hidden_>;
using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
// may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
using TopIndex = MultiIndex<ndim_top_>;
using HiddenIndex = MultiIndex<ndim_hidden_>;
public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorDescriptor() = default;
#else
__host__ __device__ constexpr TensorDescriptor()
: transforms_{}, element_size_{}, element_space_size_{}
{
}
#endif
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size)
: transforms_{transforms},
element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size}
: Base{transforms}, element_space_size_{element_space_size}
{
static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionIdss::Size() == ntransform_ &&
UpperDimensionIdss::Size() == ntransform_,
LowerDimensionHiddenIdss::Size() == ntransform_ &&
UpperDimensionHiddenIdss::Size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
__host__ __device__ static constexpr index_t GetNumOfDimension()
// construct from TensorAdaptor base class
__host__ __device__ constexpr TensorDescriptor(const Base& adaptor,
ElementSpaceSize element_space_size)
: Base{adaptor}, element_space_size_{element_space_size}
{
return GetNumOfVisibleDimension();
}
template <index_t IDim>
__host__ __device__ constexpr auto GetLength(Number<IDim>) const
__host__ __device__ static constexpr index_t GetNumOfDimension()
{
static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
return Base::GetNumOfTopDimension();
}
__host__ __device__ constexpr auto GetLengths() const
template <index_t IDim>
__host__ __device__ constexpr auto GetLength(Number<IDim> idim) const
{
// FIXME: use Tuple of reference instead
return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number<ndim_visible_>{});
return Base::GetTopDimensionLength(idim);
}
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
__host__ __device__ constexpr auto GetLengths() const { return Base::GetTopDimensionLengths(); }
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
template <typename Idx>
__host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
{
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
return make_tensor_coordinate(*this, idx).GetOffset();
return Base::CalculateBottomIndex(idx)[Number<0>{}];
}
// TODO make these private
__host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
__host__ __device__ static constexpr auto GetLowerDimensionIdss()
__host__ __device__ constexpr const auto& GetTransforms() const
{
return LowerDimensionIdss{};
return Base::GetTransforms();
}
__host__ __device__ static constexpr auto GetUpperDimensionIdss()
__host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
{
return UpperDimensionIdss{};
return Base::GetLowerDimensionHiddenIdss();
}
__host__ __device__ static constexpr auto GetVisibleDimensionIds()
__host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
{
return VisibleDimensionIds{};
return Base::GetUpperDimensionHiddenIdss();
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
{
bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
});
return is_known && is_known_at_compile_time<ElementSize>::value &&
is_known_at_compile_time<ElementSpaceSize>::value;
return Base::GetTopDimensionHiddenIds();
}
__host__ __device__ void Print() const
__host__ __device__ static constexpr bool IsStatic()
{
printf("{");
printf("TensorDescriptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionIds:");
LowerDimensionIdss{}.At(i).Print();
printf("UpperDimensionIds:");
UpperDimensionIdss{}.At(i).Print();
});
printf("}");
VisibleDimensionIds::Print();
return Base::IsKnownAtCompileTime() && is_known_at_compile_time<ElementSpaceSize>::value;
}
// TODO make these private
Transforms transforms_;
ElementSize element_size_;
ElementSpaceSize element_space_size_;
};
__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return IsStatic(); }
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate
{
// TODO make these private
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using VisibleIndex = MultiIndex<ndim_visible_>;
public:
__host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
__host__ __device__ static constexpr auto GetTopDimensionSafeVectorLengthStrides()
{
return Base::GetTopDimensionSafeVectorLengthStrides(
to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
}
__host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); }
__host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; }
// TODO make these private
__host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }
__host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; }
__host__ __device__ constexpr auto GetVisibleIndex() const
__host__ __device__ void Print() const
{
return get_container_subset(idx_hidden_, VisibleDimensionIds{});
}
// TODO make these private
HiddenIndex idx_hidden_;
};
printf("TensorDescriptor{");
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep
{
// TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>;
// TensorAdaptor
Base::Print();
printf(", ");
public:
__host__ __device__ constexpr TensorCoordinateStep() = default;
// element_space_size_
printf("element_space_size_: ");
print(element_space_size_);
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{
printf("}");
}
__host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); }
// TODO make these private
__host__ __device__ constexpr const auto& GetVisibleIndexDiff() const
{
return idx_diff_visible_;
}
VisibleIndex idx_diff_visible_;
MultiIndex<NTransform> do_transforms_;
// HACK: control UpdateLowerIndex()
static constexpr UpdateLowerIndexHack update_lower_index_hack_;
ElementSpaceSize element_space_size_;
};
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor, and it is placed outside the scope where it is used
// (transform_tensor_descriptor) because a template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
template <typename Adaptor, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
const ElementSpaceSize& element_space_size)
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
return Number<Tran::GetNumOfUpperDimension()>{};
}
};
constexpr index_t NDimHidden = Adaptor::GetNumOfHiddenDimension();
return TensorDescriptor<remove_cvref_t<decltype(adaptor.GetTransforms())>,
remove_cvref_t<decltype(adaptor.GetLowerDimensionHiddenIdss())>,
remove_cvref_t<decltype(adaptor.GetUpperDimensionHiddenIdss())>,
remove_cvref_t<decltype(adaptor.GetTopDimensionHiddenIds())>,
remove_cvref_t<decltype(element_space_size)>,
typename uniform_sequence_gen<NDimHidden, -1>::type,
typename uniform_sequence_gen<NDimHidden, -1>::type>{
adaptor, element_space_size};
}
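// Hypothetical usage sketch (not part of this diff): wrap a single-stage adaptor whose bottom
// is a single offset dimension into a descriptor by supplying the element space size. The
// unmerge below describes a packed (4, 8) layout, so the element space size is 32.
__host__ __device__ constexpr auto make_example_descriptor_from_adaptor()
{
    const auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0>{}),     // lower: the single offset dim
        make_tuple(Sequence<0, 1>{})); // upper: the (4, 8) top dims

    return make_tensor_descriptor_from_adaptor(adaptor, Number<32>{});
}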
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss>
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
__host__ __device__ constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
// sanity check
{
static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!");
}
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
// upper dimension's hidden idss
constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension();
constexpr auto up_dim_numbers =
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
Number<num_new_transform>{});
// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_visible_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
// put everything together
const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size};
}
template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const VisibleIndex& idx_visible)
{
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
set_container_subset(idx_hidden, visible_dim_ids, idx_visible);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
auto itran = itran_p1 - Number<1>{};
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const VisibleIndex& idx_diff_visible,
UpdateLowerIndexHack)
{
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
// decide do_transforms by checking non-zero index diff components
MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
static_for<0, ndim_visible, 1>{}(
[&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of the upper index diff components is non-zero, then
// 1) this transform needs to be done
// 2) all components of the lower index diff are assumed to be non-zero and need to be
// computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
do_transforms};
}
template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const VisibleIndex& idx_diff_visible)
{
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
return make_tensor_coordinate_step(
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
}
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
TensorCoord& coord,
const TensorCoordStep& coord_step)
{
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
// this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff
set_container_subset(
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
// update visible index
auto idx_hidden_pick_visible =
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_step.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_step.do_transforms_[itran])
{
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_diff_low;
// HACK: control UpdateLowerIndex() for Merge via the Hack number
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
new_transforms,
NewLowerDimensionOldTopIdss{},
NewUpperDimensionNewTopIdss{});
constexpr index_t NDimHiddenOld = OldTensorDescriptor::GetNumOfHiddenDimension();
constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::GetNumOfHiddenDimension();
using NewGuaranteedVectorLengths = typename sequence_merge<
typename OldTensorDescriptor::GuaranteedVectorLengths,
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
using NewGuaranteedVectorStrides = typename sequence_merge<
typename OldTensorDescriptor::GuaranteedVectorStrides,
typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
return TensorDescriptor<
remove_cvref_t<decltype(new_tensor_adaptor.GetTransforms())>,
remove_cvref_t<decltype(new_tensor_adaptor.GetLowerDimensionHiddenIdss())>,
remove_cvref_t<decltype(new_tensor_adaptor.GetUpperDimensionHiddenIdss())>,
remove_cvref_t<decltype(new_tensor_adaptor.GetTopDimensionHiddenIds())>,
remove_cvref_t<decltype(element_space_size)>,
NewGuaranteedVectorLengths,
NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}
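// Minimal sketch of the "-1 padding" above (dimension counts assumed for illustration): hidden
// dimensions introduced by the new transforms carry no vectorization guarantee, so the old
// guarantees are simply extended with -1 entries.
namespace example_detail {
using OldGuarantees = Sequence<-1, -1, 8>;                        // 3 old hidden dims
using Padding       = typename uniform_sequence_gen<2, -1>::type; // 2 new hidden dims
using NewGuarantees = typename sequence_merge<OldGuarantees, Padding>::type;
static_assert(is_same<NewGuarantees, Sequence<-1, -1, 8, -1, -1>>::value, "sketch only");
} // namespace example_detail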
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool
coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
bool valid = true;
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
const auto& idx_hidden = coord.GetHiddenIndex();
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
const auto tran = tensor_desc.GetTransforms().At(itran);
// check validity only if the current transformation does not always have a valid mapping
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_up =
get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
// Note: using valid = valid && ... results in convoluted control flow in the generated ISA
valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
}
});
return valid;
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
// check visible index
const auto& idx_visible = coord.GetVisibleIndex();
bool is_visible_index_valid = true;
static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
[&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
is_visible_index_valid =
is_visible_index_valid &&
(idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
});
// check other hidden index
return is_visible_index_valid &&
coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord);
}
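// Sketch of the intended guarded-access pattern (helper name and element type T are assumed):
// the validity flag is computed from the coordinate and handed to the buffer access, which then
// typically yields the buffer's "invalid element" value (zero by default) for out-of-range
// coordinates instead of issuing a stray load.
template <typename T, typename TensorDesc, typename TensorCoord, typename Buffer>
__device__ auto example_guarded_load(const TensorDesc& desc, const TensorCoord& coord, const Buffer& buf)
{
    const bool is_valid = coordinate_has_valid_offset(desc, coord);

    return buf.template Get<T>(coord.GetOffset(), is_valid);
}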
template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
} // namespace ck
......@@ -4,20 +4,13 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
namespace ck {
/*
* These functions create tensor descriptors at runtime. If the descriptors are not constexpr,
* their construction will likely spill into scratch memory. It is therefore better to call
* these functions on the host and pass the constructed tensor descriptors to the GPU. If the
* tensor descriptors being constructed are constexpr, they can also be built on the GPU
* without worrying about scratch memory usage.
*/
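// For example (sketch; kernel name and arguments are assumed, not taken from this code):
//
//   // host side: lengths/strides only known at run time -> build on the host
//   // and pass the descriptor to the kernel by value
//   const auto a_grid_desc =
//       make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, 1));
//   some_gemm_kernel<<<grid_size, block_size>>>(p_a_grid, a_grid_desc /*, ...*/);
//
//   // device side: a fully compile-time (constexpr) descriptor is safe to
//   // build inside the kernel without scratch memory
//   constexpr auto a_block_desc =
//       make_naive_tensor_descriptor_packed(make_tuple(Number<64>{}, Number<8>{}));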
namespace detail {
#if CK_WORKAROUND_SWDEV_275126
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
......@@ -35,7 +28,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
return acc_new;
}
}
#endif
} // namespace detail
/*
* These functions create naive tensor descriptors
*/
// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
......@@ -45,9 +43,14 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
// 2) LongNumber<>
template <typename... Lengths,
typename... Strides,
index_t GuaranteedLastDimensionVectorLength = -1,
index_t GuaranteedLastDimensionVectorStride = -1,
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides)
__host__ __device__ constexpr auto
make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides,
Number<GuaranteedLastDimensionVectorLength> = Number<-1>{},
Number<GuaranteedLastDimensionVectorStride> = Number<-1>{})
{
constexpr index_t N = sizeof...(Lengths);
......@@ -60,34 +63,24 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
#if !CK_WORKAROUND_SWDEV_275126
// the rocm-4.1 compiler would crash on a recursive lambda
// recursive function for reduction
auto f = [&](auto fs, auto i, auto acc_old) {
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
if constexpr(i.value < N - 1)
{
return fs(fs, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
};
const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
#else
const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
#endif
detail::calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
Sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
Sequence<GuaranteedLastDimensionVectorStride>>::type;
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}
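// Usage sketch (hypothetical helper, values assumed): a row-major M x K descriptor whose last
// dimension is known to be contiguous; the two trailing Number<> arguments are the new guaranteed
// last-dimension vector length and stride introduced above (both default to -1, i.e. no guarantee).
template <typename M_, typename K_, typename StrideA_>
__host__ __device__ constexpr auto
example_make_a_grid_descriptor_m_k(M_ M, K_ K, StrideA_ StrideA)
{
    return make_naive_tensor_descriptor(
        make_tuple(M, K), make_tuple(StrideA, Number<1>{}), Number<8>{}, Number<1>{});
}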
// Lengths... could be:
......@@ -96,9 +89,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
// element_space_size could be:
// 1) long_index_t, or
// 2) LongNumber<>
template <typename... Lengths>
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths,
Number<GuaranteedLastDimensionVectorLength> = Number<-1>{})
{
constexpr index_t N = sizeof...(Lengths);
......@@ -113,12 +107,20 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});
using GuaranteedVectorLengths =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
Sequence<GuaranteedLastDimensionVectorLength>>::type;
using GuaranteedVectorStrides =
typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, Sequence<1>>::type;
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
remove_cv_t<decltype(element_space_size)>,
GuaranteedVectorLengths,
GuaranteedVectorStrides>{transforms, element_space_size};
}
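// Usage sketch (hypothetical helper): a packed per-thread buffer whose last dimension holds 4
// contiguous elements, so a guaranteed vector length of 4 is safe; the packed variant always
// guarantees a last-dimension stride of 1.
__host__ __device__ constexpr auto example_make_thread_buffer_descriptor()
{
    return make_naive_tensor_descriptor_packed(
        make_tuple(Number<2>{}, Number<2>{}, Number<4>{}), Number<4>{});
}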
// Lengths... could be:
......
......@@ -6,7 +6,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/statically_indexed_array_multi_index.hpp"
#include "ck/utility/multi_index.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
......
......@@ -87,7 +87,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
return make_multi_index(0, 0, waveId_m, WMMA_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -98,7 +98,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
return make_multi_index(0, 0, waveId_n, WMMA_b_idx, 0);
}
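// Sketch of the distinction (types as used elsewhere in this PR):
//   make_tuple(0, 0, waveId_n, WMMA_b_idx, 0)       -> heterogeneous Tuple
//   make_multi_index(0, 0, waveId_n, WMMA_b_idx, 0) -> MultiIndex<5>, a
//                                                      homogeneous index_t array
// The coordinate/copy machinery consumes multi-indices, matching the
// Tuple4 -> Array4 rename in the xdlops blockwise GEMM further below.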
template <index_t m0, index_t n0>
......
......@@ -108,7 +108,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
return make_multi_index(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -119,7 +119,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
return make_multi_index(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
......@@ -144,11 +144,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
make_multi_index(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
make_multi_index(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
return make_multi_index(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
......@@ -336,17 +337,19 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
a_thread_vec.template AsType<FloatAB>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_multi_index(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_multi_index(0, 0, 0, k + i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
c_thread_desc_.CalculateOffset(make_multi_index(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
......@@ -707,7 +710,7 @@ struct BlockwiseGemmXdlops_v2
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
return make_multi_index(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -718,7 +721,7 @@ struct BlockwiseGemmXdlops_v2
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
return make_multi_index(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
......@@ -765,10 +768,10 @@ struct BlockwiseGemmXdlops_v2
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
using Array4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
__host__ __device__ BlockwiseGemmXdlops_v2(Array4 a_origin = CalculateAThreadOriginDataIndex(),
Array4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
......
......@@ -66,7 +66,7 @@ struct BlockwiseSoftmax
reduce::Add,
false>>::type;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
using ThreadClusterLengths_M_K = decltype(to_sequence(ThreadClusterDesc_M_K{}.GetLengths()));
using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize,
......
......@@ -50,6 +50,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1
using Index = MultiIndex<nDim>;
#if 1 // debug
__host__ __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1() : threadwise_transfer_{} {}
#endif
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
......
......@@ -86,7 +86,7 @@ struct BlockToCTileMap_M00_N0_M01
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1),
make_tuple(make_replicate_transform(make_tuple(1)),
make_unmerge_transform(make_tuple(M00, M01)),
make_pass_through_transform(make_tuple(N0))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
......@@ -402,7 +402,8 @@ struct BlockToCTileMap_M00_N00_M01_N01
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
make_tuple(make_replicate_transform(
make_tuple(1)), // swallow the carry from lower dimensions
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
......
......@@ -630,7 +630,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
make_multi_index(0, 0, 0, 0)}; // A_origin
auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......
......@@ -796,7 +796,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
make_multi_index(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......@@ -953,7 +953,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// works like a multi-dimensional static_for (static_ford), but provides both the linear
// index and the n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
decltype(to_sequence(c_thread_lengths)),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
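// Sketch of how such an iterator is typically driven (loop body assumed; the real consumer
// lives in the elided code below):
//
//   constexpr auto num_access = Acc0TileIterator::GetNumOfAccess();
//   static_for<0, num_access, 1>{}([&](auto i) {
//       constexpr auto idx_nd = Acc0TileIterator::GetIndex(i); // n-d index of access i
//       // ... read/modify the acc thread buffer at the offset derived from idx_nd ...
//   });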
......
......@@ -651,7 +651,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
make_multi_index(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......@@ -769,7 +769,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// works like a multi-dimensional static_for (static_ford), but provides both the linear
// index and the n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
decltype(to_sequence(c_thread_lengths)),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
......
......@@ -45,7 +45,7 @@ struct ThreadwiseTensorSliceSet_v1
constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
coordinate_has_valid_offset_assuming_top_index_is_valid(desc, coord);
constexpr index_t offset = coord.GetOffset();
......
......@@ -70,8 +70,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx,
const ElementwiseOperation& element_op)
......@@ -147,7 +145,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
......@@ -159,18 +157,14 @@ struct ThreadwiseTensorSliceTransfer_v1r3
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
move_tensor_coordinate(dst_desc, dst_coord_, forward_step);
}
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
move_tensor_coordinate(dst_desc, dst_coord_, GetDstCoordinateResetStep());
}
}
......@@ -250,8 +244,6 @@ struct ThreadwiseTensorSliceTransfer_v2
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
......@@ -311,7 +303,7 @@ struct ThreadwiseTensorSliceTransfer_v2
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
......@@ -341,18 +333,14 @@ struct ThreadwiseTensorSliceTransfer_v2
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
move_tensor_coordinate(src_desc, src_coord_, forward_step);
}
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
move_tensor_coordinate(src_desc, src_coord_, GetSrcCoordinateResetStep());
}
}
......@@ -388,29 +376,7 @@ struct ThreadwiseTensorSliceTransfer_v2
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx needs to be known at compile time, for performance reasons
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step_idx);
}
private:
......@@ -450,9 +416,6 @@ struct ThreadwiseTensorSliceTransfer_v3
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
......@@ -574,7 +537,7 @@ struct ThreadwiseTensorSliceTransfer_v3
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);
// copy data from src_buf to src_tmp_vector
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
......@@ -741,7 +704,7 @@ struct ThreadwiseTensorSliceTransfer_v3
// copy data from dst_tmp_vector to dst_buf
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
......@@ -1033,8 +996,6 @@ struct ThreadwiseTensorSliceTransfer_v4
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
......@@ -1130,19 +1091,16 @@ struct ThreadwiseTensorSliceTransfer_v4
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_idx);
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
if constexpr(SrcBuffer::IsDynamicBuffer())
......
......@@ -6,7 +6,10 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_coordinate.hpp"
#include "ck/tensor/thread_private_tensor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
// FIXME: remove
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
......@@ -53,8 +56,8 @@ template <typename SliceLengths,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDataTmp,
typename DstDataTmp,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
......@@ -77,14 +80,34 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcData = remove_cvref_t<SrcDataTmp>;
using DstData = remove_cvref_t<DstDataTmp>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
static constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
static constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
__host__ __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1()
: src_coord_{}, dst_coord_{}, src_element_op_{}, dst_element_op_{}
{
}
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc,
const Index& src_slice_origin,
......@@ -122,87 +145,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// decide whether to move forward or backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
constexpr auto src_data_idx = SrcSpaceFillingCurve::GetIndexTupleOfNumber(iAccess);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
......@@ -212,50 +163,29 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// copy data from src_vector_container into src_thread_scratch_
#if 1 // debug
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
src_data_idx, src_vector_container.template AsType<src_vector_t>()[I0]);
#else
src_thread_scratch_tuple_(thread_scratch_id)
.template Set<src_vector_t>(
src_data_idx, src_vector_container.template AsType<src_vector_t>()[I0]);
#endif
constexpr auto move_on_dim = [&]() constexpr
// move src coordinate
if constexpr(iAccess.value != num_access - 1)
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
constexpr auto step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
move_tensor_coordinate(src_desc, src_coord_, step);
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
const auto src_reset_step = GetSrcCoordinateResetStep();
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
......@@ -265,13 +195,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO: make this logic more generic for more sub-dword data types
if constexpr(SrcVectorDim != DstVectorDim &&
......@@ -347,7 +270,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
dst_thread_scratch_(idx) = dst_v;
});
#endif
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
......@@ -367,91 +289,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
// dst scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
constexpr auto num_access = DstSpaceFillingCurve::GetNumOfAccess();
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// decide whether to move forward or backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
static_for<0, num_access, 1>{}([&](auto iAccess) {
// calculate dst data index
constexpr auto dst_data_idx = DstSpaceFillingCurve::GetIndexTupleOfNumber(iAccess);
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
auto dst_vector_container =
dst_vector_type{dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx)};
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
DstData dst_v;
......@@ -468,165 +321,54 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr
// move dst coord
if constexpr(iAccess.value != num_access - 1)
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
constexpr auto step = DstSpaceFillingCurve::GetForwardStep(iAccess);
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
move_tensor_coordinate(dst_desc, dst_coord_, step);
}
();
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
move_tensor_coordinate(dst_desc, dst_coord_, GetDstCoordinateResetStep());
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// decide whether to move forward or backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate the src data index after the last iteration in RunRead(), if it has not been reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
return reset_src_data_step_;
}();
if constexpr(num_access == 0)
{
return typename SrcSpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_src_data_step;
return reset_step;
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// decide whether to move forward or backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate the dst data index after the last iteration in RunWrite(), if it has not been reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
constexpr auto num_access = DstSpaceFillingCurve::GetNumOfAccess();
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
if constexpr(num_access == 0)
{
return typename DstSpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_dst_data_step;
return reset_step;
}
}
// src_slice_origin_step_idx needs to be known at compile time, for performance reasons
......@@ -638,10 +380,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step_idx);
}
// dst_slice_origin_step_idx needs to be known at compile time, for performance reasons
......@@ -653,17 +392,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step_idx);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
......@@ -711,9 +444,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
......@@ -761,11 +491,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
#if 1 // debug
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>;
#else
using SrcThreadScratch = ThreadPrivateTensor<SrcData, decltype(src_thread_scratch_desc_)>;
#endif
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
......
......@@ -133,8 +133,8 @@ struct ThreadwiseTensorSliceTransfer_v4r1
using src_vector_t = typename decltype(src_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_data_coord);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(I0) =
......