Commit d8fed085 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Add 3d grouped conv fwd wmma instances

parent bba085d2
...@@ -234,6 +234,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( ...@@ -234,6 +234,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC, NHWGC,
...@@ -248,6 +262,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( ...@@ -248,6 +262,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
...@@ -293,6 +322,20 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( ...@@ -293,6 +322,20 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
...@@ -323,6 +366,20 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( ...@@ -323,6 +366,20 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
...@@ -354,6 +411,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( ...@@ -354,6 +411,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
...@@ -384,6 +455,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( ...@@ -384,6 +455,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif #endif
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
...@@ -516,14 +601,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -516,14 +601,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#ifdef DL_KERNELS #ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#endif #endif
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BDF16 #ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> && if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>) is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
} }
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs);
}
#endif #endif
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> && else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
...@@ -541,6 +634,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -541,6 +634,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -555,6 +649,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -555,6 +649,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs);
} }
#endif #endif
} }
...@@ -573,6 +668,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -573,6 +668,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, half_t>) is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
} }
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
...@@ -587,6 +683,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -587,6 +683,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<OutDataType, int8_t>) is_same_v<OutDataType, int8_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs);
} }
#endif #endif
} }
......
...@@ -13,9 +13,13 @@ add_instance_library(device_grouped_conv2d_fwd_instance ...@@ -13,9 +13,13 @@ add_instance_library(device_grouped_conv2d_fwd_instance
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
# WMMA # WMMA
# GNHWC, GKYXC, GNHWK
device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
# NHWGC, GKYXC, NHWGK # NHWGC, GKYXC, NHWGK
device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp
# NHWGC, GKYXC, NHWGK
device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
) )
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -24,7 +24,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -24,7 +24,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC, device_grouped_conv_fwd_wmma_f16_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
...@@ -33,7 +34,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -33,7 +34,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC, device_grouped_conv_fwd_wmma_f16_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
...@@ -42,7 +44,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -42,7 +44,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC, device_grouped_conv_fwd_wmma_f16_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
...@@ -51,7 +54,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( ...@@ -51,7 +54,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC, device_grouped_conv_fwd_wmma_f16_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -24,7 +24,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( ...@@ -24,7 +24,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC, device_grouped_conv_fwd_wmma_i8_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
...@@ -33,7 +34,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( ...@@ -33,7 +34,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC, device_grouped_conv_fwd_wmma_i8_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
...@@ -42,7 +44,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( ...@@ -42,7 +44,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC, device_grouped_conv_fwd_wmma_i8_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
...@@ -51,7 +54,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( ...@@ -51,7 +54,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
ConvFwd1x1S1P0>{}); ConvFwd1x1S1P0>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC, device_grouped_conv_fwd_wmma_i8_instances<2,
GNHWC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, GNHWK,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi ,wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1S1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi ,wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1S1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
PassThrough,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -8,4 +8,9 @@ add_instance_library(device_grouped_conv3d_fwd_instance ...@@ -8,4 +8,9 @@ add_instance_library(device_grouped_conv3d_fwd_instance
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp
) )
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[g, n, di, hi ,wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwd1x1S1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[g, n, di, hi ,wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho,
// wo, k]
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwd1x1S1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
Empty_Tuple,
PassThrough,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, di, hi ,wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1S1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_f16_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, di, hi ,wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo,
// g, k]
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwdDefault>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwd1x1S1P0>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_wmma_i8_instances<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
Empty_Tuple,
PassThrough,
ConvFwdOddC>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment