Commit 952f02ff authored by letaoqin

Change example for multiple D: switch the grouped conv fwd DL examples to the multiple-D (fused bias + AddRelu) instance

parent 1c5b6f7d
@@ -45,8 +45,10 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
                              const WeiElementOp& wei_element_op,
                              const OutElementOp& out_element_op)
 {
+    using DDataType = OutDataType;
     Tensor<InDataType> in(in_g_n_c_wis_desc);
     Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
+    Tensor<DDataType> bias(out_g_n_k_wos_desc);
     Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
     Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
 
@@ -58,31 +60,38 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
     {
     case 0: break;
     case 1:
-        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
-        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
+        in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 3});
+        wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 3});
+        bias.GenerateTensorValue(GeneratorTensor_2<DDataType>{-1, 1});
         break;
     case 2:
         in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
         wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
+        bias.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
         break;
     default:
         in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
-        wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
+        wei.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{-1});
+        bias.GenerateTensorValue(GeneratorTensor_1<DDataType>{1});
     }
 
     DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
     DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
+    DeviceMem bias_device_buf(sizeof(DDataType) * bias.mDesc.GetElementSpaceSize());
     DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
 
     in_device_buf.ToDevice(in.mData.data());
     wei_device_buf.ToDevice(wei.mData.data());
+    bias_device_buf.ToDevice(bias.mData.data());
 
     std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
     std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
     std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
     std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
-    std::array<ck::index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{};
-    std::array<ck::index_t, NDimSpatial + 3> c_g_n_k_wos_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> d_g_n_k_wos_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
     std::array<ck::index_t, NDimSpatial> input_left_pads{};
@@ -94,8 +103,10 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
     copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
     copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
     copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
-    copy(out_g_n_k_wos_desc.GetLengths(), c_g_n_k_wos_lengths);
-    copy(out_g_n_k_wos_desc.GetStrides(), c_g_n_k_wos_strides);
+    copy(out_g_n_k_wos_desc.GetLengths(), d_g_n_k_wos_lengths);
+    copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
+    copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
+    copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
     copy(conv_param.conv_filter_strides_, conv_filter_strides);
     copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
     copy(conv_param.input_left_pads_, input_left_pads);
@@ -104,15 +115,19 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
     // do Conv
     auto conv    = DeviceConvNDFwdInstance{};
     auto invoker = conv.MakeInvoker();
-    auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
+    auto argument = conv.MakeArgument(
+        in_device_buf.GetDeviceBuffer(),
         wei_device_buf.GetDeviceBuffer(),
+        std::array<const void*, 1>{bias_device_buf.GetDeviceBuffer()},
         out_device_buf.GetDeviceBuffer(),
         a_g_n_c_wis_lengths,
        a_g_n_c_wis_strides,
         b_g_k_c_xs_lengths,
         b_g_k_c_xs_strides,
-        c_g_n_k_wos_lengths,
-        c_g_n_k_wos_strides,
+        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_lengths}},
+        std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{{d_g_n_k_wos_strides}},
+        e_g_n_k_wos_lengths,
+        e_g_n_k_wos_strides,
         conv_filter_strides,
         conv_filter_dilations,
         input_left_pads,
@@ -123,7 +138,9 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
 
     if(!conv.IsSupportedArgument(argument))
     {
-        return true;
+        throw std::runtime_error(
+            "wrong! device_conv with the specified compilation parameters does "
+            "not support this Conv problem");
     }
 
     float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
@@ -138,16 +155,18 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
     if(do_verification)
     {
-        auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
+        auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
+            NDimSpatial,
             InDataType,
             WeiDataType,
             OutDataType,
             InElementOp,
             WeiElementOp,
-            OutElementOp>();
+            ck::tensor_operation::element_wise::PassThrough>();
 
         auto ref_invoker = ref_conv.MakeInvoker();
 
-        auto ref_argument = ref_conv.MakeArgument(in,
+        auto ref_argument =
+            ref_conv.MakeArgument(in,
             wei,
             out_host,
             conv_param.conv_filter_strides_,
@@ -156,10 +175,14 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
             conv_param.input_right_pads_,
             in_element_op,
             wei_element_op,
-            out_element_op);
+            ck::tensor_operation::element_wise::PassThrough{});
 
         ref_invoker.Run(ref_argument);
 
+        // cde_elementwise
+        out_host.ForEach(
+            [&](auto&, auto idx) { out_element_op(out_host(idx), out_host(idx), bias(idx)); });
+
         out_device_buf.FromDevice(out_device.mData.data());
 
         return ck::utils::check_err(
......
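With the multiple-D instance, the verification path above runs the reference convolution with a PassThrough output op and then applies the fused CDE op on the host (the `out_host.ForEach` loop). A minimal standalone sketch of that post-processing step, with plain `float` vectors standing in for the CK tensors (names here are illustrative, not CK's):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the AddRelu CDE op used above: e = max(c + d, 0).
float add_relu(float c, float d) { return std::max(c + d, 0.0f); }

int main()
{
    // out_host plays the role of C (reference conv output), bias the role of D0.
    std::vector<float> out_host{-1.5f, 0.25f, 2.0f};
    std::vector<float> bias{1.0f, -1.0f, 0.5f};

    // Equivalent of the out_host.ForEach(...) loop in the diff:
    // apply the CDE op in place, element by element.
    for(std::size_t i = 0; i < out_host.size(); ++i)
        out_host[i] = add_relu(out_host[i], bias[i]);

    assert(out_host[0] == 0.0f); // -1.5 + 1.0 = -0.5 -> clamped to 0
    assert(out_host[2] == 2.5f); //  2.0 + 0.5 =  2.5 -> kept
    return 0;
}
```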
@@ -3,7 +3,7 @@
 #include "convnd_fwd_dl_common.hpp"
 
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 
@@ -17,7 +17,7 @@ using S = ck::Sequence<Is...>;
 using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
-using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
+using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
 
 static constexpr auto ConvSpec =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
@@ -26,12 +26,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
 template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
 // clang-format off
-using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK
+using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
-// ######| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
+// ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
-// ######| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
+// ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
+// ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+< NDimSpatial, InDataType, WeiDataType, ck::Tuple<OutDataType>, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple<OutLayout>, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
 // clang-format on
 
 #include "run_convnd_fwd_dl_example.inc"
......
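The instance above (and the two variants that follow, which differ only in K1 and the vector widths) differs from the old single-output instance in two tuple parameters, ck::Tuple<OutDataType> and ck::Tuple<OutLayout>, plus the AddRelu output op. A toy illustration of why the D tensors arrive as a tuple type and one-element runtime arrays; `DsPack` is my own stand-in, not CK's `ck::Tuple`:

```cpp
#include <array>
#include <cstddef>

// Toy pack (my own, not CK code): the multiple-D kernel is generic over any
// number of extra D tensors, and a single fused bias is simply the 1-tuple case.
template <typename... Ds>
struct DsPack
{
    static constexpr std::size_t size = sizeof...(Ds);
};

using BiasOnly = DsPack<float>; // mirrors ck::Tuple<OutDataType>
static_assert(BiasOnly::size == 1, "one D tensor: the bias");

int main()
{
    // The runtime arguments line up with the pack: one pointer and one
    // length/stride descriptor per D tensor, hence the std::array<..., 1>
    // arguments passed to MakeArgument in the common file above.
    std::array<const void*, BiasOnly::size> ds_ptrs{nullptr};
    return ds_ptrs.size() == 1 ? 0 : 1;
}
```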
@@ -3,7 +3,7 @@
 #include "convnd_fwd_dl_common.hpp"
 
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 
@@ -17,7 +17,7 @@ using S = ck::Sequence<Is...>;
 using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
-using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
+using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
 
 static constexpr auto ConvSpec =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
@@ -26,12 +26,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
 template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
 // clang-format off
-using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK
+using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
-// ######| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
+// ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
-// ######| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
+// ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
+// ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+< NDimSpatial, InDataType, WeiDataType, ck::Tuple<OutDataType>, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple<OutLayout>, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
 // clang-format on
 
 #include "run_convnd_fwd_dl_example.inc"
......
@@ -3,7 +3,7 @@
 #include "convnd_fwd_dl_common.hpp"
 
-#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
 #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
 
@@ -17,7 +17,7 @@ using S = ck::Sequence<Is...>;
 using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
-using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
+using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
 
 static constexpr auto ConvSpec =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
@@ -26,12 +26,12 @@ static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecial
 template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
 // clang-format off
-using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK
+using DeviceGroupedConvNDFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
-// ######| NDim| InData| WeiData| OutData| AccData| InLayout| WeiLayout| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
+// ######| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
-// ######| Spatial| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
+// ######| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
+// ######| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< NDimSpatial, InDataType, WeiDataType, OutDataType, AccDataType, InLayout, WeiLayout, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
+< NDimSpatial, InDataType, WeiDataType, ck::Tuple<OutDataType>, OutDataType, AccDataType, InLayout, WeiLayout, ck::Tuple<OutLayout>, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
 // clang-format on
 
 #include "run_convnd_fwd_dl_example.inc"
......
@@ -18,7 +18,6 @@ using S = ck::Sequence<Is...>;
 using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
 using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
 using OutElementOp = ck::tensor_operation::element_wise::Relu;
-;
 
 static constexpr auto ConvSpec =
     ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
......
@@ -11,7 +11,7 @@ bool run_convnd_fwd_dl_example(int argc, char* argv[])
     bool time_kernel = false;
 
     ck::utils::conv::ConvParam conv_param{
-        2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
+        2, 1, 128, 256, 192, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
 
     if(argc == 1)
     {
......
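The default problem shrinks from a 71×71 input to 14×14 (reading the `ConvParam` fields as NDim, G, N, K, C, filter sizes, input sizes, strides, dilations, left pads, right pads; the field order is an assumption from the surrounding examples). With the standard convolution shape rule, the new problem yields a 7×7 output:

```cpp
// Shape check for the new default problem; the formula below is the standard
// convolution output-size rule, not a CK API.
constexpr int Wi = 14, X = 3, Stride = 2, Dilation = 1, LeftPad = 1, RightPad = 1;
constexpr int Wo = (Wi + LeftPad + RightPad - Dilation * (X - 1) - 1) / Stride + 1;
static_assert(Wo == 7, "14x14 input, 3x3 filter, stride 2, dilation 1, pad 1 -> 7x7");

int main() { return 0; }
```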
@@ -218,12 +218,12 @@ __global__ void
 template <index_t NDimSpatial,
           typename ADataType,
           typename BDataType,
-          typename DsLayout,
+          typename DsDataType,
           typename EDataType,
           typename AccDataType,
           typename ALayout,
           typename BLayout,
-          typename DsDataType,
+          typename DsLayout,
           typename ELayout,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
@@ -392,7 +392,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
                                          ADataType,
                                          AccDataType,
-                                         DsLayout,
+                                         DsDataType,
                                          EDataType,
                                          AElementwiseOperation,
                                          BElementwiseOperation,
......
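The swap above reorders `DsDataType` and `DsLayout` in the device struct's template parameter list so that `GridwiseGemmDlMultipleD_km_kn_mn` receives the data-type pack in the slot where it expects one, instead of a layout pack. A toy reproduction of the failure mode (all names here are mine, not CK's):

```cpp
#include <type_traits>

// Toy gridwise kernel that expects a *data type* as its third parameter.
template <typename AData, typename Acc, typename DsData>
struct GridwiseToy
{
    static_assert(std::is_arithmetic_v<DsData>,
                  "DsData must be a data type, not a layout tag");
};

struct NHWK {}; // a layout tag

// GridwiseToy<float, float, NHWK> bad; // would fire the static_assert:
//                                      // a layout passed where data is expected
GridwiseToy<float, float, float> ok;    // data type in the right slot

int main()
{
    (void)ok;
    return 0;
}
```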
@@ -187,6 +187,22 @@ struct AddRelu
         const float a = x0 + type_convert<float>(x1);
         y             = a > 0.0f ? a : 0.0f;
     };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
+    {
+        const int a = x0 + x1;
+        y           = a > 0 ? a : 0;
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
+    {
+        const int8_t a = x0 + x1;
+        y              = a > 0 ? a : 0;
+    };
 };
 
 struct AddHardswish
......
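The two new `AddRelu` specializations make the fused bias epilogue usable on integer paths: `y = max(x0 + x1, 0)` with an `int` accumulator and `int8_t` bias, or entirely in `int8_t`. A standalone check that mirrors them (a local copy for illustration, not the real `ck::tensor_operation::element_wise::AddRelu`):

```cpp
#include <cassert>
#include <cstdint>

// Local copy of the two integer overloads added above; the cast in the int8_t
// path just silences narrowing warnings and matches the diff's semantics.
struct AddReluSketch
{
    void operator()(int& y, const int& x0, const int8_t& x1) const
    {
        const int a = x0 + x1;
        y           = a > 0 ? a : 0;
    }
    void operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const
    {
        const int8_t a = static_cast<int8_t>(x0 + x1);
        y              = a > 0 ? a : 0;
    }
};

int main()
{
    AddReluSketch op{};

    int y = 0;
    op(y, -5, int8_t{3}); // -5 + 3 = -2 -> clamped to 0
    assert(y == 0);

    int8_t y8 = 0;
    op(y8, int8_t{4}, int8_t{2}); // 4 + 2 = 6 -> kept
    assert(y8 == 6);
    return 0;
}
```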