Commit cc391773 authored by Jing Zhang

merge master

parents ff4b1b1d e4790c25
...@@ -45,7 +45,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -45,7 +45,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
const FloatC* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -237,7 +236,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -237,7 +236,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -252,7 +250,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -252,7 +250,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
...@@ -265,7 +262,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -265,7 +262,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -280,7 +276,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -280,7 +276,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
...@@ -293,7 +288,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -293,7 +288,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -308,7 +302,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -308,7 +302,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
...@@ -321,7 +314,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -321,7 +314,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(out_k_n_ho_wo_global_desc), decltype(out_k_n_ho_wo_global_desc),
const FloatC*,
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -336,7 +328,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -336,7 +328,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
p_in_global, p_in_global,
out_k_n_ho_wo_global_desc, out_k_n_ho_wo_global_desc,
p_d_global,
p_out_global, p_out_global,
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
......
...@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -271,7 +271,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -271,7 +271,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
FloatAB*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
...@@ -302,7 +302,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -302,7 +302,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
FloatAB*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
...@@ -333,7 +333,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -333,7 +333,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
FloatAB*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
...@@ -364,7 +364,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -364,7 +364,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
decltype(in_e_n_ho_wo_global_desc), decltype(in_e_n_ho_wo_global_desc),
const FloatAB*, const FloatAB*,
decltype(add_k_n_hopx2_wopx2_global_desc), decltype(add_k_n_hopx2_wopx2_global_desc),
FloatAB*, const FloatC*,
decltype(out_k_n_hop_wop_global_desc), decltype(out_k_n_hop_wop_global_desc),
FloatC*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
......
...@@ -467,8 +467,10 @@ struct DynamicEmbed ...@@ -467,8 +467,10 @@ struct DynamicEmbed
} }
}; };
// Implementation of "Merge" transformation primitive that uses regular to do lowering of
// multi-index and use carry-and-borrow check to do lowering of multi-index delta
template <typename LowLengths> template <typename LowLengths>
struct DynamicMerge struct DynamicMerge_v1_carry_check
{ {
static constexpr index_t NDimLow = LowLengths::Size(); static constexpr index_t NDimLow = LowLengths::Size();
...@@ -485,9 +487,9 @@ struct DynamicMerge ...@@ -485,9 +487,9 @@ struct DynamicMerge
LowLengthsScan low_lengths_scan_; LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_; UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge() = default; __host__ __device__ constexpr DynamicMerge_v1_carry_check() = default;
__host__ __device__ constexpr DynamicMerge(const LowLengths& low_lengths) __host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths)
: low_lengths_{low_lengths}, : low_lengths_{low_lengths},
low_lengths_scan_{ low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
...@@ -511,7 +513,8 @@ struct DynamicMerge ...@@ -511,7 +513,8 @@ struct DynamicMerge
index_t tmp = idx_up[Number<0>{}]; index_t tmp = idx_up[Number<0>{}];
static_for<0, NDimLow - 1, 1>{}([&idx_low, &tmp, this](auto i) { // normal division
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i]; idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp -= idx_low[i] * this->low_lengths_scan_[i]; tmp -= idx_low[i] * this->low_lengths_scan_[i];
}); });
...@@ -978,7 +981,7 @@ struct DynamicMerge ...@@ -978,7 +981,7 @@ struct DynamicMerge
__host__ __device__ void Print() const __host__ __device__ void Print() const
{ {
printf("{"); printf("{");
printf("DynamicMerge, "); printf("DynamicMerge_v1_carry_check, ");
printf("low_lengths_ "); printf("low_lengths_ ");
print_multi_index(low_lengths_); print_multi_index(low_lengths_);
printf("low_lengths_scan_ "); printf("low_lengths_scan_ ");
...@@ -989,6 +992,178 @@ struct DynamicMerge ...@@ -989,6 +992,178 @@ struct DynamicMerge
} }
}; };
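For context on the carry-and-borrow naming above: the v1 variant lowers the multi-index with ordinary division, but updates it incrementally and fixes up per-dimension overflow instead of re-dividing. The fragment below is only a generic illustration of that idea, under the assumption of a small positive step per dimension; it is not the library's UpdateLowerIndex, which also handles borrows for negative steps and the Hack parameter.

// Generic carry-propagation sketch (assumption: each per-dimension delta keeps the
// intermediate index below 2 * low_lengths[i]); not the library implementation.
template <int NDim>
__host__ __device__ void carry_update_sketch(int (&idx_low)[NDim],
                                             const int (&idx_diff_low)[NDim],
                                             const int (&low_lengths)[NDim])
{
    int carry = 0;
    for(int i = NDim - 1; i > 0; --i)
    {
        idx_low[i] += idx_diff_low[i] + carry;
        carry = 0;
        if(idx_low[i] >= low_lengths[i]) // carry check replaces a division
        {
            idx_low[i] -= low_lengths[i];
            carry = 1;
        }
    }
    idx_low[0] += idx_diff_low[0] + carry;
}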
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
}
};
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_shift
{
template <index_t I>
__host__ __device__ constexpr auto operator()(Number<I> i) const
{
return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
}
};
// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
// of both multi-index and delta of multi-index
// Caution:
// 1. The magic number division implementation being used would produce correct result if the
// dividended is uint32_t and its value is with in 31-bit value range of uint32_t.
// 2. The magic number division for int32_t dividened has not been implemented, the int32_t
// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for
// uint32_t is then used.
// 3. For Merge primitive, upper-index is the dividend.
// 4. When upper-index is uint32_t, its value need to be within 31-bit range.
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
// non-negative.
template <typename LowLengths>
struct DynamicMerge_v2_magic_division
{
static constexpr index_t NDimLow = LowLengths::Size();
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
using LowLengthsMagicDivisorMultipiler = decltype(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
Number<NDimLow>{}));
using LowLengthsMagicDivisorShift = decltype(
generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
Number<NDimLow>{}));
LowLengths low_lengths_;
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
UpLengths up_lengths_;
__host__ __device__ constexpr DynamicMerge_v2_magic_division() = default;
__host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_magic_divisor_multiplier_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
Number<NDimLow>{})},
low_lengths_magic_divisor_shift_{generate_tuple(
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
Number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[Number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
});
idx_low(Number<0>{}) = tmp;
}
template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up_new[Number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
MagicDivision::DoMagicDivision(tmp,
this->low_lengths_magic_divisor_multiplier_[i],
this->low_lengths_magic_divisor_shift_[i]);
index_t idx_low_old = idx_low[i];
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
idx_diff_low(i) = idx_low[i] - idx_low_old;
});
idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
idx_low(Number<0>{}) = tmp;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsMagicDivisorMultipiler>::value &&
is_known_at_compile_time<LowLengthsMagicDivisorShift>::value &&
is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicMerge_v2_magic_division, ");
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_magic_divisor_multiplier_ ");
print_multi_index(low_lengths_magic_divisor_multiplier_);
printf("low_lengths_magic_divisor_shift_ ");
print_multi_index(low_lengths_magic_divisor_shift_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("}");
}
};
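As a concrete illustration of the lowering loop in CalculateLowerIndex above, the host-side sketch below (with assumed low lengths {4, 3, 5}; in the kernel these come from compile-time tuples) decomposes a flat upper index dimension by dimension, innermost first, using the MagicDivision helpers introduced later in this commit.

// Host-side sketch with assumed lengths {4, 3, 5}; mirrors CalculateLowerIndex above.
// idx_up = 37 -> idx_low = {2, 1, 2}, since 37 = (2 * 3 + 1) * 5 + 2.
inline void merge_magic_lowering_sketch()
{
    constexpr uint32_t low_lengths[3] = {4, 3, 5};

    uint32_t idx_low[3];
    uint32_t tmp = 37; // flat upper index

    for(int i = 2; i > 0; --i)
    {
        const uint32_t mult  = MagicDivision::CalculateMagicMultiplier(low_lengths[i]);
        const uint32_t shift = MagicDivision::CalculateMagicShift(low_lengths[i]);

        const uint32_t tmp2 = MagicDivision::DoMagicDivision(tmp, mult, shift); // tmp / length
        idx_low[i]          = tmp - tmp2 * low_lengths[i];                      // tmp % length
        tmp                 = tmp2;
    }
    idx_low[0] = tmp; // idx_low is now {2, 1, 2}
}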
template <typename UpLengths, bool Use24BitIntegerCalculation> template <typename UpLengths, bool Use24BitIntegerCalculation>
struct DynamicUnMerge struct DynamicUnMerge
{ {
......
...@@ -53,7 +53,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng ...@@ -53,7 +53,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{ {
return DynamicMerge<LowLengths>{low_lengths}; #if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
#else
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
#endif
} }
template <typename UpLengths, bool Use24BitIntegerCalculation = false> template <typename UpLengths, bool Use24BitIntegerCalculation = false>
......
...@@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3 ...@@ -104,7 +104,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type; using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
static_for<0, NSliceRow, 1>{}([&](auto i) { static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
......
...@@ -75,7 +75,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -75,7 +75,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
FloatAB* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
...@@ -392,8 +392,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -392,8 +392,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i) for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
{ {
ThreadwiseDynamicTensorSliceTransfer_v2< ThreadwiseDynamicTensorSliceTransfer_v2<
FloatAB, FloatC,
FloatAB, FloatC,
decltype(d_k_n_hox2_wox2_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<1, 1, 1, 1>, Sequence<1, 1, 1, 1>,
...@@ -414,18 +414,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -414,18 +414,18 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_d_global, p_d_global,
d_k_n_hox2_wox2_thread_desc, d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
&(d_vec.Vector()), &(d_vec.template AsType<FloatC>()(Number<0>{})),
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
static_for<0, vector_len, 1>{}([&](auto i) { static_for<0, vector_len, 1>{}([&](auto i) {
d_vec.Scalars()(i) += d_vec.template AsType<int8_t>()(i) +=
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset( p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))]; make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
}); });
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAB, FloatC,
FloatAB, FloatC,
decltype(d_k_n_hox2_wox2_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_hox2_wox2_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
Sequence<1, 1, 1, 1>, Sequence<1, 1, 1, 1>,
...@@ -444,7 +444,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -444,7 +444,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
wox2_thread_data_on_global + w_i)) wox2_thread_data_on_global + w_i))
.Run(d_k_n_hox2_wox2_thread_desc, .Run(d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
&(d_vec.Vector()), &(d_vec.template AsType<FloatC>()[Number<0>{}]),
d_k_n_hox2_wox2_global_desc, d_k_n_hox2_wox2_global_desc,
p_c_global, p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
...@@ -462,7 +462,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -462,7 +462,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
FloatAB* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -492,7 +492,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -492,7 +492,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const BGlobalDesc* p_b_e_n_ho_wo_global_desc, const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
FloatAB* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc, const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
...@@ -521,7 +521,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -521,7 +521,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const void* p_b_e_n_ho_wo_global_desc, const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const DGlobalDesc& d_k_n_hox2_wox2_global_desc, const DGlobalDesc& d_k_n_hox2_wox2_global_desc,
FloatAB* __restrict__ p_d_global, const FloatC* __restrict__ p_d_global,
const void* p_c_k_n_ho_wo_global_desc, const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
......
...@@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -172,16 +172,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
}(); }();
// copy data // copy data
vector_type<DstData, DstScalarPerVector> dst_vector; typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type<DstData, DstScalarPerVector>::type; using dst_vector_t =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = constexpr index_t src_offset =
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector); i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = type_convert<DstData>{}(p_src[Number<src_offset>{}]); dst_vector.template AsType<DstData>()(i) =
type_convert<DstData>{}(p_src[Number<src_offset>{}]);
}); });
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
...@@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -192,7 +194,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
amd_buffer_store_v2<DstData, DstScalarPerVector>( amd_buffer_store_v2<DstData, DstScalarPerVector>(
dst_vector.Vector(), dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
p_dst, p_dst,
dst_slice_origin_coord_.GetOffset(), dst_slice_origin_coord_.GetOffset(),
is_dst_valid, is_dst_valid,
...@@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -201,7 +203,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
if(is_dst_valid) if(is_dst_valid)
{ {
*reinterpret_cast<dst_vector_t*>( *reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector(); &(p_dst[dst_slice_origin_coord_.GetOffset()])) =
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
} }
#endif #endif
} }
...@@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -210,7 +213,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
if(is_dst_valid) if(is_dst_valid)
{ {
*reinterpret_cast<dst_vector_t*>( *reinterpret_cast<dst_vector_t*>(
&(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector(); &(p_dst[dst_slice_origin_coord_.GetOffset()])) =
dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
} }
} }
...@@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -500,9 +504,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// copy data // copy data
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst"); static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
vector_type<SrcData, SrcScalarPerVector> src_vector; typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type; using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_); src_desc, src_slice_origin_coord_);
...@@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -510,24 +515,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
if constexpr(SrcAddressSpace == AddressSpace::Global) if constexpr(SrcAddressSpace == AddressSpace::Global)
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>( src_vector.template AsType<src_vector_t>()(Number<0>{}) =
p_src, amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
src_slice_origin_coord_.GetOffset(), p_src,
is_src_valid, src_slice_origin_coord_.GetOffset(),
src_desc.GetElementSpaceSize()); is_src_valid,
src_desc.GetElementSpaceSize());
#else #else
src_vector.Vector() = is_src_valid src_vector.template AsType<src_vector_t>()(Number<0>{}) =
? *reinterpret_cast<const src_vector_t*>( is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()]) &p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0}; : src_vector_t{0};
#endif #endif
} }
else else
{ {
src_vector.Vector() = is_src_valid src_vector.template AsType<src_vector_t>()(Number<0>{}) =
? *reinterpret_cast<const src_vector_t*>( is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()]) &p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0}; : src_vector_t{0};
} }
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
...@@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 ...@@ -535,7 +541,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector.Scalars()[i]; p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
}); });
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
...@@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -833,9 +839,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
}(); }();
// copy data // copy data
vector_type<SrcData, SrcScalarPerVector> src_vector; typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::type; using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_); src_desc, src_slice_origin_coord_);
...@@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -843,31 +850,32 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
if constexpr(SrcAddressSpace == AddressSpace::Global) if constexpr(SrcAddressSpace == AddressSpace::Global)
{ {
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
src_vector.Vector() = amd_buffer_load_v2<SrcData, SrcScalarPerVector>( src_vector.template AsType<src_vector_t>()(Number<0>{}) =
p_src, amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
src_slice_origin_coord_.GetOffset(), p_src,
is_src_valid, src_slice_origin_coord_.GetOffset(),
src_desc.GetElementSpaceSize()); is_src_valid,
src_desc.GetElementSpaceSize());
#else #else
src_vector.Vector() = is_src_valid src_vector.template AsType<src_vector_t>()(Number<0>{}) =
? *reinterpret_cast<const src_vector_t*>( is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()]) &p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0}; : src_vector_t{0};
#endif #endif
} }
else else
{ {
src_vector.Vector() = is_src_valid src_vector.template AsType<src_vector_t>()(Number<0>{}) =
? *reinterpret_cast<const src_vector_t*>( is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()]) &p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0}; : src_vector_t{0};
} }
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset = constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
buffer_(Number<buffer_offset>{}) = src_vector.Scalars()[i]; buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
}); });
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
...@@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 ...@@ -1018,19 +1026,20 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
DstInMemOp == InMemoryDataOperation::Set, DstInMemOp == InMemoryDataOperation::Set,
"wrong! hardcoded for ds_write"); "wrong! hardcoded for ds_write");
vector_type<DstData, DstScalarPerVector> dst_vector; typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset = constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
dst_vector.Scalars()(i) = buffer_[Number<buffer_offset>{}]; dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
}); });
using DstVectorType = typename vector_type<DstData, DstScalarPerVector>::type; using DstVectorType =
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) = *reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
dst_vector.Vector(); dst_vector.template AsType<DstVectorType>()[Number<0>{}];
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
......
...@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2 ...@@ -41,7 +41,7 @@ struct ThreadwiseMatrixSliceCopy_v2
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type; using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
static_for<0, NSliceRow, 1>{}([&](auto i) { static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
......
...@@ -6,6 +6,17 @@ ...@@ -6,6 +6,17 @@
namespace ck { namespace ck {
template <typename T>
union BufferResource
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data;
T* address[2];
int32_t range[4];
int32_t config[4];
};
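The union above is the usual way the 128-bit buffer resource descriptor is assembled before being handed to the buffer intrinsics: 64-bit base address in words 0-1, byte range in word 2, and a configuration dword in word 3. A hedged sketch of filling it follows; the config value is a placeholder assumption, not necessarily the dword this header uses.

// Hypothetical sketch: build a buffer resource for n elements of T.
// 0x00027000 is an assumed placeholder for the 3rd config dword (data-format bits
// are target-specific); the real code may use a different constant.
template <typename T>
__device__ int32x4_t make_buffer_resource_sketch(T* p, index_t n)
{
    BufferResource<T> res;

    res.address[0] = p;             // words 0-1: 64-bit base address
    res.range[2]   = n * sizeof(T); // word 2: number of records, in bytes here
    res.config[3]  = 0x00027000;    // word 3: placeholder config dword

    return res.data; // view the filled descriptor as a 128-bit SGPR quad
}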
__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc, __device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc,
index_t vindex, index_t vindex,
index_t offset, index_t offset,
......
...@@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo ...@@ -72,6 +72,7 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo
__device__ void __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{ {
// TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a); const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0); const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1); const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
...@@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, ...@@ -132,6 +133,7 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
float& c2, float& c2,
float& c3) float& c3)
{ {
// TODO remove pointer casting
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a); const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0); const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1); const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
...@@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a, ...@@ -177,6 +179,7 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
float& c3) float& c3)
{ {
// TODO remove pointer casting
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a); const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0); const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1); const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
...@@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a, ...@@ -200,6 +203,7 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
float& c2, float& c2,
float& c3) float& c3)
{ {
// TODO remove pointer casting
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a); const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0); const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1); const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
...@@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 ...@@ -224,10 +228,14 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
v_dot4_i32_i8 %1, %2, %4, %1\n \ v_dot4_i32_i8 %1, %2, %4, %1\n \
" "
: "=v"(c0), "=v"(c1) : "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); : "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"0"(c0),
"1"(c1));
#else #else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false); c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false); c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
#endif #endif
} }
...@@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, ...@@ -253,12 +261,20 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
v_dot4_i32_i8 %3, %4, %8, %3\n \ v_dot4_i32_i8 %3, %4, %8, %3\n \
" "
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); : "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"v"(as_type<int32_t>(b2)),
"v"(as_type<int32_t>(b3)),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
#else #else
c0 = __builtin_amdgcn_sdot4(a, b0, c0, false); c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(a, b1, c1, false); c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(a, b2, c2, false); c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(a, b3, c3, false); c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
#endif #endif
} }
...@@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a, ...@@ -272,28 +288,24 @@ __device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int32_t& c2, int32_t& c2,
int32_t& c3) int32_t& c3)
{ {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a); amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0); vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1); vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2); vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3); vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
p_b0_int8x4_t[0],
p_b1_int8x4_t[0],
p_b2_int8x4_t[0],
p_b3_int8x4_t[0],
c0, c0,
c1, c1,
c2, c2,
c3); c3);
amd_assembly_outer_product_1x4(p_a_int8x4_t[1], amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
p_b0_int8x4_t[1], vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
p_b1_int8x4_t[1], vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
p_b2_int8x4_t[1], vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
p_b3_int8x4_t[1], vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
c0, c0,
c1, c1,
c2, c2,
...@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, ...@@ -311,28 +323,46 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int32_t& c3) int32_t& c3)
{ {
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a); amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0); vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1); vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2); vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3); vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[0], amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
p_b0_int8x8_t[0], vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
p_b1_int8x8_t[0], vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
p_b2_int8x8_t[0], vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
p_b3_int8x8_t[0], vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
c0, c0,
c1, c1,
c2, c2,
c3); c3);
amd_assembly_outer_product_1x4(p_a_int8x8_t[1], amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
p_b0_int8x8_t[1], vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
p_b1_int8x8_t[1], vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
p_b2_int8x8_t[1], vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
p_b3_int8x8_t[1], vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
c0, c0,
c1, c1,
c2, c2,
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "tuple_helper.hpp" #include "tuple_helper.hpp"
#include "type.hpp" #include "type.hpp"
#include "utility.hpp" #include "utility.hpp"
#include "magic_division.hpp"
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
......
...@@ -88,7 +88,7 @@ ...@@ -88,7 +88,7 @@
// experimental implementation // experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
...@@ -115,6 +115,9 @@ ...@@ -115,6 +115,9 @@
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
// merge transformation uses magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
// hack: have underlying assumptions that need to be satisfied, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// thread-invariant, otherwise it's a bug
...@@ -139,6 +142,11 @@ ...@@ -139,6 +142,11 @@
#define CK_WORKAROUND_SWDEV_275126 1 #define CK_WORKAROUND_SWDEV_275126 1
#endif #endif
// workaround for compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX
#define CK_WORKAROUND_SWDEV_XXXXXX 1
#endif
namespace ck { namespace ck {
enum AddressSpace enum AddressSpace
......
#ifndef CK_MAGIC_DIVISION_HPP
#define CK_MAGIC_DIVISION_HPP
#include "config.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
namespace ck {
// magic number division
// Caution:
// 1. For a uint32_t dividend: the magic number division implementation used here produces the
// correct result only if the dividend's value is within the 31-bit value range of uint32_t.
// 2. For an int32_t dividend: magic number division for an int32_t dividend has not been
// implemented; the int32_t dividend is bit-wise interpreted as uint32_t and the uint32_t magic
// number division implementation is then used. Therefore the dividend value needs to be
// non-negative.
// TODO:
// 1. Implement magic number division for int32_t
// 2. Implement magic number division for uint32_t with the full 32-bit value range
struct MagicDivision
{
// uint32_t
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
{
// assert(divisor >= 1 && divisor <= INT32_MAX);
uint32_t shift = 0;
for(shift = 0; shift < 32; ++shift)
{
if((1U << shift) >= divisor)
{
break;
}
}
uint64_t one = 1;
uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
// assert(multiplier <= 0xffffffffUL);
return make_tuple(uint32_t(multiplier), shift);
}
__host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
{
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<0>{}];
}
__host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor)
{
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<1>{}];
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
{
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[Number<0>{}];
constexpr uint32_t shift = tmp[Number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
{
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
return integral_constant<uint32_t, multiplier>{};
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<uint32_t, Divisor>)
{
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
return integral_constant<uint32_t, shift>{};
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
{
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
{
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<int32_t, Divisor>)
{
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
}
// magic division for uint32_t
__host__ __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
return (tmp + dividend) >> shift;
}
// HACK: magic division for int32_t
// HACK: use dividend_i32 as if it were uint32_t; dividend_i32 needs to be
// non-negative for the result to be correct
// TODO: figure out how to do magic number division for int32_t as the dividend
__host__ __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32;
return (tmp + dividend_i32) >> shift;
}
};
} // namespace ck
#endif
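A quick host-side way to convince oneself that CalculateMagicNumbers and DoMagicDivision above agree with ordinary integer division, over assumed test ranges that respect the documented 31-bit dividend constraint:

// Host-side sanity check sketch (assumed ranges, not part of this header).
#include <cassert>
#include <cstdint>

inline void check_magic_division_sketch()
{
    for(uint32_t divisor = 1; divisor <= 1000; ++divisor)
    {
        const uint32_t mult  = ck::MagicDivision::CalculateMagicMultiplier(divisor);
        const uint32_t shift = ck::MagicDivision::CalculateMagicShift(divisor);

        for(uint32_t dividend = 0; dividend <= 1000000; dividend += 37)
        {
            assert(ck::MagicDivision::DoMagicDivision(dividend, mult, shift) ==
                   dividend / divisor);
        }
    }
}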
...@@ -42,5 +42,19 @@ struct is_known_at_compile_time<integral_constant<T, X>> ...@@ -42,5 +42,19 @@ struct is_known_at_compile_time<integral_constant<T, X>>
static constexpr bool value = true; static constexpr bool value = true;
}; };
template <typename Y,
typename X,
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x)
{
union AsType
{
X x;
Y y;
};
return AsType{x}.y;
}
} // namespace ck } // namespace ck
#endif #endif
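The as_type helper added above is what the updated v_dot4 wrappers and DoMagicDivision rely on for bit-level reinterpretation without pointer casts. A minimal usage sketch (assuming it is called from inside namespace ck, like the code above):

// Reinterpret four packed int8 lanes as one 32-bit scalar and back; sizes must match.
__host__ __device__ inline int32_t pack_int8x4_sketch()
{
    int8x4_t a = {1, 2, 3, 4}; // assumed ext_vector_type brace initialization

    int32_t packed = as_type<int32_t>(a);       // bit-identical view as int32_t
    int8x4_t back  = as_type<int8x4_t>(packed); // round-trips to the original lanes

    (void)back;
    return packed; // 0x04030201 on a little-endian target
}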
...@@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -40,7 +40,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
#if 0 #if 1
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
...@@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -167,7 +167,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#elif 1 #elif 0
// cdata = 64, BlockSize 64, 16x256x4 // cdata = 64, BlockSize 64, 16x256x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
......
...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr auto C0 = C / Number<InWeiVectorSize>{}; constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{}; constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0 #if 1
// run-time variables // run-time variables
constexpr auto in_n_hi_wi_c0_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1 #if 0
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -211,7 +211,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 1 #elif 0
// cdata = 64, BlockSize = 64, 16x256x4 // cdata = 64, BlockSize = 64, 16x256x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -310,7 +310,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#elif 0 #elif 1
// cdata = 64, BlockSize = 256, 128x128x8 // cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
......
...@@ -150,7 +150,7 @@ int main(int argc, char* argv[]) ...@@ -150,7 +150,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
......