Unverified commit 0e92deb7, authored by Chao Liu, committed by GitHub

Tile program init bulk PR (#4)

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>

parent 0077eeb3
@@ -8,9 +8,56 @@
 namespace ck {
 
+enum struct IndexTransformEnum
+{
+    Undefined,
+    PassThrough,
+    Pad,
+    Embed,
+    Merge,
+    UnMerge,
+    Replicate,
+    Xor,
+};
+
+template <index_t NDimLow, index_t NDimUp>
+struct BaseTransform
+{
+    __host__ __device__ static constexpr auto GetTypeEnum()
+    {
+        return IndexTransformEnum::Undefined;
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
+
+    // return safe value for vector length/stride, based only on compile-time-known variables
+    // MUST be a static function
+    template <typename LowVectorLengths, typename LowVectorStrides>
+    __host__ __device__ static constexpr auto
+    CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths&,
+                                                   const LowVectorStrides&)
+    {
+        if constexpr(NDimUp > 0)
+        {
+            Array<index_t, NDimUp> up_vector_lengths{-1};
+            Array<index_t, NDimUp> up_vector_strides{-1};
+
+            return make_tuple(up_vector_lengths, up_vector_strides);
+        }
+        else
+        {
+            return make_tuple(Array<index_t, 0>{}, Array<index_t, 0>{});
+        }
+    }
+};
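For orientation, here is a minimal host-only sketch (illustrative only, not part of this commit; `base_transform_sketch` and the `index_t` stand-in are hypothetical names) of the "-1 means unknown" convention that the default CalculateUpperDimensionSafeVectorLengthStrides above establishes: a transform that cannot prove anything about vectorizability reports -1 for every upper dimension.

    #include <array>
    #include <cstdio>

    using index_t = int; // stand-in for ck::index_t

    template <int NDimUp>
    struct base_transform_sketch
    {
        // Mirrors the BaseTransform default: every upper dimension reports -1
        // ("unknown"), i.e. no vectorized access may be assumed through this
        // transform unless a derived transform overrides this.
        static std::array<index_t, NDimUp> safe_vector_lengths()
        {
            std::array<index_t, NDimUp> lengths;
            lengths.fill(-1);
            return lengths;
        }
    };

    int main()
    {
        const auto l = base_transform_sketch<2>::safe_vector_lengths();
        std::printf("%d %d\n", l[0], l[1]); // prints: -1 -1
        return 0;
    }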
+
 template <typename LowLength>
-struct PassThrough
+struct PassThrough : public BaseTransform<1, 1>
 {
+    static constexpr auto type_enum = IndexTransformEnum::PassThrough;
+
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;
@@ -25,9 +72,10 @@ struct PassThrough
     {
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+    __host__ __device__ static constexpr auto GetTypeEnum()
+    {
+        return IndexTransformEnum::PassThrough;
+    }
 
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
@@ -41,16 +89,11 @@ struct PassThrough
         idx_low(Number<0>{}) = idx_up[Number<0>{}];
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                      const UpIdxDiff& idx_diff_up,
                                                      LowIdx& idx_low,
-                                                     const UpIdx&,
-                                                     Number<Hack>)
+                                                     const UpIdx&)
     {
         static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                           UpIdx::Size() == 1,
@@ -63,8 +106,6 @@ struct PassThrough
         idx_low += idx_diff_low;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return true;
@@ -82,12 +123,24 @@ struct PassThrough
         return is_known_at_compile_time<UpLengths>::value;
     }
 
+    // MUST be a static function
+    template <typename LowVectorLengths, typename LowVectorStrides>
+    __host__ __device__ static constexpr auto
+    CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
+                                                   const LowVectorStrides& low_vector_strides)
+    {
+        return make_tuple(low_vector_lengths, low_vector_strides);
+    }
+
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("PassThrough, ");
-        printf("up_lengths_");
-        print_multi_index(up_lengths_);
+        printf("PassThrough{");
+
+        //
+        printf("up_lengths_:");
+        print(up_lengths_);
+
+        //
         printf("}");
     }
 };
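For readers new to this interface, a standalone sketch (illustrative only, not part of this commit; all names here are hypothetical) of the contract that PassThrough implements, reduced to plain integers: CalculateLowerIndex maps an upper index to a lower index, and UpdateLowerIndex advances the lower index given an upper-index delta.

    #include <cassert>

    struct pass_through_sketch
    {
        // The upper index maps to the lower index unchanged...
        static int calculate_lower_index(int idx_up) { return idx_up; }

        // ...so an upper-index delta is also the lower-index delta.
        static void update_lower_index(int& idx_diff_low, int idx_diff_up, int& idx_low)
        {
            idx_diff_low = idx_diff_up;
            idx_low += idx_diff_low;
        }
    };

    int main()
    {
        int idx_low      = pass_through_sketch::calculate_lower_index(5);
        int idx_diff_low = 0;
        pass_through_sketch::update_lower_index(idx_diff_low, 3, idx_low);
        assert(idx_low == 8 && idx_diff_low == 3);
        return 0;
    }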
@@ -96,7 +149,7 @@ template <typename LowLength,
           typename LeftPadLength,
           typename RightPadLength,
           bool SkipIsValidCheck = false>
-struct Pad
+struct Pad : public BaseTransform<1, 1>
 {
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;
 
@@ -107,7 +160,7 @@ struct Pad
     LeftPadLength left_pad_length_;
     RightPadLength right_pad_length_;
 
-    __host__ __device__ constexpr Pad() = default;
+    __host__ __device__ constexpr Pad() : up_lengths_{}, left_pad_length_{}, right_pad_length_{} {}
 
     __host__ __device__ constexpr Pad(const LowLength& low_length,
                                       const LeftPadLength& left_pad_length,
@@ -118,10 +171,6 @@ struct Pad
     {
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
     template <typename LowIdx, typename UpIdx>
@@ -134,16 +183,11 @@ struct Pad
         idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                      const UpIdxDiff& idx_diff_up,
                                                      LowIdx& idx_low,
-                                                     const UpIdx&,
-                                                     Number<Hack>)
+                                                     const UpIdx&)
     {
         static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                           UpIdx::Size() == 1,
@@ -156,8 +200,6 @@ struct Pad
         idx_low += idx_diff_low;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return SkipIsValidCheck;
@@ -181,12 +223,22 @@ struct Pad
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("Pad, ");
-        printf("up_lengths_");
-        print_multi_index(up_lengths_);
-        printf("left_pad_length %d", index_t{left_pad_length_});
-        printf("right_pad_length %d", index_t{right_pad_length_});
+        printf("Pad{");
+
+        //
+        printf("up_lengths_: ");
+        print(up_lengths_);
+        printf(", ");
+
+        //
+        printf("left_pad_length_: ");
+        print(left_pad_length_);
+        printf(", ");
+
+        //
+        printf("right_pad_length_: ");
+        print(right_pad_length_);
+
         printf("}");
     }
 };
@@ -210,10 +262,6 @@ struct LeftPad
     {
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
     template <typename LowIdx, typename UpIdx>
@@ -226,16 +274,11 @@ struct LeftPad
         idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                      const UpIdxDiff& idx_diff_up,
                                                      LowIdx& idx_low,
-                                                     const UpIdx&,
-                                                     Number<Hack>)
+                                                     const UpIdx&)
     {
         static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                           UpIdx::Size() == 1,
@@ -248,8 +291,6 @@ struct LeftPad
         idx_low += idx_diff_low;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return SkipIsValidCheck;
@@ -270,17 +311,23 @@ struct LeftPad
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("LeftPad, ");
-        printf("up_lengths_");
-        print_multi_index(up_lengths_);
-        printf("left_pad_length_ %d", index_t{left_pad_length_});
+        printf("LeftPad{");
+
+        //
+        printf("up_lengths_: ");
+        print(up_lengths_);
+        printf(", ");
+
+        //
+        printf("left_pad_length_: ");
+        print(left_pad_length_);
+
         printf("}");
     }
 };
 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
-struct RightPad
+struct RightPad : public BaseTransform<1, 1>
 {
     using LowerIndex = MultiIndex<1>;
     using UpperIndex = MultiIndex<1>;
 
@@ -301,10 +348,6 @@ struct RightPad
     {
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
     template <typename LowIdx, typename UpIdx>
@@ -317,16 +360,11 @@ struct RightPad
         idx_low(Number<0>{}) = idx_up[Number<0>{}];
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                      const UpIdxDiff& idx_diff_up,
                                                      LowIdx& idx_low,
-                                                     const UpIdx&,
-                                                     Number<Hack>)
+                                                     const UpIdx&)
     {
         static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                           UpIdx::Size() == 1,
@@ -339,8 +377,6 @@ struct RightPad
         idx_low += idx_diff_low;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return SkipIsValidCheck;
@@ -362,12 +398,17 @@ struct RightPad
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("RightPad, ");
-        printf("up_lengths_");
-        print_multi_index(up_lengths_);
-        printf("low_length_ %d", index_t{low_length_});
-        printf("left_pad_length_ %d", index_t{right_pad_length_});
+        printf("RightPad{");
+
+        //
+        printf("up_lengths_: ");
+        print(up_lengths_);
+        printf(", ");
+
+        //
+        printf("right_pad_length_: ");
+        print(right_pad_length_);
+
         printf("}");
     }
 };
@@ -381,7 +422,7 @@ struct RightPad
 template <typename UpLengths,
           typename Coefficients,
           typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
-struct Embed
+struct Embed : public BaseTransform<1, UpLengths::Size()>
 {
     static constexpr index_t NDimUp = UpLengths::Size();
 
@@ -399,9 +440,7 @@ struct Embed
     {
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
+    __host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Embed; }
 
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
@@ -417,735 +456,29 @@ struct Embed
         static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
             idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
         });
     }
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
-    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
-                                              const UpIdxDiff& idx_diff_up,
-                                              LowIdx& idx_low,
-                                              const UpIdx&,
-                                              Number<Hack>) const
-    {
-        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
-                          LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
-                      "wrong! inconsistent # of dimension");
-
-        idx_diff_low(Number<0>{}) = 0;
-
-        static_for<0, NDimUp, 1>{}(
-            [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
-
-        idx_low += idx_diff_low;
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
-    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
-    {
-        return true;
-    }
-
-    template <typename UpIdx>
-    __host__ __device__ static constexpr bool
-    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
-    {
-        return true;
-    }
-
-    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
-    {
-        return is_known_at_compile_time<UpLengths>::value &&
-               is_known_at_compile_time<Coefficients>::value;
-    }
-
-    __host__ __device__ void Print() const
-    {
-        printf("{");
-        printf("Embed, ");
-        printf("up_lengths_ ");
-        print_multi_index(up_lengths_);
-        printf("coefficients_ ");
-        print_multi_index(coefficients_);
-        printf("}");
-    }
-};
-// Implementation of "Merge" transformation primitive that uses regular division to do lowering of
-// multi-index and uses carry-and-borrow check to do lowering of multi-index delta
-template <typename LowLengths>
-struct Merge_v1_carry_check
-{
-    static constexpr index_t NDimLow = LowLengths::Size();
-
-    using LowerIndex = MultiIndex<NDimLow>;
-    using UpperIndex = MultiIndex<1>;
-
-    using LowLengthsScan =
-        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
-
-    using UpLengths =
-        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
-
-    LowLengths low_lengths_;
-    LowLengthsScan low_lengths_scan_;
-    UpLengths up_lengths_;
-
-    __host__ __device__ constexpr Merge_v1_carry_check() = default;
-
-    __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
-        : low_lengths_{low_lengths},
-          low_lengths_scan_{
-              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
-          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
-    {
-        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
-    }
-
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
-    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
-
-    template <typename LowIdx, typename UpIdx>
-    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
-                                                           const UpIdx& idx_up) const
-    {
-        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        index_t tmp = idx_up[Number<0>{}];
-
-        // normal division
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_low(i) = tmp / this->low_lengths_scan_[i];
-            tmp -= idx_low[i] * this->low_lengths_scan_[i];
-        });
-
-        idx_low(Number<NDimLow - 1>{}) = tmp;
-    }
-
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
-    __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
-                                                 const UpIdxDiff& idx_diff_up,
-                                                 LowIdx& idx_low,
-                                                 const UpIdx& /* idx_up_new */,
-                                                 Number<Hack>) const
-    {
-        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
-                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
-        // However:
-        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
-        //      can be calculated at compile-time.
-        //   2) If idx_diff_up is not known at compile-time, but its value doesn't
-        //      change during the whole kernel execution, then idx_diff_low_const
-        //      also doesn't change during the whole kernel execution. The
-        //      compiler-generated ISA should calculate idx_diff_low_const only
-        //      once and keep it for the whole kernel execution.
-        // If neither 1) nor 2) is satisfied, the calculation is done at run-time
-        // each time this function is called, and can be very expensive.
-        LowerIndex idx_diff_low_const;
-        LowerIndex idx_low_length_minus_idx_diff_low_const;
-        LowerIndex idx_low_length_plus_idx_diff_low_const;
-
-#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
-        index_t tmp = idx_diff_up[Number<0>{}];
-
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
-            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
-        });
-
-        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
-
-        static_for<0, NDimLow, 1>{}([&](auto i) {
-            idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
-            idx_low_length_plus_idx_diff_low_const(i)  = low_lengths_[i] + idx_diff_low_const[i];
-        });
-#else
-        // Hack: this forces the result into an SGPR. Need to make sure the result is
-        // thread-invariant.
-        index_t tmp = idx_diff_up[Number<0>{}];
-
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
-            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
-        });
-
-        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
-
-        static_for<0, NDimLow, 1>{}([&](auto i) {
-            idx_low_length_minus_idx_diff_low_const(i) =
-                __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
-
-            idx_low_length_plus_idx_diff_low_const(i) =
-                __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
-        });
-#endif
-
-        if constexpr(Hack == 1)
-        {
-            // do carry check on each low dimension in reversed order
-            // do not need to check the first dimension
-            index_t carry = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                index_t idx_low_tmp = idx_low[i] + carry;
-
-                bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
-
-                idx_diff_low(i) =
-                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
-
-                idx_diff_low(i) += carry;
-
-                carry = do_carry ? 1 : 0;
-            });
-
-            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
-
-            idx_low += idx_diff_low;
-        }
-        else if constexpr(Hack == 2)
-        {
-            // do borrow check on each low dimension in reversed order
-            // do not need to check the first dimension
-            index_t borrow = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                index_t idx_low_tmp = idx_low[i] - borrow;
-
-                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
-
-                idx_diff_low(i) =
-                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
-
-                idx_diff_low(i) -= borrow;
-
-                borrow = do_borrow ? 1 : 0;
-            });
-
-            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
-
-            idx_low += idx_diff_low;
-        }
-        else
-        {
-            // do carry-and-borrow check on each low dimension in reversed order
-            // do not need to check the first dimension
-            index_t carry = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                index_t idx_low_tmp = idx_low[i] + carry;
-
-                bool do_carry  = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
-                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
-
-                idx_diff_low(i) =
-                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
-                idx_diff_low(i) =
-                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
-
-                idx_diff_low(i) += carry;
-
-                carry = do_carry ? 1 : 0;
-                carry = do_borrow ? -1 : carry;
-            });
-
-            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
-
-            idx_low += idx_diff_low;
-        }
-    }
-
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
-    __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
-                                                 const UpIdxDiff& idx_diff_up,
-                                                 LowIdx& idx_low,
-                                                 const UpIdx& /* idx_up_new */,
-                                                 Number<Hack>) const
-    {
-        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
-                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
-        // However:
-        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
-        //      can be calculated at compile-time.
-        //   2) If idx_diff_up is not known at compile-time, but its value doesn't
-        //      change during the whole kernel execution, then idx_diff_low_const
-        //      also doesn't change during the whole kernel execution. The
-        //      compiler-generated ISA should calculate idx_diff_low_const only
-        //      once and keep it for the whole kernel execution.
-        // If neither 1) nor 2) is satisfied, the calculation is done at run-time
-        // each time this function is called, and can be very expensive.
-        LowerIndex idx_diff_low_const;
-        LowerIndex idx_low_length_minus_idx_diff_low_const;
-        LowerIndex idx_low_length_plus_idx_diff_low_const;
-
-#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
-        index_t tmp = idx_diff_up[Number<0>{}];
-
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
-            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
-        });
-
-        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
-
-        static_for<0, NDimLow, 1>{}([&](auto i) {
-            idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
-            idx_low_length_plus_idx_diff_low_const(i)  = low_lengths_[i] + idx_diff_low_const[i];
-        });
-#else
-        // Hack: this forces the result into an SGPR. Need to make sure the result is
-        // thread-invariant.
-        index_t tmp = idx_diff_up[Number<0>{}];
-
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
-            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
-        });
-
-        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
-
-        static_for<0, NDimLow, 1>{}([&](auto i) {
-            idx_low_length_minus_idx_diff_low_const(i) =
-                __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
-
-            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
-        });
-#endif
-
-        if constexpr(Hack == 1)
-        {
-            // do carry check on each low dimension in reversed order
-            // do not need to check the first dimension
-            index_t carry = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                index_t idx_low_tmp = idx_low[i] + carry;
-
-                bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
-
-                idx_diff_low(i) =
-                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
-
-                idx_diff_low(i) += carry;
-
-                carry = do_carry ? 1 : 0;
-            });
-
-            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
-
-            idx_low += idx_diff_low;
-        }
-        else if constexpr(Hack == 2)
-        {
-            // do borrow check on each low dimension in reversed order
-            // do not need to check the first dimension
-            index_t borrow = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                index_t negative_idx_low_tmp = borrow - idx_low[i];
-
-                bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];
-
-                idx_diff_low(i) =
-                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
-
-                idx_diff_low(i) -= borrow;
-
-                borrow = do_borrow ? 1 : 0;
-            });
-
-            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
-
-            idx_low += idx_diff_low;
-        }
-        else
-        {
-            // do carry-and-borrow check on each low dimension in reversed order
-            // do not need to check the first dimension
-            index_t carry = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                index_t idx_low_tmp = idx_low[i] + carry;
-
-                bool do_carry  = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
-                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
-
-                idx_diff_low(i) =
-                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
-                idx_diff_low(i) =
-                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
-
-                idx_diff_low(i) += carry;
-
-                carry = do_carry ? 1 : 0;
-                carry = do_borrow ? -1 : carry;
-            });
-
-            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
-
-            idx_low += idx_diff_low;
-        }
-    }
-
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
-    __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
-                                                const UpIdxDiff& idx_diff_up,
-                                                LowIdx& idx_low,
-                                                const UpIdx& /* idx_up_new */,
-                                                Number<Hack>) const
-    {
-        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
-                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
-        // However:
-        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
-        //      can be calculated at compile-time.
-        //   2) If idx_diff_up is not known at compile-time, but its value doesn't
-        //      change during the whole kernel execution, then idx_diff_low_const
-        //      also doesn't change during the whole kernel execution. The
-        //      compiler-generated ISA should calculate idx_diff_low_const only
-        //      once and keep it for the whole kernel execution.
-        // If neither 1) nor 2) is satisfied, the calculation is done at run-time
-        // each time this function is called, and can be very expensive.
-        LowerIndex idx_diff_low_const;
-
-#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
-        index_t tmp = idx_diff_up[Number<0>{}];
-
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
-            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
-        });
-
-        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
-#else
-        // Hack: this forces the result into an SGPR. Need to make sure the result is
-        // thread-invariant.
-        index_t tmp = idx_diff_up[Number<0>{}];
-
-        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
-            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
-            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
-        });
-
-        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
-#endif
-
-        if constexpr(Hack == 1)
-        {
-            // do carry check on each low dimension in reversed order
-            // do not need to check the first dimension
-            bool do_carry = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                idx_diff_low(i) = idx_diff_low_const[i] + do_carry;
-
-                index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
-
-                do_carry = idx_low_tmp >= low_lengths_[i];
-
-#if 0
-                // TODO: use exec-mask inline asm, which uses 1 VALU
-                if(do_carry)
-                {
-                    idx_diff_low(i) -= low_lengths_[i];
-                }
-#elif 1
-                // this uses 2 VALU
-                idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
-#elif 1
-                // this uses 2 VALU
-                index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
-                idx_diff_low(i)          = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
-#endif
-
-                idx_low(i) += idx_diff_low[i];
-            });
-
-            constexpr auto I0 = Number<0>{};
-
-            idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;
-
-            idx_low(I0) += idx_diff_low[I0];
-        }
-        else if constexpr(Hack == 2)
-        {
-            // do borrow check on each low dimension in reversed order
-            // do not need to check the first dimension
-            bool do_borrow = 0;
-
-            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
-                idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;
-
-                index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
-
-                do_borrow = idx_low_tmp < 0;
-
-#if 0
-                // TODO: use exec-mask inline asm
-                if(do_borrow)
-                {
-                    idx_diff_low(i) += low_lengths_[i];
-                }
-#elif 1
-                idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
-#elif 1
-                index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
-                idx_diff_low(i)          = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
-#endif
-
-                idx_low(i) += idx_diff_low[i];
-            });
-
-            constexpr auto I0 = Number<0>{};
-
-            idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;
-
-            idx_low(I0) += idx_diff_low[I0];
-        }
-        else
-        {
-            // not implemented
-        }
-    }
-
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
-    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
-                                              const UpIdxDiff& idx_diff_up,
-                                              LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
-                                              Number<Hack>) const
-    {
-#if 1
-        UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
-#elif 0
-        UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
-#else
-        UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
-#endif
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-
-    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
-    {
-        return true;
-    }
-
-    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
-    {
-        return is_known_at_compile_time<LowLengths>::value &&
-               is_known_at_compile_time<LowLengthsScan>::value &&
-               is_known_at_compile_time<UpLengths>::value;
-    }
-
-    template <typename UpIdx>
-    __host__ __device__ static constexpr bool
-    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
-    {
-        return true;
-    }
-
-    __host__ __device__ void Print() const
-    {
-        printf("{");
-        printf("Merge_v1_carry_check, ");
-        printf("low_lengths_ ");
-        print_multi_index(low_lengths_);
-        printf("low_lengths_scan_ ");
-        print_multi_index(low_lengths_scan_);
-        printf("up_lengths_ ");
-        print_multi_index(up_lengths_);
-        printf("}");
-    }
-};
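To make the carry-and-borrow idea above concrete, here is a standalone sketch (illustrative only, not part of this commit) of the carry pass that Merge_v1_carry_check performs, shown for a 2D merge of lower lengths {3, 4}. Moving by an upper delta of +1 from upper index 7 (lower index {1, 3}), the precomputed per-dimension delta is {0, 1}; the last dimension wraps from 3 to 0 and carries 1 into the dimension before it. This simplified loop checks every dimension, whereas the removed code skips the first one.

    #include <cassert>

    int main()
    {
        const int lengths[2] = {3, 4};
        int idx_low[2]       = {1, 3}; // lower index for upper index 7
        const int diff[2]    = {0, 1}; // lower delta of upper delta +1, before carry

        int carry = 0;
        for(int i = 1; i >= 0; --i)
        {
            idx_low[i] += diff[i] + carry;
            carry = 0;
            if(idx_low[i] >= lengths[i])
            {
                idx_low[i] -= lengths[i];
                carry = 1;
            }
        }

        assert(idx_low[0] == 2 && idx_low[1] == 0); // upper index 8 -> lower {2, 0}
        return 0;
    }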
-template <typename LowLengths>
-struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
-{
-    template <index_t I>
-    __host__ __device__ constexpr auto operator()(Number<I> i) const
-    {
-        return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
-    }
-};
-
-template <typename LowLengths>
-struct lambda_merge_generate_MagicDivision_calculate_magic_shift
-{
-    template <index_t I>
-    __host__ __device__ constexpr auto operator()(Number<I> i) const
-    {
-        return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
-    }
-};
-
-// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
-// of both multi-index and delta of multi-index
-// Caution:
-// 1. The magic number division implementation being used produces the correct result only if the
-// dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
-// 2. The magic number division for an int32_t dividend has not been implemented; an int32_t
-// dividend is bit-wise interpreted as uint32_t, and the magic number division implementation for
-// uint32_t is then used.
-// 3. For the Merge primitive, the upper-index is the dividend.
-// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
-// 5. When the upper-index is int32_t (when index_t is int32_t), its value needs to be
-// non-negative.
-template <typename LowLengths>
-struct Merge_v2_magic_division
-{
-    static constexpr index_t NDimLow = LowLengths::Size();
-
-    using LowerIndex = MultiIndex<NDimLow>;
-    using UpperIndex = MultiIndex<1>;
-
-    using UpLengths =
-        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
-
-    using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
-        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
-        Number<NDimLow>{}));
-
-    using LowLengthsMagicDivisorShift = decltype(generate_tuple(
-        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
-        Number<NDimLow>{}));
-
-    LowLengths low_lengths_;
-    LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
-    LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
-    UpLengths up_lengths_;
-
-    __host__ __device__ constexpr Merge_v2_magic_division() = default;
-
-    __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
-        : low_lengths_{low_lengths},
-          low_lengths_magic_divisor_multiplier_{generate_tuple(
-              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
-              Number<NDimLow>{})},
-          low_lengths_magic_divisor_shift_{generate_tuple(
-              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
-              Number<NDimLow>{})},
-          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
-    {
-        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
-    }
-
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
-    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
-
-    template <typename LowIdx, typename UpIdx>
-    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
-                                                           const UpIdx& idx_up) const
-    {
-        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        index_t tmp = idx_up[Number<0>{}];
-
-        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
-            index_t tmp2 =
-                MagicDivision::DoMagicDivision(tmp,
-                                               this->low_lengths_magic_divisor_multiplier_[i],
-                                               this->low_lengths_magic_divisor_shift_[i]);
-            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
-            tmp        = tmp2;
-        });
-
-        idx_low(Number<0>{}) = tmp;
-    }
-
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
-    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
-                                              const UpIdxDiff&,
-                                              LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
-                                              Number<Hack>) const
-    {
-        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
-                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
-                      "wrong! inconsistent # of dimension");
-
-        index_t tmp = idx_up_new[Number<0>{}];
-
-        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
-            index_t tmp2 =
-                MagicDivision::DoMagicDivision(tmp,
-                                               this->low_lengths_magic_divisor_multiplier_[i],
-                                               this->low_lengths_magic_divisor_shift_[i]);
-
-            index_t idx_low_old = idx_low[i];
-
-            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
-            tmp        = tmp2;
-
-            idx_diff_low(i) = idx_low[i] - idx_low_old;
-        });
-
-        idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
-
-        idx_low(Number<0>{}) = tmp;
-    }
-
-    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-
-    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
-    {
-        return true;
-    }
-
-    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
-    {
-        return is_known_at_compile_time<LowLengths>::value &&
-               is_known_at_compile_time<LowLengthsMagicDivisorMultipiler>::value &&
-               is_known_at_compile_time<LowLengthsMagicDivisorShift>::value &&
-               is_known_at_compile_time<UpLengths>::value;
-    }
-
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
+    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
+                                              const UpIdxDiff& idx_diff_up,
+                                              LowIdx& idx_low,
+                                              const UpIdx&) const
+    {
+        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
+                          LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
+                      "wrong! inconsistent # of dimension");
+
+        idx_diff_low(Number<0>{}) = 0;
+
+        static_for<0, NDimUp, 1>{}(
+            [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
+
+        idx_low += idx_diff_low;
+    }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
     template <typename UpIdx>
@@ -1155,22 +488,39 @@ struct Merge_v2_magic_division
         return true;
     }
 
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value &&
+               is_known_at_compile_time<Coefficients>::value;
+    }
+
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("Merge_v2_magic_division, ");
-        printf("low_lengths_ ");
-        print_multi_index(low_lengths_);
-        printf("low_lengths_magic_divisor_multiplier_ ");
-        print_multi_index(low_lengths_magic_divisor_multiplier_);
-        printf("low_lengths_magic_divisor_shift_ ");
-        print_multi_index(low_lengths_magic_divisor_shift_);
-        printf("up_lengths_ ");
-        print_multi_index(up_lengths_);
+        printf("Embed{");
+
+        //
+        printf("up_lengths_: ");
+        print(up_lengths_);
+        printf(", ");
+
+        //
+        printf("coefficients_: ");
+        print(coefficients_);
+
         printf("}");
     }
 };
 
+template <typename LowLengths>
+struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
+{
+    template <index_t I>
+    __host__ __device__ constexpr auto operator()(Number<I> i) const
+    {
+        return MagicDivision::CalculateMagicNumbers(LowLengths{}[i]);
+    }
+};
+
 // Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
 // of both multi-index and delta of multi-index
 // Caution:
@@ -1184,53 +534,40 @@ struct Merge_v2_magic_division
 // 5. When the upper-index is int32_t (when index_t is int32_t), its value needs to be
 // non-negative.
 template <typename LowLengths>
-struct Merge_v2r2_magic_division
+struct Merge_v2_magic_division : public BaseTransform<LowLengths::Size(), 1>
 {
     static constexpr index_t NDimLow = LowLengths::Size();
 
     using LowerIndex = MultiIndex<NDimLow>;
     using UpperIndex = MultiIndex<1>;
 
-    using LowLengthsScan =
-        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
-
     using UpLengths =
         decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
 
-    using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
-        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
-        Number<NDimLow>{}));
-
-    using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
-        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
-        Number<NDimLow>{}));
+    using LowLengthsMagicDivisor = decltype(
+        generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
+                       Number<NDimLow>{}));
 
     LowLengths low_lengths_;
-    LowLengthsScan low_lengths_scan_;
-    LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_;
-    LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
+    LowLengthsMagicDivisor low_lengths_magic_divisor_;
     UpLengths up_lengths_;
 
-    __host__ __device__ constexpr Merge_v2r2_magic_division() = default;
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
 
-    __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
+    __host__ __device__ constexpr Merge_v2_magic_division() = default;
+
+    __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
         : low_lengths_{low_lengths},
-          low_lengths_scan_{
-              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
-          low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
-              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
-              Number<NDimLow>{})},
-          low_lengths_scan_magic_divisor_shift_{generate_tuple(
-              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
-              Number<NDimLow>{})},
-          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
+          low_lengths_magic_divisor_{generate_tuple(
+              [&](auto i) { return MagicDivision::CalculateMagicNumbers(low_lengths[i]); },
+              Number<NDimLow>{})},
+          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, I1))}
     {
         static_assert(LowerIndex::Size() == NDimLow, "wrong!");
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+    __host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Merge; }
 
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
@@ -1241,30 +578,24 @@ struct Merge_v2_magic_division
         static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                       "wrong! inconsistent # of dimension");
 
-        index_t tmp = idx_up[Number<0>{}];
+        index_t tmp = idx_up[I0];
 
-        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
-            idx_low(i) =
-                MagicDivision::DoMagicDivision(tmp,
-                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
-                                               this->low_lengths_scan_magic_divisor_shift_[i]);
-            tmp -= idx_low[i] * this->low_lengths_scan_[i];
+        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
+            index_t tmp2 = MagicDivision::DoMagicDivision(tmp,
+                                                          this->low_lengths_magic_divisor_[i][I0],
+                                                          this->low_lengths_magic_divisor_[i][I1]);
+            idx_low(i)   = tmp - tmp2 * this->low_lengths_[i];
+            tmp          = tmp2;
         });
 
-        idx_low(Number<NDimLow - 1>{}) = tmp;
+        idx_low(Number<0>{}) = tmp;
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                               const UpIdxDiff&,
                                               LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
-                                              Number<Hack>) const
+                                              const UpIdx& idx_up_new) const
     {
         static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                           LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
@@ -1272,26 +603,24 @@ struct Merge_v2_magic_division
 
         index_t tmp = idx_up_new[Number<0>{}];
 
-        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
-            index_t idx_low_old = idx_low[i];
-            idx_low(i) =
-                MagicDivision::DoMagicDivision(tmp,
-                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
-                                               this->low_lengths_scan_magic_divisor_shift_[i]);
-            idx_diff_low(i) = idx_low[i] - idx_low_old;
-            tmp -= idx_low[i] * this->low_lengths_scan_[i];
+        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
+            index_t tmp2 = MagicDivision::DoMagicDivision(tmp,
+                                                          this->low_lengths_magic_divisor_[i][I0],
+                                                          this->low_lengths_magic_divisor_[i][I1]);
+            index_t idx_low_old = idx_low[i];
+
+            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
+            tmp        = tmp2;
+
+            idx_diff_low(i) = idx_low[i] - idx_low_old;
         });
 
-        idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];
-
-        idx_low(Number<NDimLow - 1>{}) = tmp;
+        idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
+
+        idx_low(Number<0>{}) = tmp;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return true;
@@ -1300,8 +629,7 @@ struct Merge_v2_magic_division
     __host__ __device__ static constexpr bool IsKnownAtCompileTime()
     {
         return is_known_at_compile_time<LowLengths>::value &&
-               is_known_at_compile_time<LowLengthsScanMagicDivisorMultipiler>::value &&
-               is_known_at_compile_time<LowLengthsScanMagicDivisorShift>::value &&
+               is_known_at_compile_time<LowLengthsMagicDivisor>::value &&
               is_known_at_compile_time<UpLengths>::value;
     }
 
@@ -1312,20 +640,34 @@ struct Merge_v2_magic_division
         return true;
     }
 
+    // MUST be a static function
+    template <typename LowVectorLengths, typename LowVectorStrides>
+    __host__ __device__ static constexpr auto
+    CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
+                                                   const LowVectorStrides& low_vector_strides)
+    {
+        Array<index_t, 1> up_vector_lengths{-1};
+        Array<index_t, 1> up_vector_strides{-1};
+
+        up_vector_lengths(0) = low_vector_lengths[Number<NDimLow - 1>{}];
+        up_vector_strides(0) = low_vector_strides[Number<NDimLow - 1>{}];
+
+        return make_tuple(up_vector_lengths, up_vector_strides);
+    }
+
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("Merge_v2r2_magic_division, ");
-        printf("low_lengths_ ");
-        print_multi_index(low_lengths_);
-        printf("low_lengths_scan ");
-        print_multi_index(low_lengths_scan_);
-        printf("low_lengths_scan_magic_divisor_multiplier_ ");
-        print_multi_index(low_lengths_scan_magic_divisor_multiplier_);
-        printf("low_lengths_scan_magic_divisor_shift_ ");
-        print_multi_index(low_lengths_scan_magic_divisor_shift_);
-        printf("up_lengths_ ");
-        print_multi_index(up_lengths_);
+        printf("Merge_v2_magic_division{");
+
+        //
+        printf("low_lengths_ ");
+        print(low_lengths_);
+        printf(", ");
+
+        //
+        printf("up_lengths_ ");
+        print(up_lengths_);
+
         printf("}");
     }
 };
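For context, a standalone sketch (illustrative only, not part of this commit) of the magic-number-division idea behind MagicDivision: for a divisor d fixed ahead of time, division becomes a 32x32-to-64-bit multiply plus shifts, n / d == ((n * multiplier) >> 32) >> shift for 0 <= n < 2^31, which is exactly the restriction stated in the caution comment above. The multiplier/shift pair here is hand-derived for d = 5; in CK they would come from something like MagicDivision::CalculateMagicNumbers(d).

    #include <cassert>
    #include <cstdint>

    int main()
    {
        const uint32_t d          = 5;
        const uint32_t multiplier = 0xCCCCCCCD; // ceil(2^34 / 5), truncated to 32 bits
        const uint32_t shift      = 2;

        for(uint32_t n = 0; n < 1000; ++n)
        {
            // high 32 bits of the 64-bit product, then the per-divisor shift
            uint32_t q =
                static_cast<uint32_t>((static_cast<uint64_t>(n) * multiplier) >> 32) >> shift;
            assert(q == n / d);
        }
        return 0;
    }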
@@ -1334,7 +676,7 @@ struct Merge_v2r2_magic_division
 // be used for low_lengths that are known at compile time and are power of 2, otherwise performance
 // will be very bad
 template <typename LowLengths>
-struct Merge_v3_division_mod
+struct Merge_v3_division_mod : public BaseTransform<LowLengths::Size(), 1>
 {
     static constexpr index_t NDimLow = LowLengths::Size();
 
@@ -1362,10 +704,6 @@ struct Merge_v3_division_mod
         static_assert(LowerIndex::Size() == NDimLow, "wrong!");
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
-
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
     template <typename LowIdx, typename UpIdx>
@@ -1386,16 +724,11 @@ struct Merge_v3_division_mod
         idx_low(Number<NDimLow - 1>{}) = tmp;
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                               const UpIdxDiff&,
                                               LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
-                                              Number<Hack>) const
+                                              const UpIdx& idx_up_new) const
     {
         static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                           LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
@@ -1418,8 +751,6 @@ struct Merge_v3_division_mod
         idx_diff_low(INm1) = idx_low[INm1] - tmp2;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return true;
@@ -1439,22 +770,45 @@ struct Merge_v3_division_mod
         return true;
     }
 
+    // MUST be a static function
+    template <typename LowVectorLengths, typename LowVectorStrides>
+    __host__ __device__ static constexpr auto
+    CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
+                                                   const LowVectorStrides& low_vector_strides)
+    {
+        Array<index_t, 1> up_vector_lengths{-1};
+        Array<index_t, 1> up_vector_strides{-1};
+
+        up_vector_lengths(0) = low_vector_lengths[Number<NDimLow - 1>{}];
+        up_vector_strides(0) = low_vector_strides[Number<NDimLow - 1>{}];
+
+        return make_tuple(up_vector_lengths, up_vector_strides);
+    }
+
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("Merge_v3_direct_division_mod, ");
-        printf("low_lengths_ ");
-        print_multi_index(low_lengths_);
-        printf("low_lengths_scan_ ");
-        print_multi_index(low_lengths_scan_);
-        printf("up_lengths_ ");
-        print_multi_index(up_lengths_);
+        printf("Merge_v3_direct_division_mod{");
+
+        //
+        printf("low_lengths_ ");
+        print(low_lengths_);
+        printf(", ");
+
+        //
+        printf("low_lengths_scan_ ");
+        print(low_lengths_scan_);
+        printf(", ");
+
+        //
+        printf("up_lengths_ ");
+        print(up_lengths_);
+
         printf("}");
     }
 };
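A standalone sketch (illustrative only, not part of this commit) of the direct division/modulo lowering that Merge_v3_division_mod performs, for lower lengths {4, 8}. The reverse exclusive product scan of the lengths is {8, 1}, so a merged upper index decomposes as idx_up = idx0 * 8 + idx1; this sketch uses % where the removed code uses a multiply-subtract, which is equivalent.

    #include <cassert>

    int main()
    {
        const int scan[2] = {8, 1}; // strides of each lower dimension in the merged index
        int idx_low[2];

        int tmp    = 27;            // merged (upper) index
        idx_low[0] = tmp / scan[0]; // 27 / 8 = 3
        tmp        = tmp % scan[0]; // 27 % 8 = 3
        idx_low[1] = tmp;           // the last dimension takes the remainder

        assert(idx_low[0] == 3 && idx_low[1] == 3); // 3 * 8 + 3 == 27
        return 0;
    }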
 template <typename UpLengths, bool Use24BitIntegerCalculation>
-struct UnMerge
+struct UnMerge : public BaseTransform<1, UpLengths::Size()>
 {
     static constexpr index_t NDimUp = UpLengths::Size();
 
@@ -1476,9 +830,7 @@ struct UnMerge
     {
     }
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
+    __host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::UnMerge; }
 
     __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
 
@@ -1505,24 +857,17 @@ struct UnMerge
         }
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                               const UpIdxDiff& idx_diff_up,
                                               LowIdx& idx_low,
-                                              const UpIdx&,
-                                              Number<Hack>) const
+                                              const UpIdx&) const
     {
         CalculateLowerIndex(idx_diff_low, idx_diff_up);
 
         idx_low += idx_diff_low;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return true;
@@ -1541,20 +886,49 @@ struct UnMerge
                is_known_at_compile_time<UpLengthsScan>::value;
     }
 
+    // MUST be a static function
+    template <typename LowVectorLengths, typename LowVectorStrides>
+    __host__ __device__ static constexpr auto
+    CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
+                                                   const LowVectorStrides& low_vector_strides)
+    {
+        Array<index_t, NDimUp> up_vector_lengths{-1};
+        Array<index_t, NDimUp> up_vector_strides{-1};
+
+        constexpr auto up_length_last = UpLengths{}[Number<NDimUp - 1>{}];
+
+        if constexpr(is_known_at_compile_time<decltype(up_length_last)>::value)
+        {
+            if(low_vector_lengths[0] != -1)
+            {
+                up_vector_lengths(NDimUp - 1) = math::gcd(low_vector_lengths[0], up_length_last);
+            }
+        }
+
+        up_vector_strides(NDimUp - 1) = low_vector_strides[0];
+
+        return make_tuple(up_vector_lengths, up_vector_strides);
+    }
+
     __host__ __device__ void Print() const
     {
-        printf("{");
-        printf("UnMerge, ");
-        printf("up_lengths_");
-        print_multi_index(up_lengths_);
-        printf("up_lengths_scan_");
-        print_multi_index(up_lengths_scan_);
+        printf("UnMerge{");
+
+        //
        printf("up_lengths_");
+        print(up_lengths_);
+        printf(", ");
+
+        //
+        printf("up_lengths_scan_");
+        print(up_lengths_scan_);
+
         printf("}");
     }
 };
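UnMerge's lowering is the inverse bookkeeping of Merge: a multi-dimensional upper index collapses into one lower offset through up_lengths_scan. A standalone sketch (illustrative only, not part of this commit) for up_lengths {2, 3, 4}, whose reverse exclusive product scan is {12, 4, 1}, i.e. a row-major flattening:

    #include <cassert>

    int main()
    {
        const int scan[3]   = {12, 4, 1}; // reverse exclusive product scan of {2, 3, 4}
        const int idx_up[3] = {1, 2, 3};

        int idx_low = 0;
        for(int i = 0; i < 3; ++i)
            idx_low += idx_up[i] * scan[i];

        assert(idx_low == 23); // 1*12 + 2*4 + 3*1
        return 0;
    }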
 template <typename LowerIndex>
-struct Freeze
+struct Freeze : public BaseTransform<1, 0>
 {
     LowerIndex low_idx_;
 
@@ -1562,10 +936,6 @@ struct Freeze
     __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
 
-    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
-
-    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }
-
     __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }
 
     template <typename LowIdx, typename UpIdx>
@@ -1578,22 +948,15 @@ struct Freeze
         idx_low(Number<0>{}) = low_idx_;
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                      const UpIdxDiff& /* idx_diff_up */,
                                                      LowIdx& /* idx_low */,
-                                                     const UpIdx& /* idx_up_new */,
-                                                     Number<Hack>)
+                                                     const UpIdx& /* idx_up_new */)
     {
         idx_diff_low(Number<0>{}) = 0;
     }
 
-    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
-
     __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
     {
         return true;
@@ -1613,54 +976,46 @@ struct Freeze
     __host__ __device__ void Print() const
     {
-        printf("Freeze");
-        printf("low_idx_ %d", index_t{low_idx_});
+        printf("Freeze{");
+
+        //
+        printf("low_idx_: ");
+        print(low_idx_);
+
+        printf("}");
     }
 };
// Insert a dangling upper dimension without lower dimension // Replicate the original tensor and create a higher dimensional tensor
template <typename UpperLength> template <typename UpLengths>
struct Insert struct Replicate : public BaseTransform<0, UpLengths::Size()>
{ {
using UpLengths = decltype(make_tuple(UpperLength{})); static constexpr index_t NDimUp = UpLengths::Size();
UpLengths up_lengths_;
__host__ __device__ constexpr Insert() = default; __host__ __device__ constexpr Replicate() = default;
__host__ __device__ constexpr Insert(const UpperLength& up_length) __host__ __device__ constexpr Replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
: up_lengths_{make_tuple(up_length)}
{ {
} }
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; } __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx> template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
{ {
static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 0 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
} }
template <typename LowIdxDiff, template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ static void __host__ __device__ static void
UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>) UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
{ {
static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 && static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == NDimUp &&
UpIdx::Size() == 1, LowIdx::Size() == 0 && UpIdx::Size() == NDimUp,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{ {
return true; return true;
...@@ -1675,39 +1030,47 @@ struct Insert ...@@ -1675,39 +1030,47 @@ struct Insert
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsKnownAtCompileTime()
{ {
return is_known_at_compile_time<UpperLength>::value; return is_known_at_compile_time<UpLengths>::value;
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("Insert"); printf("Replicate{");
print_multi_index(up_lengths_);
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
} }
//
UpLengths up_lengths_;
}; };
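// Example (illustrative sketch; make_replicate_transform is declared in the
// helper header further down this diff): Replicate has no lower dimension, so
// all upper indices of the replicated lengths alias the same underlying data.
//
//   const auto rep = make_replicate_transform(make_tuple(Number<2>{}, Number<4>{}));
//   // rep.GetUpperLengths() == (2, 4); CalculateLowerIndex writes no component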
template <typename VectorSize, typename UpLength> template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct Vectorize struct Slice : public BaseTransform<1, 1>
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{})); using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
UpLengths up_lengths_; UpLengths up_lengths_;
VectorSize vector_size_; SliceBegin slice_begin_;
SliceEnd slice_end_;
__host__ __device__ constexpr Vectorize() = default; __host__ __device__ constexpr Slice() = default;
__host__ __device__ constexpr Vectorize(const VectorSize& vector_size, __host__ __device__ constexpr Slice(const LowLength&,
const UpLength& up_length) const SliceBegin& slice_begin,
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
{ {
} }
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx> template <typename LowIdx, typename UpIdx>
...@@ -1717,19 +1080,14 @@ struct Vectorize ...@@ -1717,19 +1080,14 @@ struct Vectorize
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}]; idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
} }
template <typename LowIdxDiff, template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
typename UpIdxDiff, __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
typename LowIdx, const UpIdxDiff& idx_diff_up,
typename UpIdx, LowIdx& idx_low,
index_t Hack> const UpIdx&)
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>) const
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1, UpIdx::Size() == 1,
...@@ -1737,67 +1095,72 @@ struct Vectorize ...@@ -1737,67 +1095,72 @@ struct Vectorize
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = vector_size_ * idx_diff_up[I0]; idx_diff_low(I0) = idx_diff_up[I0];
idx_low += idx_diff_low; idx_low += idx_diff_low;
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{ {
return true; return true;
} }
template <typename UpIdx> template <typename UpIdx>
__host__ __device__ static constexpr bool __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{ {
return true; return true;
} }
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsKnownAtCompileTime()
{ {
return is_known_at_compile_time<UpLengths>::value; return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<SliceBegin>::value &&
is_known_at_compile_time<SliceEnd>::value;
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("Slice{");
printf("Vectorize, ");
printf("up_lengths_"); //
print_multi_index(up_lengths_); printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}"); printf("}");
} }
}; };
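// Example (illustrative sketch; make_slice_transform is declared in the helper
// header further down this diff): the slice [2, 6) of a length-8 dimension
// exposes an upper length of 6 - 2 == 4, and upper index i maps to lower i + 2.
//
//   const auto slice = make_slice_transform(Number<8>{}, Number<2>{}, Number<6>{});
//   MultiIndex<1> idx_low;
//   slice.CalculateLowerIndex(idx_low, make_multi_index(1));
//   // idx_low[Number<0>{}] == 3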
template <typename LowLength, typename SliceBegin, typename SliceEnd> /*
struct Slice * \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
template <typename Modulus, typename UpLength>
struct Modulo : public BaseTransform<1, 1>
{ {
using LowerIndex = MultiIndex<1>; using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); Modulus modulus_;
UpLengths up_lengths_; UpLengths up_lengths_;
SliceBegin slice_begin_;
SliceEnd slice_end_;
__host__ __device__ constexpr Slice() = default; __host__ __device__ constexpr Modulo() = default;
__host__ __device__ constexpr Slice(const LowLength&, __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
const SliceBegin& slice_begin, : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
{ {
} }
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx> template <typename LowIdx, typename UpIdx>
...@@ -1807,19 +1170,14 @@ struct Slice ...@@ -1807,19 +1170,14 @@ struct Slice
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_; idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
} }
template <typename LowIdxDiff, template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
typename UpIdxDiff, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
typename LowIdx, const UpIdxDiff& idx_diff_up,
typename UpIdx, LowIdx& idx_low,
index_t Hack> const UpIdx& up_idx) const
__host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&,
Number<Hack>)
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
UpIdx::Size() == 1, UpIdx::Size() == 1,
...@@ -1827,67 +1185,62 @@ struct Slice ...@@ -1827,67 +1185,62 @@ struct Slice
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
idx_diff_low(I0) = idx_diff_up[I0]; const auto idx_low_old = idx_low;
idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
idx_low += idx_diff_low; idx_diff_low(I0) = idx_low - idx_low_old;
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return true; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{ {
return true; return true;
} }
template <typename UpIdx> template <typename UpIdx>
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const __host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{ {
return true; return true;
} }
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsKnownAtCompileTime()
{ {
return is_known_at_compile_time<UpLengths>::value && return is_known_at_compile_time<UpLengths>::value;
is_known_at_compile_time<SliceBegin>::value &&
is_known_at_compile_time<SliceEnd>::value;
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("Modulus{");
printf("Slice, ");
printf("up_lengths_"); //
print_multi_index(up_lengths_); printf("up_lengths_: ");
printf("slice_begin_ %d", index_t{slice_begin_}); print(up_lengths_);
printf("slice_end %d", index_t{slice_end_});
printf("}"); printf("}");
} }
}; };
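// Example (illustrative sketch; make_modulo_transform is declared in the helper
// header further down this diff): with modulus 4 and upper length 10, upper
// index 9 maps to lower index 9 % 4 == 1; UpdateLowerIndex recomputes the
// wrapped value from scratch, which is why the \brief above calls it expensive.
//
//   const auto mod = make_modulo_transform(Number<4>{}, Number<10>{});
//   MultiIndex<1> idx_low;
//   mod.CalculateLowerIndex(idx_low, make_multi_index(9));
//   // idx_low[Number<0>{}] == 1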
/* // 2D XOR
* \brief lower_idx = upper_idx % modulus. template <typename LowLengths, typename RightShift>
* TODO: Need an improved implementation since the modulo operation is expensive. struct Xor : public BaseTransform<2, 2>
*/
template <typename Modulus, typename UpLength>
struct Modulo
{ {
using LowerIndex = MultiIndex<1>; static constexpr auto type_enum = IndexTransformEnum::Xor;
using UpperIndex = MultiIndex<1>;
using UpLengths = decltype(make_tuple(UpLength{})); using LowerIndex = MultiIndex<2>;
using UpperIndex = MultiIndex<2>;
using UpLengths = LowLengths;
Modulus modulus_;
UpLengths up_lengths_; UpLengths up_lengths_;
RightShift right_shift_;
__host__ __device__ constexpr Modulo() = default; __host__ __device__ constexpr Xor() : up_lengths_{}, right_shift_{} {}
__host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length) __host__ __device__ constexpr Xor(const LowLengths& low_lengths, const RightShift& right_shift)
: modulus_{modulus}, up_lengths_{make_tuple(up_length)} : up_lengths_{low_lengths}, right_shift_{right_shift}
{ {
} }
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } __host__ __device__ static constexpr auto GetTypeEnum() { return IndexTransformEnum::Xor; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
...@@ -1895,35 +1248,36 @@ struct Modulo ...@@ -1895,35 +1248,36 @@ struct Modulo
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const const UpIdx& idx_up) const
{ {
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_; idx_low(Number<0>{}) = idx_up[Number<0>{}];
const auto idx_low_1_tmp =
(idx_up[Number<1>{}] - idx_up[Number<0>{}] * right_shift_) % up_lengths_[Number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[Number<1>{}] + idx_low_1_tmp;
idx_low(Number<1>{}) = idx_low_1;
} }
template <typename LowIdxDiff, template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up, const UpIdxDiff&,
LowIdx& idx_low, LowIdx& idx_low,
const UpIdx& up_idx, const UpIdx& idx_up) const
Number<Hack>) const
{ {
static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 &&
UpIdx::Size() == 1, UpIdx::Size() == 2,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
constexpr auto I0 = Number<0>{};
const auto idx_low_old = idx_low; const auto idx_low_old = idx_low;
idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
idx_diff_low(I0) = idx_low - idx_low_old;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; } CalculateLowerIndex(idx_low, idx_up);
idx_diff_low = idx_low - idx_low_old;
}
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{ {
...@@ -1939,16 +1293,45 @@ struct Modulo ...@@ -1939,16 +1293,45 @@ struct Modulo
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsKnownAtCompileTime()
{ {
return is_known_at_compile_time<UpLengths>::value; return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<RightShift>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
__host__ __device__ constexpr auto
CalculateUpperDimensionSafeVectorLengthStrides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides) const
{
Array<index_t, 2> up_vector_lengths = low_vector_lengths;
Array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = math::gcd(low_vector_lengths[1], math::abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
} }
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("Xor{");
printf("Modulus, ");
printf("up_lengths_"); //
print_multi_index(up_lengths_); printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}"); printf("}");
} }
}; };
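// Worked example (illustrative sketch; make_xor_transform is declared in the
// helper header further down this diff): with lengths (4, 8) and right_shift 1,
// upper (row, col) maps to lower (row, (col - row) mod 8), so (2, 1) -> (2, 7);
// each row is rotated against its neighbors, the usual bank-conflict swizzle.
//
//   const auto xor_t = make_xor_transform(make_tuple(Number<4>{}, Number<8>{}), Number<1>{});
//   MultiIndex<2> idx_low;
//   xor_t.CalculateLowerIndex(idx_low, make_multi_index(2, 1));
//   // idx_low == (2, 7)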
} // namespace ck } // namespace ck
...@@ -51,32 +51,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng ...@@ -51,32 +51,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
return Embed<UpLengths, Coefficients>{up_lengths, coefficients}; return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
} }
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return make_merge_transform_v2_magic_division(low_lengths);
#else
return make_merge_transform_v1_carry_check(low_lengths);
#endif
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
return Merge_v1_carry_check<LowLengths>{low_lengths};
}
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths) make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{ {
#if 1
return Merge_v2_magic_division<LowLengths>{low_lengths}; return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
} }
template <typename LowLengths> template <typename LowLengths>
...@@ -86,6 +65,12 @@ make_merge_transform_v3_division_mod(const LowLengths& low_lengths) ...@@ -86,6 +65,12 @@ make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
return Merge_v3_division_mod<LowLengths>{low_lengths}; return Merge_v3_division_mod<LowLengths>{low_lengths};
} }
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return make_merge_transform_v2_magic_division(low_lengths);
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false> template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform( __host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths, const UpLengths& up_lengths,
...@@ -100,10 +85,10 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i ...@@ -100,10 +85,10 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return Freeze<LowerIndex>{low_idx}; return Freeze<LowerIndex>{low_idx};
} }
template <typename UpperIndex> template <typename UpLengths>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx) __host__ __device__ constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{ {
return Insert<UpperIndex>{up_idx}; return Replicate<UpLengths>{up_lengths};
} }
template <typename LowLength, typename SliceBegin, typename SliceEnd> template <typename LowLength, typename SliceBegin, typename SliceEnd>
...@@ -114,17 +99,18 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len ...@@ -114,17 +99,18 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end}; return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
} }
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
template <typename Modulus, typename UpLength> template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length) const UpLength& up_length)
{ {
return Modulo<Modulus, UpLength>{modulus, up_length}; return Modulo<Modulus, UpLength>{modulus, up_length};
} }
template <typename LowLengths, typename RightShift>
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
{
return Xor<LowLengths, RightShift>{low_lengths, right_shift};
}
} // namespace ck } // namespace ck
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
...@@ -35,21 +33,23 @@ struct TensorAdaptor ...@@ -35,21 +33,23 @@ struct TensorAdaptor
return UpperDimensionHiddenIdss{}; return UpperDimensionHiddenIdss{};
} }
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds() __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
{ {
return TopDimensionHiddenIds{}; return BottomDimensionHiddenIds{};
} }
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds() __host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
{ {
return BottomDimensionHiddenIds{}; return TopDimensionHiddenIds{};
} }
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{ {
const auto lengths = generate_tuple( const auto lengths = generate_tuple(
[&](auto idim_top) { [&](auto idim_top) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<idim_hidden>{});
constexpr index_t itran = tmp[Number<0>{}]; constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}]; constexpr index_t idim_up = tmp[Number<1>{}];
...@@ -69,12 +69,12 @@ struct TensorAdaptor ...@@ -69,12 +69,12 @@ struct TensorAdaptor
return container_reduce(lengths, math::multiplies{}, Number<1>{}); return container_reduce(lengths, math::multiplies{}, Number<1>{});
} }
template <index_t IDim> template <index_t IDimHidden>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>) __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDimHidden>)
{ {
constexpr auto idim_top = Number<IDim>{}; // FIXME: length of bottom dimension is not known, since info about lower dim length are not
// saved in transformation
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
index_t itran_found = 0; index_t itran_found = 0;
index_t idim_up_found = 0; index_t idim_up_found = 0;
...@@ -84,7 +84,7 @@ struct TensorAdaptor ...@@ -84,7 +84,7 @@ struct TensorAdaptor
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden) if constexpr(up_dim_ids[idim_up] == IDimHidden)
{ {
itran_found = itran; itran_found = itran;
idim_up_found = idim_up; idim_up_found = idim_up;
...@@ -138,11 +138,7 @@ struct TensorAdaptor ...@@ -138,11 +138,7 @@ struct TensorAdaptor
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>; using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public: public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorAdaptor() = default; __host__ __device__ constexpr TensorAdaptor() = default;
#else
__host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
#endif
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)} : transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
...@@ -157,17 +153,52 @@ struct TensorAdaptor ...@@ -157,17 +153,52 @@ struct TensorAdaptor
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; } __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
#if 0 // debug // FIXME: this logic is wrong when getting bottom dimension lengths
template <index_t I> template <index_t IDimHidden>
__host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const __host__ __device__ constexpr auto GetHiddenDimensionLength(Number<IDimHidden>) const
{ {
// TODO: not implemented static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDimHidden>{});
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
} }
template <index_t I> template <index_t IDimTop>
__host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const __host__ __device__ constexpr auto GetTopDimensionLength(Number<IDimTop> idim_top) const
{ {
// TODO: not implemented return GetHiddenDimensionLength(TopDimensionHiddenIds::At(idim_top));
}
#if 0
// FIXME: GetHiddenDimensionLength is wrong when getting bottom dimension lengths
template <index_t IDimBottom>
__host__ __device__ constexpr index_t
GetBottomDimensionLength(Number<IDimBottom> idim_bottom) const
{
return GetHiddenDimensionLength(TopDimensionHiddenIds::At(idim_bottom));
}
#endif
__host__ __device__ constexpr auto GetTopDimensionLengths() const
{
return generate_tuple([&](auto i) { return GetTopDimensionLength(i); },
Number<ndim_top_>{});
}
#if 0
// FIXME: GetHiddenDimensionLength is wrong when getting bottom dimension lengths
__host__ __device__ constexpr auto GetBottomDimensionLengths() const
{
return generate_tuple([&](auto i) { return GetBottomDimensionLength(i); },
Number<ndim_bottom_>{});
} }
#endif #endif
...@@ -204,7 +235,7 @@ struct TensorAdaptor ...@@ -204,7 +235,7 @@ struct TensorAdaptor
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
} }
__host__ __device__ static constexpr bool IsKnownAtCompileTime() __host__ __device__ static constexpr bool IsStatic()
{ {
bool is_known = true; bool is_known = true;
...@@ -215,23 +246,81 @@ struct TensorAdaptor ...@@ -215,23 +246,81 @@ struct TensorAdaptor
return is_known && is_known_at_compile_time<ElementSize>::value; return is_known && is_known_at_compile_time<ElementSize>::value;
} }
__host__ __device__ void Print() const __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return IsStatic(); }
__host__ __device__ static constexpr auto GetTopDimensionSafeVectorLengthStrides(
const Array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
const Array<index_t, ndim_hidden_>& guaranteed_vector_strides)
{ {
printf("{"); auto vector_lengths = guaranteed_vector_lengths;
printf("TensorAdaptor, "); auto vector_strides = guaranteed_vector_strides;
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: "); static_for<0, GetNumOfTransform(), 1>{}([&](auto itran) {
transforms_[i].Print(); constexpr auto low_dims = GetLowerDimensionHiddenIdss().At(itran);
printf("LowerDimensionHiddenIds:"); constexpr auto up_dims = GetUpperDimensionHiddenIdss().At(itran);
LowerDimensionHiddenIdss{}.At(i).Print();
printf("UpperDimensionHiddenIds:"); const auto up_guaranteed_vector_lengths =
UpperDimensionHiddenIdss{}.At(i).Print(); get_container_subset(guaranteed_vector_lengths, up_dims);
const auto up_guaranteed_vector_strides =
get_container_subset(guaranteed_vector_strides, up_dims);
// only need type of transform
auto [up_vector_lengths, up_vector_strides] =
Transforms{}.At(itran).CalculateUpperDimensionSafeVectorLengthStrides(
get_container_subset(vector_lengths, low_dims),
get_container_subset(vector_strides, low_dims));
if constexpr(up_dims.Size() > 0)
{
for(index_t i = 0; i < up_dims.Size(); ++i)
{
up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
? up_guaranteed_vector_lengths[i]
: up_vector_lengths[i];
up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
? up_guaranteed_vector_strides[i]
: up_vector_strides[i];
}
}
set_container_subset(vector_lengths, up_dims, up_vector_lengths);
set_container_subset(vector_strides, up_dims, up_vector_strides);
}); });
printf("BottomDimensionHiddenIds:"); constexpr auto top_dims = TopDimensionHiddenIds{};
BottomDimensionHiddenIds::Print();
printf("TopDimensionHiddenIds:"); return make_tuple(get_container_subset(vector_lengths, top_dims),
TopDimensionHiddenIds::Print(); get_container_subset(vector_strides, top_dims));
}
__host__ __device__ void Print() const
{
printf("TensorAdaptor{");
//
printf("transforms: ");
print(transforms_);
printf(", ");
//
printf("LowerDimensionHiddenIds: ");
print(LowerDimensionHiddenIdss{});
printf(", ");
//
printf("UpperDimensionHiddenIds: ");
print(UpperDimensionHiddenIdss{});
printf(", ");
//
printf("BottomDimensionHiddenIds: ");
print(BottomDimensionHiddenIds{});
printf(", ");
//
printf("TopDimensionHiddenIds: ");
print(TopDimensionHiddenIds{});
printf("}"); printf("}");
} }
...@@ -241,6 +330,161 @@ struct TensorAdaptor ...@@ -241,6 +330,161 @@ struct TensorAdaptor
ElementSize element_size_; ElementSize element_size_;
}; };
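// Example (illustrative; "SomeAdaptor" is a hypothetical TensorAdaptor type):
// GetTopDimensionSafeVectorLengthStrides seeds all hidden dimensions with the
// guaranteed values (-1 meaning "unknown"), walks the transforms bottom-up so
// each one can shrink the safe vector length of its upper dimensions (e.g. the
// gcd clamp of UnMerge above), and finally extracts the top-dimension subset.
//
//   constexpr auto ndim_hidden = SomeAdaptor::GetNumOfHiddenDimension();
//   Array<index_t, ndim_hidden> guaranteed_lengths{-1};
//   Array<index_t, ndim_hidden> guaranteed_strides{-1};
//   const auto [top_lengths, top_strides] =
//       SomeAdaptor::GetTopDimensionSafeVectorLengthStrides(guaranteed_lengths,
//                                                           guaranteed_strides);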
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<remove_cvref_t<Transforms>,
remove_cvref_t<decltype(low_dim_hidden_idss)>,
remove_cvref_t<decltype(up_dim_hidden_idss)>,
remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
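// Usage sketch (illustrative): a single-stage 2D -> 1D adaptor built from one
// Merge transform over lengths (4, 8); the old top dims {0, 1} become the
// bottom, and the merged dim becomes the new top.
//
//   const auto adaptor = make_single_stage_tensor_adaptor(
//       make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
//       make_tuple(Sequence<0, 1>{}),
//       make_tuple(Sequence<0>{}));
//   // adaptor.CalculateBottomIndex(make_multi_index(13)) == (1, 5)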
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
// (transform_tensor_adaptor) because template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
return Number<Tran::GetNumOfUpperDimension()>{};
}
};
template <typename OldTensorAdaptor,
typename NewTransforms,
typename NewLowerDimensionOldTopIdss,
typename NewUpperDimensionNewTopIdss>
__host__ __device__ constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
const NewTransforms& new_transforms,
NewLowerDimensionOldTopIdss,
NewUpperDimensionNewTopIdss)
{
// sanity check
{
static_assert(NewTransforms::Size() == NewLowerDimensionOldTopIdss::Size() &&
NewTransforms::Size() == NewUpperDimensionNewTopIdss::Size(),
"wrong! inconsitent number of transform");
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldTopIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!");
}
// lower dimension's hidden idss
// convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension top ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_top_ids) constexpr {
return transform_sequences(
// convert lower dimension top id to hidden id
[](auto low_dim_top_id) constexpr {
return OldTensorAdaptor::GetTopDimensionHiddenIds()[low_dim_top_id];
},
low_dim_top_ids);
},
NewLowerDimensionOldTopIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
// upper dimension's hidden idss
constexpr index_t old_hidden_dim_number = OldTensorAdaptor::GetNumOfHiddenDimension();
constexpr auto up_dim_numbers =
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
Number<num_new_transform>{});
// new top dimension's hidden ids
constexpr auto unordered_new_top_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_top_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
constexpr auto new_top_dim_hidden_ids =
unordered_new_top_dim_hidden_ids.ReorderGivenOld2New(new_top_dim_unordered2ordered);
// put everything together
const auto all_transforms =
container_concat(old_tensor_adaptor.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
container_concat(OldTensorAdaptor::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
container_concat(OldTensorAdaptor::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss);
return TensorAdaptor<remove_cvref_t<decltype(all_transforms)>,
remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cvref_t<decltype(OldTensorAdaptor::GetBottomDimensionHiddenIds())>,
remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
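// Usage sketch (illustrative; assumes a make_pad_transform helper for the Pad
// transform, and "adaptor" from the example above): each call stacks new
// transforms onto the old top dimensions and renumbers hidden ids automatically.
//
//   const auto padded = transform_tensor_adaptor(
//       adaptor,
//       make_tuple(make_pad_transform(Number<32>{}, Number<0>{}, Number<3>{})),
//       make_tuple(Sequence<0>{}),
//       make_tuple(Sequence<0>{}));
//   // padded has one top dim of length 32 + 0 + 3 == 35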
template <typename TensorAdaptor0, typename TensorAdaptor1> template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
const TensorAdaptor1& adaptor1) const TensorAdaptor1& adaptor1)
...@@ -415,62 +659,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a ...@@ -415,62 +659,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{}; TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
// put everything together // put everything together
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>, return TensorAdaptor<remove_cvref_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>, remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>, remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>, remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms}; remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<remove_cv_t<Transforms>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
} }
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false> template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct TensorAdaptorCoordinate
{
static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::Size();
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using BottomIndex = MultiIndex<ndim_bottom_>;
using TopIndex = MultiIndex<ndim_top_>;
public:
__host__ __device__ constexpr TensorAdaptorCoordinate() = default;
__host__ __device__ constexpr TensorAdaptorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
{
}
__host__ __device__ constexpr auto GetTopIndex() const
{
return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
}
__host__ __device__ constexpr auto GetBottomIndex() const
{
return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
}
__host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }
__host__ __device__ constexpr auto& GetHiddenIndex() { return idx_hidden_; }
//
HiddenIndex idx_hidden_;
};
template <typename Adaptor, typename TopIndex>
__host__ __device__ constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
const TopIndex& idx_top)
{
static_assert(Adaptor::GetNumOfTopDimension() == TopIndex::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = Adaptor::GetNumOfTransform();
constexpr index_t ndim_hidden = Adaptor::GetNumOfHiddenDimension();
constexpr auto bottom_dim_ids = Adaptor::GetBottomDimensionHiddenIds();
constexpr auto top_dim_ids = Adaptor::GetTopDimensionHiddenIds();
MultiIndex<ndim_hidden> idx_hidden;
// initialize top index
set_container_subset(idx_hidden, top_dim_ids, idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
auto itran = itran_p1 - Number<1>{};
const auto& tran = adaptor.GetTransforms().At(itran);
constexpr auto dims_low = Adaptor::GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = Adaptor::GetUpperDimensionHiddenIdss().At(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return TensorAdaptorCoordinate<ndim_hidden,
remove_cvref_t<decltype(bottom_dim_ids)>,
remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
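// Usage sketch (illustrative; "adaptor" is the merge adaptor from the earlier
// example): the coordinate caches the full hidden index, so both the top and
// the bottom (memory-side) index are available after construction.
//
//   const auto coord = make_tensor_adaptor_coordinate(adaptor, make_multi_index(13));
//   // coord.GetTopIndex() == (13), coord.GetBottomIndex() == (1, 5)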
template <bool JudgeDoTransforms = true,
typename Adaptor,
typename AdaptorCoord,
typename TopIndex,
typename BottomIndex>
__host__ __device__ constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top,
BottomIndex& idx_diff_bottom)
{
constexpr index_t ndim_hidden = Adaptor::GetNumOfHiddenDimension();
constexpr index_t ndim_top = Adaptor::GetNumOfTopDimension();
// constexpr index_t ndim_bottom = Adaptor::GetNumOfBottomDimension();
constexpr index_t ntransform = Adaptor::GetNumOfTransform();
// STATIC_ASSERT(TopIndex::Size() == ndim_top && BottomIndex::Size() == ndim_bottom, "");
// judge whether calculation of lower diff is needed for each transform
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
if constexpr(JudgeDoTransforms)
{
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
// decide do_transforms by checking non-zero index diff components
MultiIndex<ndim_top> non_zero_diff_pick_top;
static_for<0, ndim_top, 1>{}(
[&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });
set_container_subset(
is_non_zero_diff, Adaptor::GetTopDimensionHiddenIds(), non_zero_diff_pick_top);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = Adaptor::GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = Adaptor::GetUpperDimensionHiddenIdss().At(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of upper index diff components is non-zero, then
// 1) Need to do this transform
// 2) all components of lower index diff are assumed to be non-zero and need to be
// computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
}
else
{
static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
}
// this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize top index diff
set_container_subset(idx_diff_hidden, Adaptor::GetTopDimensionHiddenIds(), idx_diff_top);
// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
// update top index
auto idx_hidden_pick_top =
get_container_subset(idx_hidden, Adaptor::GetTopDimensionHiddenIds());
idx_hidden_pick_top += idx_diff_top;
set_container_subset(idx_hidden, Adaptor::GetTopDimensionHiddenIds(), idx_hidden_pick_top);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(do_transforms[itran])
{
const auto& tran = adaptor.GetTransforms().At(itran);
constexpr auto dims_low = Adaptor::GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = Adaptor::GetUpperDimensionHiddenIdss().At(itran);
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_diff_low;
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new);
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
// set bottom index diff
idx_diff_bottom = get_container_subset(idx_diff_hidden, Adaptor::GetBottomDimensionHiddenIds());
}
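// Usage sketch (illustrative, continuing the example above): moving by +1 in
// the merged top dimension only re-runs transforms whose upper index diffs are
// non-zero, and reports the resulting bottom index diff.
//
//   auto coord = make_tensor_adaptor_coordinate(adaptor, make_multi_index(13));
//   MultiIndex<2> idx_diff_bottom;
//   move_tensor_adaptor_coordinate(adaptor, coord, make_multi_index(1), idx_diff_bottom);
//   // coord.GetBottomIndex() == (1, 6), idx_diff_bottom == (0, 1)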
template <bool JudgeDoTransforms = true, typename Adaptor, typename AdaptorCoord, typename TopIndex>
__host__ __device__ constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
AdaptorCoord& coord,
const TopIndex& idx_diff_top)
{
constexpr index_t ndim_bottom = Adaptor::GetNumOfBottomDimension();
MultiIndex<ndim_bottom> tmp;
move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}
template <typename Adaptor, typename AdaptorCoord>
__host__ __device__ constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
const AdaptorCoord& coord)
{
bool valid = true;
constexpr index_t ntransform = Adaptor::GetNumOfTransform();
const auto& idx_hidden = coord.GetHiddenIndex();
static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
const auto tran = adaptor.GetTransforms().At(itran);
// check validity, only if the current transformation does not always have a valid mapping
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_up =
get_container_subset(idx_hidden, Adaptor::GetUpperDimensionHiddenIdss().At(itran));
// Comment: using valid = valid && .. will result in weird control flow in ISA
valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
}
});
return valid;
}
template <typename Adaptor, typename AdpatorCoord>
__host__ __device__ constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
const AdpatorCoord& coord)
{
// check top index
const auto& idx_top = coord.GetTopIndex();
bool is_top_index_valid = true;
static_for<0, Adaptor::GetNumOfDimension(), 1>{}(
[&is_top_index_valid, &idx_top, &adaptor](auto i) {
is_top_index_valid =
is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.GetLength(i));
});
// check other hidden index
return is_top_index_valid &&
adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
namespace ck {
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct TensorCoordinate
: public TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>
{
using Base = TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>;
// TODO make these private
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using TopIndex = MultiIndex<ndim_top_>;
public:
__host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) : Base{idx_hidden}
{
}
// construct from TensorAdaptorCoordinate base class
__host__ __device__ constexpr TensorCoordinate(const Base& adaptor_coord) : Base{adaptor_coord}
{
}
__host__ __device__ constexpr auto GetIndex() const { return Base::GetTopIndex(); }
__host__ __device__ constexpr index_t GetOffset() const
{
return Base::GetBottomIndex()[Number<0>{}];
}
__host__ __device__ constexpr const auto& GetHiddenIndex() const
{
return Base::GetHiddenIndex();
}
__host__ __device__ auto& GetHiddenIndex() { return Base::GetHiddenIndex(); }
};
template <typename TensorDesc, typename TopIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const TopIndex& idx_top)
{
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
return TensorCoordinate<TensorDesc::GetNumOfHiddenDimension(),
remove_cvref_t<decltype(TensorDesc::GetTopDimensionHiddenIds())>>{
adaptor_coord};
}
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
__host__ __device__ constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid(tensor_desc, coord);
}
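// Usage sketch (illustrative; "desc" is a hypothetical TensorDescriptor for a
// row-major 4 x 8 tensor): a TensorCoordinate is a TensorAdaptorCoordinate
// whose single bottom dimension is the linear offset.
//
//   auto coord = make_tensor_coordinate(desc, make_multi_index(1, 5));
//   // coord.GetOffset() == 1 * 8 + 5 == 13
//   move_tensor_coordinate(desc, coord, make_multi_index(0, 1));
//   // coord.GetOffset() == 14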
} // namespace ck
...@@ -4,612 +4,200 @@ ...@@ -4,612 +4,200 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/sequence_helper.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform.hpp"
namespace ck { namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep;
// Transforms: Tuple<transforms...> // Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...> // LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionIdss : Tuple<Sequence<...>, ...> // UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// VisibleDimensionIds> : Sequence<...> // TopDimensionHiddenIds> : Sequence<...>
template <typename Transforms, template <typename Transforms,
typename LowerDimensionIdss, typename LowerDimensionHiddenIdss,
typename UpperDimensionIdss, typename UpperDimensionHiddenIdss,
typename VisibleDimensionIds, typename TopDimensionHiddenIds,
typename ElementSpaceSize> typename ElementSpaceSize,
struct TensorDescriptor typename GuaranteedVectorLengths_,
typename GuaranteedVectorStrides_>
struct TensorDescriptor : public TensorAdaptor<Transforms,
LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
Sequence<0>,
TopDimensionHiddenIds>
{ {
// TODO make these private using Base = TensorAdaptor<Transforms,
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } LowerDimensionHiddenIdss,
UpperDimensionHiddenIdss,
__host__ __device__ static constexpr index_t GetNumOfVisibleDimension() Sequence<0>,
{ TopDimensionHiddenIds>;
return VisibleDimensionIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids), using ElementSpaceSizeType = ElementSpaceSize;
math::less<index_t>,
math::equal<index_t>>::type;
return unique_sort_all_dim_ids::Size(); constexpr static index_t ntransform_ = Base::GetNumOfTransform();
} constexpr static index_t ndim_hidden_ = Base::GetNumOfHiddenDimension();
constexpr static index_t ndim_top_ = Base::GetNumOfTopDimension();
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_visible) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
const auto length = using GuaranteedVectorLengths = GuaranteedVectorLengths_;
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}]; using GuaranteedVectorStrides = GuaranteedVectorSrides_;
return length; static_assert(GuaranteedVectorLengths::Size() == ndim_hidden_ &&
}, GuaranteedVectorStrides::Size() == ndim_hidden_,
Number<ndim_visible_>{}); "wrong! inconsistent # of hidden dimensions");
// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
{
constexpr auto idim_visible = Number<IDim>{};
constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); using TopIndex = MultiIndex<ndim_top_>;
using HiddenIndex = MultiIndex<ndim_hidden_>;
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
constexpr static index_t ntransform_ = GetNumOfTransform();
constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension();
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
using VisibleIndex = MultiIndex<ndim_visible_>;
using HiddenIndex = MultiIndex<ndim_hidden_>;
using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
// may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public: public:
#if 0 // workaround compiler complaint about constexpr
__host__ __device__ constexpr TensorDescriptor() = default; __host__ __device__ constexpr TensorDescriptor() = default;
#else
__host__ __device__ constexpr TensorDescriptor()
: transforms_{}, element_size_{}, element_space_size_{}
{
}
#endif
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size) ElementSpaceSize element_space_size)
: transforms_{transforms}, : Base{transforms}, element_space_size_{element_space_size}
element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size}
{ {
static_assert(Transforms::Size() == ntransform_ && static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionIdss::Size() == ntransform_ && LowerDimensionHiddenIdss::Size() == ntransform_ &&
UpperDimensionIdss::Size() == ntransform_, UpperDimensionHiddenIdss::Size() == ntransform_,
"wrong! inconsistent # of transformations"); "wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid // TODO check dependency of dimensions is valid
} }
+    // construct from TensorAdaptor base class
+    __host__ __device__ constexpr TensorDescriptor(const Base& adaptor,
+                                                   ElementSpaceSize element_space_size)
+        : Base{adaptor}, element_space_size_{element_space_size}
+    {
+    }
+
    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
-        return GetNumOfVisibleDimension();
+        return Base::GetNumOfTopDimension();
    }

    template <index_t IDim>
-    __host__ __device__ constexpr auto GetLength(Number<IDim>) const
+    __host__ __device__ constexpr auto GetLength(Number<IDim> idim) const
    {
-        static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");
-
-        constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});
-
-        constexpr index_t itran   = tmp[Number<0>{}];
-        constexpr index_t idim_up = tmp[Number<1>{}];
-        constexpr bool found      = tmp[Number<2>{}];
-
-        static_assert(found == true,
-                      "wrong! no matching transformation and upper-dimension found");
-
-        return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
+        return Base::GetTopDimensionLength(idim);
    }

-    __host__ __device__ constexpr auto GetLengths() const
-    {
-        // FIXME: use Tuple of reference instead
-        return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number<ndim_visible_>{});
-    }
-
-    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
+    __host__ __device__ constexpr auto GetLengths() const { return Base::GetTopDimensionLengths(); }

    __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }

    template <typename Idx>
    __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
    {
-        static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
-
-        return make_tensor_coordinate(*this, idx).GetOffset();
+        return Base::CalculateBottomIndex(idx)[Number<0>{}];
    }
    // TODO make these private
-    __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
-
-    __host__ __device__ static constexpr auto GetLowerDimensionIdss()
+    __host__ __device__ constexpr const auto& GetTransforms() const
    {
-        return LowerDimensionIdss{};
+        return Base::GetTransforms();
    }

-    __host__ __device__ static constexpr auto GetUpperDimensionIdss()
+    __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
    {
-        return UpperDimensionIdss{};
+        return Base::GetLowerDimensionHiddenIdss();
    }

-    __host__ __device__ static constexpr auto GetVisibleDimensionIds()
+    __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
    {
-        return VisibleDimensionIds{};
+        return Base::GetUpperDimensionHiddenIdss();
    }

-    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    __host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
    {
-        bool is_known = true;
-
-        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
-            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
-        });
-
-        return is_known && is_known_at_compile_time<ElementSize>::value &&
-               is_known_at_compile_time<ElementSpaceSize>::value;
+        return Base::GetTopDimensionHiddenIds();
    }

-    __host__ __device__ void Print() const
+    __host__ __device__ static constexpr bool IsStatic()
    {
-        printf("{");
-        printf("TensorDescriptor, ");
-        static_for<0, ntransform_, 1>{}([&](auto i) {
-            printf("transforms: ");
-            transforms_[i].Print();
-            printf("LowerDimensionIds:");
-            LowerDimensionIdss{}.At(i).Print();
-            printf("UpperDimensionIds:");
-            UpperDimensionIdss{}.At(i).Print();
-        });
-        printf("}");
-
-        VisibleDimensionIds::Print();
+        return Base::IsKnownAtCompileTime() && is_known_at_compile_time<ElementSpaceSize>::value;
    }

-    // TODO make these private
-    Transforms transforms_;
-    ElementSize element_size_;
-    ElementSpaceSize element_space_size_;
-};
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return IsStatic(); }
+    __host__ __device__ static constexpr auto GetTopDimensionSafeVectorLengthStrides()
+    {
+        return Base::GetTopDimensionSafeVectorLengthStrides(
+            to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
+            to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
+    }
+
+    __host__ __device__ void Print() const
+    {
+        printf("TensorDescriptor{");
+
+        // TensorAdaptor
+        Base::Print();
+        printf(", ");
+
+        // element_space_size_
+        printf("element_space_size_: ");
+        print(element_space_size_);
+
+        printf("}");
+    }
+
+    // TODO make these private
+    ElementSpaceSize element_space_size_;
+};

-template <index_t NDimHidden, typename VisibleDimensionIds>
-struct TensorCoordinate
-{
-    // TODO make these private
-    static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
-
-    using HiddenIndex  = MultiIndex<NDimHidden>;
-    using VisibleIndex = MultiIndex<ndim_visible_>;
-
-    public:
-    __host__ __device__ constexpr TensorCoordinate() = default;
-
-    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
-        : idx_hidden_{idx_hidden}
-    {
-    }
-
-    __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); }
-
-    __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; }
-
-    // TODO make these private
-    __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }
-
-    __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; }
-
-    __host__ __device__ constexpr auto GetVisibleIndex() const
-    {
-        return get_container_subset(idx_hidden_, VisibleDimensionIds{});
-    }
-
-    // TODO make these private
-    HiddenIndex idx_hidden_;
-};
-
-template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
-struct TensorCoordinateStep
-{
-    // TODO make these private
-    using VisibleIndex = MultiIndex<NDimVisible>;
-
-    public:
-    __host__ __device__ constexpr TensorCoordinateStep() = default;
-
-    __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
-                                                       const MultiIndex<NTransform>& do_transforms)
-        : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
-    {
-    }
-
-    __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); }
-
-    // TODO make these private
-    __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const
-    {
-        return idx_diff_visible_;
-    }
-
-    VisibleIndex idx_diff_visible_;
-    MultiIndex<NTransform> do_transforms_;
-
-    // HACK: control UpdateLowerIndex()
-    static constexpr UpdateLowerIndexHack update_lower_index_hack_;
-};
-// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
-// doesn't have a constructor, and it sits outside the scope where it is used
-// (transform_tensor_descriptor) because a template cannot be defined inside a
-// function template
-template <typename NewTransforms>
-struct lambda_get_up_dim_num
-{
-    template <typename I>
-    __host__ __device__ constexpr auto operator()(I) const
-    {
-        using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
-        return Number<Tran::GetNumOfUpperDimension()>{};
-    }
-};
+template <typename Adaptor, typename ElementSpaceSize>
+__host__ __device__ constexpr auto
+make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
+                                    const ElementSpaceSize& element_space_size)
+{
+    constexpr index_t NDimHidden = Adaptor::GetNumOfHiddenDimension();
+
+    return TensorDescriptor<remove_cvref_t<decltype(adaptor.GetTransforms())>,
+                            remove_cvref_t<decltype(adaptor.GetLowerDimensionHiddenIdss())>,
+                            remove_cvref_t<decltype(adaptor.GetUpperDimensionHiddenIdss())>,
+                            remove_cvref_t<decltype(adaptor.GetTopDimensionHiddenIds())>,
+                            remove_cvref_t<decltype(element_space_size)>,
+                            typename uniform_sequence_gen<NDimHidden, -1>::type,
+                            typename uniform_sequence_gen<NDimHidden, -1>::type>{
+        adaptor, element_space_size};
+}
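For orientation, a minimal usage sketch of the new helper. The example is not part of this PR: make_single_stage_tensor_adaptor and make_merge_transform are assumed from CK's adaptor/transform helpers, and the extents are made up.

    // Build a 2D -> 1D adaptor that merges (M, N) into one top dimension, then
    // promote it to a descriptor by attaching an element space size. The two
    // guaranteed-vector sequences default to all -1, i.e. "nothing guaranteed".
    constexpr index_t M = 16;
    constexpr index_t N = 8;

    const auto m_n_to_linear = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(Number<M>{}, Number<N>{}))),
        make_tuple(Sequence<0, 1>{}), // bottom dimension ids consumed per transform
        make_tuple(Sequence<0>{}));   // top dimension ids produced per transform

    // element space size of a packed M x N tensor
    const auto desc = make_tensor_descriptor_from_adaptor(m_n_to_linear, Number<M * N>{});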
template <typename OldTensorDescriptor,
          typename NewTransforms,
-          typename NewLowerDimensionOldVisibleIdss,
-          typename NewUpperDimensionNewVisibleIdss>
+          typename NewLowerDimensionOldTopIdss,
+          typename NewUpperDimensionNewTopIdss>
__host__ __device__ constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
-                            NewLowerDimensionOldVisibleIdss,
-                            NewUpperDimensionNewVisibleIdss)
+                            NewLowerDimensionOldTopIdss,
+                            NewUpperDimensionNewTopIdss)
{
-    // sanity check
-    {
-        static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
-                          NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
-                      "wrong! inconsistent number of transforms");
-
-        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
-                                                NewLowerDimensionOldVisibleIdss{});
-
-        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
-                                                NewUpperDimensionNewVisibleIdss{});
-
-        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
-                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
-                      "wrong!");
-    }
-
-    // lower dimension's hidden idss
-    // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
-    // sequences)
-    constexpr auto low_dim_hidden_idss = transform_tuples(
-        // convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
-        [](auto low_dim_visible_ids) constexpr {
-            return transform_sequences(
-                // convert lower dimension visible id to hidden id
-                [](auto low_dim_visible_id) constexpr {
-                    return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
-                },
-                low_dim_visible_ids);
-        },
-        NewLowerDimensionOldVisibleIdss{});
-
-    constexpr index_t num_new_transform = NewTransforms::Size();
-
-    // upper dimension's hidden idss
-    constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension();
-
-    constexpr auto up_dim_numbers =
-        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});
-
-    constexpr auto up_dim_numbers_scan = merge_sequences(
-        Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
-
-    constexpr auto up_dim_hidden_idss = generate_tuple(
-        [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
-            return
-                typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
-                                                 old_hidden_dim_number + up_dim_numbers_scan[i + 1],
-                                                 1>::type{};
-        },
-        Number<num_new_transform>{});
-
-    // new visible dimension's hidden ids
-    constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
-        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
-
-    constexpr auto new_visible_dim_unordered2ordered = unpack(
-        [](auto... xs) constexpr { return merge_sequences(xs...); },
-        NewUpperDimensionNewVisibleIdss{});
-
-    constexpr auto new_visible_dim_hidden_ids =
-        unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
-
-    // put everything together
-    const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);
-
-    constexpr auto all_low_dim_hidden_idss =
-        container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
-
-    constexpr auto all_up_dim_hidden_idss =
-        container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
-
    const auto element_space_size = old_tensor_desc.GetElementSpaceSize();

-    return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
-                            remove_cv_t<decltype(all_low_dim_hidden_idss)>,
-                            remove_cv_t<decltype(all_up_dim_hidden_idss)>,
-                            remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
-                            remove_cv_t<decltype(element_space_size)>>{all_transforms,
-                                                                       element_space_size};
+    const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
+                                                             new_transforms,
+                                                             NewLowerDimensionOldTopIdss{},
+                                                             NewUpperDimensionNewTopIdss{});
+
+    constexpr index_t NDimHiddenOld = OldTensorDescriptor::GetNumOfHiddenDimension();
+    constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::GetNumOfHiddenDimension();
+
+    using NewGuaranteedVectorLengths = typename sequence_merge<
+        typename OldTensorDescriptor::GuaranteedVectorLengths,
+        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
+
+    using NewGuaranteedVectorStrides = typename sequence_merge<
+        typename OldTensorDescriptor::GuaranteedVectorStrides,
+        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;
+
+    return TensorDescriptor<
+        remove_cvref_t<decltype(new_tensor_adaptor.GetTransforms())>,
+        remove_cvref_t<decltype(new_tensor_adaptor.GetLowerDimensionHiddenIdss())>,
+        remove_cvref_t<decltype(new_tensor_adaptor.GetUpperDimensionHiddenIdss())>,
+        remove_cvref_t<decltype(new_tensor_adaptor.GetTopDimensionHiddenIds())>,
+        remove_cvref_t<decltype(element_space_size)>,
+        NewGuaranteedVectorLengths,
+        NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}
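The call-site contract is unchanged by this refactor apart from the renames (top index instead of visible index); a hedged sketch with illustrative extents, assuming the make_merge_transform helper from CK's transform utilities:

    // Fold a packed (M, N) = (16, 8) descriptor into a 1D view: the merge
    // transform consumes old top dimensions {0, 1} and produces new top
    // dimension {0}, so desc_linear.GetLength(Number<0>{}) == 128.
    const auto desc_m_n = make_naive_tensor_descriptor_packed(
        make_tuple(Number<16>{}, Number<8>{}));

    const auto desc_linear = transform_tensor_descriptor(
        desc_m_n,
        make_tuple(make_merge_transform(make_tuple(Number<16>{}, Number<8>{}))),
        make_tuple(Sequence<0, 1>{}),  // NewLowerDimensionOldTopIdss
        make_tuple(Sequence<0>{}));    // NewUpperDimensionNewTopIdss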
-template <typename TensorDesc, typename VisibleIndex>
-__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
-                                                           const VisibleIndex& idx_visible)
-{
-    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
-                  "wrong! # of dimension inconsistent");
-
-    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
-    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
-    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
-
-    MultiIndex<ndim_hidden> idx_hidden;
-
-    // initialize visible index
-    set_container_subset(idx_hidden, visible_dim_ids, idx_visible);
-
-    // calculate hidden index
-    static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
-        auto itran              = itran_p1 - Number<1>{};
-        const auto& tran        = tensor_desc.GetTransforms().At(itran);
-        constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
-        constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);
-
-        const auto idx_up = get_container_subset(idx_hidden, dims_up);
-
-        MultiIndex<dims_low.Size()> idx_low;
-
-        tran.CalculateLowerIndex(idx_low, idx_up);
-
-        set_container_subset(idx_hidden, dims_low, idx_low);
-    });
-
-    return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
-}
-
-// UpdateLowerIndexHack: Sequence<...>
-// HACK: control UpdateLowerIndex
-template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
-__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
-                                                               const VisibleIndex& idx_diff_visible,
-                                                               UpdateLowerIndexHack)
-{
-    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
-                  "wrong! # of dimension inconsistent");
-
-    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
-    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
-    constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
-    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
-
-    static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");
-
-    // use index_t for boolean type
-    auto do_transforms    = make_zero_multi_index<ntransform>();
-    auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
-
-    // decide do_transforms by checking the non-zero index diff components
-    MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
-
-    static_for<0, ndim_visible, 1>{}(
-        [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
-
-    set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
-
-    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
-        constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
-        constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);
-
-        const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
-
-        MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
-
-        // if any of the upper index diff components is non-zero, then
-        //   1) this transform needs to be done, and
-        //   2) all components of the lower index diff are assumed to be non-zero and need to
-        //      be computed
-        const bool idx_diff_up_has_non_zero = container_reduce(
-            non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
-
-        do_transforms(itran) = idx_diff_up_has_non_zero;
-
-        static_for<0, dims_low.Size(), 1>{}(
-            [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
-
-        set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
-    });
-
-    return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
-                                                                                do_transforms};
-}
-
-template <typename TensorDesc, typename VisibleIndex>
-__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
-                                                               const VisibleIndex& idx_diff_visible)
-{
-    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
-
-    return make_tensor_coordinate_step(
-        TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
-}
-
-template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
-__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
-                                                          TensorCoord& coord,
-                                                          const TensorCoordStep& coord_step)
-{
-    constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
-    constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();
-
-    // this is what needs to be calculated
-    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
-
-    // initialize visible index diff
-    set_container_subset(
-        idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
-
-    // this is what needs to be updated
-    auto& idx_hidden = coord.GetHiddenIndex();
-
-    // update visible index
-    auto idx_hidden_pick_visible =
-        get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
-
-    idx_hidden_pick_visible += coord_step.GetIndexDiff();
-
-    set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
-
-    // update the rest of the hidden index
-    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
-        if(coord_step.do_transforms_[itran])
-        {
-            const auto& tran        = tensor_desc.GetTransforms().At(itran);
-            constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
-            constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);
-
-            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
-            auto idx_low           = get_container_subset(idx_hidden, dims_low);
-            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
-
-            MultiIndex<dims_low.Size()> idx_diff_low;
-
-            // HACK: control UpdateLowerIndex for Merge using hack
-            constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
-
-            tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
-
-            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
-            set_container_subset(idx_hidden, dims_low, idx_low);
-        }
-    });
-}
-
-template <typename TensorDesc, typename TensorCoord>
-__host__ __device__ constexpr bool
-coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc,
-                                                            const TensorCoord& coord)
-{
-    bool valid = true;
-
-    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
-
-    const auto& idx_hidden = coord.GetHiddenIndex();
-
-    static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
-        const auto tran = tensor_desc.GetTransforms().At(itran);
-
-        // check validity only if the current transformation does not always have a valid mapping
-        if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
-        {
-            const auto idx_up =
-                get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
-
-            // Comment: using "valid = valid && ..." would result in weird control flow in ISA
-            valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
-        }
-    });
-
-    return valid;
-}
-
-template <typename TensorDesc, typename TensorCoord>
-__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
-                                                               const TensorCoord& coord)
-{
-    // check visible index
-    const auto& idx_visible = coord.GetVisibleIndex();
-
-    bool is_visible_index_valid = true;
-
-    static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
-        [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
-            is_visible_index_valid =
-                is_visible_index_valid &&
-                (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
-        });
-
-    // check the remaining, hidden index components
-    return is_visible_index_valid &&
-           coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord);
-}
-
-template <typename TensorDesc>
-using TensorCoordinate_t = decltype(make_tensor_coordinate(
-    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
-
-template <typename TensorDesc>
-using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
-    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
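These helpers are removed here rather than rewritten in place: judging from the hunks further down, the coordinate machinery now lives in tensor_coordinate.hpp with "visible" renamed to "top", and move_tensor_coordinate accepts a plain index step directly, so TensorCoordinateStep and make_tensor_coordinate_step disappear. A sketch of the post-change calling convention, under that assumption:

    // Walk a coordinate across a packed 4 x 8 descriptor without ever
    // materializing a step object.
    const auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));
    auto coord      = make_tensor_coordinate(desc, make_multi_index(0, 0));

    move_tensor_coordinate(desc, coord, make_multi_index(1, 2)); // now at (1, 2)

    const bool valid     = coordinate_has_valid_offset_assuming_top_index_is_valid(desc, coord);
    const index_t offset = coord.GetOffset(); // 1 * 8 + 2 = 10 for this layout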
} // namespace ck
@@ -4,20 +4,13 @@

#pragma once

#include "ck/utility/common_header.hpp"
-#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"

namespace ck {

-/*
- * These functions create tensor descriptors at runtime. If they are not constexpr, you will
- * likely see usage of scratch memory during construction of these tensor descriptors. So
- * it's better to call these functions on host and then pass the constructed tensor descriptors
- * to GPU. If the tensor descriptors being constructed are constexpr, then you can call these
- * functions on GPU without worrying about scratch memory usage.
- */
-#if CK_WORKAROUND_SWDEV_275126
+namespace detail {
+
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,

@@ -35,7 +28,12 @@

        return acc_new;
    }
}
-#endif
+
+} // namespace detail
+
+/*
+ * These functions create naive tensor descriptors.
+ */
// Lengths..., Strides... could be:
//   1) index_t, which is known at run-time, or

@@ -45,9 +43,14 @@

//   2) LongNumber<>
template <typename... Lengths,
          typename... Strides,
+          index_t GuaranteedLastDimensionVectorLength = -1,
+          index_t GuaranteedLastDimensionVectorStride = -1,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
-__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
-                                                                const Tuple<Strides...>& strides)
+__host__ __device__ constexpr auto
+make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
+                             const Tuple<Strides...>& strides,
+                             Number<GuaranteedLastDimensionVectorLength> = Number<-1>{},
+                             Number<GuaranteedLastDimensionVectorStride> = Number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

@@ -60,34 +63,24 @@

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

-#if !CK_WORKAROUND_SWDEV_275126
-    // the rocm-4.1 compiler would crash on a recursive lambda
-    // recursive function for reduction
-    auto f = [&](auto fs, auto i, auto acc_old) {
-        auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
-
-        if constexpr(i.value < N - 1)
-        {
-            return fs(fs, i + Number<1>{}, acc_new);
-        }
-        else
-        {
-            return acc_new;
-        }
-    };
-
-    const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
-#else
    const auto element_space_size =
-        calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
-#endif
+        detail::calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
+
+    using GuaranteedVectorLengths =
+        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
+                                Sequence<GuaranteedLastDimensionVectorLength>>::type;
+
+    using GuaranteedVectorStrides =
+        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
+                                Sequence<GuaranteedLastDimensionVectorStride>>::type;

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
-                            remove_cv_t<decltype(element_space_size)>>{transforms,
-                                                                       element_space_size};
+                            remove_cv_t<decltype(element_space_size)>,
+                            GuaranteedVectorLengths,
+                            GuaranteedVectorStrides>{transforms, element_space_size};
}
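The element space size computed above is 1 + sum_i (length_i - 1) * stride_i, i.e. one past the farthest reachable element. A usage sketch of the extended signature (extents and the vector-guarantee hints are illustrative):

    // 4 x 3 row-major tensor whose rows are padded to a stride of 5:
    // element space size = 1 + (4 - 1) * 5 + (3 - 1) * 1 = 18 elements.
    constexpr auto desc = make_naive_tensor_descriptor(
        make_tuple(Number<4>{}, Number<3>{}), // lengths
        make_tuple(Number<5>{}, Number<1>{}), // strides
        Number<3>{},  // caller guarantees: last dim is vectorizable by 3...
        Number<1>{}); // ...with unit stride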
// Lengths... could be:

@@ -96,9 +89,10 @@

// element_space_size could be:
//   1) long_index_t, or
//   2) LongNumber<>
-template <typename... Lengths>
+template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
__host__ __device__ constexpr auto
-make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
+make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths,
+                                    Number<GuaranteedLastDimensionVectorLength> = Number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

@@ -113,12 +107,20 @@

    const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});

+    using GuaranteedVectorLengths =
+        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
+                                Sequence<GuaranteedLastDimensionVectorLength>>::type;
+
+    using GuaranteedVectorStrides =
+        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, Sequence<1>>::type;
+
    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
-                            remove_cv_t<decltype(element_space_size)>>{transforms,
-                                                                       element_space_size};
+                            remove_cv_t<decltype(element_space_size)>,
+                            GuaranteedVectorLengths,
+                            GuaranteedVectorStrides>{transforms, element_space_size};
}
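Same idea for the packed variant, where strides are implied and the guaranteed vector stride of the last dimension is hard-wired to 1 (a packed layout is contiguous there). Sketch with illustrative extents:

    // Packed row-major 4 x 3 descriptor; implied strides are (3, 1) and the
    // element space size is exactly 4 * 3 = 12.
    constexpr auto desc_packed = make_naive_tensor_descriptor_packed(
        make_tuple(Number<4>{}, Number<3>{}),
        Number<3>{}); // optional hint: last dim can be accessed as 3-wide vectors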
// Lengths... could be:
@@ -6,7 +6,7 @@

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
-#include "ck/utility/statically_indexed_array_multi_index.hpp"
+#include "ck/utility/multi_index.hpp"
#include "ck/utility/tuple_helper.hpp"

#include "ck/tensor_description/tensor_adaptor.hpp"
@@ -87,7 +87,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle

        const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
        //  |KRepeat   |MRepeat|MWave    |MLane  |KPack
-        return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
+        return make_multi_index(0, 0, waveId_m, WMMA_a_idx, 0);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()

@@ -98,7 +98,7 @@

        const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
        //  |KRepeat   |NRepeat|Nwave    |NLane  |KPack
-        return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
+        return make_multi_index(0, 0, waveId_n, WMMA_b_idx, 0);
    }

    template <index_t m0, index_t n0>
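The mechanical make_tuple -> make_multi_index substitutions in this and the following hunks (and Tuple4 -> Array4 below) suggest MultiIndex is now Array-based rather than Tuple-based, matching the multi_index.hpp include rename above. The practical difference, sketched under that assumption:

    // make_tuple may mix element types and yields a Tuple<...>;
    // make_multi_index yields a homogeneous MultiIndex<N> (an Array of
    // index_t), which is what coordinate arithmetic like operator+= expects.
    const auto idx = make_multi_index(0, 1, 2, 3); // MultiIndex<4>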
@@ -108,7 +108,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

-        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
+        return make_multi_index(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()

@@ -119,7 +119,7 @@

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

-        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
+        return make_multi_index(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>

@@ -144,11 +144,12 @@

                                             make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
-            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+            make_multi_index(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
-            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+            make_multi_index(n0, waveId_n, blk_idx[I1]))[I0];

-        return make_tuple(c_thread_m, c_thread_n);
+        return make_multi_index(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>

@@ -336,17 +337,19 @@

                    vector_type<FloatAB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto i) {
-                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
-                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
-                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                        a_thread_vec.template AsType<FloatAB>()(i) =
+                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                make_multi_index(0, 0, 0, k + i))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(i) =
+                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                make_multi_index(0, 0, 0, k + i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
-                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                        c_thread_desc_.CalculateOffset(make_multi_index(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),

@@ -707,7 +710,7 @@ struct BlockwiseGemmXdlops_v2

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

-        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
+        return make_multi_index(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()

@@ -718,7 +721,7 @@

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

-        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
+        return make_multi_index(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>

@@ -765,10 +768,10 @@

            m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
    }

-    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
+    using Array4 = decltype(CalculateAThreadOriginDataIndex());

-    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
-                                               Tuple4 b_origin = CalculateBThreadOriginDataIndex())
+    __host__ __device__ BlockwiseGemmXdlops_v2(Array4 a_origin = CalculateAThreadOriginDataIndex(),
+                                               Array4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
@@ -66,7 +66,7 @@ struct BlockwiseSoftmax

                                                           reduce::Add,
                                                           false>>::type;

-    using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
+    using ThreadClusterLengths_M_K = decltype(to_sequence(ThreadClusterDesc_M_K{}.GetLengths()));

    using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
                                                                BlockSize,
@@ -50,6 +50,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1

    using Index = MultiIndex<nDim>;

+#if 1 // debug
+    __host__ __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1() : threadwise_transfer_{} {}
+#endif
+
    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
@@ -86,7 +86,7 @@ struct BlockToCTileMap_M00_N0_M01

        const auto M00 = math::integer_divide_ceil(M0, M01);

        const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_insert_transform(1),
+            make_tuple(make_replicate_transform(make_tuple(1)),
                       make_unmerge_transform(make_tuple(M00, M01)),
                       make_pass_through_transform(make_tuple(N0))),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),

@@ -402,7 +402,8 @@ struct BlockToCTileMap_M00_N00_M01_N01

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
-                make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
+                make_tuple(make_replicate_transform(
+                               make_tuple(1)), // swallow the carry from lower dimensions
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
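make_replicate_transform generalizes the old insert transform: it consumes no lower dimensions at all (note the empty Sequence<> in the lower-dimension list), so it fabricates upper dimensions of the given lengths; here a length-1 dimension that absorbs the carry. A standalone sketch under that reading, with M00 and M01 as illustrative extents:

    // Map a linear block id to a (dummy, m00, m01) tile index; top dim 0 is a
    // replicated length-1 dimension with no bottom counterpart.
    const auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_replicate_transform(make_tuple(1)),
                   make_unmerge_transform(make_tuple(M00, M01))),
        make_tuple(Sequence<>{}, Sequence<0>{}),      // bottom ids per transform
        make_tuple(Sequence<0>{}, Sequence<1, 2>{})); // top ids per transform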
@@ -630,7 +630,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle

            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
            // BMmaKStride
-            make_tuple(0, 0, 0, 0)}; // A_origin
+            make_multi_index(0, 0, 0, 0)}; // A_origin

        auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();

@@ -796,7 +796,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle

            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
            // BMmaKStride
-            make_tuple(0, 0, 0, 0)}; // A_origin
+            make_multi_index(0, 0, 0, 0)}; // A_origin

        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();

@@ -953,7 +953,7 @@

        // works like multi-dimension static_for (static_ford), but provides both the linear
        // index as well as the n-d index
        using Acc0TileIterator = SpaceFillingCurve<
-            decltype(c_thread_lengths),
+            decltype(to_sequence(c_thread_lengths)),
            typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
            typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
            false>; // SnakeCurved

@@ -651,7 +651,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
            // BMmaKStride
-            make_tuple(0, 0, 0, 0)}; // A_origin
+            make_multi_index(0, 0, 0, 0)}; // A_origin

        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();

@@ -769,7 +769,7 @@

        // works like multi-dimension static_for (static_ford), but provides both the linear
        // index as well as the n-d index
        using Acc0TileIterator = SpaceFillingCurve<
-            decltype(c_thread_lengths),
+            decltype(to_sequence(c_thread_lengths)),
            typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
            typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
            false>; // SnakeCurved
@@ -45,7 +45,7 @@ struct ThreadwiseTensorSliceSet_v1

            constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(desc, coord);

            constexpr index_t offset = coord.GetOffset();

@@ -70,8 +70,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3

    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
-
    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
                                                            const Index& dst_slice_origin_idx,
                                                            const ElementwiseOperation& element_op)

@@ -147,7 +145,7 @@

            });

            const bool is_dst_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
            dst_buf.template Update<DstInMemOp, dst_vector_t>(

@@ -159,18 +157,14 @@

            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

-                move_tensor_coordinate(
-                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
+                move_tensor_coordinate(dst_desc, dst_coord_, forward_step);
            }
        });

        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
-            const auto dst_reset_step =
-                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
-
-            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
+            move_tensor_coordinate(dst_desc, dst_coord_, GetDstCoordinateResetStep());
        }
    }
@@ -250,8 +244,6 @@ struct ThreadwiseTensorSliceTransfer_v2

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-
    __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
                                                          const Index& src_slice_origin_idx)
        : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))

@@ -311,7 +303,7 @@

            constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

            const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf into src_vector
            src_vector.template AsType<src_vector_t>()(Number<0>{}) =

@@ -341,18 +333,14 @@

            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

-                move_tensor_coordinate(
-                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
+                move_tensor_coordinate(src_desc, src_coord_, forward_step);
            }
        });

        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
-            const auto src_reset_step =
-                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
-
-            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
+            move_tensor_coordinate(src_desc, src_coord_, GetSrcCoordinateResetStep());
        }
    }

@@ -388,29 +376,7 @@

            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
-
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
-    }
-
-    // src_slice_origin_step_idx needs to be known at compile time, for performance reasons
-    template <typename SrcMoveSliceWindowStepHack>
-    __device__ void
-    MoveSrcSliceWindow(const SrcDesc& src_desc,
-                       const Index& src_slice_origin_step_idx,
-                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
-    {
-        // if src coord was not reset by RunRead(), then the step needs to be adjusted here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(
-            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
-
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+        move_tensor_coordinate(src_desc, src_coord_, adjusted_step_idx);
    }
    private:
@@ -450,9 +416,6 @@ struct ThreadwiseTensorSliceTransfer_v3

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
-
    __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
                                                          const Index& src_slice_origin,
                                                          const DstDesc& dst_desc,

@@ -574,7 +537,7 @@

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);

            // copy data from src_buf to src_tmp_vector
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =

@@ -741,7 +704,7 @@

            // copy data from dst_tmp_vector to dst_buf
            const bool is_dst_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+                coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);

            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset(),

@@ -1033,8 +996,6 @@ struct ThreadwiseTensorSliceTransfer_v4

    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-
    __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
        : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {

@@ -1130,19 +1091,16 @@

            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

-            constexpr auto src_ref_to_data_disp_coord_step =
-                make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
-
            auto src_data_coord = src_ref_coord_;

-            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
+            move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_idx);

            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

-            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-                src_desc, src_data_coord);
+            const bool is_src_valid =
+                coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
            if constexpr(SrcBuffer::IsDynamicBuffer())
@@ -6,7 +6,10 @@

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_coordinate.hpp"
+#include "ck/tensor/thread_private_tensor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+// FIXME: remove
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
@@ -53,8 +56,8 @@ template <typename SliceLengths,

          typename SrcElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
-          typename SrcData,
-          typename DstData,
+          typename SrcDataTmp,
+          typename DstDataTmp,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,

@@ -77,14 +80,34 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

+    using SrcData = remove_cvref_t<SrcDataTmp>;
+    using DstData = remove_cvref_t<DstDataTmp>;
+
    using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
-
    static constexpr auto I0 = Number<0>{};

+    // scalar per access on each dim
+    // TODO: don't use lambda_scalar_per_access
+    static constexpr auto src_scalar_per_access = generate_sequence(
+        detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+
+    static constexpr auto dst_scalar_per_access = generate_sequence(
+        detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+
+    using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                   SrcDimAccessOrder,
+                                                   remove_cv_t<decltype(src_scalar_per_access)>>;
+
+    using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                   DstDimAccessOrder,
+                                                   remove_cv_t<decltype(dst_scalar_per_access)>>;
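The rewritten RunRead below drives the copy loop off these curves instead of the hand-rolled forward/backward sweep it replaces. The iteration contract, sketched with a 2 x 4 slice and 2-wide vector accesses (extents illustrative; member names are the ones used in this hunk):

    using Sfc = SpaceFillingCurve<Sequence<2, 4>,   // slice lengths
                                  Sequence<0, 1>,   // dim access order
                                  Sequence<1, 2>>;  // scalars moved per access

    // 2 * 4 / 2 = 4 accesses; consecutive accesses snake so each step is small
    static_for<0, Sfc::GetNumOfAccess(), 1>{}([&](auto i) {
        constexpr auto idx = Sfc::GetIndexTupleOfNumber(i); // n-d origin of access i

        if constexpr(i.value != Sfc::GetNumOfAccess() - 1)
        {
            constexpr auto step = Sfc::GetForwardStep(i); // delta to access i + 1
            // a real caller would move_tensor_coordinate(desc, coord, step) here
        }
    });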
+    __host__ __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1()
+        : src_coord_{}, dst_coord_{}, src_element_op_{}, dst_element_op_{}
+    {
+    }
+
    __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
        const SrcDesc& src_desc,
        const Index& src_slice_origin,
...@@ -122,87 +145,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -122,87 +145,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value, is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent"); "wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
// calculate src data index // calculate src data index
constexpr auto src_data_idx = [&]() { constexpr auto src_data_idx = SrcSpaceFillingCurve::GetIndexTupleOfNumber(iAccess);
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
-            const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+            const bool is_src_valid =
+                coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_coord_);
             using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
             using src_vector_t    = typename src_vector_type::type;
@@ -212,50 +163,29 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
             // copy data from src_vector_container into src_thread_scratch_
-#if 1 // debug
             src_thread_scratch_tuple_(thread_scratch_id)
                 .template SetAsType<src_vector_t>(
-                    src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
+                    src_data_idx, src_vector_container.template AsType<src_vector_t>()[I0]);
-#else
-            src_thread_scratch_tuple_(thread_scratch_id)
-                .template Set<src_vector_t>(
-                    src_data_idx, src_vector_container.template AsType<src_vector_t>()[I0]);
-#endif
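Note: the scratch buffer is a static (compile-time indexed) tensor, so the store
must see its index as constants. That is why the old path first converted
src_data_idx into a sequence of Number<>s, while the space-filling-curve path can
pass src_data_idx directly; GetIndexTupleOfNumber is assumed to return a tuple of
integral constants, e.g.:

    // sketch, assuming the 2D curve above; values are illustrative
    constexpr auto idx = SrcSpaceFillingCurve::GetIndexTupleOfNumber(Number<3>{});
    // decltype(idx) ~ Tuple<Number<1>, Number<8>>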
-            constexpr auto move_on_dim = [&]() constexpr
-            {
-                StaticallyIndexedArray<bool, nDim> move_on_dim_;
-                static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
-                    static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim_(i) &=
-                            ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
-                    });
-                });
-                return move_on_dim_;
-            }
-            ();
-            // move src coord
-            static_for<0, nDim, 1>{}([&](auto i) {
-                if constexpr(move_on_dim[i])
-                {
-                    if constexpr(forward_sweep[i])
-                    {
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
-                    }
-                    else
-                    {
-                        move_tensor_coordinate(
-                            src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
-                    }
-                }
-            });
+            // move src coordinate
+            if constexpr(iAccess.value != num_access - 1)
+            {
+                constexpr auto step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
+                move_tensor_coordinate(src_desc, src_coord_, step);
+            }
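Note: the deleted move_on_dim logic is an odometer: dimension i advances only when
it has not reached its last access position and every faster dimension j > i has.
For the raw [2, 3] loop order of the example above, that gives:

    // raw idx [0,0]: dim 1 steps (+4)     raw idx [0,1]: dim 1 steps (+4)
    // raw idx [0,2]: dim 0 steps (carry)  raw idx [1,0]: dim 1 steps (-4, backward row)
    // raw idx [1,1]: dim 1 steps (-4)     raw idx [1,2]: last access, no move

The space-filling-curve version collapses all of it into one precomputed
GetForwardStep(iAccess) per access.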
         });
         // move src coordinate back to slice origin (or not)
         if constexpr(SrcResetCoordinateAfterRun)
         {
-            const auto src_reset_step =
-                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
+            const auto src_reset_step = GetSrcCoordinateResetStep();
             move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
         }
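Note: RunRead now delegates the whole traversal to SrcSpaceFillingCurve. A minimal
usage sketch of the interface relied on here (the template parameters are an
assumption based on this diff, not a verified header):

    using SFC = SpaceFillingCurve<Sequence<2, 12>,   // slice lengths
                                  Sequence<0, 1>,    // dim access order
                                  Sequence<1, 4>>;   // scalars per access
    constexpr index_t n  = SFC::GetNumOfAccess();                   // 2 * 3 = 6 accesses
    constexpr auto idx0  = SFC::GetIndexTupleOfNumber(Number<0>{}); // data index of access 0
    constexpr auto step0 = SFC::GetForwardStep(Number<0>{});        // delta from access 0 to 1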
@@ -265,13 +195,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ void
     TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
     {
-#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
-        static_ford<SliceLengths>{}([&](auto idx) {
-            // convert from SrcData to DstData here
-            dst_thread_scratch_(idx) =
-                type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
-        });
-#else
         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
         // TODO: make this logic more generic for more sub-dword data types
         if constexpr(SrcVectorDim != DstVectorDim &&
@@ -347,7 +270,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]);
                 dst_thread_scratch_(idx) = dst_v;
             });
-#endif
     }
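Note: with the elementwise #if branch gone, the in-register sub-dword transpose is
the only path through this function. "Sub-dword" refers to element types smaller
than a 32-bit dword (e.g. half_t, bhalf_t, int8_t), where several elements are
packed into one VGPR, so a transpose between SrcVectorDim and DstVectorDim has to
shuffle packed lanes; the removed fallback simply type_convert-ed element by
element, as the deleted lines above show.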
     template <typename DstBuffer, index_t ThreadScratchId = 0>
@@ -367,91 +289,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
             "wrong! SrcBuffer or DstBuffer data type is wrong");
-        // dst scalar per access on each dim
-        // TODO: don't use this
-        constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
-        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
-        constexpr auto dst_dim_access_order = DstDimAccessOrder{};
-        constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
-        // make forward steps
-        const auto dst_forward_steps = generate_tuple(
-            [&](auto i) {
-                Index forward_step_idx;
-                static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
-                });
-                return make_tensor_coordinate_step(dst_desc, forward_step_idx);
-            },
-            Number<nDim>{});
-        // make backward steps
-        const auto dst_backward_steps = generate_tuple(
-            [&](auto i) {
-                Index backward_step_idx;
-                static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
-                });
-                return make_tensor_coordinate_step(dst_desc, backward_step_idx);
-            },
-            Number<nDim>{});
+        constexpr auto num_access = DstSpaceFillingCurve::GetNumOfAccess();
         // loop over tensor and copy
-        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
-            // decide whether to sweep forward or backward along each dim
-            constexpr auto forward_sweep = [&]() {
-                StaticallyIndexedArray<bool, nDim> forward_sweep_;
-                forward_sweep_(I0) = true;
-                static_for<1, nDim, 1>{}([&](auto i) {
-                    index_t tmp = ordered_dst_access_idx[I0];
-                    static_for<1, i, 1>{}([&](auto j) {
-                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
-                    });
-                    forward_sweep_(i) = tmp % 2 == 0;
-                });
-                return forward_sweep_;
-            }();
-            // calculate dst data index
-            constexpr auto dst_data_idx = [&]() {
-                Index ordered_idx;
-                static_for<0, nDim, 1>{}([&](auto i) {
-                    ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
-                                                      : ordered_dst_access_lengths[i] - 1 -
-                                                            ordered_dst_access_idx[i];
-                });
-                return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
-                       dst_scalar_per_access;
-            }();
-            constexpr auto dst_data_idx_seq = generate_sequence_v2(
-                [&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
-            const bool is_dst_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
+        static_for<0, num_access, 1>{}([&](auto iAccess) {
+            // calculate dst data index
+            constexpr auto dst_data_idx = DstSpaceFillingCurve::GetIndexTupleOfNumber(iAccess);
+            const bool is_dst_valid =
+                coordinate_has_valid_offset_assuming_top_index_is_valid(dst_desc, dst_coord_);
             using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
             using dst_vector_t    = typename dst_vector_type::type;
             // copy data from dst_thread_scratch_ into dst_vector_container
-            auto dst_vector_container = dst_vector_type{
-                dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
+            auto dst_vector_container =
+                dst_vector_type{dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx)};
             static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                 DstData dst_v;
@@ -468,165 +321,54 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 is_dst_valid,
                 dst_vector_container.template AsType<dst_vector_t>()[I0]);
-            constexpr auto move_on_dim = [&]() constexpr
-            {
-                StaticallyIndexedArray<bool, nDim> move_on_dim_;
-                static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
-                    static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim_(i) &=
-                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
-                    });
-                });
-                return move_on_dim_;
-            }
-            ();
-            // move dst coord
-            static_for<0, nDim, 1>{}([&](auto i) {
-                if constexpr(move_on_dim[i])
-                {
-                    if constexpr(forward_sweep[i])
-                    {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
-                    }
-                    else
-                    {
-                        move_tensor_coordinate(
-                            dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
-                    }
-                }
-            });
+            // move dst coord
+            if constexpr(iAccess.value != num_access - 1)
+            {
+                constexpr auto step = DstSpaceFillingCurve::GetForwardStep(iAccess);
+                move_tensor_coordinate(dst_desc, dst_coord_, step);
+            }
         });
         // move dst coordinate back to slice origin (or not)
         if constexpr(DstResetCoordinateAfterRun)
         {
-            const auto dst_reset_step =
-                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
-            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
+            move_tensor_coordinate(dst_desc, dst_coord_, GetDstCoordinateResetStep());
         }
     }
     __device__ static constexpr auto GetSrcCoordinateResetStep()
     {
-        // scalar per access on each dim
-        // TODO: don't use lambda_scalar_per_access
-        constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
-        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
-        constexpr auto src_dim_access_order = SrcDimAccessOrder{};
-        constexpr auto ordered_src_access_lengths =
-            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
-        // decide whether the last iteration swept forward or backward on each dim
-        constexpr auto forward_sweep = [&]() {
-            StaticallyIndexedArray<bool, nDim> forward_sweep_;
-            forward_sweep_(I0) = true;
-            static_for<1, nDim, 1>{}([&](auto i) {
-                index_t tmp = ordered_src_access_lengths[I0] - 1;
-                static_for<1, i, 1>{}([&](auto j) {
-                    tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
-                });
-                forward_sweep_(i) = tmp % 2 == 0;
-            });
-            return forward_sweep_;
-        }();
-        // calculate src data index after the last iteration in RunRead(), if it has not been
-        // reset by RunRead()
-        constexpr auto src_data_idx = [&]() {
-            Index ordered_idx;
-            static_for<0, nDim, 1>{}([&](auto i) {
-                ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
-            });
-            return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
-                   src_scalar_per_access;
-        }();
-        constexpr auto reset_src_data_step = [&]() {
-            Index reset_src_data_step_;
-            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
-            return reset_src_data_step_;
-        }();
-        return reset_src_data_step;
+        constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
+        if constexpr(num_access == 0)
+        {
+            return typename SrcSpaceFillingCurve::Index{};
+        }
+        else
+        {
+            constexpr auto reset_step =
+                SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
+            return reset_step;
+        }
     }
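Note: the reset step is now simply the curve step from the last access back to the
first, instead of being re-derived from the sweep parity. A sketch of the
equivalence, continuing the [2, 3] snake example with scalars per access [1, 4]:

    // last access (number 5) sits at data index [1, 0]
    // GetStepBetween(Number<5>{}, Number<0>{}) == [-1, 0]  -> back to access 0 at [0, 0]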
     __device__ static constexpr auto GetDstCoordinateResetStep()
     {
-        // scalar per access on each dim
-        // TODO: don't use lambda_scalar_per_access
-        constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
-        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
-        constexpr auto dst_dim_access_order = DstDimAccessOrder{};
-        constexpr auto ordered_dst_access_lengths =
-            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
-        // decide whether the last iteration swept forward or backward on each dim
-        constexpr auto forward_sweep = [&]() {
-            StaticallyIndexedArray<bool, nDim> forward_sweep_;
-            forward_sweep_(I0) = true;
-            static_for<1, nDim, 1>{}([&](auto i) {
-                index_t tmp = ordered_dst_access_lengths[I0] - 1;
-                static_for<1, i, 1>{}([&](auto j) {
-                    tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
-                });
-                forward_sweep_(i) = tmp % 2 == 0;
-            });
-            return forward_sweep_;
-        }();
-        // calculate dst data index after the last iteration in RunWrite(), if it has not been
-        // reset by RunWrite()
-        constexpr auto dst_data_idx = [&]() {
-            Index ordered_idx;
-            static_for<0, nDim, 1>{}([&](auto i) {
-                ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
-            });
-            return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
-                   dst_scalar_per_access;
-        }();
-        constexpr auto reset_dst_data_step = [&]() {
-            Index reset_dst_data_step_;
-            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
-            return reset_dst_data_step_;
-        }();
-        return reset_dst_data_step;
+        constexpr auto num_access = DstSpaceFillingCurve::GetNumOfAccess();
+        if constexpr(num_access == 0)
+        {
+            return typename DstSpaceFillingCurve::Index{};
+        }
+        else
+        {
+            constexpr auto reset_step =
+                DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
+            return reset_step;
+        }
     }
     // src_slice_origin_step_idx needs to be known at compile time, for performance
@@ -638,10 +380,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                        : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+        move_tensor_coordinate(src_desc, src_coord_, adjusted_step_idx);
     }
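Note: when SrcResetCoordinateAfterRun is false, the coordinate is still parked at
the last access of the previous slice, so the window step folds the reset in. A
small worked case, continuing the example above:

    // residual position after RunRead(): [1, 0] relative to the window origin
    // requested window step:             [0, 16]
    // adjusted_step_idx = [0, 16] + (-[1, 0]) = [-1, 16]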
     // dst_slice_origin_step_idx needs to be known at compile time, for performance
@@ -653,17 +392,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                        : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
-        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step_idx);
     }
     __device__ static constexpr auto GetSrcThreadScratchDescriptor()
     {
-        constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
         constexpr auto src_access_lengths_and_vector_length = container_push_back(
@@ -711,9 +444,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetDstThreadScratchDescriptor()
     {
         // 1st stage of transforms
-        constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
         constexpr auto dst_access_lengths_and_vector_length = container_push_back(
@@ -761,11 +491,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
     static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
+#if 1 // debug
     using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              SrcData,
                                                              SrcScalarPerVector,
                                                              decltype(src_thread_scratch_desc_),
                                                              true>;
+#else
+    using SrcThreadScratch = ThreadPrivateTensor<SrcData, decltype(src_thread_scratch_desc_)>;
+#endif
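Note: the new #if 1 // debug switch keeps StaticTensorTupleOfVectorBuffer as the
active scratch container; ThreadPrivateTensor is left in as an alternative
register-backed implementation for experimentation, and the dead branch can
presumably be dropped once the tile-program path settles.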
     using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              DstData,
...
@@ -133,8 +133,8 @@ struct ThreadwiseTensorSliceTransfer_v4r1
         using src_vector_t = typename decltype(src_vector)::type;
-        const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-            src_desc, src_data_coord);
+        const bool is_src_valid =
+            coordinate_has_valid_offset_assuming_top_index_is_valid(src_desc, src_data_coord);
         // copy data from src_buf into src_vector
         src_vector.template AsType<src_vector_t>()(I0) =
...