"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "f015c77687827c3cf68f9227db0ed15006902deb"
Commit aa0a891a authored by Chao Liu's avatar Chao Liu
Browse files

added 1x1 conv

parent a86da8e0
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0.hpp"
#include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
// clang-format off
//#######################################################################| InData| WeiData| OutData| AccData| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//#######################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#######################################################################| | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#######################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 2, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 1, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_P0< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 2, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, S<1, 4, 8>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 7, 1, true, true>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_fp16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& device_conv_instances)
{
using DeviceConvs = device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances;
const auto device_convs = DeviceConvs{};
ck::static_for<0, std::tuple_size_v<DeviceConvs>, 1>{}([&](auto i) {
using Conv = remove_cvref_t<decltype(std::get<i>(device_convs))>;
auto conv = Conv{};
device_conv_instances.push_back(std::make_unique<Conv>(conv));
});
}
} // namespace device_conv_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#ifndef DEVICE_BASE_HPP #ifndef DEVICE_BASE_HPP
#define DEVICE_BASE_HPP #define DEVICE_BASE_HPP
#include <string>
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -32,6 +34,7 @@ struct BaseOperator ...@@ -32,6 +34,7 @@ struct BaseOperator
BaseOperator& operator=(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default;
virtual bool IsSupportedArgument(const BaseArgument*) = 0; virtual bool IsSupportedArgument(const BaseArgument*) = 0;
virtual std::string GetTypeString() const = 0;
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
...@@ -34,76 +34,12 @@ struct DeviceConvFwd : public BaseOperator ...@@ -34,76 +34,12 @@ struct DeviceConvFwd : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvBwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in,
const void* p_wei,
const void* p_out,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
void* p_wei,
const void* p_out,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InElementwiseOperation, template <typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation> typename OutElementwiseOperation>
using DeviceConvFwdPtr = std::unique_ptr< using DeviceConvFwdPtr = std::unique_ptr<
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>; DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvBwdPtr = std::unique_ptr<
DeviceConvBwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP #define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP
#include <iostream> #include <iostream>
#include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv.hpp" #include "device_conv.hpp"
...@@ -594,6 +595,24 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -594,6 +595,24 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
}; // namespace device }; // namespace device
} // namespace device } // namespace device
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_1X1_S1_P0_HPP #define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_1X1_S1_P0_HPP
#include <iostream> #include <iostream>
#include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv.hpp" #include "device_conv.hpp"
...@@ -546,6 +547,24 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1 ...@@ -546,6 +547,24 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_S1
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K_1x1_s1_P0"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
}; // namespace device }; // namespace device
} // namespace device } // namespace device
......
...@@ -21,30 +21,6 @@ void add_device_conv_fwd_instance( ...@@ -21,30 +21,6 @@ void add_device_conv_fwd_instance(
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&); ck::tensor_operation::element_wise::PassThrough>>&);
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
void add_device_conv_bwd_instance(
std::vector<DeviceConvBwdPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
void add_device_conv_wrw_instance(
std::vector<DeviceConvWrwPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
} // namespace device_conv_instance } // namespace device_conv_instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#define DEVICE_GEMM_XDL_HPP #define DEVICE_GEMM_XDL_HPP
#include <iostream> #include <iostream>
#include <sstream>
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_gemm.hpp" #include "device_gemm.hpp"
...@@ -483,6 +484,24 @@ struct DeviceGemmXdl ...@@ -483,6 +484,24 @@ struct DeviceGemmXdl
{ {
return std::make_unique<Invoker>(Invoker{}); return std::make_unique<Invoker>(Invoker{});
} }
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
}; };
} // namespace device } // namespace device
......
...@@ -34,6 +34,7 @@ install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) ...@@ -34,6 +34,7 @@ install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
set(DEVICE_CONV_INSTANCE_SOURCE set(DEVICE_CONV_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instance.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instance.cpp; ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instance.cpp;
) )
...@@ -48,4 +49,5 @@ set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp) ...@@ -48,4 +49,5 @@ set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp)
add_executable(ckProfiler ${PROFILER_SOURCE}) add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv_instance)
...@@ -39,6 +39,9 @@ void add_device_conv_fwd_instance<2, ...@@ -39,6 +39,9 @@ void add_device_conv_fwd_instance<2,
ck::tensor_layout::convolution::NHWK>( ck::tensor_layout::convolution::NHWK>(
std::vector<DeviceConvFwdNoOpPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_fp16_instances(
std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances( void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances(
std::vector<DeviceConvFwdNoOpPtr>&); std::vector<DeviceConvFwdNoOpPtr>&);
...@@ -158,6 +161,9 @@ void profile_conv(int do_verification, ...@@ -158,6 +161,9 @@ void profile_conv(int do_verification,
OutLayout>( OutLayout>(
conv_ptrs); conv_ptrs);
ck::tensor_operation::device::device_conv_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_fp16_instances(conv_ptrs);
ck::tensor_operation::device::device_conv_instance:: ck::tensor_operation::device::device_conv_instance::
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances(conv_ptrs); add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_fp16_instances(conv_ptrs);
...@@ -166,6 +172,7 @@ void profile_conv(int do_verification, ...@@ -166,6 +172,7 @@ void profile_conv(int do_verification,
throw std::runtime_error("wrong! no device Conv instance found"); throw std::runtime_error("wrong! no device Conv instance found");
} }
std::string best_conv_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
...@@ -195,6 +202,8 @@ void profile_conv(int do_verification, ...@@ -195,6 +202,8 @@ void profile_conv(int do_verification,
if(conv_ptr->IsSupportedArgument(argument_ptr.get())) if(conv_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
std::string conv_name = conv_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
...@@ -208,10 +217,11 @@ void profile_conv(int do_verification, ...@@ -208,10 +217,11 @@ void profile_conv(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl; << " GB/s, " << conv_name << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_conv_name = conv_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
...@@ -241,7 +251,7 @@ void profile_conv(int do_verification, ...@@ -241,7 +251,7 @@ void profile_conv(int do_verification,
} }
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s" << std::endl; << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl;
} }
} // namespace profiler } // namespace profiler
......
...@@ -164,6 +164,7 @@ void profile_gemm(int do_verification, ...@@ -164,6 +164,7 @@ void profile_gemm(int do_verification,
throw std::runtime_error("wrong! no device GEMM instance found"); throw std::runtime_error("wrong! no device GEMM instance found");
} }
std::string best_gemm_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
...@@ -189,9 +190,12 @@ void profile_gemm(int do_verification, ...@@ -189,9 +190,12 @@ void profile_gemm(int do_verification,
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N;
...@@ -200,10 +204,11 @@ void profile_gemm(int do_verification, ...@@ -200,10 +204,11 @@ void profile_gemm(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl; << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_gemm_name = gemm_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
...@@ -234,7 +239,7 @@ void profile_gemm(int do_verification, ...@@ -234,7 +239,7 @@ void profile_gemm(int do_verification,
} }
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s" << std::endl; << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
} }
} // namespace profiler } // namespace profiler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment