Unverified commit 4ba52b35 authored by Bartłomiej Kocot, committed by GitHub

Add support for NGCHW in grouped conv fwd (#1499)

* Support NGCHW in grouped conv fwd

* Remove not needed variable

* Fixes
parent 0c39954d

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, y, x, c] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NGCHW,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NGKHW,
                                                                F16,
                                                                F16,
                                                                Empty_Tuple,
                                                                F16,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
                                                                NGCHW,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NGKHW,
                                                                ConvFwdDefault>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
                                                                NGCHW,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NGKHW,
                                                                ConvFwd3x3>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
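
Once registered, these instances can be enumerated through CK's instance factory. Below is a minimal consumer sketch, assuming the grouped-convolution-forward factory specialization (pulled in via ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp) covers this layout/type combination after this PR; it only lists the registered kernels rather than running them.

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"

using namespace ck::tensor_operation::device;
namespace conv      = ck::tensor_layout::convolution;
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;

// Same signature as the registration function above: 2-D grouped conv fwd,
// NGCHW input, GKYXC weight, NGKHW output, f16 throughout.
using DeviceOp = DeviceGroupedConvFwdMultipleABD<2,
                                                 conv::NGCHW,
                                                 conv::GKYXC,
                                                 ck::Tuple<>,
                                                 conv::NGKHW,
                                                 ck::half_t,
                                                 ck::half_t,
                                                 ck::Tuple<>,
                                                 ck::half_t,
                                                 PassThroughOp,
                                                 PassThroughOp,
                                                 PassThroughOp>;

int main()
{
    // Enumerate every instance registered for this signature, including
    // the merged-groups NGCHW ones added by this PR, and print their names.
    auto ops = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    for(const auto& op : ops)
        std::cout << op->GetTypeString() << '\n';
}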

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, y, x, c] = out[n, g, k, ho, wo]
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
                                                                NGCHW,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NGKHW,
                                                                F32,
                                                                F32,
                                                                Empty_Tuple,
                                                                F32,
                                                                PassThrough,
                                                                PassThrough,
                                                                PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
                                                                NGCHW,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NGKHW,
                                                                ConvFwdDefault>{});
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
                                                                NGCHW,
                                                                GKYXC,
                                                                Empty_Tuple,
                                                                NGKHW,
                                                                ConvFwd3x3>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
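
For readers unfamiliar with the helper used in both files: add_device_operation_instances appends one type-erased instance per element of a compile-time tuple of concrete kernel types. A simplified sketch of that idiom (not CK's actual implementation, which lives in add_device_operation_instance.hpp) looks like:

#include <memory>
#include <tuple>
#include <vector>

// For every concrete op type in a compile-time tuple, push one instance,
// owned through a unique_ptr to the common base class, into the caller's
// vector. This mirrors the mechanism of add_device_operation_instances.
template <typename BaseOp, typename... Ops>
void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances,
                          std::tuple<Ops...>)
{
    (instances.push_back(std::make_unique<Ops>()), ...);
}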
@@ -148,6 +148,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
     bool pass = true;

     auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
+        // workspace_sz will be equal to 0 for layouts other than NGCHW
+        const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+        DeviceMem workspace_dev(workspace_sz);
+        op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
+
         if(op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
             // re-init output to zero before profiling next kernel
...
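
The four added lines are the workspace handshake that NGCHW needs: the op reports its scratch requirement (non-zero only for NGCHW, where the kernel presumably transposes tensors to a vectorizable layout internally), and the caller allocates and attaches the buffer before dispatch. A standalone sketch of the same pattern, using only the BaseOperator calls that appear in this diff plus MakeInvokerPointer/Run; op_ptr and argument_ptr stand for any grouped conv fwd instance and its argument obtained elsewhere, and DeviceMem comes from "ck/library/utility/device_memory.hpp". This mirrors the diff, not the profiler's full measurement loop.

template <typename OpPtr, typename ArgPtr>
bool run_with_workspace(OpPtr& op_ptr, ArgPtr& argument_ptr)
{
    // Non-NGCHW layouts report a size of 0, so the allocation below is a
    // harmless no-op for them.
    const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
    DeviceMem workspace_dev(workspace_sz);
    op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());

    // As in the diff, the workspace is attached before the support check
    // and the kernel launch.
    if(!op_ptr->IsSupportedArgument(argument_ptr.get()))
        return false;

    auto invoker_ptr = op_ptr->MakeInvokerPointer();
    invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
    return true;
}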
@@ -45,6 +45,8 @@ static void print_helper_msg()
            "N, Ho, Wo, K]\n"
         << "                 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, "
            "Ho, Wo, G, K]\n"
+        << "                 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
+           "G, K, Ho, Wo]\n"
         << "arg4: verification (0: no, 1: yes)\n"
         << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
         << "arg6: print tensor value (0: no; 1: yes)\n"
...
@@ -15,6 +15,7 @@ enum struct ConvLayout
 {
     GNHWC_GKYXC_GNHWK, // 0
     NHWGC_GKYXC_NHWGK, // 1
+    NGCHW_GKYXC_NGKHW, // 2
 };

 enum struct ConvDataType
@@ -54,6 +55,8 @@ static void print_helper_msg()
         << "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
         << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
         << "                     1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
+        << "                     2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
+           "G, K, Ho, Wo]\n"
         << "arg5: verification (0: no, 1: yes)\n"
         << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
         << "arg7: print tensor value (0: no; 1: yes)\n"
@@ -111,6 +114,11 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
     using GNHWK  = ck::tensor_layout::convolution::GNHWK;
     using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
+    //
+    using NGCHW = ck::tensor_layout::convolution::NGCHW;
+    using NGKHW = ck::tensor_layout::convolution::NGKHW;
     //
     using NWGC  = ck::tensor_layout::convolution::NWGC;
     using NHWGC = ck::tensor_layout::convolution::NHWGC;
@@ -284,6 +292,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
         }
     }
+    else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
+    {
+        if(data_type == ConvDataType::F32_F32_F32)
+        {
+            return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
+        }
+        else if(data_type == ConvDataType::F16_F16_F16)
+        {
+            return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
+        }
+    }
     else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
...
@@ -28,6 +28,8 @@ def parse_layouts(args):
        args.in_layout == "NCDHW":
         if args.ck_profier_op == "grouped_conv_bwd_weight":
             args.layout = 3
+        elif args.ck_profier_op == "grouped_conv_fwd":
+            args.layout = 2
         else:
             print('Not supported layout for this op')
             exit(1)
...
@@ -62,7 +62,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
                                        std::tuple<float, NHWGC, GKYXC, NHWGK>,
                                        std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
                                        std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
-                                       std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
+                                       std::tuple<int8_t, NHWGC, GKYXC, NHWGK>,
+                                       std::tuple<float, NGCHW, GKYXC, NGKHW>,
+                                       std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>>;
 using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
                                        std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
...
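
The two new tuples extend the 2-D typed-test matrix to the NGCHW path in f32 and f16. For readers unfamiliar with the pattern, here is a hedged sketch of how such a type list typically drives a gtest typed suite; the fixture and test names below are placeholders, not CK's actual test classes.

#include <gtest/gtest.h>
#include <tuple>

// Placeholder fixture illustrating how KernelTypes2d parameterizes a
// typed suite; CK's real fixture runs the conv instance and verifies
// its output against a host reference.
template <typename Tuple>
class GroupedConvFwdSketch : public ::testing::Test
{
    protected:
    using DataType  = std::tuple_element_t<0, Tuple>;
    using InLayout  = std::tuple_element_t<1, Tuple>;
    using WeiLayout = std::tuple_element_t<2, Tuple>;
    using OutLayout = std::tuple_element_t<3, Tuple>;
};

TYPED_TEST_SUITE(GroupedConvFwdSketch, KernelTypes2d);

// With the two added tuples, this body now also instantiates for
// <float, NGCHW, GKYXC, NGKHW> and <ck::half_t, NGCHW, GKYXC, NGKHW>.
TYPED_TEST(GroupedConvFwdSketch, CompilesForEachLayout) {}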