Unverified Commit 88833bd9 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Merge pull request #32 from ROCmSoftwarePlatform/develop

Merge develop into master
parents 31b40352 f3acd251
#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// Transform a backward-weight (wrw) convolution with NCHW input, KCYX weight
// and NKHW output layouts into an implicit GEMM problem:
//   GemmM = K           (weight-gradient rows)
//   GemmK = N * Ho * Wo (reduction dimension)
//   GemmN = C * Y * X   (weight-gradient columns)
// GemmK is additionally unmerged into [GemmK0, GemmK1], with GemmK1 supplied
// by the caller as a compile-time value.
// Returns a tuple of three grid descriptors:
//   (out_gemmk0_gemmm_gemmk1, in_gemmk0_gemmn_gemmk1, wei_gemmm_gemmn)
// NOTE(review): GemmK0 = GemmK / GemmK1 uses integer division — assumes
// GemmK1 evenly divides N * Ho * Wo; confirm callers guarantee this.
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
// Read problem sizes from the descriptors (N, C, Hi, Wi from input;
// K, Ho, Wo from output; Y, X filter sizes from weight).
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// GEMM problem dimensions (see mapping at the top of this function).
const auto GemmM = K;
const auto GemmN = C * Y * X;
const auto GemmK = N * Ho * Wo;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor: viewed as a packed [K, C*Y*X] matrix.
// NOTE(review): a fresh packed descriptor is built here, so the strides of
// the incoming wei_k_c_y_x_grid_desc are ignored — assumes the weight
// tensor is dense/packed KCYX; confirm against callers.
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
// Step 1: pad H and W so the filter window never reads out of bounds.
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// Step 2: embed padded H/W into (filter tap, output position) pairs:
// hip = y*ConvDilationH + ho*ConvStrideH, wip = x*ConvDilationW + wo*ConvStrideW.
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
// Step 3: merge (C, Y, X) -> GemmN and (N, Ho, Wo) -> GemmK.
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// Step 4: split GemmK into [GemmK0, GemmK1] for the blockwise GEMM.
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor: packed [N, K, Ho*Wo] view (incoming strides ignored —
// same packed-layout assumption as the weight tensor above), then
// K -> GemmM and (N, Ho*Wo) -> GemmK, split into [GemmK0, GemmK1].
const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif
...@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division ...@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division
} }
}; };
// Implementation of "Merge" transformation primitive that uses plain integer
// division and mod to map one merged "upper" index back onto the multi-index
// of the lower dimensions. It is supposed to be used for low_lengths that are
// known at compile time and are powers of 2; otherwise performance will be
// very bad (div/mod cannot be strength-reduced to shifts and masks).
// Fix vs. previous revision: Print() emitted the wrong type name
// ("Merge_v3_direct_division_mod"), which made debug dumps misleading.
template <typename LowLengths>
struct Merge_v3_division_mod
{
static constexpr index_t NDimLow = LowLengths::Size();
using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;
// low_lengths_scan_[i] holds the product of low_lengths_[i+1..NDimLow-1]
// (reverse exclusive scan), i.e. the stride of lower dimension i inside
// the merged index space.
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
// Single upper length = product of all lower lengths.
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_;
__host__ __device__ constexpr Merge_v3_division_mod() = default;
__host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
    : low_lengths_{low_lengths},
      low_lengths_scan_{
          container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
      up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
    static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
// Decompose the single upper index into per-dimension lower indices by
// repeated division/mod with the scan strides; the last lower dimension
// receives the final remainder.
template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                       const UpIdx& idx_up) const
{
    static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");
    index_t tmp = idx_up[Number<0>{}];
    // division and mod
    static_for<0, NDimLow - 1, 1>{}([&](auto i) {
        idx_low(i) = tmp / this->low_lengths_scan_[i];
        tmp %= this->low_lengths_scan_[i];
    });
    idx_low(Number<NDimLow - 1>{}) = tmp;
}
// Recompute the lower index from the already-updated upper index and emit
// the per-dimension deltas relative to the previous lower index.
template <typename LowIdxDiff,
          typename UpIdxDiff,
          typename LowIdx,
          typename UpIdx,
          index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                          const UpIdxDiff&,
                                          LowIdx& idx_low,
                                          const UpIdx& idx_up_new,
                                          Number<Hack>) const
{
    static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                      LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                  "wrong! inconsistent # of dimension");
    constexpr auto I0   = Number<0>{};
    constexpr auto INm1 = Number<NDimLow - 1>{};
    index_t tmp = idx_up_new[I0];
    static_for<0, NDimLow - 1, 1>{}([&](auto i) {
        const index_t tmp2 = idx_low[i];
        idx_low(i)         = tmp / this->low_lengths_scan_[i];
        idx_diff_low(i)    = idx_low[i] - tmp2;
        tmp %= this->low_lengths_scan_[i];
    });
    const index_t tmp2 = idx_low[INm1];
    idx_low(INm1)      = tmp;
    idx_diff_low(INm1) = idx_low[INm1] - tmp2;
}
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
    return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
    return is_known_at_compile_time<LowLengths>::value &&
           is_known_at_compile_time<LowLengthsScan>::value &&
           is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
    return true;
}
// Debug dump of the transform's configuration (name matches the struct).
__host__ __device__ void Print() const
{
    printf("{");
    printf("Merge_v3_division_mod, ");
    printf("low_lengths_ ");
    print_multi_index(low_lengths_);
    printf("low_lengths_scan_ ");
    print_multi_index(low_lengths_scan_);
    printf("up_lengths_ ");
    print_multi_index(up_lengths_);
    printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation> template <typename UpLengths, bool Use24BitIntegerCalculation>
struct UnMerge struct UnMerge
{ {
......
...@@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng ...@@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{ {
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION #if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return Merge_v1_carry_check<LowLengths>{low_lengths}; return make_merge_transform_v2_magic_division(low_lengths);
#else #else
return make_merge_transform_v1_carry_check(low_lengths);
#endif
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
return Merge_v1_carry_check<LowLengths>{low_lengths};
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
#if 1 #if 1
return Merge_v2_magic_division<LowLengths>{low_lengths}; return Merge_v2_magic_division<LowLengths>{low_lengths};
#else #else
return Merge_v2r2_magic_division<LowLengths>{low_lengths}; return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif #endif
#endif
} }
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths) make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
{ {
return Merge_v2_magic_division<LowLengths>{low_lengths}; return Merge_v3_division_mod<LowLengths>{low_lengths};
} }
template <typename UpLengths, bool Use24BitIntegerCalculation = false> template <typename UpLengths, bool Use24BitIntegerCalculation = false>
......
...@@ -189,8 +189,7 @@ struct TensorAdaptor ...@@ -189,8 +189,7 @@ struct TensorAdaptor
bool is_known = true; bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) { static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &= is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
}); });
return is_known && is_known_at_compile_time<ElementSize>::value; return is_known && is_known_at_compile_time<ElementSize>::value;
......
...@@ -185,8 +185,7 @@ struct TensorDescriptor ...@@ -185,8 +185,7 @@ struct TensorDescriptor
bool is_known = true; bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) { static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &= is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
}); });
return is_known && is_known_at_compile_time<ElementSize>::value && return is_known && is_known_at_compile_time<ElementSize>::value &&
...@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& ...@@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
template <typename TensorDesc> template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate( using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
template <typename TensorDesc> template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{})); TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));
} // namespace ck } // namespace ck
#endif #endif
...@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
const BThreadBuffer& b_thread_buf, const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const CThreadBuffer& c_thread_buf) const
{ {
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<FloatA>>>::value && is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>, is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
remove_cv_t<remove_reference_t<FloatB>>>::value && is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>, "wrong! inconsistent type");
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
......
...@@ -4,21 +4,21 @@ ...@@ -4,21 +4,21 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp" #include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
class ABlockDesc, typename AK0MK1BlockDesc,
class BBlockDesc, typename BK0NK1BlockDesc,
index_t MPerWave, index_t MPerXDL,
index_t NPerWave, index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t K1> index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{ {
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -26,329 +26,165 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -26,329 +26,165 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{}; static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t KPerBlock = K0;
static constexpr index_t MWaves = M1 / MPerWave; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NRepeat = N0; static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } __device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex() __device__ static auto CalculateAThreadOriginDataIndex()
{ {
const index_t thread_id = get_thread_local_1d_id(); const auto wave_idx = GetWaveIdx();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize; const auto waveId_m = wave_idx[I0];
const index_t waveId_m = waveId / NWaves;
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
if constexpr(xdlops_gemm.IsKReduction)
{ return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset, 0);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
} }
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
{ {
const index_t thread_id = get_thread_local_1d_id(); const auto wave_idx = GetWaveIdx();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize; const auto waveId_n = wave_idx[I1];
const index_t waveId_n = waveId % NWaves;
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
if constexpr(xdlops_gemm.IsKReduction)
{ return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset, 0);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i> template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex __device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>) CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{ {
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const index_t waveId = get_thread_local_1d_id() / WaveSize; const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t waveId_m = waveId / NWaves; constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
const index_t waveId_n = waveId % NWaves; make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return CIndex{m_offset, n_offset}; return make_tuple(c_thread_m, c_thread_n);
} }
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
{ {
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K0 dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
"wrong! K1 dimension not consistent"); "wrong! K1 dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
} }
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{ {
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>( constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
b_thread_copy_.Run(BBlockDesc{}, constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
make_tuple(k, I0, I0, I0), constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
b_block_buf, constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
using mfma_input_type = return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
});
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
m0,
n0>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
});
} }
private: __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
// A[K, M]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
template <index_t BlockSize,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
{ {
const index_t thread_id = get_thread_local_1d_id(); constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
const index_t waveId = thread_id / WaveSize; make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
const index_t laneId = thread_id % WaveSize; Number<NRepeat>{},
const index_t waveId_m = waveId / NWaves; Number<MWaves>{},
Number<NWaves>{},
if constexpr(xdlops_gemm.IsKReduction) Number<MPerXDL>{},
{ Number<NPerXDL>{}));
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId); return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
return make_tuple(k_offset, 0, m_offset, 0);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
} }
__device__ static auto CalculateBThreadOriginDataIndex() template <typename CMNGridDesc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{ {
const index_t thread_id = get_thread_local_1d_id(); const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
const index_t waveId = thread_id / WaveSize; c_m_n_grid_desc,
const index_t laneId = thread_id % WaveSize; make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
const index_t waveId_n = waveId % NWaves; make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
if constexpr(xdlops_gemm.IsKReduction) make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset, 0);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
} }
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i> __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
__device__ static CIndex
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{ {
return transform_tensor_descriptor(
const index_t waveId = get_thread_local_1d_id() / WaveSize; AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
const index_t waveId_m = waveId / NWaves; make_pass_through_transform(Number<K1>{})),
const index_t waveId_n = waveId % NWaves; make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{m_offset, n_offset};
} }
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
{ {
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), return transform_tensor_descriptor(
"wrong! Desc should be known at compile-time"); BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), make_unmerge_transform(
"wrong! K dimension not consistent"); make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
"wrong! K1 dimension not consistent"); make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
} }
static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf, __device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf, const BBlockBuffer& b_block_buf,
...@@ -359,165 +195,87 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline ...@@ -359,165 +195,87 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); vector_type<FloatAB, K1> a_thread_vec;
// read A_sub_0 vector_type<FloatAB, K1> b_thread_vec;
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0, I0), static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
a_block_buf, // read A
a_thread_desc_, a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(I0, I0, I0, I0), make_tuple(k0, I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
a_thread_buf); a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // read B
xdlops_gemm.template Run<decltype(a_thread_desc_), b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
decltype(b_thread_desc_), make_tuple(k0, I0, I0, I0, I0),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0, I0),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I1, I0, I0), make_tuple(I0, I0, I0, I0, I0),
b_thread_buf); b_thread_buf);
// read A_sub_1 using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0 static_for<0, MRepeat, 1>{}([&](auto m0) {
xdlops_gemm.template Run<decltype(a_thread_desc_), static_for<0, NRepeat, 1>{}([&](auto n0) {
decltype(b_thread_desc_), static_for<0, K1, 1>{}([&](auto i) {
decltype(c_thread_desc_), a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
0, [Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
0>(a_thread_buf, b_thread_buf, c_thread_buf); });
// C_sub_01 += transpose(A_sub_0) * B_sub_1 static_for<0, K1, 1>{}([&](auto i) {
xdlops_gemm.template Run<decltype(a_thread_desc_), b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
decltype(b_thread_desc_), [Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
decltype(c_thread_desc_), });
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf); constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run<c_offset>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
}); });
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
} }
private: private:
// A[K, M] // A[K, M]
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{})); make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
// B[K, N] // B[K, N]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{})); make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ = static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
ABlockDesc, decltype(a_k0_m0_m1_m2_k1_block_desc),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, MRepeat, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3, 4>,
3, 4,
1, // K1, K1,
1>; 1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
BBlockDesc, decltype(b_k0_n0_n1_n2_k1_block_desc),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>, Sequence<1, NRepeat, 1, 1, K1>,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3, 4>,
3, 4,
1, // K1, K1,
1>; 1>;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
}; };
} // namespace ck } // namespace ck
......
...@@ -18,7 +18,7 @@ template <typename GridwiseGemm, ...@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc, typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor> typename CBlockClusterAdaptor>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -29,7 +29,7 @@ __global__ void ...@@ -29,7 +29,7 @@ __global__ void
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc, const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor) const CBlockClusterAdaptor c_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
...@@ -43,7 +43,7 @@ __global__ void ...@@ -43,7 +43,7 @@ __global__ void
p_shared_block, p_shared_block,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...@@ -52,7 +52,7 @@ template <typename GridwiseGemm, ...@@ -52,7 +52,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AK0MK1GridDesc, typename AK0MK1GridDesc,
typename BK0NK1GridDesc, typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc, typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor> typename CBlockClusterAdaptor>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -63,7 +63,7 @@ __global__ void ...@@ -63,7 +63,7 @@ __global__ void
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor) const void CONSTANT* p_c_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
...@@ -73,8 +73,9 @@ __global__ void ...@@ -73,8 +73,9 @@ __global__ void
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>( const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>( const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); *reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>( const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
...@@ -86,7 +87,7 @@ __global__ void ...@@ -86,7 +87,7 @@ __global__ void
p_shared_block, p_shared_block,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
} }
#endif #endif
...@@ -102,8 +103,8 @@ template <index_t BlockSize, ...@@ -102,8 +103,8 @@ template <index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t MPerWave, index_t MPerXDL,
index_t NPerWave, index_t NPerXDL,
index_t K1Value, index_t K1Value,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
...@@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
...@@ -179,14 +183,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -179,14 +183,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) && (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
} }
__host__ __device__ static constexpr index_t __host__ __device__ static constexpr index_t
...@@ -201,29 +207,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -201,29 +207,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{ {
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{}; constexpr auto max_lds_align = K1;
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr auto M0 = Number<CLayout.M1()>{};
constexpr auto M1 = Number<CLayout.N1()>{};
constexpr auto M2 = Number<CLayout.M0()>{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr auto N1 = Number<CLayout.N0()>{}; constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
c_m_n_grid_desc, make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return c_m0_m1_m2_n_grid_desc; using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>;
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -253,8 +258,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -253,8 +258,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return c_blockid_to_m0_n0_block_cluster_adaptor; return c_blockid_to_m0_n0_block_cluster_adaptor;
} }
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
...@@ -262,7 +267,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -262,7 +267,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor) const CBlockClusterAdaptor& c_block_cluster_adaptor)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
...@@ -270,7 +275,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -270,7 +275,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
...@@ -358,50 +363,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -358,50 +363,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// register // register
// sanity check // sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB, FloatAB,
decltype(a_k0_m0_m1_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc), decltype(b_k0_n_k1_block_desc),
MPerWave, MPerXDL,
NPerWave, NPerXDL,
K1>{}; MRepeat,
NRepeat,
constexpr auto CLayout = blockwise_gemm.GetCLayout(); K1>{};
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
constexpr auto c_mr_nr_blk_desc = constexpr auto c_mr_nr_blk_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>, vector_type<FloatAcc, CBlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize(), c_mr_nr_blk_desc.GetElementSpaceSize(),
true> true>
c_thread_buf; c_thread_buf;
...@@ -474,94 +455,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -474,94 +455,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
} }
#if 0
// output: register to global memory // output: register to global memory
{ {
constexpr index_t M0 = CLayout.M1(); constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
constexpr index_t M1 = CLayout.N1(); blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
constexpr index_t M2 = CLayout.M0();
constexpr index_t N0 = CLayout.N1();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<M0>{},
Number<1>{},
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
c_thread_buf[Number<blk_off>{}]
.template AsType<FloatAcc>()[Number<j>{}];
});
});
});
// calculate origin of thread output tensor on global memory constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
// blockwise GEMM c matrix starting index constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
const auto c_thread_mtx_on_block = constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
n_thread_data_on_grid / (N1 * NWaves),
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
n_thread_data_on_grid % (N1 * NWaves) / N1,
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid % N1)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
}
#else
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
// calculate origin of thread output tensor on global memory // calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
...@@ -574,92 +475,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -574,92 +475,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid = const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
auto c_thread_copy = auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC, ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC, FloatC,
decltype(c_m0_m1_m2_n_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>, Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>{ true>{
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
make_multi_index(0, make_multi_index(0,
0, 0,
0, 0,
0, 0,
m_thread_data_on_grid / (M2 * M1), m_thread_data_on_grid / (M3 * M4),
m_thread_data_on_grid % (M2 * M1) / M2, m_thread_data_on_grid % (M3 * M4) / M4,
m_thread_data_on_grid % M2, m_thread_data_on_grid % M4,
n_thread_data_on_grid)}; n_thread_data_on_grid)};
auto init_copy = [&](auto c_thread_idx_) { auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
return c_thread_idx_; return c_thread_idx_;
}; };
auto mrepeat_plus_copy = [&](auto c_thread_idx_) { auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto nrepeat_plus_copy = [&](auto c_thread_idx_) { auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto mrepeat_minus_copy = [&](auto c_thread_idx_) { auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
auto nrepeat_minus_copy = [&](auto c_thread_idx_) { auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(), c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf, c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
}; };
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
...@@ -791,7 +696,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -791,7 +696,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
init_copy(make_tuple(I0, I0)); init_copy(make_tuple(I0, I0));
} }
} }
#endif
} }
}; // namespace ck }; // namespace ck
......
...@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 ...@@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CBuffer& c_buf, CBuffer& c_buf,
COriginIdx) COriginIdx)
{ {
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value, is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); "wrong! inconsistent type");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ ...@@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CBuffer& c_buf, CBuffer& c_buf,
COriginIdx) COriginIdx)
{ {
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value, is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); "wrong! inconsistent type");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 ...@@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time"); "wrong! Desc should be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert( static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value && is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value && is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value, is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); "wrong! inconsistent type");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
......
...@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1 ...@@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value, static_assert(is_known_at_compile_time<remove_cvref_t<OriginIdx>>::value,
"wrong! OriginIdx need to be known at compile-time"); "wrong! OriginIdx need to be known at compile-time");
// Desc is known at compile-time // Desc is known at compile-time
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{}; constexpr auto desc = remove_cvref_t<Desc>{};
// OriginIdx is known at compile-time // OriginIdx is known at compile-time
constexpr auto origin_idx = to_multi_index(OriginIdx{}); constexpr auto origin_idx = to_multi_index(OriginIdx{});
......
...@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_assert(SrcDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value,
is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value, "wrong! SrcSliceOrigin need to known at compile-time");
"wrong! SrcSliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
// static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
// remove_cv_t<remove_reference_t<SrcData>>>::value,
//"wrong! SrcBuffer data type is wrong");
// SrcDesc and src_slice_origin_idx are known at compile-time // SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -208,10 +203,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -208,10 +203,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf // copy data from dst_vector into dst_buf
dst_buf.template Set<dst_vector_t>( if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
dst_coord_.GetOffset(), {
is_dst_valid, dst_buf.template Set<dst_vector_t>(
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]); dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
{
dst_buf.template AtomicAdd<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
}
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -411,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -411,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert(DstDesc::IsKnownAtCompileTime(), static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time"); "wrong! DstDesc need to known at compile-time");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value, "wrong! DstSliceOrigin need to known at compile-time");
"wrong! DstSliceOrigin need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<DstData>>>::value && is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
"wrong! inconsistent type"); "wrong! inconsistent type");
// DstDesc and dst_slice_origin_idx are known at compile-time // DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -732,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -732,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value, is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent"); "wrong! SrcBuffer and SrcData data type are inconsistent");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -889,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3 ...@@ -889,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<DstData>>>::value, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong"); "wrong! SrcBuffer or DstBuffer data type is wrong");
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -1305,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1305,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value && is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
remove_cv_t<remove_reference_t<DstData>>>::value, "wrong! SrcBuffer or DstBuffer data type is wrong");
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time< is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value && "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value, "at compile-time");
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time // SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
......
...@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value, is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent"); "wrong! SrcBuffer and SrcData data type are inconsistent");
// tensor descriptor for src_vector // tensor descriptor for src_vector
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
...@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!"); "wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<DstData>>>::value, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong"); "wrong! SrcBuffer or DstBuffer data type is wrong");
// tensor descriptor for dst_vector // tensor descriptor for dst_vector
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
...@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1 ...@@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time"); "wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>, static_assert(
remove_cv_t<remove_reference_t<SrcData>>>::value && is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value &&
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>, is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
remove_cv_t<remove_reference_t<DstData>>>::value, "wrong! SrcBuffer or DstBuffer data type is wrong");
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert( static_assert(is_known_at_compile_time<remove_cvref_t<SrcRefToOriginDisplacement>>::value &&
is_known_at_compile_time< is_known_at_compile_time<remove_cvref_t<DstOriginIdx>>::value,
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value && "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value, "at compile-time");
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time // SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{}; constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{}; constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
......
...@@ -7,21 +7,18 @@ ...@@ -7,21 +7,18 @@
namespace ck { namespace ck {
enum struct mfma_instr enum struct MfmaInstr
{ {
/// fp32
mfma_f32_32x32x1xf32 = 0, mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32, mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32, mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction
/// fp16
mfma_f32_32x32x4f16, mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16, mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16, mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction mfma_f32_16x16x16f16, // k reduction
/// bfp16
mfma_f32_32x32x2bf16, mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16, mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16, mfma_f32_4x4x2bf16,
...@@ -29,25 +26,23 @@ enum struct mfma_instr ...@@ -29,25 +26,23 @@ enum struct mfma_instr
mfma_f32_16x16x8bf16, // k reduction mfma_f32_16x16x8bf16, // k reduction
}; };
template <mfma_instr instr> template <MfmaInstr instr>
struct mfma_info; struct mfma_type;
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32> struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 1; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -62,21 +57,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32> ...@@ -62,21 +57,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32> struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -91,21 +84,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32> ...@@ -91,21 +84,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32> struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -120,21 +111,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32> ...@@ -120,21 +111,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32> struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 1; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -150,21 +139,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32> ...@@ -150,21 +139,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
// treat 4x4x1 as a single-blk 4x64 mfma // treat 4x4x1 as a single-blk 4x64 mfma
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32> struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4; static constexpr index_t m_per_blk = 4;
static constexpr index_t m = 4; static constexpr index_t n_per_blk = 64;
static constexpr index_t n = 64; static constexpr index_t k_per_blk = 1;
static constexpr index_t k = 1; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -179,21 +166,19 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32> ...@@ -179,21 +166,19 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -208,21 +193,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> ...@@ -208,21 +193,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16> struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 8; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -237,21 +220,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16> ...@@ -237,21 +220,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16> struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 16; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -266,21 +247,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16> ...@@ -266,21 +247,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16> struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -295,21 +274,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16> ...@@ -295,21 +274,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16> struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4; static constexpr index_t m_per_blk = 4;
static constexpr index_t m = 4; static constexpr index_t n_per_blk = 64;
static constexpr index_t n = 64; static constexpr index_t k_per_blk = 4;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -325,21 +302,19 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16> ...@@ -325,21 +302,19 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
#if 0 #if 0
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16> struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2; static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -359,21 +334,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16> ...@@ -359,21 +334,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16> struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4; static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_blk = 32; static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 32;
static constexpr index_t m = 32; static constexpr index_t n_per_blk = 32;
static constexpr index_t n = 32; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 4; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -392,21 +365,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16> ...@@ -392,21 +365,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16> struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 8; static constexpr bool is_k_reduction = true;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -425,21 +396,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16> ...@@ -425,21 +396,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16> struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 16; static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk; static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4; static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; static constexpr index_t m_per_blk = 16;
static constexpr index_t m = 16; static constexpr index_t n_per_blk = 16;
static constexpr index_t n = 16; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -458,21 +427,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16> ...@@ -458,21 +427,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
}; };
template <> template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
{ {
static constexpr index_t group_size = 4; static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1; static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk; static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_blk = 64; static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64; static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1; static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1; static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4; static constexpr index_t m_per_blk = 4;
static constexpr index_t m = 4; static constexpr index_t n_per_blk = 64;
static constexpr index_t n = 64; static constexpr index_t k_per_blk = 2;
static constexpr index_t k = 2; static constexpr bool is_k_reduction = false;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops, template <index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
...@@ -491,200 +458,227 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16> ...@@ -491,200 +458,227 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
}; };
#endif #endif
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_> template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct xdlops_info struct MfmaSelector
{ {
static constexpr auto mfma_type = mfma_info<instr>{}; template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
static constexpr auto GetMfma();
static constexpr index_t MPerXdlops = MPerXdlops_; template <>
static constexpr index_t NPerXdlops = NPerXdlops_; static constexpr auto GetMfma<float, 64, 64>()
static constexpr bool IsABroadcast()
{ {
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); return MfmaInstr::mfma_f32_32x32x1xf32;
return true;
} }
static constexpr bool IsKReduction() template <>
static constexpr auto GetMfma<float, 32, 64>()
{ {
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); return MfmaInstr::mfma_f32_32x32x1xf32;
} }
static constexpr index_t GetKPerXdlops() template <>
static constexpr auto GetMfma<float, 16, 64>()
{ {
return IsKReduction() ? mfma_type.num_input_blks : 1; return MfmaInstr::mfma_f32_16x16x1xf32;
} }
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
};
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm
{
template <class base_type_ = base_type,
index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave>
static constexpr auto GetXdlopsInfo();
template <> template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>() static constexpr auto GetMfma<float, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{}; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>() static constexpr auto GetMfma<float, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{}; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>() static constexpr auto GetMfma<float, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{}; return MfmaInstr::mfma_f32_32x32x2xf32;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>() static constexpr auto GetMfma<float, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{}; return MfmaInstr::mfma_f32_16x16x4xf32;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>() static constexpr auto GetMfma<half_t, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{}; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>() static constexpr auto GetMfma<half_t, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{}; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>() static constexpr auto GetMfma<half_t, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{}; return MfmaInstr::mfma_f32_32x32x8f16;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>() static constexpr auto GetMfma<half_t, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{}; return MfmaInstr::mfma_f32_16x16x16f16;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>() static constexpr auto GetMfma<half_t, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{}; return MfmaInstr::mfma_f32_16x16x4f16;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>() static constexpr auto GetMfma<half_t, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{}; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>() static constexpr auto GetMfma<half_t, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{}; return MfmaInstr::mfma_f32_4x4x4f16;
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>() static constexpr auto GetMfma<ushort, 128, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{}; return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>() static constexpr auto GetMfma<ushort, 64, 128>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{}; return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>() static constexpr auto GetMfma<ushort, 64, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{}; return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
} }
#if 0
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>() static constexpr auto GetMfma<ushort, 64, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{}; return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>() static constexpr auto GetMfma<ushort, 32, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{}; return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>() static constexpr auto GetMfma<ushort, 64, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{}; return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>() static constexpr auto GetMfma<ushort, 16, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{}; return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>() static constexpr auto GetMfma<ushort, 8, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{}; return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>() static constexpr auto GetMfma<ushort, 4, 64>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{}; return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>() static constexpr auto GetMfma<ushort, 32, 32>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{}; return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
} }
template <> template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>() static constexpr auto GetMfma<ushort, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{}; return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
} }
#endif
template <> static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
__host__ __device__ static constexpr void mfma_check()
{ {
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{}; static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
"n_per_blk != num_threads_per_blk");
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
selected_mfma.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
selected_mfma.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
selected_mfma.m_per_blk * selected_mfma.n_per_blk,
"num_regs_per_blk incorrect");
static_assert(selected_mfma.is_k_reduction ||
(selected_mfma.num_input_blks == selected_mfma.num_output_blks),
"is_k_reduction wrong!");
} }
template <> __host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
static constexpr bool IsABroadcast()
{ {
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{}; static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
return true;
} }
template <> static constexpr index_t GetKPerXdlops()
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
{ {
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{}; return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
selected_mfma.k_per_blk;
} }
#endif
static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; }
};
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>; using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops() __device__ static constexpr index_t GetNumXdlops()
{ {
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); return MPerXdlops * NPerXdlops /
(mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
} }
__host__ __device__ constexpr XdlopsGemm() __host__ __device__ constexpr XdlopsGemm()
...@@ -697,104 +691,142 @@ struct XdlopsGemm ...@@ -697,104 +691,142 @@ struct XdlopsGemm
MPerXdlops == 64, MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, }
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || template <typename CM0N0M1N1M2N2Desc>
mfma_type.num_output_blks == 1, __host__ __device__ static constexpr auto
"incorrect num_output_blks"); MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc)
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n, {
"num_regs_blk incorrect"); const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0);
const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1);
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2);
const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3);
return transform_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_desc,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
} }
__device__ static constexpr index_t GetRegSizePerXdlops() __device__ static constexpr index_t GetRegSizePerXdlops()
{ {
return MPerXdlops * NPerXdlops / mfma_type.wave_size; return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
} }
template <class ADesc, template <index_t c_offset, class FloatA, class FloatB, class FloatC>
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value || static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value, is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!"); "base base_type must be float, half, ushort!");
static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[k], p_b_wave[k], p_c_thread);
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); __device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); make_tuple(make_merge_transform(
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>( const auto blk_idx =
p_a_wave[Number<a_offset / mfma_type.k_base>{}], threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread); const auto blk_id = blk_idx[I1];
}); const auto blk_td = blk_idx[I2];
return make_tuple(blk_id, blk_td);
} }
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) __host__ __device__ static auto CalculateAThreadOriginDataIndex()
{ {
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; const auto laneId = GetLaneId();
const index_t blk_id = laneId / mfma_type.num_threads_blk; const auto blk_idx = GetBlkIdx();
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t n_offset = blk_i * mfma_type.n + blk_td; const auto blk_id = blk_idx[I0];
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; const auto blk_td = blk_idx[I1];
return CIndex{m_offset, n_offset}; if constexpr(mfma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
} }
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; __host__ __device__ static auto CalculateBThreadOriginDataIndex()
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; {
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; const auto laneId = GetLaneId();
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; const auto blk_idx = GetBlkIdx();
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); const auto blk_id = blk_idx[I0];
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); const auto blk_td = blk_idx[I1];
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr auto GetBlkId(const index_t lane_id) if constexpr(mfma_instr.is_k_reduction)
{ {
return lane_id / mfma_type.num_threads_blk; return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
} }
static constexpr auto GetBlkTd(const index_t lane_id) __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{ {
return lane_id % mfma_type.num_threads_blk; const auto blk_idx = GetBlkIdx();
}
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
struct CLayout index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
{ index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } return CIndex{m_offset, n_offset};
}
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
__device__ static constexpr index_t GetNumXdlops() static constexpr auto mfma_instr = mfma.selected_mfma;
{
return MPerXdlops * NPerXdlops / static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks); static constexpr auto KPerThread = mfma.GetKPerThread();
}
};
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
{
return make_tuple(
Number<mfma_instr.num_groups_per_blk>{}, I1, Number<mfma_instr.group_size>{}, I1);
}
}; };
} // namespace ck } // namespace ck
......
...@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t voffset, index_t voffset,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// atomic add
// int
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");
// float
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
...@@ -209,13 +225,49 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -209,13 +225,49 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
index_t src_wave_addr_offset) index_t src_wave_addr_offset)
{ {
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)), (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(is_same<T, float>::value) if constexpr(is_same<T, double>::value)
{
// use fp32 load to mimic fp64 load
if constexpr(N == 1)
{
const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<double>(tmp);
}
else if constexpr(N == 2)
{
const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<double2_t>(tmp);
}
else if constexpr(N == 4)
{
const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
const float4_t f32_1 =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
0);
vector_type<double, 4> tmp;
tmp.AsType<double2_t>()(Number<0>{}) = as_type<double2_t>(f32_0);
tmp.AsType<double2_t>()(Number<1>{}) = as_type<double2_t>(f32_1);
return tmp.AsType<double4_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, float>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
...@@ -267,25 +319,11 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -267,25 +319,11 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
} }
else if constexpr(N == 8) else if constexpr(N == 8)
{ {
#if 0 // use fp32 load to mimic fp16 load
vector_type<half_t, 8> tmp;
tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<half4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);
return tmp.AsType<half8_t>()(Number<0>{});
#else
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<half8_t>(tmp); return as_type<half8_t>(tmp);
#endif
} }
} }
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
...@@ -417,13 +455,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -417,13 +455,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
index_t dst_wave_addr_offset) index_t dst_wave_addr_offset)
{ {
static_assert( static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(is_same<T, float>::value) if constexpr(is_same<T, double>::value)
{
// use fp32 store to mimic fp64 store
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32x2(as_type<float2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x4(as_type<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, float>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
...@@ -450,6 +509,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -450,6 +509,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0); 0);
} }
} }
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
}
}
else if constexpr(is_same<T, int32_t>::value) else if constexpr(is_same<T, int32_t>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
...@@ -536,53 +638,132 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -536,53 +638,132 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
0); 0);
} }
} }
else if constexpr(is_same<T, half_t>::value) }
template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{ {
if constexpr(N == 1) if constexpr(N == 1)
{ {
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 2) else if constexpr(N == 2)
{ {
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, vector_type<float, 2> tmp{src_thread_data};
dst_wave_buffer_resource,
dst_thread_addr_offset, llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
dst_wave_addr_offset, dst_wave_buffer_resource,
0); dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
} }
else if constexpr(N == 4) else if constexpr(N == 4)
{ {
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, vector_type<float, 4> tmp{src_thread_data};
dst_wave_buffer_resource,
dst_thread_addr_offset, llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
dst_wave_addr_offset, dst_wave_buffer_resource,
0); dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(float),
0);
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(float),
0);
} }
else if constexpr(N == 8) }
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{ {
vector_type<half_t, 8> tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
vector_type<int32_t, 2> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}], llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}], llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t), dst_wave_addr_offset + sizeof(int32_t),
0); 0);
}
else if constexpr(N == 4)
{
vector_type<int32_t, 4> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 2 * sizeof(int32_t),
0);
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 3 * sizeof(int32_t),
0);
} }
} }
} }
// buffer_load requires: // buffer_load requires:
// 1) p_src_wave must be in global memory space // 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer. // 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
...@@ -616,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, ...@@ -616,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
} }
// buffer_load requires: // buffer_load requires:
// 1) p_src_wave must be in global memory space // 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer. // 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
...@@ -644,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -644,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
} }
// buffer_store requires: // buffer_store requires:
// 1) p_dst_wave must be global memory // 1) p_dst_wave must point to global memory
// 2) p_dst_wave to be a wavewise pointer. // 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true. // It is user's responsibility to make sure that is true.
template <typename T, index_t N> template <typename T, index_t N>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data, __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
...@@ -677,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t ...@@ -677,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#endif #endif
} }
// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
//
// Atomically adds N elements of type T (packed in src_thread_data) into
// p_dst_wave at dst_thread_element_offset. When dst_thread_element_valid is
// false the add is suppressed (either branchlessly via the OOB-offset trick,
// or with an explicit branch).
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
                      T* p_dst_wave,
                      const index_t dst_thread_element_offset,
                      const bool dst_thread_element_valid,
                      const index_t dst_element_space_size)
{
    // 128-bit buffer resource descriptor covering the destination window.
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    // Byte offset of this thread's first element within the buffer.
    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t = typename vector_type_maker<T, N>::type::type;
    using scalar_t = typename scalar_type<vector_t>::type;

    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
    // Branchless invalid-element handling: push the address far past the buffer
    // range so the hardware buffer out-of-bounds check drops the atomic.
    // NOTE(review): assumes dst_thread_addr_offset is small enough that adding
    // 0x7fffffff cannot wrap back into the valid range -- confirm.
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;

    amd_buffer_atomic_add_impl<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_atomic_add_impl<scalar_t, vector_size>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -48,7 +48,7 @@ struct Array<TData, 0> ...@@ -48,7 +48,7 @@ struct Array<TData, 0>
template <typename X, typename... Xs> template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{ {
using data_type = remove_cv_t<remove_reference_t<X>>; using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}}; return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
} }
......
...@@ -85,8 +85,8 @@ ...@@ -85,8 +85,8 @@
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif #endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK #ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
#endif #endif
// pass tensor descriptor by value or void* // pass tensor descriptor by value or void*
......
...@@ -73,6 +73,13 @@ struct scalar_type<vector_type<T, N>> ...@@ -73,6 +73,13 @@ struct scalar_type<vector_type<T, N>>
}; };
// //
// scalar_type specialization for fp64: a plain double is its own scalar
// element with vector size 1 (mirrors the float specialization below).
template <>
struct scalar_type<double>
{
    using type = double;
    static constexpr index_t vector_size = 1;
};
template <> template <>
struct scalar_type<float> struct scalar_type<float>
{ {
...@@ -864,6 +871,10 @@ struct vector_type<T, 256> ...@@ -864,6 +871,10 @@ struct vector_type<T, 256>
} }
}; };
// fp64 vector aliases, following the same pattern as the fp32/fp16 aliases.
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;
// fp32 // fp32
using float2_t = typename vector_type<float, 2>::type; using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type; using float4_t = typename vector_type<float, 4>::type;
......
...@@ -39,18 +39,15 @@ struct DynamicBuffer ...@@ -39,18 +39,15 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -67,15 +64,14 @@ struct DynamicBuffer ...@@ -67,15 +64,14 @@ struct DynamicBuffer
if constexpr(InvalidElementUseNumericalZeroValue) if constexpr(InvalidElementUseNumericalZeroValue)
{ {
return amd_buffer_load_invalid_element_return_return_zero< return amd_buffer_load_invalid_element_return_return_zero<remove_cvref_t<T>,
remove_cv_t<remove_reference_t<T>>, t_per_x>(
t_per_x>(p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_);
} }
else else
{ {
return amd_buffer_load_invalid_element_return_customized_value< return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
remove_cv_t<remove_reference_t<T>>, t_per_x>(
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
} }
} }
...@@ -94,18 +90,15 @@ struct DynamicBuffer ...@@ -94,18 +90,15 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename enable_if< typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
constexpr index_t scalar_per_t_vector = constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector = constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T"); "wrong! X need to be multiple T");
...@@ -115,7 +108,7 @@ struct DynamicBuffer ...@@ -115,7 +108,7 @@ struct DynamicBuffer
#if CK_USE_AMD_BUFFER_ADDRESSING #if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>( amd_buffer_store<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_);
#else #else
if(is_valid_element) if(is_valid_element)
...@@ -136,70 +129,65 @@ struct DynamicBuffer ...@@ -136,70 +129,65 @@ struct DynamicBuffer
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128 // ds_write_b128
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type, if constexpr(is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value)
int8_t>::value)
{ {
static_assert( static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) || (is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8x2_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) || (is_same<remove_cvref_t<T>, int8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || (is_same<remove_cvref_t<T>, int8x4_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) || (is_same<remove_cvref_t<T>, int8x8_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value && is_same<remove_cvref_t<X>, int8x8_t>::value) ||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) || (is_same<remove_cvref_t<T>, int8x16_t>::value &&
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value && is_same<remove_cvref_t<X>, int8x16_t>::value),
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value), "wrong! not implemented for this combination, please add "
"wrong! not implemented for this combination, please add " "implementation");
"implementation");
if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && is_same<remove_cvref_t<X>, int8_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) = *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x); *c_style_pointer_cast<const int8_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) is_same<remove_cvref_t<X>, int8x2_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) = *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x); *c_style_pointer_cast<const int16_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value && else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) is_same<remove_cvref_t<X>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
int8x4_t>::value && is_same<remove_cvref_t<X>, int8x4_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) = *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x); *c_style_pointer_cast<const int32_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
int8x8_t>::value && is_same<remove_cvref_t<X>, int8x8_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) = *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x); *c_style_pointer_cast<const int32x2_t*>(&x);
} }
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
int8x16_t>::value && is_same<remove_cvref_t<X>, int8x16_t>::value)
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{ {
// HACK: cast pointer of x is bad // HACK: cast pointer of x is bad
// TODO: remove this after compiler fix // TODO: remove this after compiler fix
...@@ -223,6 +211,35 @@ struct DynamicBuffer ...@@ -223,6 +211,35 @@ struct DynamicBuffer
} }
} }
// Atomically accumulates x (one or more packed T elements) into p_data_[i].
// The add is skipped when is_valid_element is false. Only enabled when X and T
// share the same underlying scalar type; only valid for global-memory buffers.
template <typename X,
          typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                     typename scalar_type<remove_cvref_t<T>>::type>::value,
                             bool>::type = false>
__host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
{
    // X contains multiple T
    constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

    constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

    static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                  "wrong! X need to be multiple T");

    // Buffer atomics address global memory only.
    static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ADDRESSING
    // Number of T elements carried by one X.
    constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

    amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
        x, p_data_, i, is_valid_element, element_space_size_);
#else
    if(is_valid_element)
    {
        // NOTE(review): generic atomicAdd on X -- assumes X is a scalar type
        // that atomicAdd supports; confirm behavior for vector-valued X.
        atomicAdd(&p_data_[i], x);
    }
#endif
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
...@@ -234,9 +251,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el ...@@ -234,9 +251,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size}; return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
} }
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize> template <
AddressSpaceEnum_t BufferAddressSpace,
typename T,
typename ElementSpaceSize,
typename X,
typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value) make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{ {
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{ return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
p, element_space_size, invalid_element_value}; p, element_space_size, invalid_element_value};
......
...@@ -114,12 +114,11 @@ struct MagicDivision ...@@ -114,12 +114,11 @@ struct MagicDivision
__host__ __device__ static constexpr uint32_t __host__ __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{ {
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift; return (tmp + dividend) >> shift;
} }
#if 1 // debug // magic division for int32_t
// HACK: magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct // non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended // TODO: figure out how to do magic number divison for int32_t as dividended
...@@ -127,27 +126,9 @@ struct MagicDivision ...@@ -127,27 +126,9 @@ struct MagicDivision
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ {
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32); uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp = uint32_t tmp = __umulhi(dividend_u32, multiplier);
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
return (tmp + dividend_u32) >> shift; return (tmp + dividend_u32) >> shift;
} }
#else
// the inline ASM is producing wrong result
__host__ __device__ static int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t r;
asm volatile("\n \
v_mul_hi_u32 %0, %1, %2 \n \
v_add_u32_e32 %0, %1, %0 \n \
v_lshrrev_b32_e32 %0, %3, %0 \n \
"
: "=v"(r)
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
return as_type<int32_t>(r);
}
#endif
}; };
} // namespace ck } // namespace ck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment