"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "fe04a5bab23c974b043e3054f30c3f1de038d3f9"
Unverified commit 472fa029, authored by Bartłomiej Kocot, committed by GitHub

Enable grouped conv with small K or C (#822)

* Enable grouped conv with small K or C

* Add missing instances

* Refactor grouped conv fwd instances

* Fix fp16 instances, since fp16 only supports src_per_vec % 2 == 0

* Add generic instances
parent 9c54eaab
@@ -378,6 +378,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         const index_t GemmM = K;
         const index_t GemmN = C * X;
+        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
+        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
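The two PadGemm* lines are the core of the change: they round each GEMM dimension up to the next multiple of the block tile, which is what lets these kernels accept grouped convolutions whose K or C is smaller than MPerBlock/NPerBlock. A minimal sketch of the same expression, with illustrative tile sizes (the helper name and the numbers are ours, not from the diff):

#include <cassert>

using index_t = int;

// Same expression as PadGemmM / PadGemmN above: the distance from `dim` up to
// the next multiple of `per_block`, and 0 when `dim` already divides evenly.
index_t pad_to_multiple(index_t dim, index_t per_block)
{
    return (per_block - dim % per_block) % per_block;
}

int main()
{
    assert(pad_to_multiple(3, 64) == 61);  // small K: GemmM = 3 is padded to 64
    assert(pad_to_multiple(128, 64) == 0); // exact multiple: no padding needed
    return 0;
}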
@@ -496,9 +499,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         const auto wei_gemmm_gemmn_grid_desc =
             make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
 
-        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
-                          in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                          wei_gemmm_gemmn_grid_desc);
+        // Padd
+        const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
+            transform_tensor_descriptor(
+                out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(GemmKBatch),
+                           make_pass_through_transform(GemmK0),
+                           make_right_pad_transform(GemmM, PadGemmM),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
+            transform_tensor_descriptor(
+                in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(GemmKBatch),
+                           make_pass_through_transform(GemmK0),
+                           make_right_pad_transform(GemmN, PadGemmN),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto wei_gemmm_gemmn_pad_grid_desc =
+            transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
+                                        make_tuple(make_right_pad_transform(GemmM, PadGemmM),
+                                                   make_right_pad_transform(GemmN, PadGemmN)),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
+                          in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
+                          wei_gemmm_gemmn_pad_grid_desc);
         }
     }
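Each descriptor keeps its KBatch/K0/K1 dimensions as pass-through; only the M (or N) dimension is right-padded. Conceptually, make_right_pad_transform extends a dimension's logical length without touching memory; accesses that land in the padded tail have no backing data and are masked by the grid-wise GEMM. A rough illustration of that idea, in our own words rather than CK's implementation:

// Hedged sketch of a right-padded dimension (illustrative; not CK's code).
struct RightPadDim
{
    int length; // original extent, e.g. GemmN
    int pad;    // extra extent, e.g. PadGemmN

    int padded_length() const { return length + pad; } // what the GEMM sees

    // The padded tail has no backing memory: reads there are masked to zero
    // and writes are dropped, so results are unaffected by the padding.
    bool is_valid(int idx) const { return idx < length; }
};

The 2D and 3D overloads below repeat the same three transforms; only GemmN's definition changes with the filter rank.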
@@ -546,6 +577,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         const index_t GemmM = K;
         const index_t GemmN = C * X * Y;
+        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
+        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
@@ -651,9 +685,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 
-        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
-                          in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                          wei_grid_desc);
+        // Padd
+        const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
+            transform_tensor_descriptor(
+                out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(GemmKBatch),
+                           make_pass_through_transform(GemmK0),
+                           make_right_pad_transform(GemmM, PadGemmM),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
+            transform_tensor_descriptor(
+                in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(GemmKBatch),
+                           make_pass_through_transform(GemmK0),
+                           make_right_pad_transform(GemmN, PadGemmN),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto wei_gemmm_gemmn_pad_grid_desc =
+            transform_tensor_descriptor(wei_grid_desc,
+                                        make_tuple(make_right_pad_transform(GemmM, PadGemmM),
+                                                   make_right_pad_transform(GemmN, PadGemmN)),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
+                          in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
+                          wei_gemmm_gemmn_pad_grid_desc);
         }
     }
@@ -708,6 +770,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         const index_t GemmM = K;
         const index_t GemmN = C * Z * X * Y;
+        const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
+        const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
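For the 3D case GemmN multiplies C by all three filter extents, so even a small C often lands off a tile boundary. A worked example with illustrative sizes (not taken from the diff): C = 4 and a 3x3x3 filter give GemmN = 4 * 3 * 3 * 3 = 108; with NPerBlock = 128, PadGemmN = (128 - 108 % 128) % 128 = 20, so the padded GemmN is exactly one 128-wide tile.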
@@ -822,9 +887,37 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
 
-        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
-                          in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                          wei_grid_desc);
+        // Padd
+        const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
+            transform_tensor_descriptor(
+                out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(GemmKBatch),
+                           make_pass_through_transform(GemmK0),
+                           make_right_pad_transform(GemmM, PadGemmM),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
+            transform_tensor_descriptor(
+                in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
+                make_tuple(make_pass_through_transform(GemmKBatch),
+                           make_pass_through_transform(GemmK0),
+                           make_right_pad_transform(GemmN, PadGemmN),
+                           make_pass_through_transform(GemmK1Number)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto wei_gemmm_gemmn_pad_grid_desc =
+            transform_tensor_descriptor(wei_grid_desc,
+                                        make_tuple(make_right_pad_transform(GemmM, PadGemmM),
+                                                   make_right_pad_transform(GemmN, PadGemmN)),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
+                                        make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
+                          in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
+                          wei_gemmm_gemmn_pad_grid_desc);
         }
     } // function end
...
@@ -63,6 +63,7 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
         DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
         DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
         DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, GNHWK, GKYXC, Empty_Tuple, GNHWC, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+        // TODO: After enable, add instance for small conv.K and conv.C
 #endif
     // clang-format on
     >;
@@ -97,6 +98,7 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances =
         DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
         DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
         DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
+        // TODO: After enable, add instance for small conv.K and conv.C
 #endif
     // clang-format on
     >;
@@ -131,6 +133,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
+       // TODO: After enable, add instance for small conv.K and conv.C
 #endif
     // clang-format on
     >;
...
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
+#include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
-#include "device_grouped_conv2d_fwd_common.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {
 
+using BF16 = ck::bhalf_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Empty_Tuple = ck::Tuple<>;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using namespace ck::tensor_layout::convolution;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto ConvFwdDefault =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
+
+static constexpr auto ConvFwd1x1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
+
+static constexpr auto ConvFwd1x1S1P0 =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
+
+static constexpr auto ConvFwdOddC =
+    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
+
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
 template <typename InLayout,
           typename WeiLayout,
           typename DsLayout,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;

using Empty_Tuple = ck::Tuple<>;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;

static constexpr auto ConvFwdOddC =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;

static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
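The GemmMNKPadding alias defined above is presumably what the commit's "generic instances" select: with that GEMM specialization, M, N, and K are all padded to tile multiples, so a single instance can cover arbitrary (including very small) K and C, trading some performance against the tuned instances that assume divisibility.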
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_dl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_dl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -24,40 +24,36 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
                     PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<GNHWC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                GNHWK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwdDefault>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              GNHWC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              GNHWK,
+                                                                              ConvFwdDefault>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<GNHWC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                GNHWK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwd1x1P0>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              GNHWC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              GNHWK,
+                                                                              ConvFwd1x1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<GNHWC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                GNHWK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwd1x1S1P0>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              GNHWC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              GNHWK,
+                                                                              ConvFwd1x1S1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<GNHWC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                GNHWK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwdOddC>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              GNHWC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              GNHWK,
+                                                                              ConvFwdOddC>{});
 }
 
 } // namespace instance
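For readers new to this registration pattern: add_device_operation_instances takes a tuple of concrete instance types and appends one heap-allocated copy of each to the caller's vector of base-class pointers. A self-contained, simplified stand-in for that mechanism (our own sketch, not CK's exact implementation):

#include <memory>
#include <tuple>
#include <vector>

// Walk the tuple of concrete instance types and push one owning pointer per
// element into the shared list (C++17 fold over the comma operator).
template <typename BaseOp, typename... Ops>
void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& out,
                          const std::tuple<Ops...>& ops)
{
    std::apply([&](const Ops&... op) {
        (out.push_back(std::make_unique<Ops>(op)), ...);
    }, ops);
}

struct Base { virtual ~Base() = default; };
struct A : Base {};
struct B : Base {};

int main()
{
    std::vector<std::unique_ptr<Base>> v;
    add_instances_sketch(v, std::tuple<A, B>{});
    return v.size() == 2 ? 0 : 1; // two instances registered
}

The remaining .cpp hunks below repeat this same change (shared header, rank-parameterized instance list) for each data type and layout combination.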
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
                     PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdDefault>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwdDefault>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1P0>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwd1x1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1S1P0>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwd1x1S1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdOddC>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwdOddC>{});
 }
 
 } // namespace instance
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
                     PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdDefault>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwdDefault>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1P0>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwd1x1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1S1P0>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwd1x1S1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<GNHWC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               GNHWK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdOddC>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             GNHWC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             GNHWK,
+                                                                             ConvFwdOddC>{});
 }
 
 } // namespace instance
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
                     PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<NHWGC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                NHWGK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwdDefault>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              NHWGC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              NHWGK,
+                                                                              ConvFwdDefault>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<NHWGC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                NHWGK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwd1x1P0>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              NHWGC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              NHWGK,
+                                                                              ConvFwd1x1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<NHWGC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                NHWGK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwd1x1S1P0>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              NHWGC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              NHWGK,
+                                                                              ConvFwd1x1S1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_bf16_instances<NHWGC,
-                                                                                GKYXC,
-                                                                                Empty_Tuple,
-                                                                                NHWGK,
-                                                                                Empty_Tuple,
-                                                                                PassThrough,
-                                                                                ConvFwdOddC>{});
+                                   device_grouped_conv_fwd_xdl_bf16_instances<2,
+                                                                              NHWGC,
+                                                                              GKYXC,
+                                                                              Empty_Tuple,
+                                                                              NHWGK,
+                                                                              ConvFwdOddC>{});
 }
 
 } // namespace instance
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
                     PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdDefault>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwdDefault>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1P0>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwd1x1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1S1P0>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwd1x1S1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f16_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdOddC>{});
+                                   device_grouped_conv_fwd_xdl_f16_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwdOddC>{});
 }
 
 } // namespace instance
...
@@ -2,7 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-#include "device_grouped_conv2d_fwd_xdl_instance.hpp"
+#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -24,40 +24,36 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
                     PassThrough>>>& instances)
 {
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdDefault>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwdDefault>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1P0>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwd1x1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwd1x1S1P0>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwd1x1S1P0>{});
     add_device_operation_instances(instances,
-                                   device_grouped_conv2d_fwd_xdl_f32_instances<NHWGC,
-                                                                               GKYXC,
-                                                                               Empty_Tuple,
-                                                                               NHWGK,
-                                                                               Empty_Tuple,
-                                                                               PassThrough,
-                                                                               ConvFwdOddC>{});
+                                   device_grouped_conv_fwd_xdl_f32_instances<2,
+                                                                             NHWGC,
+                                                                             GKYXC,
+                                                                             Empty_Tuple,
+                                                                             NHWGK,
+                                                                             ConvFwdOddC>{});
 }
 
 } // namespace instance
...
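Downstream, a dispatcher typically walks such an instance list, skips instances that do not support the given problem (which is where the new small-K/small-C and generic instances widen coverage), and keeps the fastest supported one. An outline of that pattern with our own toy interface (CK's real device ops expose analogous IsSupportedArgument/invoker methods, but their exact signatures vary per operation):

#include <memory>
#include <vector>

// Toy stand-in for a CK device operation instance.
struct DeviceOpSketch
{
    virtual ~DeviceOpSketch() = default;
    virtual bool is_supported() const = 0; // cf. IsSupportedArgument in CK
    virtual float run_ms() = 0;            // launch the kernel and time it
};

// Keep the fastest instance that actually supports the problem; a longer
// instance list means more problem shapes get a working (and fast) kernel.
inline DeviceOpSketch* pick_best(const std::vector<std::unique_ptr<DeviceOpSketch>>& ops)
{
    DeviceOpSketch* best = nullptr;
    float best_ms = 1e30f;
    for(const auto& op : ops)
    {
        if(!op->is_supported())
            continue;
        const float ms = op->run_ms();
        if(ms < best_ms)
        {
            best_ms = ms;
            best = op.get();
        }
    }
    return best;
}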