Commit 857fdad0 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Redesign grouped_conv_bwd_weight instances

parent 139b950f
......@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -26,13 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances<GNHWC,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
GNHWC,
GKYXC,
GNHWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances<
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
2,
GNHWC,
GKYXC,
GNHWK,
......
......@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -26,15 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances<
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
GNHWC,
GKYXC,
GNHWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances<
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
2,
GNHWC,
GKYXC,
GNHWK,
......
......@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -26,15 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances<
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
GNHWC,
GKYXC,
GNHWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances<
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
2,
GNHWC,
GKYXC,
GNHWK,
......
......@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -26,13 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances<NHWGC,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances<
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
2,
NHWGC,
GKYXC,
NHWGK,
......
......@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -26,15 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances<
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances<
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
2,
NHWGC,
GKYXC,
NHWGK,
......
......@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
......@@ -26,15 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances<
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances(
instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances<
add_device_operation_instances(instances,
device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
2,
NHWGC,
GKYXC,
NHWGK,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment