Commit 857fdad0 authored by Bartlomiej Kocot
Browse files

Redesign grouped_conv_bwd_weight instances

parent 139b950f
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,13 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in ...@@ -26,13 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances<GNHWC, device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
GKYXC, GNHWC,
GNHWK, GKYXC,
ConvBwdWeightDefault>{}); GNHWK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GNHWK, GNHWK,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GNHWK, GNHWK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(instances,
instances, device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances< 2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GNHWK, GNHWK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GNHWK, GNHWK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(instances,
instances, device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances< 2,
GNHWC, GNHWC,
GKYXC, GKYXC,
GNHWK, GNHWK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,13 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in ...@@ -26,13 +26,15 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances<NHWGC, device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<2,
GKYXC, NHWGC,
NHWGK, GKYXC,
ConvBwdWeightDefault>{}); NHWGK,
ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_bf16_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances<
2,
NHWGC, NHWGC,
GKYXC, GKYXC,
NHWGK, NHWGK,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( ...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
NHWGK, NHWGK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(instances,
instances, device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_instances<
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f16_default_instances< 2,
NHWGC, NHWGC,
GKYXC, GKYXC,
NHWGK, NHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "device_grouped_conv2d_bwd_weight_xdl_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( ...@@ -26,19 +26,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
// 1. Default // 1. Default
add_device_operation_instances( add_device_operation_instances(
instances, instances,
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances< device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
NHWGK, NHWGK,
ConvBwdWeightDefault>{}); ConvBwdWeightDefault>{});
// 2. Filter1x1Stride1Pad0 // 2. Filter1x1Stride1Pad0
add_device_operation_instances( add_device_operation_instances(instances,
instances, device_grouped_conv_bwd_weight_xdl_c_shuffle_f32_instances<
device_grouped_conv2d_bwd_weight_xdl_c_shuffle_f32_default_instances< 2,
NHWGC, NHWGC,
GKYXC, GKYXC,
NHWGK, NHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{}); ConvBwdWeightFilter1x1Stride1Pad0>{});
} }
} // namespace instance } // namespace instance
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment