Unverified Commit 82fc5383 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Enable grouped conv bwd wei bf16 NGCHW (#1589)

* Enable grouped conv bwd wei bf16 NGCHW

* fixes

* fixes

* Fixes

* fixes

* fixes

* Fixes
parent 0394f8a7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
3,
NGCDHW,
GKZYXC,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NGCDHW,
GKZYXC,
NGKDHW,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances<
3,
NGCDHW,
GKZYXC,
NGKDHW,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v5>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...@@ -24,14 +24,16 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16 ...@@ -24,14 +24,16 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3, device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
3,
GNDHWC, GNDHWC,
GKZYXC, GKZYXC,
GNDHWK, GNDHWK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances, add_device_operation_instances(
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
3, 3,
GNDHWC, GNDHWC,
GKZYXC, GKZYXC,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
...@@ -25,14 +25,16 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16 ...@@ -25,14 +25,16 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<3, device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
NDHWGK, NDHWGK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances, add_device_operation_instances(
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances< instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_f32_bf16_instances<
3, 3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
...@@ -25,7 +25,8 @@ enum struct ConvDataType ...@@ -25,7 +25,8 @@ enum struct ConvDataType
F16_F16_F16, // 1 F16_F16_F16, // 1
BF16_F32_BF16, // 2 BF16_F32_BF16, // 2
F16_F16_F16_BF8_F8, // 3 F16_F16_F16_BF8_F8, // 3
I8_I8_I8 // 4 I8_I8_I8, // 4
BF16_BF16_BF16, // 5
}; };
#define OP_NAME "grouped_conv_bwd_weight" #define OP_NAME "grouped_conv_bwd_weight"
...@@ -38,7 +39,8 @@ static void print_helper_msg() ...@@ -38,7 +39,8 @@ static void print_helper_msg()
<< " 1: Input fp16, Weight fp16, Output fp16\n" << " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight fp32, Output bf16\n" << " 2: Input bf16, Weight fp32, Output bf16\n"
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n" << " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
<< " 4: Input int8, Weight int8, Output int8)\n" << " 4: Input int8, Weight int8, Output int8\n"
<< " 5: Input bf16, Weight bf16, Output bf16)\n"
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, " << "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
"N, K, Ho, Wo]\n" "N, K, Ho, Wo]\n"
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, " << " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
...@@ -187,6 +189,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -187,6 +189,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
if(data_type == ConvDataType::BF16_BF16_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
} }
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{ {
...@@ -203,6 +210,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -203,6 +210,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::I8_I8_I8) else if(data_type == ConvDataType::I8_I8_I8)
{ {
return profile( return profile(
...@@ -224,6 +236,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -224,6 +236,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel // fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
} }
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::F16_F16_F16_BF8_F8) if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{ {
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
...@@ -240,6 +257,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[]) ...@@ -240,6 +257,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{ {
return profile(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); return profile(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
} }
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
} }
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -65,8 +65,9 @@ def parse_data_type(args): ...@@ -65,8 +65,9 @@ def parse_data_type(args):
if args.ck_profier_op == "grouped_conv_fwd": if args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 3 args.data_type = 3
if args.data_type == "bfp16": if args.data_type == "bfp16":
if args.ck_profier_op == "grouped_conv_bwd_weight" or \ if args.ck_profier_op == "grouped_conv_bwd_weight":
args.ck_profier_op == "grouped_conv_bwd_data" or \ args.data_type = 5
if args.ck_profier_op == "grouped_conv_bwd_data" or \
args.ck_profier_op == "grouped_conv_fwd": args.ck_profier_op == "grouped_conv_fwd":
args.data_type = 2 args.data_type = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment