Unverified commit f95267f1, authored by Chao Liu, committed by GitHub

Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hardcode

* remove hardcode

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build
parent f91579aa
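
A change that recurs throughout the diffs below is the migration from unscoped `enum` to scoped `enum struct` (equivalent to `enum class`). As a minimal standalone sketch of the pattern — the `main` body here is ours for illustration, only the enum names come from the PR — it also shows why the parsed command-line selectors switch from `const int` to `const auto`:

#include <string>

// Unscoped enums convert implicitly to int, so a layout selector could be
// compared against, or assigned to, an unrelated integer with no diagnostic.
// A scoped enum closes that hole: enumerators must be qualified and there is
// no implicit conversion in either direction.
enum struct GemmMatrixLayout
{
    MK_KN_MN, // 0
    MK_NK_MN, // 1
};

int main()
{
    // Converting a numeric selector requires an explicit static_cast, and the
    // result must be held as the enum type itself, hence `const auto` rather
    // than `const int` in the profiler sources below.
    const auto layout = static_cast<GemmMatrixLayout>(std::stoi("0"));

    return layout == GemmMatrixLayout::MK_KN_MN ? 0 : 1;
}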
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_conv_bwd_data_impl.hpp"
 
-enum ConvDataType
+enum struct ConvDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -14,19 +14,19 @@ enum ConvDataType
     INT8_INT8_INT8, // 3
 };
 
-enum ConvInputLayout
+enum struct ConvInputLayout
 {
     NCHW, // 0
     NHWC, // 1
 };
 
-enum ConvWeightLayout
+enum struct ConvWeightLayout
 {
     KCYX, // 0
     KYXC, // 1
 };
 
-enum ConvOutputLayout
+enum struct ConvOutputLayout
 {
     NKHW, // 0
     NHWK, // 1
@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
-    const int in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
-    const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
-    const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
+    const auto data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
+    const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
+    const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
     const bool do_verification = std::stoi(argv[6]);
     const int init_method      = std::stoi(argv[7]);
     const bool do_log          = std::stoi(argv[8]);
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_conv_fwd_impl.hpp"
 
-enum ConvDataType
+enum struct ConvDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -14,19 +14,19 @@ enum ConvDataType
     INT8_INT8_INT8, // 3
 };
 
-enum ConvInputLayout
+enum struct ConvInputLayout
 {
     NCHW, // 0
     NHWC, // 1
 };
 
-enum ConvWeightLayout
+enum struct ConvWeightLayout
 {
     KCYX, // 0
     KYXC, // 1
 };
 
-enum ConvOutputLayout
+enum struct ConvOutputLayout
 {
     NKHW, // 0
     NHWK, // 1
@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
-    const int in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
-    const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
-    const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
+    const auto data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
+    const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
+    const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
     const bool do_verification = std::stoi(argv[6]);
     const int init_method      = std::stoi(argv[7]);
     const bool do_log          = std::stoi(argv[8]);
...
@@ -6,25 +6,25 @@
 #include <half.hpp>
 #include "profile_conv_fwd_bias_relu_impl.hpp"
 
-enum ConvDataType
+enum struct ConvDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
 };
 
-enum ConvInputLayout
+enum struct ConvInputLayout
 {
     NCHW, // 0
     NHWC, // 1
 };
 
-enum ConvWeightLayout
+enum struct ConvWeightLayout
 {
     KCYX, // 0
     KYXC, // 1
 };
 
-enum ConvOutputLayout
+enum struct ConvOutputLayout
 {
     NKHW, // 0
     NHWK, // 1
@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
-    const int in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
-    const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
-    const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
+    const auto data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
+    const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
+    const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
     const bool do_verification = std::stoi(argv[6]);
     const int init_method      = std::stoi(argv[7]);
     const bool do_log          = std::stoi(argv[8]);
...
@@ -6,25 +6,25 @@
 #include <half.hpp>
 #include "profile_conv_fwd_bias_relu_add_impl.hpp"
 
-enum ConvDataType
+enum struct ConvDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
 };
 
-enum ConvInputLayout
+enum struct ConvInputLayout
 {
     NCHW, // 0
     NHWC, // 1
 };
 
-enum ConvWeightLayout
+enum struct ConvWeightLayout
 {
     KCYX, // 0
     KYXC, // 1
 };
 
-enum ConvOutputLayout
+enum struct ConvOutputLayout
 {
     NKHW, // 0
     NHWK, // 1
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
-    const int in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
-    const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
-    const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
+    const auto data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
+    const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
+    const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
     const bool do_verification = std::stoi(argv[6]);
     const int init_method      = std::stoi(argv[7]);
     const bool do_log          = std::stoi(argv[8]);
...
@@ -6,25 +6,25 @@
 #include <half.hpp>
 #include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
 
-enum ConvDataType
+enum struct ConvDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
 };
 
-enum ConvInputLayout
+enum struct ConvInputLayout
 {
     NCHW, // 0
     NHWC, // 1
 };
 
-enum ConvWeightLayout
+enum struct ConvWeightLayout
 {
     KCYX, // 0
     KYXC, // 1
 };
 
-enum ConvOutputLayout
+enum struct ConvOutputLayout
 {
     NKHW, // 0
     NHWK, // 1
@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
-    const int in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
-    const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
-    const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
+    const auto data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
+    const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
+    const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
     const bool do_verification = std::stoi(argv[6]);
     const int init_method      = std::stoi(argv[7]);
     const bool do_log          = std::stoi(argv[8]);
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_gemm_impl.hpp"
 
-enum GemmMatrixLayout
+enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
     MK_NK_MN, // 1
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
     KM_NK_NM, // 7
 };
 
-enum GemmDataType
+enum struct GemmDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
-    const int layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_gemm_bias_2d_impl.hpp"
 
-enum GemmMatrixLayout
+enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
     MK_NK_MN, // 1
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
     KM_NK_NM, // 7
 };
 
-enum GemmDataType
+enum struct GemmDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
-    const int layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_gemm_bias_relu_impl.hpp"
 
-enum GemmMatrixLayout
+enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
     MK_NK_MN, // 1
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
     KM_NK_NM, // 7
 };
 
-enum GemmDataType
+enum struct GemmDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
         exit(1);
     }
 
-    const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
-    const int layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_gemm_bias_relu_add_impl.hpp"
 
-enum GemmMatrixLayout
+enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
     MK_NK_MN, // 1
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
     KM_NK_NM, // 7
 };
 
-enum GemmDataType
+enum struct GemmDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -43,8 +43,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
        exit(1);
     }
 
-    const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
-    const int layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
...
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <half.hpp>

#include "profile_gemm_reduce_impl.hpp"

int profile_gemm_reduce(int argc, char* argv[])
{
    enum struct GemmMatrixLayout_t
    {
        MK_KN_MN, // 0
        MK_NK_MN, // 1
        KM_KN_MN, // 2
        KM_NK_MN, // 3
    };

    enum struct GemmReduceDataType_t
    {
        F32_F32_F32_F32_F32, // 0
        F16_F16_F16_F32_F32, // 1
    };

    if(!(argc == 14 || argc == 15))
    {
        printf("arg1: tensor operation (gemm_reduce: GEMM+Reduce)\n");
        printf("arg2: data type (0: fp32; 1: fp16)\n");
        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg4: verification (0: no; 1: yes)\n");
        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
        printf("arg6: print tensor value (0: no; 1: yes)\n");
        printf("arg7: run kernel # of times (>1)\n");
        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
        printf("arg14: split K into multiple batches (optional)\n");
        exit(1);
    }
    const auto data_type       = static_cast<GemmReduceDataType_t>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout_t>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const int nrepeat          = std::stoi(argv[7]);

    const int M = std::stoi(argv[8]);
    const int N = std::stoi(argv[9]);
    const int K = std::stoi(argv[10]);

    const int StrideA = std::stoi(argv[11]);
    const int StrideB = std::stoi(argv[12]);
    const int StrideC = std::stoi(argv[13]);
    if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
       layout == GemmMatrixLayout_t::MK_KN_MN)
    {
        ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                               ck::half_t,
                                               ck::half_t,
                                               float,
                                               ck::tensor_layout::gemm::RowMajor,
                                               ck::tensor_layout::gemm::RowMajor,
                                               ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            nrepeat,
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC);
    }
    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
            layout == GemmMatrixLayout_t::MK_NK_MN)
    {
        ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                               ck::half_t,
                                               ck::half_t,
                                               float,
                                               ck::tensor_layout::gemm::RowMajor,
                                               ck::tensor_layout::gemm::ColumnMajor,
                                               ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            nrepeat,
            M,
            N,
            K,
            (StrideA < 0) ? K : StrideA,
            (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC);
    }
    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
            layout == GemmMatrixLayout_t::KM_KN_MN)
    {
        ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                               ck::half_t,
                                               ck::half_t,
                                               float,
                                               ck::tensor_layout::gemm::ColumnMajor,
                                               ck::tensor_layout::gemm::RowMajor,
                                               ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            nrepeat,
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? N : StrideB,
            (StrideC < 0) ? N : StrideC);
    }
    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
            layout == GemmMatrixLayout_t::KM_NK_MN)
    {
        ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                               ck::half_t,
                                               ck::half_t,
                                               float,
                                               ck::tensor_layout::gemm::ColumnMajor,
                                               ck::tensor_layout::gemm::ColumnMajor,
                                               ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            nrepeat,
            M,
            N,
            K,
            (StrideA < 0) ? M : StrideA,
            (StrideB < 0) ? K : StrideB,
            (StrideC < 0) ? N : StrideC);
    }
    else
    {
        throw std::runtime_error("wrong! this data_type & layout is not implemented");
    }

    return 1;
}
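
The repeated `(Stride < 0) ? ... : Stride` ternaries above encode a default-stride convention: a negative stride argument selects the packed leading dimension implied by the layout, i.e. the extent of the contiguous axis. A minimal sketch of that rule factored into a helper; the function `default_stride` is ours for illustration, not part of this PR:

#include <type_traits>
#include "tensor_layout.hpp" // assumed on the include path, as in the sources above

// For a logical [rows, cols] matrix, the packed leading dimension is `cols`
// when stored row-major and `rows` when stored column-major. A negative
// user-supplied stride means "use the packed default".
template <typename Layout>
int default_stride(int rows, int cols, int user_stride)
{
    if(user_stride >= 0)
        return user_stride;

    return std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value ? cols : rows;
}

// Applied to the MK_NK_MN branch above (A[m, k] row-major; B read as B[n, k],
// i.e. a [K, N] matrix stored column-major; C[m, n] row-major):
//   default_stride<ck::tensor_layout::gemm::RowMajor>(M, K, StrideA)    // -> K
//   default_stride<ck::tensor_layout::gemm::ColumnMajor>(K, N, StrideB) // -> K
//   default_stride<ck::tensor_layout::gemm::RowMajor>(M, N, StrideC)    // -> N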
@@ -84,7 +84,7 @@ static std::vector<T> getTypeValuesFromString(const char* cstr_values)
     return (values);
 }
 
-typedef enum
+enum struct appDataType_t
 {
     appHalf  = 0,
     appFloat = 1,
@@ -93,7 +93,7 @@ typedef enum
     appInt8x4   = 4,
     appBFloat16 = 5,
     appDouble   = 6,
-} appDataType_t;
+};
 
 static void check_reduce_dims(const int rank, const std::vector<int>& reduceDims)
 {
@@ -131,8 +131,8 @@ class AppArgs
     std::vector<float> scales;
 
     ReduceTensorOp_t reduceOp = ReduceTensorOp_t::ADD;
-    appDataType_t compTypeId  = appFloat;
-    appDataType_t outTypeId   = appFloat;
+    appDataType_t compTypeId  = appDataType_t::appFloat;
+    appDataType_t outTypeId   = appDataType_t::appFloat;
     bool compType_assigned    = false;
     bool outType_assigned     = false;
@@ -339,15 +339,16 @@ int profile_reduce(int argc, char* argv[])
     if(args.use_half)
     {
         if(!args.compType_assigned)
-            args.compTypeId = appHalf;
+            args.compTypeId = appDataType_t::appHalf;
 
-        if(args.outType_assigned && (args.outTypeId != appHalf && args.outTypeId != appFloat))
-            args.outTypeId = appFloat;
+        if(args.outType_assigned &&
+           (args.outTypeId != appDataType_t::appHalf && args.outTypeId != appDataType_t::appFloat))
+            args.outTypeId = appDataType_t::appFloat;
 
         if(!args.outType_assigned)
-            args.outTypeId = appHalf;
+            args.outTypeId = appDataType_t::appHalf;
 
-        if(args.compTypeId == appHalf)
+        if(args.compTypeId == appDataType_t::appHalf)
         {
             profile_reduce_impl<ck::half_t, ck::half_t, ck::half_t>(args.do_verification,
                                                                     args.init_method,
@@ -362,7 +363,7 @@ int profile_reduce(int argc, char* argv[])
                                                                     args.scales[0],
                                                                     args.scales[1]);
         }
-        else if(args.compTypeId == appFloat)
+        else if(args.compTypeId == appDataType_t::appFloat)
         {
             profile_reduce_impl<ck::half_t, float, ck::half_t>(args.do_verification,
                                                                args.init_method,
@@ -398,15 +399,16 @@ int profile_reduce(int argc, char* argv[])
     else if(args.use_int8)
     {
         if(!args.compType_assigned)
-            args.compTypeId = appInt8;
+            args.compTypeId = appDataType_t::appInt8;
 
-        if(args.outType_assigned && (args.outTypeId != appInt8 && args.outTypeId != appInt32))
-            args.outTypeId = appInt32;
+        if(args.outType_assigned &&
+           (args.outTypeId != appDataType_t::appInt8 && args.outTypeId != appDataType_t::appInt32))
+            args.outTypeId = appDataType_t::appInt32;
 
         if(!args.outType_assigned)
-            args.outTypeId = appInt8;
+            args.outTypeId = appDataType_t::appInt8;
 
-        if(args.compTypeId == appInt8)
+        if(args.compTypeId == appDataType_t::appInt8)
         {
             profile_reduce_impl<int8_t, int8_t, int8_t>(args.do_verification,
                                                         args.init_method,
@@ -421,7 +423,7 @@ int profile_reduce(int argc, char* argv[])
                                                         args.scales[0],
                                                         args.scales[1]);
         }
-        else if(args.compTypeId == appInt32)
+        else if(args.compTypeId == appDataType_t::appInt32)
        {
             profile_reduce_impl<int8_t, int32_t, int8_t>(args.do_verification,
                                                          args.init_method,
@@ -441,11 +443,12 @@ int profile_reduce(int argc, char* argv[])
     }
     else if(args.use_bf16)
     {
-        if(args.outType_assigned && (args.outTypeId != appBFloat16 && args.outTypeId != appFloat))
-            args.outTypeId = appFloat;
+        if(args.outType_assigned && (args.outTypeId != appDataType_t::appBFloat16 &&
+                                     args.outTypeId != appDataType_t::appFloat))
+            args.outTypeId = appDataType_t::appFloat;
 
         if(!args.outType_assigned)
-            args.outTypeId = appBFloat16;
+            args.outTypeId = appDataType_t::appBFloat16;
 
         profile_reduce_impl<ck::bhalf_t, float, ck::bhalf_t>(args.do_verification,
                                                              args.init_method,
@@ -462,7 +465,7 @@ int profile_reduce(int argc, char* argv[])
     }
     else
     {
-        if(args.compTypeId == appFloat)
+        if(args.compTypeId == appDataType_t::appFloat)
         {
             profile_reduce_impl<float, float, float>(args.do_verification,
                                                      args.init_method,
@@ -477,7 +480,7 @@ int profile_reduce(int argc, char* argv[])
                                                      args.scales[0],
                                                      args.scales[1]);
         }
-        else if(args.compTypeId == appDouble)
+        else if(args.compTypeId == appDataType_t::appDouble)
         {
             profile_reduce_impl<float, double, float>(args.do_verification,
                                                       args.init_method,
...
@@ -5,17 +5,18 @@
 #include <cstring>
 
 int profile_gemm(int, char*[]);
-int profile_batched_gemm(int, char*[]);
 int profile_gemm_bias_2d(int, char*[]);
 int profile_gemm_bias_relu(int, char*[]);
 int profile_gemm_bias_relu_add(int, char*[]);
+int profile_gemm_reduce(int, char*[]);
+int profile_batched_gemm(int, char*[]);
+int profile_grouped_gemm(int, char*[]);
 int profile_conv_fwd(int, char*[]);
 int profile_conv_fwd_bias_relu(int, char*[]);
 int profile_conv_fwd_bias_relu_add(int, char*[]);
 int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
 int profile_conv_bwd_data(int, char*[]);
 int profile_reduce(int, char*[]);
-int profile_grouped_gemm(int, char*[]);
 
 int main(int argc, char* argv[])
 {
@@ -35,10 +36,18 @@ int main(int argc, char* argv[])
     {
         return profile_gemm_bias_relu_add(argc, argv);
     }
+    else if(strcmp(argv[1], "gemm_reduce") == 0)
+    {
+        return profile_gemm_reduce(argc, argv);
+    }
     else if(strcmp(argv[1], "batched_gemm") == 0)
     {
         return profile_batched_gemm(argc, argv);
     }
+    else if(strcmp(argv[1], "grouped_gemm") == 0)
+    {
+        return profile_grouped_gemm(argc, argv);
+    }
     else if(strcmp(argv[1], "conv_fwd") == 0)
     {
         return profile_conv_fwd(argc, argv);
@@ -63,10 +72,6 @@ int main(int argc, char* argv[])
     {
         return profile_reduce(argc, argv);
     }
-    else if(strcmp(argv[1], "grouped_gemm") == 0)
-    {
-        return profile_grouped_gemm(argc, argv);
-    }
     else
     {
         // clang-format off
@@ -74,13 +79,14 @@ int main(int argc, char* argv[])
            " gemm_bias_2d: GEMM+Bias(2D)\n"
            " gemm_bias_relu: GEMM+Bias+ReLU\n"
            " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
+           " gemm_reduce: GEMM+Reduce\n"
+           " grouped_gemm: Grouped Gemm\n"
            " conv_fwd: ForwardConvolution\n"
            " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
            " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
            " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"
            " conv_bwd: BackwardConvolution\n"
-           " grouped_gemm: Grouped Gemm\n"
-           " reduce: REDUCE\n");
+           " reduce: Reduce\n");
         // clang-format on
         return 0;
...
@@ -16,6 +16,7 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
     ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
     ${PROJECT_SOURCE_DIR}/test/include
+    ${PROJECT_SOURCE_DIR}/profiler/include
     ${PROJECT_SOURCE_DIR}/external/include/half
 )
@@ -35,9 +36,10 @@ add_subdirectory(space_filling_curve)
 add_subdirectory(conv_util)
 add_subdirectory(reference_conv_fwd)
 add_subdirectory(gemm)
-add_subdirectory(grouped_gemm)
 add_subdirectory(gemm_split_k)
+add_subdirectory(gemm_reduce)
+add_subdirectory(batched_gemm)
+add_subdirectory(grouped_gemm)
 add_subdirectory(convnd_fwd)
 add_subdirectory(conv2d_bwd_data)
-add_subdirectory(batched_gemm)
 add_subdirectory(reduce)
include_directories(BEFORE
    ${PROJECT_SOURCE_DIR}/profiler/include
    ${PROJECT_SOURCE_DIR}/test/include
    ${PROJECT_SOURCE_DIR}/external/include/half
)
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance)
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

#include "profile_gemm_reduce_impl.hpp"

int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    int M = 512;
    int N = 256;
    int K = 128;

    bool pass = true;

    pass = pass &&
           ck::profiler::
               profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Row, Row>(
                   true, 1, false, 1, M, N, K, K, N, N);

    pass = pass &&
           ck::profiler::
               profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Col, Row>(
                   true, 1, false, 1, M, N, K, K, K, N);

    pass = pass &&
           ck::profiler::
               profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Row, Row>(
                   true, 1, false, 1, M, N, K, M, N, N);

    pass = pass &&
           ck::profiler::
               profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Col, Row>(
                   true, 1, false, 1, M, N, K, M, K, N);

    if(pass)
    {
        std::cout << "test GEMM+Reduce fp16: Pass" << std::endl;
        return 0;
    }
    else
    {
        std::cout << "test GEMM+Reduce fp16: Fail" << std::endl;
        return -1;
    }
}
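
For readability, here is the first of the four calls above with its positional arguments annotated. The parameter names are inferred from how profile_gemm_reduce.cpp reads its argv; the authoritative signature lives in profile_gemm_reduce_impl.hpp:

    pass = pass &&
           ck::profiler::
               profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Row, Row>(
                   /*do_verification=*/true,
                   /*init_method=*/1, // 1: integer initialization
                   /*do_log=*/false,
                   /*nrepeat=*/1,
                   M,
                   N,
                   K,
                   /*StrideA=*/K,  // A[M, K] row-major: leading dimension K
                   /*StrideB=*/N,  // B[K, N] row-major: leading dimension N
                   /*StrideC=*/N); // C[M, N] row-major: leading dimension N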
@@ -12,7 +12,7 @@
 #include "tensor_layout.hpp"
 #include "device_gemm_xdl_splitk.hpp"
 
-enum GemmMatrixLayout
+enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
     MK_NK_MN, // 1
@@ -59,7 +59,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
 struct gemmArgs
 {
-    int layout;
+    GemmMatrixLayout layout;
     int M;
     int N;
     int K;
@@ -216,13 +216,13 @@ int main(int argc, char* argv[])
     std::vector<gemmArgs> test_cases;
 
     if(argc == 1)
     {
-        test_cases = {{0, 3, 3, 3, 3, 3, 3, 1}};
+        test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
         // JD: Populate with more and meaningful
         return 0;
     }
     else if(argc == 9)
     {
-        const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
+        const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
         const int M = std::stoi(argv[2]);
         const int N = std::stoi(argv[3]);
...