Unverified Commit 23ecf0fa authored by ltqin's avatar ltqin Committed by GitHub
Browse files

Add multiple d gridwise gemm on Navi21 for ResNet50 (#517)



* start add example

* add multiple d fp16 example

* device transfer elementwiseop to gridwise

* gridwise add multiple d

* change example for multiple d

* fix spill registers

* fix for passthrough element op

* fix int8 overflow

* change example file name

* add instance for dl multiple d

* example add DsDataType

* remove grouped_convolution_forward_dl.hpp

* add head file(was deleted before)

* fix not support device issue

* format

* remove passthrough check
Co-authored-by: default avatarletaoqin <letaoqin@amd.com>
parent abf9cc6c
...@@ -8,3 +8,4 @@ add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp6 ...@@ -8,3 +8,4 @@ add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp6
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
...@@ -30,6 +30,7 @@ void print_helper_msg() ...@@ -30,6 +30,7 @@ void print_helper_msg()
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename DsDataType,
typename OutDataType, typename OutDataType,
typename InElementOp, typename InElementOp,
typename WeiElementOp, typename WeiElementOp,
...@@ -46,8 +47,10 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -46,8 +47,10 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
const WeiElementOp& wei_element_op, const WeiElementOp& wei_element_op,
const OutElementOp& out_element_op) const OutElementOp& out_element_op)
{ {
using DDataType = ck::remove_cvref_t<ck::tuple_element_t<0, DsDataType>>;
Tensor<InDataType> in(in_g_n_c_wis_desc); Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc); Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<DDataType> bias(out_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc); Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc); Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
...@@ -59,31 +62,38 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -59,31 +62,38 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 3});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5}); wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 3});
bias.GenerateTensorValue(GeneratorTensor_2<DDataType>{-2, 3});
break; break;
case 2: case 2:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0}); in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5}); wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
bias.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
break; break;
default: default:
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}); in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1}); wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{-1});
bias.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
} }
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(DDataType) * bias.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data()); in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data()); wei_device_buf.ToDevice(wei.mData.data());
bias_device_buf.ToDevice(bias.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{}; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{}; std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{}; std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{}; std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{}; std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> c_g_n_k_wos_strides{}; std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -95,8 +105,10 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -95,8 +105,10 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), c_g_n_k_wos_lengths); copy(out_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), c_g_n_k_wos_strides); copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides); copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations); copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_left_pads_, input_left_pads);
...@@ -105,15 +117,19 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -105,15 +117,19 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
// do Conv // do Conv
auto conv = DeviceConvNDFwdInstance{}; auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker(); auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), auto argument = conv.MakeArgument(
in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(), wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
out_device_buf.GetDeviceBuffer(), out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths, a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
c_g_n_k_wos_lengths, std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_lengths}},
c_g_n_k_wos_strides, std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_strides}},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -124,6 +140,9 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -124,6 +140,9 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
if(!conv.IsSupportedArgument(argument)) if(!conv.IsSupportedArgument(argument))
{ {
std::cout << "wrong! device_conv with the specified compilation parameters does not "
"support this Conv problem"
<< std::endl;
return true; return true;
} }
...@@ -139,16 +158,18 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -139,16 +158,18 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
if(do_verification) if(do_verification)
{ {
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial, auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
NDimSpatial,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp>(); ck::tensor_operation::element_wise::PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker(); auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in, auto ref_argument =
ref_conv.MakeArgument(in,
wei, wei,
out_host, out_host,
conv_param.conv_filter_strides_, conv_param.conv_filter_strides_,
...@@ -157,10 +178,14 @@ bool run_grouped_conv_fwd_dl(bool do_verification, ...@@ -157,10 +178,14 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
conv_param.input_right_pads_, conv_param.input_right_pads_,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op); ck::tensor_operation::element_wise::PassThrough{});
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
// cde_elementwise
out_host.ForEach(
[&](auto&, auto idx) { out_element_op(out_host(idx), out_host(idx), bias(idx)); });
out_device_buf.FromDevice(out_device.mData.data()); out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err( return ck::utils::check_err(
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
#include "convnd_fwd_dl_common.hpp" #include "convnd_fwd_dl_common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::half_t; using InDataType = ck::half_t;
using WeiDataType = ck::half_t; using WeiDataType = ck::half_t;
using AccDataType = float; using AccDataType = float;
using DsDataType = ck::Tuple<ck::half_t>;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -17,7 +18,7 @@ using S = ck::Sequence<Is...>; ...@@ -17,7 +18,7 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto ConvSpec = static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
...@@ -26,12 +27,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial ...@@ -26,12 +27,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout> template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
// clang-format off // clang-format off
using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// ######| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; < NDimSpatial, InDataType, WeiDataType, DsDataType, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple<OutLayout>, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on // clang-format on
#include "run_convnd_fwd_dl_example.inc" #include "run_convnd_fwd_dl_example.inc"
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
#include "convnd_fwd_dl_common.hpp" #include "convnd_fwd_dl_common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = float; using InDataType = float;
using WeiDataType = float; using WeiDataType = float;
using AccDataType = float; using AccDataType = float;
using DsDataType = ck::Tuple<float>;
using OutDataType = float; using OutDataType = float;
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -17,7 +18,7 @@ using S = ck::Sequence<Is...>; ...@@ -17,7 +18,7 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto ConvSpec = static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
...@@ -26,12 +27,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial ...@@ -26,12 +27,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout> template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
// clang-format off // clang-format off
using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// ######| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; < NDimSpatial, InDataType, WeiDataType, DsDataType, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple<OutLayout>, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on // clang-format on
#include "run_convnd_fwd_dl_example.inc" #include "run_convnd_fwd_dl_example.inc"
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
#include "convnd_fwd_dl_common.hpp" #include "convnd_fwd_dl_common.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = int8_t; using InDataType = int8_t;
using WeiDataType = int8_t; using WeiDataType = int8_t;
using AccDataType = int32_t; using AccDataType = int32_t;
using DsDataType = ck::Tuple<int8_t>;
using OutDataType = int8_t; using OutDataType = int8_t;
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -17,7 +18,7 @@ using S = ck::Sequence<Is...>; ...@@ -17,7 +18,7 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto ConvSpec = static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
...@@ -26,12 +27,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial ...@@ -26,12 +27,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout> template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
// clang-format off // clang-format off
using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// ######| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; < NDimSpatial, InDataType, WeiDataType, DsDataType, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple<OutLayout>, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on // clang-format on
#include "run_convnd_fwd_dl_example.inc" #include "run_convnd_fwd_dl_example.inc"
......
...@@ -61,6 +61,7 @@ bool run_convnd_fwd_dl_example(int argc, char* argv[]) ...@@ -61,6 +61,7 @@ bool run_convnd_fwd_dl_example(int argc, char* argv[])
ndim_spatial_value, ndim_spatial_value,
InDataType, InDataType,
WeiDataType, WeiDataType,
DsDataType,
OutDataType, OutDataType,
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
......
...@@ -187,6 +187,22 @@ struct AddRelu ...@@ -187,6 +187,22 @@ struct AddRelu
const float a = x0 + type_convert<float>(x1); const float a = x0 + type_convert<float>(x1);
y = a > 0.0f ? a : 0.0f; y = a > 0.0f ? a : 0.0f;
}; };
template <>
__host__ __device__ constexpr void
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > 0 ? a : 0;
};
template <>
__host__ __device__ constexpr void
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > 0 ? a : 0;
};
}; };
struct AddHardswish struct AddHardswish
......
...@@ -131,6 +131,47 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( ...@@ -131,6 +131,47 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -273,11 +314,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -273,11 +314,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
...@@ -289,6 +332,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -289,6 +332,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
} }
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> && else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwd<2,
GNHWC,
GKYXC,
GNHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwd<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp"
#include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -199,8 +198,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification, ...@@ -199,8 +198,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
} }
}; };
// xdl
{
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NDimSpatial, using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<NDimSpatial,
InLayout, InLayout,
WeiLayout, WeiLayout,
...@@ -244,49 +241,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification, ...@@ -244,49 +241,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
run_impl(op_ptr, argument_ptr); run_impl(op_ptr, argument_ptr);
} }
}
// dl
{
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwd<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "dl found " << op_ptrs.size() << " instances" << std::endl;
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
run_impl(op_ptr, argument_ptr);
}
}
std::cout << "Best configuration parameters:" std::cout << "Best configuration parameters:"
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\nname: " << best_op_name << "\navg_time: " << best_avg_time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment