Commit 6c8ca54b authored by ltqin

gemm reference add double data type

parent bf5af9f9
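This commit threads a new AccDataType template parameter through the host-side ReferenceGemm (and through the profiler and tests that instantiate it), so the reference computation accumulates in an explicitly chosen type instead of a hard-coded float: float for the f16/bf16/f32 paths, int32_t for int8, and double for the f64 path backed by the mfma_f64_16x16x4f64 instruction below. A minimal sketch of the instantiation this enables, assuming the PassThrough elementwise operators used throughout this diff:

```cpp
// Sketch only: a double-precision reference GEMM, made possible by the new
// AccDataType parameter (the fourth template argument).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ReferenceGemmF64 = ck::tensor_operation::host::
    ReferenceGemm<double, double, double, double, PassThrough, PassThrough, PassThrough>;
```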
@@ -81,7 +81,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<float, float, float, PassThrough, PassThrough, PassThrough>;
+    ReferenceGemm<float, float, float, float, PassThrough, PassThrough, PassThrough>;

 int main(int argc, char* argv[])
 {
......
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

 int main(int argc, char* argv[])
 {
......
@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

 int main(int argc, char* argv[])
 {
......
@@ -80,8 +80,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
         4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AccDataType,
+                                                                        PassThrough,
+                                                                        PassThrough,
+                                                                        PassThrough>;

 int main(int argc, char* argv[])
 {
......
@@ -72,8 +72,13 @@ using DeviceConvBwdWeightInstance = ck::tensor_operation::device::
         8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on
-using ReferenceConvBwdWeightInstance = ck::tensor_operation::host::
-    ReferenceConvBwdWeight<InDataType, WeiDataType, OutDataType, InElementOp, WeiElementOp, OutElementOp>;
+using ReferenceConvBwdWeightInstance =
+    ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
+                                                       WeiDataType,
+                                                       OutDataType,
+                                                       InElementOp,
+                                                       WeiElementOp,
+                                                       OutElementOp>;

 int main(int argc, char* argv[])
 {
......
@@ -83,8 +83,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
         16>; // CBlockTransferScalarPerVector_NWaveNPerXdl
 // clang-format on
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AccDataType,
+                                                                        PassThrough,
+                                                                        PassThrough,
+                                                                        RequantReluRequant>;

 int main(int argc, char* argv[])
 {
......
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

 int main(int argc, char* argv[])
 {
......
@@ -52,7 +52,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

 int main(int argc, char* argv[])
 {
......
@@ -387,17 +387,17 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
 template <>
 struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
 {
-    static constexpr index_t group_size = 4;
-    static constexpr index_t num_groups_per_blk = 1;
-    static constexpr index_t num_regs_per_blk = 4; //group_size * num_groups_per_blk;
+    static constexpr index_t group_size          = 4;
+    static constexpr index_t num_groups_per_blk  = 1;
+    static constexpr index_t num_regs_per_blk    = 4; // group_size * num_groups_per_blk;
     static constexpr index_t num_threads_per_blk = 16;
     static constexpr index_t wave_size           = 64;
-    static constexpr index_t num_input_blks = 4; //wave_size / num_threads_per_blk;
-    static constexpr index_t num_output_blks = 1;
-    static constexpr index_t m_per_blk = 16;
-    static constexpr index_t n_per_blk = 16;
-    static constexpr index_t k_per_blk = 1;
-    static constexpr bool is_k_reduction = true;
+    static constexpr index_t num_input_blks      = 4; // wave_size / num_threads_per_blk;
+    static constexpr index_t num_output_blks     = 1;
+    static constexpr index_t m_per_blk           = 16;
+    static constexpr index_t n_per_blk           = 16;
+    static constexpr index_t k_per_blk           = 1;
+    static constexpr bool is_k_reduction         = true;

     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
@@ -690,8 +690,9 @@ struct XdlopsGemm
     template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
-        static_assert(is_same<base_type, double>::value ||is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
-                          is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value,
+        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
+                          is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
+                          is_same<base_type, int8_t>::value,
                       "base_type must be double, float, half, bfloat16, or int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
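The mfma_f64_16x16x4f64 trait block above encodes the geometry of the double-precision MFMA instruction: a 16x16 output tile, one k-slice per block, and K-reduction across the wave's four 16-lane input blocks. A small standalone check of the derivations stated in the comments (treating index_t as a 32-bit signed index is an assumption):

```cpp
// Standalone sanity check of the commented derivations above; index_t is
// assumed to be a 32-bit signed integer, as in composable_kernel.
#include <cstdint>
using index_t = std::int32_t;

constexpr index_t group_size          = 4;
constexpr index_t num_groups_per_blk  = 1;
constexpr index_t num_regs_per_blk    = 4;
constexpr index_t num_threads_per_blk = 16;
constexpr index_t wave_size           = 64;
constexpr index_t num_input_blks      = 4;

// num_regs_per_blk = group_size * num_groups_per_blk
static_assert(num_regs_per_blk == group_size * num_groups_per_blk, "");
// num_input_blks = wave_size / num_threads_per_blk; with is_k_reduction,
// each of the 4 input blocks contributes one k-slice of the 16x16x4 MFMA.
static_assert(num_input_blks == wave_size / num_threads_per_blk, "");

int main() { return 0; }
```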
@@ -13,6 +13,7 @@ namespace host {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
@@ -55,12 +56,12 @@ struct ReferenceGemm : public device::BaseOperator
         auto f_mk_kn_mn = [&](auto m, auto n) {
             const int K = arg.a_m_k_.mDesc.GetLengths()[1];

-            float v_acc = 0;
+            AccDataType v_acc = 0;

             for(int k = 0; k < K; ++k)
             {
-                float v_a;
-                float v_b;
+                AccDataType v_a;
+                AccDataType v_b;

                 arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
                 arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
@@ -68,7 +69,7 @@ struct ReferenceGemm : public device::BaseOperator
                 v_acc += v_a * v_b;
             }

-            float v_c;
+            AccDataType v_c;

             arg.c_element_op_(v_c, v_acc);
......
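The change above is the heart of the commit: the accumulator and the intermediate operands switch from hard-coded float to the new AccDataType parameter, so verifying a double-precision device GEMM no longer loses precision in the host reference. A simplified, self-contained version of the same loop (plain row-major arrays and no elementwise operators, unlike the real ReferenceGemm):

```cpp
#include <cstddef>

// Simplified sketch of the reference GEMM above: C = A * B with the
// accumulation type as an explicit template parameter.
template <typename ADataType, typename BDataType, typename CDataType, typename AccDataType>
void reference_gemm(const ADataType* a, // M x K, row-major
                    const BDataType* b, // K x N, row-major
                    CDataType* c,       // M x N, row-major
                    std::size_t M,
                    std::size_t N,
                    std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0; // accumulate in AccDataType, not hard-coded float
            for(std::size_t k = 0; k < K; ++k)
                v_acc += static_cast<AccDataType>(a[m * K + k]) *
                         static_cast<AccDataType>(b[k * N + n]);
            c[m * N + n] = static_cast<CDataType>(v_acc);
        }
}

// e.g. reference_gemm<double, double, double, double> for f64 verification,
// or reference_gemm<int8_t, int8_t, int8_t, int32_t> for int8.
```

Note that the static_cast<const float> calls feeding the elementwise operators are unchanged in this hunk, so the A and B loads still pass through float even when AccDataType is double.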
@@ -85,6 +85,7 @@ namespace profiler {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename ALayout,
           typename BLayout,
           typename CLayout>
@@ -457,8 +458,14 @@ void profile_gemm_impl(int do_verification,
             bf16_to_f32_(b_k_n, b_f32_k_n);
             bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);

-            using ReferenceGemmInstance = ck::tensor_operation::host::
-                ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
+            using ReferenceGemmInstance =
+                ck::tensor_operation::host::ReferenceGemm<float,
+                                                          float,
+                                                          float,
+                                                          float,
+                                                          AElementOp,
+                                                          BElementOp,
+                                                          CElementOp>;

             auto ref_gemm    = ReferenceGemmInstance{};
             auto ref_invoker = ref_gemm.MakeInvoker();
@@ -490,6 +497,7 @@ void profile_gemm_impl(int do_verification,
             ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                       BDataType,
                                                       CDataType,
+                                                      AccDataType,
                                                       AElementOp,
                                                       BElementOp,
                                                       CElementOp>;
......
@@ -127,8 +127,13 @@ bool profile_gemm_reduce_impl(int do_verification,
     if(do_verification)
     {
-        using ReferenceGemmInstance = ck::tensor_operation::host::
-            ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CDataType,
+                                                                                DDataType,
+                                                                                AElementOp,
+                                                                                BElementOp,
+                                                                                CElementOp>;
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
......
@@ -43,6 +43,7 @@ namespace profiler {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename ALayout,
           typename BLayout,
           typename CLayout>
@@ -270,6 +271,7 @@ void profile_grouped_gemm_impl(int do_verification,
         ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                   BDataType,
                                                   CDataType,
+                                                  AccDataType,
                                                   AElementOp,
                                                   BElementOp,
                                                   CElementOp>;
......
@@ -68,6 +68,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::half_t,
                                         ck::half_t,
                                         ck::half_t,
+                                        float,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -88,6 +89,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::half_t,
                                         ck::half_t,
                                         ck::half_t,
+                                        float,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -108,6 +110,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::half_t,
                                         ck::half_t,
                                         ck::half_t,
+                                        float,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -128,6 +131,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::half_t,
                                         ck::half_t,
                                         ck::half_t,
+                                        float,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         ck::profiler::profile_gemm_impl<float,
+                                        float,
                                         float,
                                         float,
                                         ck::tensor_layout::gemm::RowMajor,
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
     {
         ck::profiler::profile_gemm_impl<float,
+                                        float,
                                         float,
                                         float,
                                         ck::tensor_layout::gemm::RowMajor,
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
     {
         ck::profiler::profile_gemm_impl<float,
+                                        float,
                                         float,
                                         float,
                                         ck::tensor_layout::gemm::ColumnMajor,
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
    {
         ck::profiler::profile_gemm_impl<float,
+                                        float,
                                         float,
                                         float,
                                         ck::tensor_layout::gemm::ColumnMajor,
@@ -228,6 +236,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<int8_t,
                                         int8_t,
                                         int8_t,
+                                        int32_t,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -248,6 +257,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<int8_t,
                                         int8_t,
                                         int8_t,
+                                        int32_t,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -268,6 +278,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<int8_t,
                                         int8_t,
                                         int8_t,
+                                        int32_t,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -288,6 +299,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<int8_t,
                                         int8_t,
                                         int8_t,
+                                        int32_t,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -308,6 +320,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                         ck::bhalf_t,
                                         ck::bhalf_t,
+                                        float,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -328,6 +341,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                         ck::bhalf_t,
                                         ck::bhalf_t,
+                                        float,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -348,6 +362,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                         ck::bhalf_t,
                                         ck::bhalf_t,
+                                        float,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
@@ -368,6 +383,7 @@ int profile_gemm(int argc, char* argv[])
         ck::profiler::profile_gemm_impl<ck::bhalf_t,
                                         ck::bhalf_t,
                                         ck::bhalf_t,
+                                        float,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::ColumnMajor,
                                         ck::tensor_layout::gemm::RowMajor>(
......
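Each branch of profile_gemm now supplies the accumulation type explicitly: float for the f16, bf16, and f32 paths, and int32_t for int8. The mapping the new argument encodes can be summarized in a small standalone sketch (a local illustration only; default_acc_type is not a composable_kernel name):

```cpp
#include <cstdint>
#include <type_traits>

// Local illustration of the data-type -> accumulation-type pairing used in
// the branches above.
template <typename T>
struct default_acc_type { using type = float; }; // f16 / bf16 / f32 -> float
template <>
struct default_acc_type<std::int8_t> { using type = std::int32_t; }; // int8 -> int32_t
template <>
struct default_acc_type<double> { using type = double; }; // f64 -> double

static_assert(std::is_same<default_acc_type<float>::type, float>::value, "");
static_assert(std::is_same<default_acc_type<std::int8_t>::type, std::int32_t>::value, "");
static_assert(std::is_same<default_acc_type<double>::type, double>::value, "");

int main() { return 0; }
```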
@@ -52,9 +52,10 @@ void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
 int main()
 {
-    using ADataType = ck::half_t;
-    using BDataType = ck::half_t;
-    using CDataType = ck::half_t;
+    using ADataType   = ck::half_t;
+    using BDataType   = ck::half_t;
+    using CDataType   = ck::half_t;
+    using AccDataType = float;

     using RowMajor    = ck::tensor_layout::gemm::RowMajor;
     using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
@@ -74,6 +75,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        RowMajor,
                                        RowMajor,
@@ -96,6 +98,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        ColumnMajor,
                                        RowMajor,
@@ -118,6 +121,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        RowMajor,
                                        RowMajor,
@@ -142,6 +146,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        ColumnMajor,
                                        RowMajor,
......
@@ -53,12 +53,13 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<De
 int main()
 {
-    using ADataType = float;
-    using BDataType = float;
-    using CDataType = float;
+    using ADataType   = float;
+    using BDataType   = float;
+    using CDataType   = float;
+    using AccDataType = float;

-    using RowMajor = ck::tensor_layout::gemm::RowMajor;
-    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
+    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
+    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

     bool res = true;
     std::vector<DeviceGemmNoOpPtr> gemmPtrs;
@@ -75,6 +76,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        RowMajor,
                                        RowMajor,
@@ -97,6 +99,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        ColumnMajor,
                                        RowMajor,
@@ -119,6 +122,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        RowMajor,
                                        RowMajor,
@@ -141,6 +145,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        ColumnMajor,
                                        RowMajor,
......
@@ -46,12 +46,13 @@ void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
 int main()
 {
-    using ADataType = int8_t;
-    using BDataType = int8_t;
-    using CDataType = int8_t;
+    using ADataType   = int8_t;
+    using BDataType   = int8_t;
+    using CDataType   = int8_t;
+    using AccDataType = int32_t;

-    using RowMajor = ck::tensor_layout::gemm::RowMajor;
-    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
+    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
+    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

     std::vector<DeviceGemmNoOpPtr> gemmPtrs;
     bool res = true;
@@ -65,6 +66,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        RowMajor,
                                        RowMajor,
@@ -83,6 +85,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        ColumnMajor,
                                        ColumnMajor,
                                        RowMajor,
@@ -101,6 +104,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        RowMajor,
                                        RowMajor,
@@ -119,6 +123,7 @@ int main()
                                        ADataType,
                                        BDataType,
                                        CDataType,
+                                       AccDataType,
                                        RowMajor,
                                        ColumnMajor,
                                        RowMajor,
......
@@ -106,6 +106,7 @@ template <typename DeviceGemmPtr_,
           typename ADataType,
           typename BDataType,
           typename CDataType,
+          typename AccDataType,
           typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -188,6 +189,7 @@ struct TestGemm
         ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                   BDataType,
                                                   CDataType,
+                                                  AccDataType,
                                                   AElementwiseOperation,
                                                   BElementwiseOperation,
                                                   CElementwiseOperation>;
@@ -306,6 +308,7 @@ struct TestGemmBF16
         // use fp32 host kernel to verify bf16 device kernel
         using ReferenceGemmInstance =
             ck::tensor_operation::host::ReferenceGemm<float,
+                                                      float,
                                                       float,
                                                       float,
                                                       AElementwiseOperation,
......
@@ -151,8 +151,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
     {
         c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());

-        using ReferenceGemmInstance = ck::tensor_operation::host::
-            ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                CDataType,
+                                                                                AccDataType,
+                                                                                PassThrough,
+                                                                                PassThrough,
+                                                                                PassThrough>;
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
......
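For context, the verification flow these tests wrap around the instantiation above follows the invoker pattern visible elsewhere in this diff; the tensor names and the MakeArgument parameter list below are assumptions based on the surrounding code, not shown in this commit:

```cpp
// Sketch of the surrounding verification flow (invoker pattern). The
// MakeArgument signature and tensor names are assumptions, not part of this diff.
auto ref_gemm    = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();

auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],      // Tensor<ADataType>
                                          b_tensors[i],      // Tensor<BDataType>
                                          c_host_tensors[i], // Tensor<CDataType>, output
                                          PassThrough{},
                                          PassThrough{},
                                          PassThrough{});
ref_invoker.Run(ref_argument);
// then compare c_host_tensors[i] against c_device_tensors[i] fetched above
```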