Commit 7cacac2c authored by ltqin's avatar ltqin
Browse files

Merge branch 'develop' into backward_weight_v4r4r2_xdlops

parents e6d9dd20 9d3f634a
......@@ -4,21 +4,21 @@
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -26,329 +26,165 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t KPerBlock = K0;
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset, 0);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset, 0);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return CIndex{m_offset, n_offset};
return make_tuple(c_thread_m, c_thread_n);
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
"wrong! K0 dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2),
"wrong! K1 dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
// read B
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
});
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
m0,
n0>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
});
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N));
}
private:
// A[K, M]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
template <index_t BlockSize,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset, 0);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
}
__device__ static auto CalculateBThreadOriginDataIndex()
template <typename CMNGridDesc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset, 0);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
__host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{m_offset, n_offset};
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
__host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor()
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
"wrong! K1 dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor();
static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
......@@ -359,165 +195,87 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
vector_type<FloatAB, K1> a_thread_vec;
vector_type<FloatAB, K1> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) {
// read A
a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc,
make_tuple(k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
make_tuple(I0, I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0, I0),
// read B
b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc,
make_tuple(k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0, I0),
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
using mfma_input_type = typename vector_type<FloatAB, xdlops_gemm.KPerThread>::type;
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, K1, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, m0, 0, 0, i))>{}];
});
static_for<0, K1, 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, n0, 0, 0, i))>{}];
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run<c_offset>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
}
private:
// A[K, M]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, I1, I1, Number<K1>{}));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<NRepeat>{}, I1, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<xdlops_gemm.GetNumXdlops()>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_k0_m0_m1_m2_k1_block_desc),
decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
Sequence<1, MRepeat, 1, 1, K1>,
Sequence<0, 1, 2, 3, 4>,
4,
K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_k0_n0_n1_n2_k1_block_desc),
decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
Sequence<1, NRepeat, 1, 1, K1>,
Sequence<0, 1, 2, 3, 4>,
4,
K1,
1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
} // namespace ck
......
......@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -29,7 +29,7 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
......@@ -43,7 +43,7 @@ __global__ void
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
......@@ -52,7 +52,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -63,7 +63,7 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
......@@ -73,8 +73,9 @@ __global__ void
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
......@@ -86,7 +87,7 @@ __global__ void
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
#endif
......@@ -102,8 +103,8 @@ template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
......@@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
......@@ -179,14 +183,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
}
__host__ __device__ static constexpr index_t
......@@ -201,29 +207,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
__host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr auto M0 = Number<CLayout.M1()>{};
constexpr auto M1 = Number<CLayout.N1()>{};
constexpr auto M2 = Number<CLayout.M0()>{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr auto max_lds_align = K1;
constexpr auto N1 = Number<CLayout.N0()>{};
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
return c_m0_m1_m2_n_grid_desc;
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>;
return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
}
__host__ __device__ static constexpr auto
......@@ -253,8 +258,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......@@ -262,7 +267,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......@@ -270,7 +275,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
......@@ -358,50 +363,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// register
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
K1>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
constexpr auto c_mr_nr_blk_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>,
vector_type<FloatAcc, CBlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize(),
true>
c_thread_buf;
......@@ -474,94 +455,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#if 0
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr index_t N0 = CLayout.N1();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<M0>{},
Number<1>{},
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
c_thread_buf[Number<blk_off>{}]
.template AsType<FloatAcc>()[Number<j>{}];
});
});
});
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor();
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
n_thread_data_on_grid / (N1 * NWaves),
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
n_thread_data_on_grid % (N1 * NWaves) / N1,
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid % N1)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
}
#else
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
......@@ -574,92 +475,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{};
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc),
Sequence<I1, I1, I1, I1, M2, I1, M4, I1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
make_multi_index(0,
0,
0,
0,
m_thread_data_on_grid / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
m_thread_data_on_grid / (M3 * M4),
m_thread_data_on_grid % (M3 * M4) / M4,
m_thread_data_on_grid % M4,
n_thread_data_on_grid)};
auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
return c_thread_idx_;
};
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus);
c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks);
};
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
......@@ -791,7 +696,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
init_copy(make_tuple(I0, I0));
}
}
#endif
}
}; // namespace ck
......
......@@ -7,21 +7,18 @@
namespace ck {
enum struct mfma_instr
enum struct MfmaInstr
{
/// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
/// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
/// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
......@@ -29,25 +26,23 @@ enum struct mfma_instr
mfma_f32_16x16x8bf16, // k reduction
};
template <mfma_instr instr>
struct mfma_info;
template <MfmaInstr instr>
struct mfma_type;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
struct mfma_type<MfmaInstr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -62,21 +57,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
struct mfma_type<MfmaInstr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -91,21 +84,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
struct mfma_type<MfmaInstr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -120,21 +111,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
struct mfma_type<MfmaInstr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 1;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -150,21 +139,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
struct mfma_type<MfmaInstr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 1;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 1;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 4;
static constexpr index_t n_per_blk = 64;
static constexpr index_t k_per_blk = 1;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -179,21 +166,19 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -208,21 +193,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 8;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -237,21 +220,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 16;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -266,21 +247,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -295,21 +274,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 4;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 4;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 4;
static constexpr index_t n_per_blk = 64;
static constexpr index_t k_per_blk = 4;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -325,21 +302,19 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
#if 0
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -359,21 +334,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -392,21 +365,19 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -425,21 +396,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -458,21 +427,19 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 4;
static constexpr index_t n_per_blk = 64;
static constexpr index_t k_per_blk = 2;
static constexpr bool is_k_reduction = false;
template <index_t MPerXdlops,
index_t NPerXdlops,
......@@ -491,200 +458,227 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
};
#endif
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
struct xdlops_info
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
{
static constexpr auto mfma_type = mfma_info<instr>{};
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_>
static constexpr auto GetMfma();
static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr bool IsABroadcast()
template <>
static constexpr auto GetMfma<float, 64, 64>()
{
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
return true;
return MfmaInstr::mfma_f32_32x32x1xf32;
}
static constexpr bool IsKReduction()
template <>
static constexpr auto GetMfma<float, 32, 64>()
{
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
return MfmaInstr::mfma_f32_32x32x1xf32;
}
static constexpr index_t GetKPerXdlops()
template <>
static constexpr auto GetMfma<float, 16, 64>()
{
return IsKReduction() ? mfma_type.num_input_blks : 1;
return MfmaInstr::mfma_f32_16x16x1xf32;
}
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
};
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm
{
template <class base_type_ = base_type,
index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave>
static constexpr auto GetXdlopsInfo();
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
static constexpr auto GetMfma<float, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
static constexpr auto GetMfma<float, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
static constexpr auto GetMfma<float, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
return MfmaInstr::mfma_f32_32x32x2xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
static constexpr auto GetMfma<float, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
return MfmaInstr::mfma_f32_16x16x4xf32;
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
static constexpr auto GetMfma<half_t, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
return MfmaInstr::mfma_f32_32x32x4f16;
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
static constexpr auto GetMfma<half_t, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
return MfmaInstr::mfma_f32_32x32x4f16;
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
static constexpr auto GetMfma<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
static constexpr auto GetMfma<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
return MfmaInstr::mfma_f32_16x16x16f16;
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
static constexpr auto GetMfma<half_t, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
return MfmaInstr::mfma_f32_16x16x4f16;
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
static constexpr auto GetMfma<half_t, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
return MfmaInstr::mfma_f32_4x4x4f16;
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
static constexpr auto GetMfma<half_t, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
return MfmaInstr::mfma_f32_4x4x4f16;
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
static constexpr auto GetMfma<ushort, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
static constexpr auto GetMfma<ushort, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
static constexpr auto GetMfma<ushort, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
static constexpr auto GetMfma<ushort, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
static constexpr auto GetMfma<ushort, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
static constexpr auto GetMfma<ushort, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
static constexpr auto GetMfma<ushort, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
static constexpr auto GetMfma<ushort, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
static constexpr auto GetMfma<ushort, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
static constexpr auto GetMfma<ushort, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
static constexpr auto GetMfma<ushort, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
#endif
template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ static constexpr void mfma_check()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
selected_mfma.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
"n_per_blk != num_threads_per_blk");
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
selected_mfma.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
selected_mfma.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
selected_mfma.m_per_blk * selected_mfma.n_per_blk,
"num_regs_per_blk incorrect");
static_assert(selected_mfma.is_k_reduction ||
(selected_mfma.num_input_blks == selected_mfma.num_output_blks),
"is_k_reduction wrong!");
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
__host__ __device__ constexpr MfmaSelector() { mfma_check(); }
static constexpr bool IsABroadcast()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
return true;
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
static constexpr index_t GetKPerXdlops()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
selected_mfma.k_per_blk;
}
#endif
static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; }
};
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
return MPerXdlops * NPerXdlops /
(mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
}
__host__ __device__ constexpr XdlopsGemm()
......@@ -697,104 +691,142 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
}
template <typename CM0N0M1N1M2N2Desc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc)
{
const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0);
const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1);
const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2);
const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3);
return transform_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_desc,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
}
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / mfma_type.wave_size;
return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
}
template <class ADesc,
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
template <index_t c_offset, class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!");
static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[k], p_b_wave[k], p_c_thread);
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset / mfma_type.k_base>{}],
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread);
});
const auto blk_idx =
threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
const auto blk_id = blk_idx[I1];
const auto blk_td = blk_idx[I2];
return make_tuple(blk_id, blk_td);
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
index_t n_offset = blk_i * mfma_type.n + blk_td;
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return CIndex{m_offset, n_offset};
if constexpr(mfma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
static constexpr auto GetBlkId(const index_t lane_id)
{
return lane_id / mfma_type.num_threads_blk;
if constexpr(mfma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
}
static constexpr auto GetBlkTd(const index_t lane_id)
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
return lane_id % mfma_type.num_threads_blk;
}
const auto blk_idx = GetBlkIdx();
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
struct CLayout
{
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
return CIndex{m_offset, n_offset};
}
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
};
static constexpr auto mfma_instr = mfma.selected_mfma;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto KPerThread = mfma.GetKPerThread();
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
{
return make_tuple(
Number<mfma_instr.num_groups_per_blk>{}, I1, Number<mfma_instr.group_size>{}, I1);
}
};
} // namespace ck
......
......@@ -114,12 +114,11 @@ struct MagicDivision
__host__ __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
#if 1 // debug
// HACK: magic division for int32_t
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
......@@ -127,27 +126,9 @@ struct MagicDivision
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp =
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
#else
// the inline ASM is producing wrong result
__host__ __device__ static int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t r;
asm volatile("\n \
v_mul_hi_u32 %0, %1, %2 \n \
v_add_u32_e32 %0, %1, %0 \n \
v_lshrrev_b32_e32 %0, %3, %0 \n \
"
: "=v"(r)
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
return as_type<int32_t>(r);
}
#endif
};
} // namespace ck
......
......@@ -56,9 +56,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
......@@ -84,9 +84,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
......@@ -112,9 +112,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
......@@ -140,9 +140,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
......@@ -168,9 +168,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
......@@ -223,25 +223,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
// clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
//clang-format on
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
......@@ -263,8 +265,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
......@@ -289,7 +291,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks),
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
......@@ -301,7 +303,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
in_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
......
......@@ -195,25 +195,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
// clang-format off
constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
//clang-format on
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
......@@ -265,7 +267,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks),
decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true // CAccessOrderMRepeatNRepeat
......@@ -277,7 +279,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
in_gemmm_gemmn_grid_desc,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks,
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs =
#if 1
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
#else
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
#endif
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
for(index_t i = 0; i < 5; ++i)
{
#if 0
float ave_time = launch_kernel_gemm_xdlops_v1
#else
float ave_time = launch_kernel_gemm_xdlops_v2
#endif
<BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(descs[I0]),
decltype(descs[I1]),
decltype(descs[I2]),
decltype(descs[I3]),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy, which will be fused
// with MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
......@@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
......@@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
......@@ -105,23 +105,23 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
......@@ -169,7 +169,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
......@@ -180,7 +180,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
......
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r2.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1>,
2,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}
......@@ -56,8 +56,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
......@@ -84,9 +84,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
......@@ -112,9 +112,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
......@@ -140,9 +140,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
......@@ -168,9 +168,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
......@@ -196,9 +196,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
......@@ -249,23 +249,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0>{}), // 7+: N1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
......@@ -287,8 +287,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmMPerXDL,
GemmNPerXDL,
GemmK1,
MRepeat,
NRepeat,
......@@ -313,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
......@@ -325,7 +325,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
......
......@@ -17,8 +17,8 @@ template <ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerWave,
ck::index_t NPerWave,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t K1,
ck::index_t MRepeat,
ck::index_t NRepeat,
......@@ -79,8 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
......@@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
......@@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
......@@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel(
......@@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
return ave_time;
......
......@@ -20,12 +20,12 @@
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
enum ConvForwardAlgo
{
......@@ -126,7 +126,7 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif
#if 1
#if 0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
......
......@@ -9,8 +9,8 @@ struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
int NPerBlock;
int KPerBlock;
int MPerWave;
int NPerWave;
int MPerXDL;
int NPerXDL;
int K1;
int MRepeat;
......@@ -45,8 +45,8 @@ static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
32, // MPerXDL,
32, // NPerXDL,
4, // K1,
2, // MRepeat,
2, // NRepeat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment