Unverified commit f95267f1, authored by Chao Liu, committed by GitHub
Browse files

Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hardcode

* remove hardcode

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build
parent f91579aa
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp" #include "profile_conv_bwd_data_impl.hpp"
enum ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
...@@ -14,19 +14,19 @@ enum ConvDataType ...@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
}; };
enum ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
...@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[]) ...@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]); const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]); const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]); const bool do_log = std::stoi(argv[8]);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_fwd_impl.hpp" #include "profile_conv_fwd_impl.hpp"
enum ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
...@@ -14,19 +14,19 @@ enum ConvDataType ...@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8, // 3 INT8_INT8_INT8, // 3
}; };
enum ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
...@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[]) ...@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]); const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]); const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]); const bool do_log = std::stoi(argv[8]);
......
...@@ -6,25 +6,25 @@ ...@@ -6,25 +6,25 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp" #include "profile_conv_fwd_bias_relu_impl.hpp"
enum ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
}; };
enum ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
...@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) ...@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]); const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]); const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]); const bool do_log = std::stoi(argv[8]);
......
...@@ -6,25 +6,25 @@ ...@@ -6,25 +6,25 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_fwd_bias_relu_add_impl.hpp" #include "profile_conv_fwd_bias_relu_add_impl.hpp"
enum ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
}; };
enum ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
...@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) ...@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]); const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]); const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]); const bool do_log = std::stoi(argv[8]);
......
...@@ -6,25 +6,25 @@ ...@@ -6,25 +6,25 @@
#include <half.hpp> #include <half.hpp>
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp" #include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
enum ConvDataType enum struct ConvDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
}; };
enum ConvInputLayout enum struct ConvInputLayout
{ {
NCHW, // 0 NCHW, // 0
NHWC, // 1 NHWC, // 1
}; };
enum ConvWeightLayout enum struct ConvWeightLayout
{ {
KCYX, // 0 KCYX, // 0
KYXC, // 1 KYXC, // 1
}; };
enum ConvOutputLayout enum struct ConvOutputLayout
{ {
NKHW, // 0 NKHW, // 0
NHWK, // 1 NHWK, // 1
...@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) ...@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2])); const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3])); const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4])); const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5])); const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]); const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]); const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]); const bool do_log = std::stoi(argv[8]);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_impl.hpp" #include "profile_gemm_impl.hpp"
enum GemmMatrixLayout enum struct GemmMatrixLayout
{ {
MK_KN_MN, // 0 MK_KN_MN, // 0
MK_NK_MN, // 1 MK_NK_MN, // 1
...@@ -18,7 +18,7 @@ enum GemmMatrixLayout ...@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7 KM_NK_NM, // 7
}; };
enum GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
...@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[]) ...@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3])); const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]); const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_bias_2d_impl.hpp" #include "profile_gemm_bias_2d_impl.hpp"
enum GemmMatrixLayout enum struct GemmMatrixLayout
{ {
MK_KN_MN, // 0 MK_KN_MN, // 0
MK_NK_MN, // 1 MK_NK_MN, // 1
...@@ -18,7 +18,7 @@ enum GemmMatrixLayout ...@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7 KM_NK_NM, // 7
}; };
enum GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
...@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[]) ...@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3])); const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]); const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_bias_relu_impl.hpp" #include "profile_gemm_bias_relu_impl.hpp"
enum GemmMatrixLayout enum struct GemmMatrixLayout
{ {
MK_KN_MN, // 0 MK_KN_MN, // 0
MK_NK_MN, // 1 MK_NK_MN, // 1
...@@ -18,7 +18,7 @@ enum GemmMatrixLayout ...@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7 KM_NK_NM, // 7
}; };
enum GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
...@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[]) ...@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3])); const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]); const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <half.hpp> #include <half.hpp>
#include "profile_gemm_bias_relu_add_impl.hpp" #include "profile_gemm_bias_relu_add_impl.hpp"
enum GemmMatrixLayout enum struct GemmMatrixLayout
{ {
MK_KN_MN, // 0 MK_KN_MN, // 0
MK_NK_MN, // 1 MK_NK_MN, // 1
...@@ -18,7 +18,7 @@ enum GemmMatrixLayout ...@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7 KM_NK_NM, // 7
}; };
enum GemmDataType enum struct GemmDataType
{ {
F32_F32_F32, // 0 F32_F32_F32, // 0
F16_F16_F16, // 1 F16_F16_F16, // 1
...@@ -43,8 +43,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) ...@@ -43,8 +43,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
exit(1); exit(1);
} }
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2])); const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3])); const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]); const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]); const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -16,6 +16,7 @@ include_directories(BEFORE ...@@ -16,6 +16,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
${PROJECT_SOURCE_DIR}/test/include ${PROJECT_SOURCE_DIR}/test/include
${PROJECT_SOURCE_DIR}/profiler/include
${PROJECT_SOURCE_DIR}/external/include/half ${PROJECT_SOURCE_DIR}/external/include/half
) )
...@@ -35,9 +36,10 @@ add_subdirectory(space_filling_curve) ...@@ -35,9 +36,10 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util) add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd) add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(gemm_split_k) add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd) add_subdirectory(convnd_fwd)
add_subdirectory(conv2d_bwd_data) add_subdirectory(conv2d_bwd_data)
add_subdirectory(batched_gemm)
add_subdirectory(reduce) add_subdirectory(reduce)
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/profiler/include
${PROJECT_SOURCE_DIR}/test/include
${PROJECT_SOURCE_DIR}/external/include/half
)
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance)
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment