Unverified Commit f95267f1 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Gemm+Reduce Fusion (#128)

* add gridwise gemm v4r1

* rename

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* use sfc in shuffling

* remove hardcode

* remove hardcode

* refactor

* fix build

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* adding gemm+reduce

* format

* clean

* adding gemm+reduce

* adding profiler for gemm+reduce

* adding gemm+reduce profiler

* fix build

* clean up

* gemm+reduce

* fix build

* update DeviceGemm_Xdl_CShuffle; update enum to enum class

* clean up

* add test for gemm+reduce

* clean up

* refactor

* fix build

* fix build
parent f91579aa
......@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp"
enum ConvDataType
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
......@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8, // 3
};
enum ConvInputLayout
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum ConvWeightLayout
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum ConvOutputLayout
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
......@@ -50,10 +50,10 @@ int profile_conv_bwd_data(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
......
......@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_conv_fwd_impl.hpp"
enum ConvDataType
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
......@@ -14,19 +14,19 @@ enum ConvDataType
INT8_INT8_INT8, // 3
};
enum ConvInputLayout
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum ConvWeightLayout
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum ConvOutputLayout
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
......@@ -50,10 +50,10 @@ int profile_conv_fwd(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
......
......@@ -6,25 +6,25 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp"
enum ConvDataType
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum ConvInputLayout
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum ConvWeightLayout
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum ConvOutputLayout
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
......@@ -48,10 +48,10 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
......
......@@ -6,25 +6,25 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_add_impl.hpp"
enum ConvDataType
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum ConvInputLayout
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum ConvWeightLayout
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum ConvOutputLayout
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
......@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
......
......@@ -6,25 +6,25 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
enum ConvDataType
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum ConvInputLayout
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum ConvWeightLayout
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum ConvOutputLayout
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
......@@ -49,10 +49,10 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const int in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const int wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const int out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
......
......@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_impl.hpp"
enum GemmMatrixLayout
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
......@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7
};
enum GemmDataType
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
......@@ -45,8 +45,8 @@ int profile_gemm(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
......
......@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_bias_2d_impl.hpp"
enum GemmMatrixLayout
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
......@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7
};
enum GemmDataType
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
......@@ -45,8 +45,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
......
......@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_bias_relu_impl.hpp"
enum GemmMatrixLayout
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
......@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7
};
enum GemmDataType
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
......@@ -43,8 +43,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
......
......@@ -6,7 +6,7 @@
#include <half.hpp>
#include "profile_gemm_bias_relu_add_impl.hpp"
enum GemmMatrixLayout
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
......@@ -18,7 +18,7 @@ enum GemmMatrixLayout
KM_NK_NM, // 7
};
enum GemmDataType
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
......@@ -43,8 +43,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
exit(1);
}
const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -16,6 +16,7 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
${PROJECT_SOURCE_DIR}/test/include
${PROJECT_SOURCE_DIR}/profiler/include
${PROJECT_SOURCE_DIR}/external/include/half
)
......@@ -35,9 +36,10 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd)
add_subdirectory(conv2d_bwd_data)
add_subdirectory(batched_gemm)
add_subdirectory(reduce)
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/profiler/include
${PROJECT_SOURCE_DIR}/test/include
${PROJECT_SOURCE_DIR}/external/include/half
)
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance)
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment