Unverified Commit 309b1c64 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Fixed Weight layout of grouped_conv 3d fwd (#743)



* Changed wei layout

* changed layout for examples

* fixed client example

---------
Co-authored-by: default avatarroot <root@ctr-ubbsmc15.amd.com>
parent c5f6ec84
...@@ -141,14 +141,10 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD ...@@ -141,14 +141,10 @@ bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialD
std::next(rbegin(in_strides)), std::next(rbegin(in_strides)),
std::next(rbegin(in_strides), NumDimSpatial + 1)); std::next(rbegin(in_strides), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 2), rend(wei_lengths));
std::rotate(rbegin(wei_lengths), std::rotate(rbegin(wei_lengths),
std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths)),
std::next(rbegin(wei_lengths), NumDimSpatial + 1)); std::next(rbegin(wei_lengths), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 2), rend(wei_strides));
std::rotate(rbegin(wei_strides), std::rotate(rbegin(wei_strides),
std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides)),
std::next(rbegin(wei_strides), NumDimSpatial + 1)); std::next(rbegin(wei_strides), NumDimSpatial + 1));
......
...@@ -11,7 +11,7 @@ using WeiDataType = ck::half_t; ...@@ -11,7 +11,7 @@ using WeiDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC; using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::KZYXGC; using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK; using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t NumDimSpatial = 3;
...@@ -38,7 +38,7 @@ int main() ...@@ -38,7 +38,7 @@ int main()
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(
{N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -11,7 +11,7 @@ using WeiDataType = float; ...@@ -11,7 +11,7 @@ using WeiDataType = float;
using OutDataType = float; using OutDataType = float;
using InLayout = ck::tensor_layout::convolution::NDHWGC; using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::KZYXGC; using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK; using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t NumDimSpatial = 3;
...@@ -38,7 +38,7 @@ int main() ...@@ -38,7 +38,7 @@ int main()
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(
{N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -245,11 +245,11 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( ...@@ -245,11 +245,11 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
BF16, BF16,
...@@ -260,10 +260,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( ...@@ -260,10 +260,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
F16, F16,
...@@ -274,10 +274,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( ...@@ -274,10 +274,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
F32, F32,
...@@ -288,10 +288,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( ...@@ -288,10 +288,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC, NDHWGC,
KZYXGC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
int8_t, int8_t,
...@@ -433,28 +433,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -433,28 +433,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, KZYXGC> && is_same_v<OutLayout, NDHWGK>) is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{ {
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>) is_same_v<OutDataType, float>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> && else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
} }
} }
......
...@@ -4,8 +4,8 @@ add_instance_library(device_grouped_conv3d_fwd_instance ...@@ -4,8 +4,8 @@ add_instance_library(device_grouped_conv3d_fwd_instance
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment