Commit 9adf2e60 authored by aska-0096's avatar aska-0096
Browse files

runtime bug, cannot find symbol

parent b3cc22a3
......@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
// clang-format on
......
......@@ -10,16 +10,6 @@
namespace ck {
enum struct LoopScheduler
{
Default,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
return LoopScheduler::Default;
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
......@@ -30,18 +20,22 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I3 = Number<4>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t WaveSize = 32;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
......@@ -52,7 +46,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WMMAGemm<FloatAB, MPerWMMA, NPerWMMA, KPack>{};
static constexpr auto wmma_gemm = WmmaGemm<FloatAB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA;
......@@ -62,7 +56,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWMMA(),
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
......@@ -87,7 +81,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|Mwave |MLane |KPack
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
......@@ -131,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
......@@ -157,76 +151,49 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
make_tuple(Number<MRepeat>{}, I1, MSubGroup, Number<NRepeat>{}, I1, NThreadPerSubGroup, MAccVgprs));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack()
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{
static constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
a_block_desc_temp_km0m1m2,
AK0MK1BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<A_K0*A_K1/KPack>{}, Number<KPack>{})),
make_pass_through_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack()
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
{
static constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor(
b_block_desc_temp_kn0n1n2,
BK0NK1BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<B_K0*B_K1/KPack>{}, Number<KPack>{})),
make_pass_through_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack();
static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack();
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
......@@ -239,9 +206,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
b_thread_desc_.GetElementSpaceSize());
constexpr auto RepeatDiff = MRepeat - NRepeat;
constexpr auto WmmaK = wmma_gemm.k_per_wmma;
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto iWmmaK){
static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
// Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){
static_for<0, NRepeat, 1>{}([&](auto iN){
......@@ -251,25 +217,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iCut, 0, 0, iK))>{}];
make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iN, 0, 0, iK))>{}];
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(),
b_thread_vec.template AsType<wmma_input_type>(),
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0),
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK/A_K1>{}, Number<iCut>{}, I0, I0, Number<iWmmaK%A_K1>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
make_tuple(I0, Number<iCut>{}, I0, I0, I0),
a_thread_buf);
});
// Run FIFO fashion loopover in Square
......@@ -281,25 +247,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop+RepeatDiff, 0, 0, iK))>{}];
make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iN, 0, 0, iK))>{}];
make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(),
b_thread_vec.template AsType<wmma_input_type>(),
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop+RepeatDiff, I0, I0, I0),
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, Number<iWmmaK%A_K1>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
a_thread_buf);
static_for<WmmaInnerloop+RepeatDiff, MRepeat, 1>{}([&](auto iM){
vector_type<FloatAB, WmmaK> a_thread_vec;
......@@ -308,25 +274,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iM, 0, 0, iK))>{}];
make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop, 0, 0, iK))>{}];
make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(),
b_thread_vec.template AsType<wmma_input_type>(),
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0),
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, Number<iWmmaK%B_K1>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_buf);
});
});
......@@ -335,33 +301,33 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
protected:
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{}));
make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{}));
make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, Number<MRepeat>{}, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWMMA()));
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_krepeat_m0_m1_m2_kpack),
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>,
3,
Sequence<WmmaK/A_K1, 1, 1, 1, A_K1>,
Sequence<3, 0, 1, 2, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_krepeat_n0_n1_n2_kpack),
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>,
3,
Sequence<WmmaK/B_K1, 1, 1, 1, B_K1>,
Sequence<3, 0, 1, 2, 4>,
4,
B_K1,
B_K1>;
......
......@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -38,8 +38,8 @@ template <typename ADataType,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MWmmaPerWave,
ck::index_t NWmmaPerWave,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -196,7 +196,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1<
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......@@ -214,8 +214,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
MPerWMMA,
NPerWMMA,
K1,
MWmmaPerWave,
NWmmaPerWave,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......@@ -232,16 +232,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
Sequence<0, 1, 2, 3, 4, 5, 6>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
#endif
NumPrefetch,
LoopSched,
PipelineVer>;
// Argument
struct Argument : public BaseArgumentW
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
......@@ -263,7 +262,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_{},
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
......@@ -283,8 +282,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow(c_grid_desc_m_n_);
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_ =
GridwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n_);
}
}
......@@ -295,8 +294,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_;
typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -347,19 +346,21 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_wmma_v1r1<
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O
std::cout<<"Host kernel type is "<< type_name<decltype(kernel)>()<<std::endl;
printf("---------------------Crush before kernel launch-------------------\n");
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
......@@ -370,7 +371,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -378,13 +379,13 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
}
else
{
const auto kernel = kernel_gemm_wmma_v1r1<
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -401,7 +402,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -540,8 +541,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MWmmaPerWave << ", "
<< NWmmaPerWave
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
......
......@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs,
typename CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
......@@ -32,14 +32,14 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_wmma_v1r1(
kernel_gemm_wmma(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
......@@ -55,7 +55,7 @@ __global__ void
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
a_element_op,
b_element_op,
c_element_op,
......@@ -66,7 +66,7 @@ __global__ void
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs;
ignore = c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
......@@ -92,8 +92,8 @@ template <
index_t MPerWmma,
index_t NPerWmma,
index_t K1Value,
index_t MWmmaPerWave,
index_t NWmmaPerWave,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -114,8 +114,9 @@ template <
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -132,7 +133,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -207,8 +208,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerWmma * MWmmaPerWave) == 0) &&
(NPerBlock % (NWmmaPerWave * NPerWmma)) == 0,
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerWmma)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
......@@ -247,35 +248,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N_& c_grid_desc_m_n)
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma);
constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma);
constexpr index_t MLaneHigh = 2;
constexpr index_t MLaneLow = NWmmaPerWave / MLaneHigh;
constexpr index_t NLane = NWmmaPerWave;
const auto c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(
MBlock, Number<MWmmaPerWave>{}, Number<MWave>{}, Number<MLaneHigh>{}, Number<MLaneLow>{})),
make_unmerge_transform(make_tuple(
NBlock, Number<NWmmaPerWave>{}, Number<NWave>{}, Number<NLane>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3, 8>{}, Sequence<4, 5, 6, 7>{}));
return c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs;
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n);
}
// return block_id to C matrix tile idx (m0, n0) mapping
......@@ -285,9 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs =
using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
......@@ -300,8 +323,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs&
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs&
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
......@@ -315,15 +338,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize());
p_c_grid, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize());
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0),
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4))))
make_tuple(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0),
c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4))))
{ return; }
// Store BlockId into SGPR
......@@ -415,8 +438,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma,
NPerWmma,
MWmmaPerWave,
NWmmaPerWave,
MRepeat,
NRepeat,
KPack>{};
// Prepare Register for C matrix
......@@ -450,20 +473,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
// NO C-shuffle, direct write
/*******************************************************************************/
// write out C matrix, c shuffle not implemented
{
constexpr c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLaneLow();
constexpr c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0);
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
constexpr auto NRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I3);
constexpr auto Nwave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr auto MWave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
constexpr auto MSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
constexpr auto Nwave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
constexpr auto MAccVgprs = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
// Mapping
const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
......@@ -476,16 +496,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup =
const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor(
const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup(
const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid));
......@@ -494,8 +514,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* typename SrcData */ FloatAcc,
/* typename DstData */ FloatC,
/* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
/* typename DstDesc */ decltype(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
/* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs),
/* typename ElementwiseOperation */ CElementwiseOperation,
// Thread register Mapping
/* typename SliceLengths */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
/* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder,
/* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim,
......@@ -504,7 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* index_t DstScalarStrideInVector */ 1,
/* bool DstResetCoordinateAfterRun */ true>
{
/* dst_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
......@@ -517,9 +538,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
c_thread_copy.Run(
/* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_start point */ make_tuple(I0, I0, I0, I0, I0, I0, I0),
/* c_buffer */ c_thread_buf,
/* c_grid_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_register_beginning*/ make_tuple(I0, I0, I0, I0, I0, I0, I0),
/* c_local(register) */ c_thread_buf,
/* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_grid_buf */ c_grid_buf);
}
// clang-format on
......
......@@ -72,12 +72,14 @@ enum struct WmmaInstr
template <WmmaInstr Instr,
index_t WaveSize,
typename enable_if<WaveSize == 32 || WaveSize == 64, bool>:: = false>
struct wmma_type;
typename = void>
struct wmma_type{};
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 ||WaveSize == 64>>
{
// Absolute fixing property
// * Data Pixel
......@@ -172,11 +174,7 @@ struct WmmaSelector
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Number of Accumulator Register");
static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wmma * selected_wmma.src_data_size==
selected_wmma.m_per_wmma * selected_wmma.k_per_wmma * 4,
"WRONG! Number of Source Register");
"WRONG! Invalid Number of Accumulator Register");
}
};
......@@ -206,25 +204,25 @@ struct WmmaGemm
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
}
// XDL output supporting C = A * B
// WMMA output supporting C = A * B
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA>
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto
MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA& c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma)
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{
const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0);
const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4);
const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor(
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MRepeat),
make_pass_through_transform(Mwave),
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NRepeat),
make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{},
......@@ -266,12 +264,12 @@ struct WmmaGemm
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_a_wave[0], p_b_wave[0], p_c_thread);
p_a_wave, p_b_wave, p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_b_wave[0], p_a_wave[0], p_c_thread);
p_b_wave, p_a_wave, p_c_thread);
}
}
......@@ -318,7 +316,7 @@ struct WmmaGemm
__host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{
return make_tuple(
Number<I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment