Commit 1a089f6f authored by aska-0096

sanity bug fix

parent c8c016dd
@@ -54,8 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base
     static constexpr index_t AMmaKStride = KPack;
     static constexpr index_t BMmaKStride = KPack;

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
     static constexpr index_t KRepeat    = KPerThread / KPack;
+    static constexpr index_t KPerInnerLoop = KPack;

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
......
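For reference, the new KPerInnerLoop constant equals KPack in this pipeline base, so KPerThread still factors exactly into KRepeat * KPerInnerLoop. Below is a standalone sketch of that arithmetic; the tile sizes are chosen only for illustration and are not taken from the commit.

// Standalone sketch: how the per-thread K dimension factors after this change.
// The numeric values are illustrative stand-ins, not CK defaults.
#include <cstdio>

int main()
{
    constexpr int KPerBlock     = 64;
    constexpr int K0PerXdlops   = 2; // stands in for xdlops_gemm.K0PerXdlops
    constexpr int KPack         = 8;

    constexpr int KPerThread    = KPerBlock / K0PerXdlops; // 32
    constexpr int KRepeat       = KPerThread / KPack;      // 4
    constexpr int KPerInnerLoop = KPack;                    // constant added by this commit

    static_assert(KRepeat * KPerInnerLoop == KPerThread, "K factorization must be exact");

    // With KPerInnerLoop == KPack, the new inner loop (k_ stepping by KPack over
    // KPerInnerLoop) executes exactly once per k0, so the MFMA count is unchanged.
    int mfma_issues = 0;
    for(int k0 = 0; k0 < KRepeat; ++k0)
        for(int k_ = 0; k_ < KPerInnerLoop; k_ += KPack)
            ++mfma_issues;
    std::printf("mfma issues per (m0, n0): %d (KRepeat = %d)\n", mfma_issues, KRepeat);
    return 0;
}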
@@ -227,8 +227,26 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
                 }
             };

-            constexpr index_t minimum_occupancy =
-                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+            constexpr index_t minimum_occupancy = []() {
+                if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
+                {
+                    return 2;
+                }
+                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+                {
+                    constexpr index_t instance_lds_size =
+                        MPerBlock * KPerBlock * sizeof(ADataType) +
+                        NPerBlock * KPerBlock * sizeof(BDataType);
+                    return ((MPerBlock * NPerBlock / BlockSize <= 128) &&
+                            (instance_lds_size <= 32768))
+                               ? 2
+                               : 1;
+                }
+                else
+                {
+                    return 1;
+                }
+            }();

             if(has_main_k_block_loop)
             {
......
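The rewritten minimum_occupancy above keys on the scheduler and, for the v3 pipeline, on the block's LDS footprint and per-thread C tile. Below is a hedged, standalone sketch of the same decision rule; the enum names, template parameters, and example sizes are stand-ins and not the device-op code itself.

// Sketch of the occupancy heuristic rebuilt as a free constexpr function.
#include <cstddef>
#include <cstdint>
#include <cstdio>

enum class Scheduler { Intrawave, Interwave };
enum class PipelineVersion { v1, v2, v3 };

template <Scheduler Sched,
          PipelineVersion Ver,
          int MPerBlock,
          int NPerBlock,
          int KPerBlock,
          int BlockSize,
          typename AData,
          typename BData>
constexpr int minimum_occupancy()
{
    if constexpr(Sched == Scheduler::Interwave)
    {
        return 2; // interwave scheduling keeps the previous value of 2
    }
    else if constexpr(Ver == PipelineVersion::v3)
    {
        // LDS footprint of one block's A and B tiles, in bytes.
        constexpr std::size_t lds_size = MPerBlock * KPerBlock * sizeof(AData) +
                                         NPerBlock * KPerBlock * sizeof(BData);
        // A small per-thread C tile and at most 32 KiB of LDS leave room for
        // two blocks per compute unit; otherwise fall back to 1.
        return (MPerBlock * NPerBlock / BlockSize <= 128 && lds_size <= 32768u) ? 2 : 1;
    }
    else
    {
        return 1;
    }
}

int main()
{
    // 128x128x64 tile, 256 threads, 16-bit A/B: 64 C elements per thread, 32 KiB LDS -> 2.
    constexpr int occ = minimum_occupancy<Scheduler::Intrawave,
                                          PipelineVersion::v3,
                                          128, 128, 64, 256,
                                          std::uint16_t, std::uint16_t>();
    std::printf("minimum_occupancy = %d\n", occ);
    return 0;
}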
@@ -1550,30 +1550,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
         static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
-
-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -1592,29 +1597,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA,
-                                             xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
-
-                    xdlops_gemm.Run(
-                        a_thread_vec.template AsType<mfma_input_type>(),
-                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA,
+                                                 xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2025,31 +2032,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                                CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

         static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
-
-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2068,29 +2079,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA,
-                                             xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
-
-                    xdlops_gemm.Run(
-                        a_thread_vec.template AsType<mfma_input_type>(),
-                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA,
+                                                 xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
......
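In the epilogue hunks above, the inner static_for over KPerInnerLoop (step KPack) turns the last thread-buffer index into k_ + ik, so each k0 step now covers KPerInnerLoop elements. Below is a standalone sketch, with made-up extents, showing that the (k0, k_ + ik) indexing still touches every per-thread K element exactly once; it is not the actual thread descriptor.

// Sketch: with KPerInnerLoop == KPack, (k0, k_ + ik) enumerates the same flat
// range 0 .. KRepeat * KPerInnerLoop - 1 that (k0, ik) covered before.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int KRepeat       = 4;
    constexpr int KPack         = 4;
    constexpr int KPerInnerLoop = KPack; // as defined in the pipeline base above

    std::vector<int> visited(KRepeat * KPerInnerLoop, 0);
    for(int k0 = 0; k0 < KRepeat; ++k0)
        for(int k_ = 0; k_ < KPerInnerLoop; k_ += KPack)
            for(int ik = 0; ik < KPack; ++ik)
                ++visited[k0 * KPerInnerLoop + (k_ + ik)]; // row-major offset of (k0, k_ + ik)

    bool each_once = true;
    for(int v : visited)
        each_once = each_once && (v == 1);
    std::printf("every K element read exactly once: %s\n", each_once ? "yes" : "no");
    return 0;
}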
@@ -1685,30 +1685,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
-
-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -1728,29 +1733,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA,
-                                             xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
-
-                    xdlops_gemm.Run(
-                        a_thread_vec.template AsType<mfma_input_type>(),
-                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA,
+                                                 xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -1790,7 +1797,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                     I0,
                     cde_lds_and_global_step);

-                EpilogueScheduler();
+                // EpilogueScheduler();
             }
         });
     }
@@ -2236,30 +2243,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
-
-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2279,29 +2291,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
-
-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
-
-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA,
-                                             xdlops_gemm.K1PerXdlops>::type;
-
-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
-
-                    xdlops_gemm.Run(
-                        a_thread_vec.template AsType<mfma_input_type>(),
-                        b_thread_vec.template AsType<mfma_input_type>(),
-                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA,
+                                                 xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2341,7 +2355,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                     I0,
                     cde_lds_and_global_step);

-                EpilogueScheduler();
+                // EpilogueScheduler();
             }
         });
     }
......
@@ -234,7 +234,26 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
         {
             c_device_buf.FromDevice(e_m_n_device_result.mData.data());

-            pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
+            // set softer tolerances for fp8
+            if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
+                          is_same_v<EDataType, f8_t>) ||
+                         (is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
+                          is_same_v<EDataType, int8_t>))
+            {
+                std::string msg = "Error: Incorrect results!";
+                double rtol     = 1e-1;
+                double atol     = 1e-1;
+                pass            = pass & ck::utils::check_err(
+                    e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
+            }
+            else
+            {
+#endif
+                pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
+            }
+#endif

             if(do_log)
             {
@@ -276,27 +295,6 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
                       << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
                       << kbatch_curr << std::endl;

-#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
-            // set softer tolerances for fp8
-            if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
-                          is_same_v<EDataType, f8_t>) ||
-                         (is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
-                          is_same_v<EDataType, int8_t>))
-            {
-                std::string msg = "Error: Incorrect results!";
-                double rtol     = 1e-1;
-                double atol     = 1e-1;
-                pass            = pass & ck::utils::check_err(
-                    e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
-            }
-            else
-            {
-#endif
-            pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
-#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
-            }
-#endif
-
             if(tflops > best_tflops && ave_time > 1e-10)
             {
                 best_op_name = op_name;
......
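The profiler hunks above move the correctness check so it runs once, right after results are copied back, and relax rtol/atol to 1e-1 for f8/int8 data. The sketch below shows an elementwise rtol/atol comparison of that general shape; it is an assumption-level stand-in, not the actual ck::utils::check_err.

// Sketch of a tolerance-aware comparison in the spirit of the relaxed fp8/int8 check.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

bool check_err(const std::vector<double>& out,
               const std::vector<double>& ref,
               double rtol = 1e-5,
               double atol = 1e-8)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::abs(out[i] - ref[i]);
        if(err > atol + rtol * std::abs(ref[i]))
            return false; // one out-of-tolerance element fails the whole check
    }
    return true;
}

int main()
{
    const std::vector<double> ref{1.0, 2.0, 4.0};
    const std::vector<double> dev{1.05, 2.1, 4.2}; // roughly 5% off, as low-precision results might be
    // Default (tight) tolerances reject this result; the relaxed 1e-1 tolerances accept it.
    std::printf("tight: %d, relaxed: %d\n", check_err(dev, ref), check_err(dev, ref, 1e-1, 1e-1));
    return 0;
}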