Commit 9adf2e60 authored by aska-0096's avatar aska-0096
Browse files

runtime bug, cannot find symbol

parent b3cc22a3
...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma ...@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
// clang-format on // clang-format on
......
...@@ -10,16 +10,6 @@ ...@@ -10,16 +10,6 @@
namespace ck { namespace ck {
enum struct LoopScheduler
{
Default,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
return LoopScheduler::Default;
}
template <index_t BlockSize, template <index_t BlockSize,
typename FloatAB, typename FloatAB,
typename FloatAcc, typename FloatAcc,
...@@ -30,18 +20,22 @@ template <index_t BlockSize, ...@@ -30,18 +20,22 @@ template <index_t BlockSize,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack>
// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow /* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto I3 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size(); static constexpr index_t WaveSize = 32;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
...@@ -52,7 +46,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -52,7 +46,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WMMAGemm<FloatAB, MPerWMMA, NPerWMMA, KPack>{}; static constexpr auto wmma_gemm = WmmaGemm<FloatAB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA; static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA;
...@@ -62,7 +56,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -62,7 +56,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
MRepeat * NRepeat, MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWMMA(), wmma_gemm.GetRegSizePerWmma(),
true> true>
c_thread_buf_; c_thread_buf_;
...@@ -87,7 +81,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -87,7 +81,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
const auto waveId_m = wave_idx[I0]; const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|Mwave |MLane |KPack // |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
} }
...@@ -131,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -131,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
return make_tuple(c_thread_m, c_thread_n); return make_tuple(c_thread_m, c_thread_n);
} }
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3()
{ {
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(), BK0NK1BlockDesc::IsKnownAtCompileTime(),
...@@ -157,76 +151,49 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -157,76 +151,49 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
make_tuple(Number<MRepeat>{}, I1, MSubGroup, Number<NRepeat>{}, I1, NThreadPerSubGroup, MAccVgprs)); make_tuple(Number<MRepeat>{}, I1, MSubGroup, Number<NRepeat>{}, I1, NThreadPerSubGroup, MAccVgprs));
} }
__host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); return wmma_gemm.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
} }
__host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack() __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{ {
static constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor( return transform_tensor_descriptor(
a_block_desc_temp_km0m1m2, AK0MK1BlockDesc{},
make_tuple( make_tuple(
make_unmerge_transform(make_tuple(Number<A_K0*A_K1/KPack>{}, Number<KPack>{})), make_pass_through_transform(Number<A_K0>{}),
make_pass_through_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))), make_unmerge_transform(make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}), make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
} }
__host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack() __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
{ {
static constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
return transform_tensor_descriptor( return transform_tensor_descriptor(
b_block_desc_temp_kn0n1n2, BK0NK1BlockDesc{},
make_tuple( make_tuple(
make_unmerge_transform(make_tuple(Number<B_K0*B_K1/KPack>{}, Number<KPack>{})), make_pass_through_transform(Number<B_K0>{}),
make_pass_through_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))), make_unmerge_transform(make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}), make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
} }
static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack(); static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack(); static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer> template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf, __device__ void Run(const ABlockBuffer& a_block_buf,
...@@ -239,9 +206,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -239,9 +206,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
constexpr auto RepeatDiff = MRepeat - NRepeat; constexpr auto RepeatDiff = MRepeat - NRepeat;
constexpr auto WmmaK = wmma_gemm.k_per_wmma;
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto iWmmaK){ static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
// Cut to Repeat Retangle to Square, assume MRepeat > NRepeat // Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){ static_for<0, RepeatDiff, 1>{}([&](auto iCut){
static_for<0, NRepeat, 1>{}([&](auto iN){ static_for<0, NRepeat, 1>{}([&](auto iN){
...@@ -251,25 +217,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -251,25 +217,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) = a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iCut, 0, 0, iK))>{}]; make_tuple(iK/A_K1, iCut, 0, 0, iK%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) = b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iN, 0, 0, iK))>{}]; make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
}); });
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type; using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(), a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>(), b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack, a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0), make_tuple(Number<iWmmaK/A_K1>{}, Number<iCut>{}, I0, I0, Number<iWmmaK%A_K1>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, Number<iCut>{}, I0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
// Run FIFO fashion loopover in Square // Run FIFO fashion loopover in Square
...@@ -281,25 +247,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -281,25 +247,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) = a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop+RepeatDiff, 0, 0, iK))>{}]; make_tuple(iK/A_K1, WmmaInnerloop+RepeatDiff, 0, 0, iK%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) = b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iN, 0, 0, iK))>{}]; make_tuple(iK/B_K1, iN, 0, 0, iK%B_K1))>{}];
}); });
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type; using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0)); c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(), a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>(), b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack, a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop+RepeatDiff, I0, I0, I0), make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, Number<iWmmaK%A_K1>{}),
a_block_buf, a_block_buf,
a_thread_desc_, a_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
a_thread_buf); a_thread_buf);
static_for<WmmaInnerloop+RepeatDiff, MRepeat, 1>{}([&](auto iM){ static_for<WmmaInnerloop+RepeatDiff, MRepeat, 1>{}([&](auto iM){
vector_type<FloatAB, WmmaK> a_thread_vec; vector_type<FloatAB, WmmaK> a_thread_vec;
...@@ -308,25 +274,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -308,25 +274,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for<0, WmmaK, 1>{}([&](auto iK) { static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) = a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iM, 0, 0, iK))>{}]; make_tuple(iK/A_K1, iM, 0, 0, iK%A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) = b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop, 0, 0, iK))>{}]; make_tuple(iK/B_K1, WmmaInnerloop, 0, 0, iK%B_K1))>{}];
}); });
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type; using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
wmma_gemm.template Run( wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(), a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>(), b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack, b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0), make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, Number<iWmmaK%B_K1>{}),
b_block_buf, b_block_buf,
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
b_thread_buf); b_thread_buf);
}); });
}); });
...@@ -335,33 +301,33 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3 ...@@ -335,33 +301,33 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
protected: protected:
// A[M0, M1, M2, K0 = WmmaK] // A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ = static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK] // B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ = static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<WmmaK/B_K1>{}, Number<MRepeat>{}, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA] // C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWMMA())); make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
decltype(a_block_desc_krepeat_m0_m1_m2_kpack), decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_), decltype(a_thread_desc_),
Sequence<1, 1, 1, WmmaK>, Sequence<WmmaK/A_K1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3>, Sequence<3, 0, 1, 2, 4>,
3, 4,
A_K1, A_K1,
A_K1>; A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB, FloatAB,
decltype(b_block_desc_krepeat_n0_n1_n2_kpack), decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_), decltype(b_thread_desc_),
Sequence<1, 1, 1, WmmaK>, Sequence<WmmaK/B_K1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3>, Sequence<3, 0, 1, 2, 4>,
3, 4,
B_K1, B_K1,
B_K1>; B_K1>;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -38,8 +38,8 @@ template <typename ADataType, ...@@ -38,8 +38,8 @@ template <typename ADataType,
ck::index_t K1, ck::index_t K1,
ck::index_t MPerWMMA, ck::index_t MPerWMMA,
ck::index_t NPerWMMA, ck::index_t NPerWMMA,
ck::index_t MWmmaPerWave, ck::index_t MRepeat,
ck::index_t NWmmaPerWave, ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -196,7 +196,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -196,7 +196,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1< using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
BlockSize, BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
...@@ -214,8 +214,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -214,8 +214,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
MPerWMMA, MPerWMMA,
NPerWMMA, NPerWMMA,
K1, K1,
MWmmaPerWave, MRepeat,
NWmmaPerWave, NRepeat,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -232,16 +232,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -232,16 +232,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
#if 0 Sequence<0, 1, 2, 3, 4, 5, 6>, // CThreadTransferSrcDstAccessOrder,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
#endif
NumPrefetch, NumPrefetch,
LoopSched,
PipelineVer>; PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgumentW struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
...@@ -263,7 +262,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -263,7 +262,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
a_grid_desc_k0_m_k1_{}, a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_{}, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
...@@ -283,8 +282,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -283,8 +282,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
c_grid_desc_m_n_, c_grid_desc_m_n_,
block_2_ctile_map_)) block_2_ctile_map_))
{ {
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_ = c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n_);
} }
} }
...@@ -295,8 +294,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -295,8 +294,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_; c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -347,19 +346,21 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -347,19 +346,21 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_wmma_v1r1< const auto kernel = kernel_gemm_wmma<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>, remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>, remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O true>; // Last Option is W/O
std::cout<<"Host kernel type is "<< type_name<decltype(kernel)>()<<std::endl;
printf("---------------------Crush before kernel launch-------------------\n");
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(grid_size), dim3(grid_size),
...@@ -370,7 +371,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -370,7 +371,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_, arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -378,13 +379,13 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -378,13 +379,13 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
} }
else else
{ {
const auto kernel = kernel_gemm_wmma_v1r1< const auto kernel = kernel_gemm_wmma<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>, remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -401,7 +402,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -401,7 +402,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_, arg.c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
...@@ -540,8 +541,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout, ...@@ -540,8 +541,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
<< K1 << ", " << K1 << ", "
<< MPerWMMA << ", " << MPerWMMA << ", "
<< NPerWMMA << ", " << NPerWMMA << ", "
<< MWmmaPerWave << ", " << MRepeat << ", "
<< NWmmaPerWave << NRepeat
<< ">" << ">"
<< " NumPrefetch: " << " NumPrefetch: "
<< NumPrefetch << ", " << NumPrefetch << ", "
......
...@@ -22,7 +22,7 @@ template <typename GridwiseGemm, ...@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename FloatC, typename FloatC,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
typename CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs, typename CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
...@@ -32,14 +32,14 @@ __global__ void ...@@ -32,14 +32,14 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_wmma_v1r1( kernel_gemm_wmma(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
...@@ -55,7 +55,7 @@ __global__ void ...@@ -55,7 +55,7 @@ __global__ void
p_shared, p_shared,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
...@@ -66,7 +66,7 @@ __global__ void ...@@ -66,7 +66,7 @@ __global__ void
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs; ignore = c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
...@@ -92,8 +92,8 @@ template < ...@@ -92,8 +92,8 @@ template <
index_t MPerWmma, index_t MPerWmma,
index_t NPerWmma, index_t NPerWmma,
index_t K1Value, index_t K1Value,
index_t MWmmaPerWave, index_t MRepeat,
index_t NWmmaPerWave, index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
...@@ -114,8 +114,9 @@ template < ...@@ -114,8 +114,9 @@ template <
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
index_t NumGemmKPrefetchStage = 1, index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -132,7 +133,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -132,7 +133,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
...@@ -207,8 +208,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -207,8 +208,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerWmma * MWmmaPerWave) == 0) && static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(NPerBlock % (NWmmaPerWave * NPerWmma)) == 0, (NPerBlock % (NRepeat * NPerWmma)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
...@@ -247,35 +248,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -247,35 +248,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs( MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N_& c_grid_desc_m_n) const CGridDesc_M_N& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); constexpr auto max_lds_align = K1;
const auto N = c_grid_desc_m_n.GetLength(I1);
// A matrix in LDS memory, dst of blockwise copy
const auto MBlock = M / MPerBlock; constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
const auto NBlock = N / NPerBlock; if constexpr(ABlockLdsExtraM)
{
constexpr index_t MWave = MPerBlock / (MWmmaPerWave * MPerWmma); return make_naive_tensor_descriptor(
constexpr index_t NWave = NPerBlock / (NWmmaPerWave * NPerWmma); make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
constexpr index_t MLaneHigh = 2; }
constexpr index_t MLaneLow = NWmmaPerWave / MLaneHigh; else
constexpr index_t NLane = NWmmaPerWave; {
return make_naive_tensor_descriptor_aligned(
const auto c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs = make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
transform_tensor_descriptor( }
c_grid_desc_m_n, }();
make_tuple(make_unmerge_transform(make_tuple(
MBlock, Number<MWmmaPerWave>{}, Number<MWave>{}, Number<MLaneHigh>{}, Number<MLaneLow>{})), // B matrix in LDS memory, dst of blockwise copy
make_unmerge_transform(make_tuple( constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
NBlock, Number<NWmmaPerWave>{}, Number<NWave>{}, Number<NLane>{}))), if constexpr(BBlockLdsExtraN)
make_tuple(Sequence<0>{}, Sequence<1>{}), {
make_tuple(Sequence<0, 1, 2, 3, 8>{}, Sequence<4, 5, 6, 7>{})); return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
return c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs; make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}();
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
return BlockwiseGemm::MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(c_grid_desc_m_n);
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
...@@ -285,9 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -285,9 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>( return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n); c_grid_desc_m_n);
} }
using CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs = using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs =
remove_cvref_t<decltype( remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs( MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
CGridDesc_M_N{}))>; CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
...@@ -300,8 +323,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -300,8 +323,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MRepeat_Mwave_MSubGroup_NBlock_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs& const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs&
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
...@@ -315,15 +338,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -315,15 +338,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetElementSpaceSize());
/*******************************************************************************/ /*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n] // BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex( if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx, block_work_idx,
make_tuple(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0), make_tuple(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0),
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4)))) c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4))))
{ return; } { return; }
// Store BlockId into SGPR // Store BlockId into SGPR
...@@ -415,8 +438,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -415,8 +438,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
decltype(b_block_desc_k0perblock_nperblock_k1), decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma, MPerWmma,
NPerWmma, NPerWmma,
MWmmaPerWave, MRepeat,
NWmmaPerWave, NRepeat,
KPack>{}; KPack>{};
// Prepare Register for C matrix // Prepare Register for C matrix
...@@ -450,20 +473,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -450,20 +473,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
K0BlockMainLoop); K0BlockMainLoop);
// NO C-shuffle, direct write /*******************************************************************************/
// write out C matrix, c shuffle not implemented
{ {
constexpr c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLaneLow(); blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
constexpr c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); constexpr auto MWave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
constexpr auto MSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
constexpr auto MRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I0); constexpr auto Nwave = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1); constexpr auto NThreadPerSubGroup = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2); constexpr auto MAccVgprs = c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
constexpr auto NRepeat = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I3);
constexpr auto Nwave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
// Mapping // Mapping
const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
...@@ -476,16 +496,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -476,16 +496,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup = const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))), make_tuple(make_merge_transform(make_tuple(NRepeat, Nwave, NThreadPerSubGroup))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor( const auto m_thread_data_on_grid_idx = m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_grid)); make_multi_index(m_thread_data_on_grid));
const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup( const auto n_thread_data_on_grid_idx = n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_grid)); make_multi_index(n_thread_data_on_grid));
...@@ -494,8 +514,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -494,8 +514,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* typename SrcData */ FloatAcc, /* typename SrcData */ FloatAcc,
/* typename DstData */ FloatC, /* typename DstData */ FloatC,
/* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), /* typename SrcDesc */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
/* typename DstDesc */ decltype(c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs), /* typename DstDesc */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs),
/* typename ElementwiseOperation */ CElementwiseOperation, /* typename ElementwiseOperation */ CElementwiseOperation,
// Thread register Mapping
/* typename SliceLengths */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>, /* typename SliceLengths */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
/* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder, /* typename DimAccessOrder */ CThreadTransferSrcDstAccessOrder,
/* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim, /* index_t DstVectorDim */ CThreadTransferSrcDstVectorDim,
...@@ -504,7 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -504,7 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* index_t DstScalarStrideInVector */ 1, /* index_t DstScalarStrideInVector */ 1,
/* bool DstResetCoordinateAfterRun */ true> /* bool DstResetCoordinateAfterRun */ true>
{ {
/* dst_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, /* dst_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0], /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1], m_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2], m_thread_data_on_grid_idx[I2],
...@@ -517,9 +538,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1 ...@@ -517,9 +538,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
c_thread_copy.Run( c_thread_copy.Run(
/* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, /* c_thread_desc */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_start point */ make_tuple(I0, I0, I0, I0, I0, I0, I0), /* c_register_beginning*/ make_tuple(I0, I0, I0, I0, I0, I0, I0),
/* c_buffer */ c_thread_buf, /* c_local(register) */ c_thread_buf,
/* c_grid_desc */ c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs, /* c_grid_desc */ c_grid_desc_mblockxrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs,
/* c_grid_buf */ c_grid_buf); /* c_grid_buf */ c_grid_buf);
} }
// clang-format on // clang-format on
......
...@@ -72,12 +72,14 @@ enum struct WmmaInstr ...@@ -72,12 +72,14 @@ enum struct WmmaInstr
template <WmmaInstr Instr, template <WmmaInstr Instr,
index_t WaveSize, index_t WaveSize,
typename enable_if<WaveSize == 32 || WaveSize == 64, bool>:: = false> typename = void>
struct wmma_type; struct wmma_type{};
// A-swizzled // A-swizzled
template <index_t WaveSize> template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, WaveSize> struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
WaveSize,
typename std::enable_if_t<WaveSize == 32 ||WaveSize == 64>>
{ {
// Absolute fixing property // Absolute fixing property
// * Data Pixel // * Data Pixel
...@@ -172,11 +174,7 @@ struct WmmaSelector ...@@ -172,11 +174,7 @@ struct WmmaSelector
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size== static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave * selected_wmma.acc_data_size==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4, selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Number of Accumulator Register"); "WRONG! Invalid Number of Accumulator Register");
static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wmma * selected_wmma.src_data_size==
selected_wmma.m_per_wmma * selected_wmma.k_per_wmma * 4,
"WRONG! Number of Source Register");
} }
}; };
...@@ -206,25 +204,25 @@ struct WmmaGemm ...@@ -206,25 +204,25 @@ struct WmmaGemm
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
} }
// XDL output supporting C = A * B // WMMA output supporting C = A * B
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template <typename CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA> template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA& c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma) (const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
{ {
const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0); const auto MBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3); const auto NBlockxRepeat = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
const auto MWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1); const auto MWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
const auto NWave = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4); const auto NWave = c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma, c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
make_tuple(make_pass_through_transform(MRepeat), make_tuple(make_pass_through_transform(MBlockxRepeat),
make_pass_through_transform(Mwave), make_pass_through_transform(MWave),
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{}, make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
Number<wmma_instr.num_acc_vgprs_per_wave>{})), Number<wmma_instr.num_acc_vgprs_per_wave>{})),
make_pass_through_transform(NRepeat), make_pass_through_transform(NBlockxRepeat),
make_pass_through_transform(NWave), make_pass_through_transform(NWave),
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})), make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -266,12 +264,12 @@ struct WmmaGemm ...@@ -266,12 +264,12 @@ struct WmmaGemm
if constexpr(!TransposeC) if constexpr(!TransposeC)
{ {
wmma_instr.template run<MPerWmma, NPerWmma>( wmma_instr.template run<MPerWmma, NPerWmma>(
p_a_wave[0], p_b_wave[0], p_c_thread); p_a_wave, p_b_wave, p_c_thread);
} }
else else
{ {
wmma_instr.template run<MPerWmma, NPerWmma>( wmma_instr.template run<MPerWmma, NPerWmma>(
p_b_wave[0], p_a_wave[0], p_c_thread); p_b_wave, p_a_wave, p_c_thread);
} }
} }
...@@ -318,7 +316,7 @@ struct WmmaGemm ...@@ -318,7 +316,7 @@ struct WmmaGemm
__host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
{ {
return make_tuple( return make_tuple(
Number<I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{}); I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment