"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "276b2648b84593046ad4bc8413d6e74e642c8be0"
Unverified Commit cec69bc3 authored by JD's avatar JD Committed by GitHub
Browse files

Add host API (#220)



* Add host API

* manually rebase on develop

* clean

* manually rebase on develop

* exclude tests from all target

* address review comments

* update client app name

* fix missing lib name

* clang-format update

* refactor

* refactor

* refactor

* refactor

* refactor

* fix test issue

* refactor

* refactor

* refactor

* upate cmake and readme
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
parent 0f912e20
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
) )
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC)
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_bias_relu_add_instance) clang_tidy_check(device_conv2d_fwd_bias_relu_add_instance)
...@@ -3,9 +3,7 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE ...@@ -3,9 +3,7 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp;
) )
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_atomic_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC)
set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv2d_fwd_bias_relu_atomic_add_instance) clang_tidy_check(device_conv2d_fwd_bias_relu_atomic_add_instance)
...@@ -5,9 +5,8 @@ set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE ...@@ -5,9 +5,8 @@ set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp;
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
) )
add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) add_library(device_conv3d_fwd_instance OBJECT ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE})
target_compile_features(device_conv3d_fwd_instance PUBLIC) target_compile_features(device_conv3d_fwd_instance PUBLIC)
set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_conv3d_fwd_instance) clang_tidy_check(device_conv3d_fwd_instance)
...@@ -14,7 +14,7 @@ set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE ...@@ -14,7 +14,7 @@ set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE
device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp;
) )
add_library(device_convnd_bwd_data_instance SHARED ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE}) add_library(device_convnd_bwd_data_instance OBJECT ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE})
target_compile_features(device_convnd_bwd_data_instance PUBLIC) target_compile_features(device_convnd_bwd_data_instance PUBLIC)
set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib) install(TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib)
......
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
#include "host_interface.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances);
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl
{
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const
{
return el->MakeArgumentPointer(in_ptr,
wei_ptr,
out_ptr,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
PassThrough{});
}
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> MakeInvokerPointer() const
{
return el->MakeInvokerPointer();
}
std::string GetTypeString() { return el->GetTypeString(); }
bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg)
{
return el->IsSupportedArgument(arg);
}
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough> el;
};
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr) {}
DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default;
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default;
DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other)
: pImpl(std::make_unique<DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl>(std::move(other)))
{
}
std::unique_ptr<DeviceConvFwdPtr_t::BaseArgument>
DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr,
void* wei_ptr,
void* out_ptr,
size_t N,
size_t K,
size_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) const
{
return pImpl->MakeArgumentPointer(in_ptr,
wei_ptr,
out_ptr,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
}
std::unique_ptr<DeviceConvFwdPtr_t::BaseInvoker> DeviceConvFwdPtr_t::MakeInvokerPointer() const
{
return pImpl->MakeInvokerPointer();
}
std::string DeviceConvFwdPtr_t::GetTypeString() { return pImpl->GetTypeString(); }
bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr)
{
return pImpl->IsSupportedArgument(arg_ptr);
}
using namespace ck::tensor_operation::device::device_conv2d_fwd_instance;
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder : local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp);
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances);
for(auto& kinder : local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances);
for(auto& kinder : local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances);
for(auto& kinder : local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp); // Perhaps we can do better
}
return;
}
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(
std::vector<DeviceConvFwdPtr_t>& instances)
{
std::vector<
ck::tensor_operation::device::DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>
local_instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances);
for(auto& kinder : local_instances)
{
DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)};
instances.emplace_back(tmp);
}
return;
}
...@@ -35,10 +35,9 @@ set(DEVICE_GEMM_INSTANCE_SOURCE ...@@ -35,10 +35,9 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
) )
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE})
target_compile_features(device_gemm_instance PUBLIC) target_compile_features(device_gemm_instance PUBLIC)
set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_instance) clang_tidy_check(device_gemm_instance)
...@@ -10,9 +10,7 @@ set(DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE ...@@ -10,9 +10,7 @@ set(DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp;
) )
add_library(device_gemm_bias2d_instance SHARED ${DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE}) add_library(device_gemm_bias2d_instance OBJECT ${DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE})
target_compile_features(device_gemm_bias2d_instance PUBLIC)
set_target_properties(device_gemm_bias2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_bias2d_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_bias2d_instance) clang_tidy_check(device_gemm_bias2d_instance)
...@@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE ...@@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp;
) )
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
target_compile_features(device_gemm_bias_relu_instance PUBLIC)
set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_bias_relu_instance) clang_tidy_check(device_gemm_bias_relu_instance)
...@@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE ...@@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp;
) )
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_add_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
target_compile_features(device_gemm_bias_relu_add_instance PUBLIC)
set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_bias_relu_add_instance) clang_tidy_check(device_gemm_bias_relu_add_instance)
...@@ -6,7 +6,7 @@ set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE ...@@ -6,7 +6,7 @@ set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
) )
add_library(device_grouped_gemm_instance SHARED ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE})
target_compile_features(device_grouped_gemm_instance PUBLIC) target_compile_features(device_grouped_gemm_instance PUBLIC)
set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
......
...@@ -38,9 +38,7 @@ set(DEVICE_REDUCE_INSTANCE_SOURCE ...@@ -38,9 +38,7 @@ set(DEVICE_REDUCE_INSTANCE_SOURCE
device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp; device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp;
) )
add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE}) add_library(device_reduce_instance OBJECT ${DEVICE_REDUCE_INSTANCE_SOURCE})
target_compile_features(device_reduce_instance PUBLIC)
set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_reduce_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_reduce_instance) clang_tidy_check(device_reduce_instance)
...@@ -63,7 +63,7 @@ template <typename ADataType, ...@@ -63,7 +63,7 @@ template <typename ADataType,
bool profile_batched_gemm_impl(int do_verification, bool profile_batched_gemm_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
int M, int M,
int N, int N,
int K, int K,
...@@ -356,11 +356,12 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -356,11 +356,12 @@ bool profile_batched_gemm_impl(int do_verification,
{ {
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * BatchCount * M * N * K; std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N) * sizeof(CDataType) * M * N) *
BatchCount; BatchCount;
......
...@@ -53,7 +53,7 @@ template <typename ADataType, ...@@ -53,7 +53,7 @@ template <typename ADataType,
bool profile_batched_gemm_reduce_impl(int do_verification, bool profile_batched_gemm_reduce_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
int M, int M,
int N, int N,
int K, int K,
...@@ -258,31 +258,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -258,31 +258,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// warm up
invoker_ptr->Run(argument_ptr.get());
// timing
float total_time = 0;
for(int i = 0; i < nrepeat; ++i)
{ {
// init DO, D1 to 0 // init DO, D1 to 0
d0_device_buf.SetZero(); d0_device_buf.SetZero();
d1_device_buf.SetZero(); d1_device_buf.SetZero();
KernelTimer timer; float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
timer.Start();
invoker_ptr->Run(argument_ptr.get());
timer.End();
total_time += timer.GetElapsedTime();
}
float ave_time = total_time / nrepeat;
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
......
...@@ -51,7 +51,7 @@ template <int NDimSpatial, ...@@ -51,7 +51,7 @@ template <int NDimSpatial,
void profile_conv_bwd_data_impl(int do_verification, void profile_conv_bwd_data_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -228,7 +228,8 @@ void profile_conv_bwd_data_impl(int do_verification, ...@@ -228,7 +228,8 @@ void profile_conv_bwd_data_impl(int do_verification,
{ {
std::string conv_name = conv_ptr->GetTypeString(); std::string conv_name = conv_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamControl{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
......
#pragma once #pragma once
#include "stream_config.hpp"
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
...@@ -43,7 +45,7 @@ template <int NDimSpatial, ...@@ -43,7 +45,7 @@ template <int NDimSpatial,
bool profile_conv_bwd_weight_impl(int do_verification, bool profile_conv_bwd_weight_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -182,6 +184,7 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -182,6 +184,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
// profile device Conv instances // profile device Conv instances
bool pass = true; bool pass = true;
for(auto& conv_ptr : conv_ptrs) for(auto& conv_ptr : conv_ptrs)
{ {
// using atomic, so need to reset input // using atomic, so need to reset input
...@@ -189,6 +192,7 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -189,6 +192,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{ {
wei_device_buf.SetZero(); wei_device_buf.SetZero();
} }
auto argument_ptr = conv_ptr->MakeArgumentPointer( auto argument_ptr = conv_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()), static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
...@@ -214,7 +218,8 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -214,7 +218,8 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{ {
std::string conv_name = conv_ptr->GetTypeString(); std::string conv_name = conv_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
...@@ -242,6 +247,7 @@ bool profile_conv_bwd_weight_impl(int do_verification, ...@@ -242,6 +247,7 @@ bool profile_conv_bwd_weight_impl(int do_verification,
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
float max_error = check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); float max_error = check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result);
if(max_error > 8) if(max_error > 8)
{ {
pass = false; pass = false;
......
...@@ -42,7 +42,7 @@ template <int NDimSpatial, ...@@ -42,7 +42,7 @@ template <int NDimSpatial,
void profile_conv_fwd_bias_relu_add_impl(int do_verification, void profile_conv_fwd_bias_relu_add_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -219,7 +219,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, ...@@ -219,7 +219,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
{ {
std::string conv_name = op_ptr->GetTypeString(); std::string conv_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
......
...@@ -119,7 +119,7 @@ template <int NDimSpatial, ...@@ -119,7 +119,7 @@ template <int NDimSpatial,
void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification, void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -275,7 +275,8 @@ void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification, ...@@ -275,7 +275,8 @@ void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification,
{ {
std::string conv_name = op_ptr->GetTypeString(); std::string conv_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
......
...@@ -41,7 +41,7 @@ template <int NDimSpatial, ...@@ -41,7 +41,7 @@ template <int NDimSpatial,
void profile_conv_fwd_bias_relu_impl(int do_verification, void profile_conv_fwd_bias_relu_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -207,7 +207,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, ...@@ -207,7 +207,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
{ {
std::string conv_name = op_ptr->GetTypeString(); std::string conv_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
......
...@@ -269,7 +269,7 @@ template <int NDimSpatial, ...@@ -269,7 +269,7 @@ template <int NDimSpatial,
bool profile_convnd_bwd_data_impl(int do_verification, bool profile_convnd_bwd_data_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
...@@ -410,7 +410,8 @@ bool profile_convnd_bwd_data_impl(int do_verification, ...@@ -410,7 +410,8 @@ bool profile_convnd_bwd_data_impl(int do_verification,
{ {
std::string conv_name = conv_ptr->GetTypeString(); std::string conv_name = conv_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t flop =
ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths);
......
...@@ -65,7 +65,7 @@ template <typename ADataType, ...@@ -65,7 +65,7 @@ template <typename ADataType,
void profile_gemm_bias_2d_impl(int do_verification, void profile_gemm_bias_2d_impl(int do_verification,
int init_method, int init_method,
bool do_log, bool do_log,
int nrepeat, bool time_kernel,
int M, int M,
int N, int N,
int K, int K,
...@@ -259,7 +259,8 @@ void profile_gemm_bias_2d_impl(int do_verification, ...@@ -259,7 +259,8 @@ void profile_gemm_bias_2d_impl(int do_verification,
{ {
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment