Unverified commit 82fc5383, authored by Bartłomiej Kocot and committed via GitHub

Enable grouped conv bwd wei bf16 NGCHW (#1589)

* Enable grouped conv bwd wei bf16 NGCHW

* fixes

* fixes

* Fixes

* fixes

* fixes

* Fixes
parent 0394f8a7
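The four new instance files below only register device operations; this commit does not show them being consumed. As a quick orientation, here is a minimal usage sketch, not part of the commit: it assumes a CK build environment linked against the new instance translation units, and it declares the registration function locally (in the library the declaration would normally come from the grouped-conv instance-factory header). It fills an instance vector and prints each instance's type string.

// Minimal usage sketch (not part of this commit): enumerate the new bf16
// NGCDHW two-stage instances by name. Assumes a CK build environment and
// linking against the new instance translation units.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Local declaration matching the definition added in this commit; in the
// library it would come from the grouped_conv_bwd_weight factory header.
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NGCDHW,
                                                           GKZYXC,
                                                           NGKDHW,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances);
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

int main()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;

    // Vector type matches the parameter of the registration function above.
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NGCDHW,
                                                           GKZYXC,
                                                           NGKDHW,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>
        instances;

    add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
        instances);

    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n'; // BaseOperator::GetTypeString()
}
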
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
                                                           NDHWGK,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
            3,
            NDHWGC,
            GKZYXC,
            NDHWGK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v2>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
                                                           NDHWGK,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
            3,
            NDHWGC,
            GKZYXC,
            NDHWGK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, g, c, di, hi, wi] * wei[g, k, z, y, x, c] = out[n, g, k, do, ho, wo]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NGCDHW,
                                                           GKZYXC,
                                                           NGKDHW,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
            3,
            NGCDHW,
            GKZYXC,
            NGKDHW,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v2>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// Compilation parameters for in[n, g, c, di, hi, wi] * wei[g, k, z, y, x, c] = out[n, g, k, do, ho, wo]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NGCDHW,
                                                           GKZYXC,
                                                           NGKDHW,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // 1. Default
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
            3,
            NGCDHW,
            GKZYXC,
            NGKDHW,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v5>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
@@ -24,14 +24,16 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
// 1. Default
add_device_operation_instances(
instances,
-device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
+device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
+3,
GNDHWC,
GKZYXC,
GNDHWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
-add_device_operation_instances(instances,
-device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
+add_device_operation_instances(
+instances,
+device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
3,
GNDHWC,
GKZYXC,
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
@@ -25,14 +25,16 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
// 1. Default
add_device_operation_instances(
instances,
-device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3,
+device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
+3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
-add_device_operation_instances(instances,
-device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
+add_device_operation_instances(
+instances,
+device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
3,
NDHWGC,
GKZYXC,
......
@@ -25,7 +25,8 @@ enum struct ConvDataType
F16_F16_F16, // 1
BF16_F32_BF16, // 2
F16_F16_F16_BF8_F8, // 3
-I8_I8_I8 // 4
+I8_I8_I8, // 4
+BF16_BF16_BF16, // 5
};
#define OP_NAME "grouped_conv_bwd_weight"
@@ -38,7 +39,8 @@ static void print_helper_msg()
<< " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight fp32, Output bf16\n"
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
<< " 4: Input int8, Weight int8, Output int8)\n"
<< " 4: Input int8, Weight int8, Output int8\n"
<< " 5: Input bf16, Weight bf16, Output bf16)\n"
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]\n"
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
@@ -187,6 +189,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
+if(data_type == ConvDataType::BF16_BF16_BF16)
+{
+// fp32 atomic add is used for weight tensor in bf16 kernel
+return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+}
}
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
@@ -203,6 +210,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
+if(data_type == ConvDataType::BF16_BF16_BF16)
+{
+return profile(
+I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+}
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(
@@ -224,6 +236,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
+if(data_type == ConvDataType::BF16_BF16_BF16)
+{
+return profile(
+I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+}
if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
@@ -240,6 +257,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
return profile(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
+if(data_type == ConvDataType::BF16_BF16_BF16)
+{
+return profile(
+I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
+}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
......
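All of the profiler hunks above follow the same pattern: the data type parsed from the command line selects a branch, and empty layout and data-type tag structs (NGCDHW{}, BF16{}, ...) pick a profile() overload at compile time. Below is a self-contained sketch of that pattern; every type in it is a simplified stand-in, not the real ckProfiler machinery.

// Self-contained sketch of the tag-dispatch pattern used above. All names
// here are simplified stand-ins for the real ckProfiler types.
#include <iostream>

enum struct ConvDataType
{
    BF16_F32_BF16  = 2, // bf16 in/out, fp32 weight (pre-existing path)
    BF16_BF16_BF16 = 5  // bf16 everywhere (added by this commit)
};

// Empty tag types: their only job is to select a template instantiation.
struct NGCDHW {};
struct GKZYXC {};
struct NGKDHW {};
struct BF16 {};
struct F32 {};

template <typename InLayout, typename WeiLayout, typename OutLayout,
          typename InType, typename WeiType, typename OutType>
int profile(InLayout, WeiLayout, OutLayout, InType, WeiType, OutType)
{
    // The real profiler would enumerate and time matching device instances here.
    std::cout << "profiling the selected layout/data-type combination\n";
    return 0;
}

int main()
{
    const auto data_type = ConvDataType::BF16_BF16_BF16;

    if(data_type == ConvDataType::BF16_BF16_BF16)
    {
        // Mirrors the new branch: bf16 tags for input, weight, and output.
        return profile(NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{});
    }
    return 1;
}
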
@@ -65,8 +65,9 @@ def parse_data_type(args):
if args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 3
if args.data_type == "bfp16":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \
args.ck_profier_op == "grouped_conv_bwd_data" or \
if args.ck_profier_op == "grouped_conv_bwd_weight":
args.data_type = 5
if args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 2
......