Commit 842d910e authored by coderfeli's avatar coderfeli
Browse files

Merge branch 'update_cka8w8' into update_cka8w8_uc

parents f82c9aef e2127d7a
......@@ -56,6 +56,7 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
......
......@@ -418,13 +418,65 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf;
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf_tail);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf_tail);
});
});
}
}
......@@ -580,11 +632,66 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
......
......@@ -227,8 +227,20 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
......
......@@ -212,13 +212,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
constexpr index_t instance_lds_size =
MPerBlock * KPerBlock * sizeof(ADataType) +
NPerBlock * KPerBlock * sizeof(BDataType);
return ((MPerBlock * NPerBlock / BlockSize <= 128) &&
(instance_lds_size <= 32768))
? 2
: 1;
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
......
......@@ -1550,19 +1550,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -1571,12 +1574,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
......@@ -1592,16 +1597,17 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -1618,6 +1624,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
});
});
});
});
}
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
......@@ -2025,20 +2032,22 @@ struct GridwiseGemm_xdl_cshuffle_v3
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -2047,12 +2056,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
......@@ -2068,16 +2079,17 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -2094,6 +2106,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
});
});
});
});
}
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
......
......@@ -1685,19 +1685,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -1706,12 +1709,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
......@@ -1728,16 +1733,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -1754,6 +1760,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
});
});
});
});
}
// each thread write its data from VGPR to LDS
......@@ -1790,7 +1797,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
I0,
cde_lds_and_global_step);
EpilogueScheduler();
// EpilogueScheduler();
}
});
}
......@@ -2236,19 +2243,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -2257,12 +2267,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, num_access, 1>{}([&](auto access_id) {
......@@ -2279,16 +2291,17 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc.CalculateOffset(
make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeTypeA>()(ik) =
b_thread_buf[Number<b_thread_desc.CalculateOffset(
make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
......@@ -2305,6 +2318,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
});
});
});
});
}
// each thread write its data from VGPR to LDS
......@@ -2341,7 +2355,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
I0,
cde_lds_and_global_step);
EpilogueScheduler();
// EpilogueScheduler();
}
});
}
......
......@@ -16,7 +16,8 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
......@@ -174,6 +175,165 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
MultiplyMultiply>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
......@@ -292,7 +452,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
......@@ -329,6 +490,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part1(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
}
}
#endif
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, half_t>)
......
......@@ -15,6 +15,19 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_default_instance.cpp
......@@ -32,6 +45,15 @@ set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
......
......@@ -234,7 +234,26 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
{
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
// set softer tolerances for fp8
if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) ||
(is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
}
#endif
if(do_log)
{
......@@ -276,27 +295,6 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
// set softer tolerances for fp8
if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) ||
(is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
}
#endif
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
......
......@@ -28,6 +28,7 @@ enum struct GemmDataType
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
INT8_INT8_F16, // 8
F8_F8_F16, // 9
};
#define OP_NAME "gemm_multiply_multiply"
......@@ -40,7 +41,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
"f16->f8; 7: f8->bf16, "
"comp f8; 8: int8->f16)\n");
"comp f8; 8: int8->f16; 9: f8->f16, comp f8;)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
......@@ -166,6 +167,11 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
F8{}, F8{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::INT8_INT8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment