Commit 1a089f6f authored by aska-0096

sanity bug fix

parent c8c016dd
@@ -54,8 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
......
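For reference, the constants touched in this hunk partition each thread's share of the K dimension: a thread owns KPerThread elements, iterated as KRepeat chunks of KPack, and the newly added KPerInnerLoop is pinned to KPack. A minimal sketch of that arithmetic, using hypothetical tile sizes rather than the pipeline's real template parameters:

#include <cstdio>

// Hypothetical tile parameters, chosen only to illustrate the arithmetic;
// real values come from the GEMM instance's template arguments.
constexpr int KPerBlock   = 64;
constexpr int K0PerXdlops = 4; // stand-in for xdlops_gemm.K0PerXdlops
constexpr int KPack       = 8;

constexpr int KPerThread    = KPerBlock / K0PerXdlops; // 16
constexpr int KRepeat       = KPerThread / KPack;      // 2
constexpr int KPerInnerLoop = KPack;                   // 8

static_assert(KRepeat * KPack == KPerThread, "K tiling must divide evenly");

int main()
{
    std::printf("KPerThread=%d KRepeat=%d KPerInnerLoop=%d\n",
                KPerThread, KRepeat, KPerInnerLoop);
}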
@@ -227,8 +227,26 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
constexpr index_t instance_lds_size =
MPerBlock * KPerBlock * sizeof(ADataType) +
NPerBlock * KPerBlock * sizeof(BDataType);
return ((MPerBlock * NPerBlock / BlockSize <= 128) &&
(instance_lds_size <= 32768))
? 2
: 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
......
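The new minimum_occupancy lambda replaces the flat Intrawave/Interwave choice: Interwave still requests occupancy 2, while the v3 intrawave pipeline requests 2 only when the per-thread C tile is small (MPerBlock * NPerBlock / BlockSize <= 128) and the A plus B LDS footprint fits in 32 KiB; otherwise it falls back to 1. A standalone sketch of that heuristic under assumed, illustrative names and example sizes (not the device op's actual template parameters):

#include <cstddef>
#include <cstdio>

// Hypothetical standalone version of the occupancy heuristic added to the
// device op; names and the example tile below are illustrative only.
enum class Scheduler { Intrawave, Interwave };
enum class PipelineVersion { v1, v2, v3 };

constexpr int minimum_occupancy(Scheduler sched,
                                PipelineVersion ver,
                                int m_per_block,
                                int n_per_block,
                                int k_per_block,
                                int block_size,
                                std::size_t a_elem_size,
                                std::size_t b_elem_size)
{
    if(sched == Scheduler::Interwave)
        return 2;

    if(ver == PipelineVersion::v3)
    {
        // LDS footprint of the A and B block tiles, in bytes.
        const std::size_t lds_bytes = m_per_block * k_per_block * a_elem_size +
                                      n_per_block * k_per_block * b_elem_size;
        const bool small_c_tile = (m_per_block * n_per_block / block_size) <= 128;
        return (small_c_tile && lds_bytes <= 32768) ? 2 : 1;
    }

    return 1;
}

int main()
{
    // Example: 128x128x64 tile of 2-byte elements on a 256-thread block -> 2.
    std::printf("occupancy = %d\n",
                minimum_occupancy(Scheduler::Intrawave, PipelineVersion::v3,
                                  128, 128, 64, 256, sizeof(short), sizeof(short)));
}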
@@ -1550,30 +1550,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
@@ -1592,29 +1597,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -2025,31 +2032,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
@@ -2068,29 +2079,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
......
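All of the epilogue hunks above follow the same pattern: each per-KRepeat MFMA step gains an inner static_for over KPerInnerLoop in strides of KPack, and the thread-buffer offsets change from make_tuple(m0, I0, k0, ik) to make_tuple(m0, I0, k0, k_ + ik) so the tail indexing matches the hot loop's buffer layout. Because KPerInnerLoop equals KPack in the pipeline base, the inner loop runs exactly once and the computed result is unchanged. A plain-loop sketch of just the index arithmetic, with ordinary loops and printf standing in for static_for and the thread descriptors (values are hypothetical):

#include <cstdio>

// Hypothetical sizes, chosen only to make the indexing visible.
constexpr int KRepeat       = 2;
constexpr int KPack         = 4;
constexpr int KPerInnerLoop = KPack; // as defined in the pipeline base

int main()
{
    for(int k0 = 0; k0 < KRepeat; ++k0)
    {
        // New structure: an extra loop over the inner K span in KPack steps.
        for(int k_ = 0; k_ < KPerInnerLoop; k_ += KPack)
        {
            std::printf("k0=%d k_=%d gathers a/b elements:", k0, k_);
            for(int ik = 0; ik < KPack; ++ik)
            {
                // Old code indexed (k0, ik); new code indexes (k0, k_ + ik),
                // matching the hot loop's thread-buffer layout.
                std::printf(" %d", k_ + ik);
            }
            std::printf("\n");
        }
    }
}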
@@ -1685,30 +1685,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
@@ -1728,29 +1733,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -1790,7 +1797,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
I0,
cde_lds_and_global_step);
EpilogueScheduler();
// EpilogueScheduler();
}
});
}
@@ -2236,30 +2243,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
@@ -2279,29 +2291,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
using mfma_input_type =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset = c_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -2341,7 +2355,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
I0,
cde_lds_and_global_step);
EpilogueScheduler();
// EpilogueScheduler();
}
});
}
......
@@ -234,7 +234,26 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
{
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
// set softer tolerances for fp8
if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) ||
(is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
}
#endif
if(do_log)
{
@@ -276,27 +295,6 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
// set softer tolerances for fp8
if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) ||
(is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
}
#endif
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
......
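The profiler hunks move the fp8/int8-aware verification up to where the device result is copied back and remove the duplicate check that previously ran after the performance line was printed; when any of ADataType, BDataType, or EDataType is f8_t or int8_t, the comparison uses the relaxed rtol = atol = 1e-1. A minimal stand-in comparison showing how the looser pair changes the verdict (this is not ck::utils::check_err itself; the data is made up):

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Toy element-wise tolerance check in the style of an rtol/atol comparison.
bool check_close(const std::vector<float>& out,
                 const std::vector<float>& ref,
                 double rtol,
                 double atol)
{
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::fabs(out[i] - ref[i]);
        if(err > atol + rtol * std::fabs(ref[i]))
            return false;
    }
    return true;
}

int main()
{
    const std::vector<float> ref = {10.0f, 20.0f, 30.0f};
    const std::vector<float> out = {10.5f, 19.2f, 31.0f}; // a few percent off, as low-precision paths often are

    std::printf("strict  (rtol=atol=1e-3): %s\n", check_close(out, ref, 1e-3, 1e-3) ? "pass" : "fail");
    std::printf("relaxed (rtol=atol=1e-1): %s\n", check_close(out, ref, 1e-1, 1e-1) ? "pass" : "fail");
}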