"driver/vscode:/vscode.git/clone" did not exist on "4a66157846abd463fad2cbd32ad2d8da573f0f66"
Unverified Commit 73b67f29 authored by Bartłomiej Kocot, committed by GitHub

Add support for NGCHW in grouped conv bwd wei (#1491)

* Add support for NGCHW in grouped conv bwd wei

* Comment fixes

* Navi fixes

* Update function names
parent a9b170b5
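
For readers unfamiliar with the layout names, here is a minimal sketch of the dimension order implied by the N-major grouped layouts this commit targets. The `ck::tensor_layout::convolution` namespace path and the header path are assumptions based on CK's usual conventions; the concrete tags (NGCHW, GKYXC, NGKHW and their 3D variants) appear in the diffs that follow.

```cpp
// Sketch only: dimension order implied by the layout names used in this commit.
//   input  NGCHW -> [N, G, C, Hi, Wi]   (3D variant: NGCDHW -> [N, G, C, Di, Hi, Wi])
//   weight GKYXC -> [G, K, Y, X, C]     (3D variant: GKZYXC)
//   output NGKHW -> [N, G, K, Ho, Wo]   (3D variant: NGKDHW)
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" // assumed location of the layout tags

using InLayout  = ck::tensor_layout::convolution::NGCHW;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using OutLayout = ck::tensor_layout::convolution::NGKHW;
```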
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, g, c, di, hi, wi] * wei[g, k, z, y, x, c] = out[n, g, k, do, ho, wo]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                            NGCDHW,
                                                            GKZYXC,
                                                            NGKDHW,
                                                            F16,
                                                            F16,
                                                            F16,
                                                            PassThrough,
                                                            PassThrough,
                                                            PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances<
            3,
            NGCDHW,
            GKZYXC,
            NGKDHW,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
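
As a point of reference (not part of the commit), a hypothetical consumer of the factory above might look like the following sketch. It assumes the same headers and type aliases visible in this file (NGCDHW, GKZYXC, NGKDHW, F16, PassThrough) are in scope, and relies only on the `GetTypeString` query that CK device operators inherit from their base operator class.

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Assumes the same includes and layout/data-type aliases as the instance file above.
using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::instance;

int main()
{
    // Container matching the factory's parameter type.
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NGCDHW,
                                                           GKZYXC,
                                                           NGKDHW,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>
        instances;

    // Populate it with the newly added NGCDHW two-stage f16 instances.
    add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
        instances);

    // Print what was registered, e.g. to confirm the new instances are visible.
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';

    return 0;
}
```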
@@ -16,6 +16,7 @@ enum struct ConvLayout
     GNCHW_GKCYX_GNKHW, // 0
     GNHWC_GKYXC_GNHWK, // 1
     NHWGC_GKYXC_NHWGK, // 2
+    NGCHW_GKYXC_NGKHW, // 3
 };
 enum struct ConvDataType
@@ -178,6 +179,13 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
             return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
         }
     }
+    else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
+    {
+        if(data_type == ConvDataType::F16_F16_F16)
+        {
+            return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
+        }
+    }
     if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
     {
         if(data_type == ConvDataType::F32_F32_F32)
@@ -224,6 +232,13 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
                 I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
         }
     }
+    else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
+    {
+        if(data_type == ConvDataType::F16_F16_F16)
+        {
+            return profile(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
+        }
+    }
     std::cout << "this data_type & layout is not implemented" << std::endl;
...
@@ -23,6 +23,26 @@ def run_ck_profiler_cmd(cmd):
     subprocess.run(cmd)

+def parse_layouts(args):
+    if args.in_layout == "NCW" or args.in_layout == "NCHW" or \
+       args.in_layout == "NCDHW":
+        if args.ck_profier_op == "grouped_conv_bwd_weight":
+            args.layout = 3
+        else:
+            print('Not supported layout for this op')
+            exit(1)
+    elif args.in_layout == "NWC" or args.in_layout == "NHWC" or \
+         args.in_layout == "NDHWC":
+        if args.ck_profier_op == "grouped_conv_bwd_weight":
+            args.layout = 2
+        elif args.ck_profier_op == "grouped_conv_bwd_data" or \
+             args.ck_profier_op == "grouped_conv_fwd":
+            args.layout = 1
+        else:
+            print('Not supported layout for this op')
+            exit(1)
+
 def parse_data_type(args):
     if args.data_type == "fp32":
         if args.ck_profier_op == "grouped_conv_bwd_weight" or \
@@ -79,8 +99,7 @@ def add_conv_params_to_cmd(args, cmd):
 def run_ck_grouped_conv_fwd(args):
     args.ck_profier_op = "grouped_conv_fwd"
     parse_data_type(args)
-    # default for MIOpen NHWGC
-    args.layout = 1
+    parse_layouts(args)
     # use int32 by default
     args.index_type = 0
@@ -99,8 +118,7 @@ def run_ck_grouped_conv_fwd(args):
 def run_ck_grouped_conv_bwd_data(args):
     args.ck_profier_op = "grouped_conv_bwd_data"
     parse_data_type(args)
-    # default for MIOpen NHWGC
-    args.layout = 1
+    parse_layouts(args)
     cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
     cmd += [str(args.data_type), str(args.layout)]
@@ -117,8 +135,7 @@ def run_ck_grouped_conv_bwd_data(args):
 def run_ck_grouped_conv_bwd_weight(args):
     args.ck_profier_op = "grouped_conv_bwd_weight"
     parse_data_type(args)
-    # default for MIOpen NHWGC
-    args.layout = 2
+    parse_layouts(args)
     # Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
     args.split_k_value = -1
@@ -181,8 +198,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "-in_layout",
         "-I",
-        default=-1,
-        type=int,
+        default="NCHW",
+        type=str,
         required=False,
         help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)"
     )
...
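
To make the effect of the new `parse_layouts` helper concrete, here is a small hypothetical check (not part of the commit). It assumes `parse_layouts` from the script above is in scope; `argparse.Namespace` is used only as a stand-in for the script's parsed arguments, and the attribute names (`in_layout`, `ck_profier_op`, `layout`) are taken from the diff above.

```python
import argparse

# NCHW-style inputs now select layout 3 (NGCHW_GKYXC_NGKHW), currently only for bwd weight.
args = argparse.Namespace(in_layout="NCHW", ck_profier_op="grouped_conv_bwd_weight")
parse_layouts(args)
assert args.layout == 3

# NHWC-style inputs keep the previous default for bwd weight: layout 2 (NHWGC_GKYXC_NHWGK).
args = argparse.Namespace(in_layout="NHWC", ck_profier_op="grouped_conv_bwd_weight")
parse_layouts(args)
assert args.layout == 2
```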
@@ -66,6 +66,12 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
         {
             return true;
         }
+        // Skip due to the lack of kernels for NGCDHW
+        if constexpr(std::is_same_v<InLayout, NGCW> || std::is_same_v<InLayout, NGCHW> ||
+                     std::is_same_v<InLayout, NGCDHW>)
+        {
+            return true;
+        }
     }
     else
     {
@@ -139,7 +145,8 @@ using KernelTypes2d = ::testing::Types<
     std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
     std::tuple<float, float, float, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
     std::tuple<ck::half_t, ck::half_t, ck::half_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
-    std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>>;
+    std::tuple<ck::bhalf_t, float, ck::bhalf_t, NHWGC, GKYXC, NHWGK, ck::Number<2>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCHW, GKYXC, NGKHW, ck::Number<2>>>;

 using KernelTypes3d = ::testing::Types<
     std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
     std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
@@ -148,7 +155,8 @@ using KernelTypes3d = ::testing::Types<
     std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
     std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
     std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
-    std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
+    std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, NGCDHW, GKZYXC, NGKDHW, ck::Number<3>>>;

 TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
 TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
...