Commit e444281c authored by Chao Liu's avatar Chao Liu
Browse files

initial implementation for magic number division and DynamicMerge_v2_magic_division that uses it

parent 6d4aefcd
...@@ -467,29 +467,10 @@ struct DynamicEmbed ...@@ -467,29 +467,10 @@ struct DynamicEmbed
} }
}; };
#if 1 // Implementation of "Merge" transformation primitive that uses regular division to do lowering of
template <typename LowLengths> // multi-index and use carry-and-borrow check to do lowering of multi-index delta
struct lambda_merge_generate_magic_division_calculate_magic_multiplier
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return magic_division::CalculateMagicMultiplier(LowLengths{}[i]);
}
};
template <typename LowLengths> template <typename LowLengths>
struct lambda_merge_generate_magic_division_calculate_magic_shift struct DynamicMerge_v1_carry_check
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return magic_division::CalculateMagicShift(LowLengths{}[i]);
}
};
template <typename LowLengths>
struct DynamicMerge
{ {
static constexpr index_t NDimLow = LowLengths::Size(); static constexpr index_t NDimLow = LowLengths::Size();
...@@ -499,35 +480,19 @@ struct DynamicMerge ...@@ -499,35 +480,19 @@ struct DynamicMerge
using LowLengthsScan = decltype( using LowLengthsScan = decltype(
container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{})); container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{}));
using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
lambda_merge_generate_magic_division_calculate_magic_multiplier<LowLengths>{},
Number<NDimLow>{}));
using LowLengthsMagicDivisorShift = decltype(
generate_tuple(lambda_merge_generate_magic_division_calculate_magic_shift<LowLengths>{},
Number<NDimLow>{}));
using UpLengths = using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
LowLengths low_lengths_; LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_; LowLengthsScan low_lengths_scan_;
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge() = default; __host__ __device__ constexpr DynamicMerge_v1_carry_check() = default;
__host__ __device__ constexpr DynamicMerge(const LowLengths& low_lengths) __host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths)
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_scan_{ low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
low_lengths_magic_divisor_multiplier_{generate_tuple(
[&](auto i) { return magic_division::CalculateMagicMultiplier(low_lengths[i]); },
Number<NDimLow>{})},
low_lengths_magic_divisor_shift_{generate_tuple(
[&](auto i) { return magic_division::CalculateMagicShift(low_lengths[i]); },
Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
{ {
static_assert(LowerIndex::Size() == NDimLow, "wrong!"); static_assert(LowerIndex::Size() == NDimLow, "wrong!");
...@@ -548,7 +513,6 @@ struct DynamicMerge ...@@ -548,7 +513,6 @@ struct DynamicMerge
index_t tmp = idx_up[Number<0>{}]; index_t tmp = idx_up[Number<0>{}];
#if 1
// normal division // normal division
static_for<0, NDimLow - 1, 1>{}([&](auto i) { static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i]; idx_low(i) = tmp / this->low_lengths_scan_[i];
...@@ -556,19 +520,6 @@ struct DynamicMerge ...@@ -556,19 +520,6 @@ struct DynamicMerge
}); });
idx_low(Number<NDimLow - 1>{}) = tmp; idx_low(Number<NDimLow - 1>{}) = tmp;
#else
// magic division
static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
index_t tmp2 =
magic_division::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
});
idx_low(Number<0>{}) = tmp;
#endif
} }
template <typename LowIdxDiff, template <typename LowIdxDiff,
...@@ -1030,7 +981,7 @@ struct DynamicMerge ...@@ -1030,7 +981,7 @@ struct DynamicMerge
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicMerge, "); printf("DynamicMerge_v1_carry_check, ");
printf("low_lengths_ "); printf("low_lengths_ ");
print_multi_index(low_lengths_); print_multi_index(low_lengths_);
printf("low_lengths_scan_ "); printf("low_lengths_scan_ ");
...@@ -1040,29 +991,41 @@ struct DynamicMerge ...@@ -1040,29 +991,41 @@ struct DynamicMerge
printf("}"); printf("}");
} }
}; };
#else
template <typename LowLengths> template <typename LowLengths>
struct lambda_merge_generate_magic_division_calculate_magic_multiplier struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
{ {
template <index_t I> template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const __host__ __device__ constexpr auto operator()(Number<I> i) const
{ {
return magic_division::CalculateMagicMultiplier(LowLengths{}[i]); return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
} }
}; };
template <typename LowLengths> template <typename LowLengths>
struct lambda_merge_generate_magic_division_calculate_magic_shift struct lambda_merge_generate_MagicDivision_calculate_magic_shift
{ {
template <index_t I> template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const __host__ __device__ constexpr auto operator()(Number<I> i) const
{ {
return magic_division::CalculateMagicShift(LowLengths{}[i]); return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
} }
}; };
// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
// of both multi-index and delta of multi-index
// Caution:
// 1. The magic number division implementation being used produces the correct result only if the
// dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
// 2. Magic number division for an int32_t dividend has not been implemented; the int32_t
// dividend is bit-wise reinterpreted as uint32_t and the magic number division implementation for
// uint32_t is then used.
// 3. For the Merge primitive, the upper-index is the dividend.
// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
// 5. When the upper-index is int32_t (i.e. when index_t is int32_t), its value needs to be
// non-negative.
template <typename LowLengths> template <typename LowLengths>
struct DynamicMerge struct DynamicMerge_v2_magic_division
{ {
static constexpr index_t NDimLow = LowLengths::Size(); static constexpr index_t NDimLow = LowLengths::Size();
...@@ -1072,12 +1035,12 @@ struct DynamicMerge ...@@ -1072,12 +1035,12 @@ struct DynamicMerge
using UpLengths = using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple( using LowLengthsMagicDivisorMultipiler = decltype(
lambda_merge_generate_magic_division_calculate_magic_multiplier<LowLengths>{}, generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
Number<NDimLow>{})); Number<NDimLow>{}));
using LowLengthsMagicDivisorShift = decltype( using LowLengthsMagicDivisorShift = decltype(
generate_tuple(lambda_merge_generate_magic_division_calculate_magic_shift<LowLengths>{}, generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
Number<NDimLow>{})); Number<NDimLow>{}));
LowLengths low_lengths_; LowLengths low_lengths_;
...@@ -1085,15 +1048,15 @@ struct DynamicMerge ...@@ -1085,15 +1048,15 @@ struct DynamicMerge
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge() = default; __host__ __device__ constexpr DynamicMerge_v2_magic_division() = default;
__host__ __device__ constexpr DynamicMerge(const LowLengths& low_lengths) __host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_magic_divisor_multiplier_{generate_tuple( low_lengths_magic_divisor_multiplier_{generate_tuple(
[&](auto i) { return magic_division::CalculateMagicMultiplier(low_lengths[i]); }, [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
Number<NDimLow>{})}, Number<NDimLow>{})},
low_lengths_magic_divisor_shift_{generate_tuple( low_lengths_magic_divisor_shift_{generate_tuple(
[&](auto i) { return magic_division::CalculateMagicShift(low_lengths[i]); }, [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
Number<NDimLow>{})}, Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
{ {
...@@ -1115,9 +1078,9 @@ struct DynamicMerge ...@@ -1115,9 +1078,9 @@ struct DynamicMerge
index_t tmp = idx_up[Number<0>{}]; index_t tmp = idx_up[Number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&idx_low, &tmp, this](auto i) { static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 = index_t tmp2 =
magic_division::DoMagicDivision(tmp, MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i], this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]); this->low_lengths_magic_divisor_shift_[i]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
...@@ -1142,11 +1105,25 @@ struct DynamicMerge ...@@ -1142,11 +1105,25 @@ struct DynamicMerge
LowIdx::Size() == NDimLow && UpIdx::Size() == 1, LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension"); "wrong! inconsistent # of dimension");
auto idx_low_old = idx_low; index_t tmp = idx_up_new[Number<0>{}];
CalculateLowerIndex(idx_low, idx_up_new); static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]);
idx_diff_low = idx_low - idx_low_old; index_t idx_low_old = idx_low[i];
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
idx_diff_low(i) = idx_low[i] - idx_low_old;
});
idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
idx_low(Number<0>{}) = tmp;
} }
__host__ __device__ static constexpr bool IsLinearTransform() { return false; } __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
...@@ -1174,7 +1151,7 @@ struct DynamicMerge ...@@ -1174,7 +1151,7 @@ struct DynamicMerge
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicMerge, "); printf("DynamicMerge_v2_magic_division, ");
printf("low_lengths_ "); printf("low_lengths_ ");
print_multi_index(low_lengths_); print_multi_index(low_lengths_);
printf("low_lengths_magic_divisor_multiplier_ "); printf("low_lengths_magic_divisor_multiplier_ ");
...@@ -1186,7 +1163,6 @@ struct DynamicMerge ...@@ -1186,7 +1163,6 @@ struct DynamicMerge
printf("}"); printf("}");
} }
}; };
#endif
template <typename UpLengths, bool Use24BitIntegerCalculation> template <typename UpLengths, bool Use24BitIntegerCalculation>
struct DynamicUnMerge struct DynamicUnMerge
......
...@@ -53,7 +53,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng ...@@ -53,7 +53,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{ {
return DynamicMerge<LowLengths>{low_lengths}; #if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
#else
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
#endif
} }
template <typename UpLengths, bool Use24BitIntegerCalculation = false> template <typename UpLengths, bool Use24BitIntegerCalculation = false>
......
...@@ -115,6 +115,9 @@ ...@@ -115,6 +115,9 @@
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
// merge transformation uses magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
// hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// thread-invariant, otherwise it's a bug // thread-invariant, otherwise it's a bug
......
...@@ -10,7 +10,17 @@ ...@@ -10,7 +10,17 @@
namespace ck { namespace ck {
// magic number division // magic number division
struct magic_division // Caution:
// 1. For uint32_t as dividend: the magic number division implementation being used produces the
// correct result if the dividend is uint32_t and its value is within the 31-bit value range.
// 2. For int32_t as dividend: magic number division for an int32_t dividend has not been
// implemented; the int32_t dividend is bit-wise reinterpreted as uint32_t and the magic number
// division implementation for uint32_t is then used. Therefore, the dividend value needs to be
// non-negative.
// TODO:
// 1. Implement magic number division for int32_t
// 2. Implement magic number division for uint32_t with the full 32-bit value range
struct MagicDivision
{ {
// uint32_t // uint32_t
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
...@@ -100,13 +110,25 @@ struct magic_division ...@@ -100,13 +110,25 @@ struct magic_division
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{}); return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
} }
// magic division // magic division for uint32_t
__host__ __device__ static constexpr uint32_t __host__ __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{ {
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
return (tmp + dividend) >> shift; return (tmp + dividend) >> shift;
} }
// HACK: magic division for an int32_t dividend.
// The dividend is converted to uint32_t (modulo 2^32, i.e. the same bit pattern as the
// two's-complement int32_t) and the uint32_t algorithm is applied: take the upper 32 bits of the
// 64-bit product dividend * multiplier, add the dividend, then shift right by `shift`.
// dividend_i32 must be non-negative for the result to be correct.
// The conversion uses static_cast rather than a union-based reinterpretation so that this
// constexpr function can actually be evaluated in a constant expression.
// TODO: figure out how to do magic number division for a signed int32_t dividend
__host__ __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
    const uint32_t dividend_u32 = static_cast<uint32_t>(dividend_i32);
    const uint32_t tmp =
        static_cast<uint32_t>((uint64_t(dividend_u32) * uint64_t(multiplier)) >> 32);
    return static_cast<int32_t>((tmp + dividend_u32) >> shift);
}
}; };
} // namespace ck } // namespace ck
......
...@@ -42,5 +42,19 @@ struct is_known_at_compile_time<integral_constant<T, X>> ...@@ -42,5 +42,19 @@ struct is_known_at_compile_time<integral_constant<T, X>>
static constexpr bool value = true; static constexpr bool value = true;
}; };
// Reinterpret the bytes of `x` as a value of type Y.
// Only participates in overload resolution when X and Y have the same size.
// Implemented by storing `x` in one member of a union and reading the other member.
// NOTE(review): ISO C++ does not define reading a non-active union member; this relies on the
// compilers in use supporting union-based type punning -- confirm for new toolchains.
template <typename Y,
          typename X,
          typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x)
{
    union Punned
    {
        X source;
        Y target;
    };

    const Punned punned{x};
    return punned.target;
}
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment