Commit 648f1f13 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/gemm_tile_loop

parents 4e5190f5 cb538740
......@@ -16,38 +16,38 @@ namespace tensor_operation {
namespace device {
namespace instance {
// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F16, F32, F32, F16, F16>,
ck::Tuple<F16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F32, F32, F32, F32, F32>,
ck::Tuple<F32>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<BF16, F32, F32, BF16, BF16>,
ck::Tuple<BF16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F64, F64, F64, F64, F64>,
ck::Tuple<F64>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);
#endif
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
......@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
is_same_v<MeanVarDataType, F32>)
......@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::conv_tensor_rearrange_op;
// Image to Column
// nhwc, 1d
void add_device_image_to_column_nwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// nhwc, 2d
void add_device_image_to_column_nhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_nhwc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// nhwc, 3d
void add_device_image_to_column_ndhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ImageToColumn>>>&
instances);
void add_device_image_to_column_ndhwc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ImageToColumn>>>&
instances);
// Column to Image
// nhwc, 1d
void add_device_column_to_image_nwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<1, GNWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// nhwc, 2d
void add_device_column_to_image_nhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_nhwc_2d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<2, GNHWC, int8_t, int8_t, ColumnToImage>>>&
instances);
// nhwc, 3d
void add_device_column_to_image_ndhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, BF16, BF16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F16, F16, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, F32, F32, ColumnToImage>>>&
instances);
void add_device_column_to_image_ndhwc_3d_i8_instances(
std::vector<
std::unique_ptr<DeviceConvTensorRearrange<3, GNDHWC, int8_t, int8_t, ColumnToImage>>>&
instances);
template <ck::index_t NumDimSpatial,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename ConvTensorRearrangeOp>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceConvTensorRearrange<NumDimSpatial,
ImageLayout,
InDataType,
OutDataType,
ConvTensorRearrangeOp>>
{
using DeviceOp = DeviceConvTensorRearrange<NumDimSpatial,
ImageLayout,
InDataType,
OutDataType,
ConvTensorRearrangeOp>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ConvTensorRearrangeOp, ImageToColumn>)
{
if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, GNWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_ndhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_ndhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_ndhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_ndhwc_3d_i8_instances(op_ptrs);
}
}
}
else if constexpr(is_same_v<ConvTensorRearrangeOp, ColumnToImage>)
{
if constexpr(NumDimSpatial == 1 && is_same_v<ImageLayout, GNWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<ImageLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_nhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_nhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_nhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_nhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<ImageLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_column_to_image_ndhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_column_to_image_ndhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_column_to_image_ndhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_column_to_image_ndhwc_3d_i8_instances(op_ptrs);
}
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using namespace ck::conv_tensor_rearrange_op;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_bf16_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 64, 64, 64, S<8, 8>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 128, 64, 128, S<8, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 8>
// clang-format on
>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_f16_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 64, 64, 64, S<8, 8>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 128, 64, 128, S<8, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 8>
// clang-format on
>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_f32_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, F32, F32, 256, 128, 128, S<16, 16>, 4>
// clang-format on
>;
template <ck::index_t NDimSpatial, typename InLayout>
using device_column_to_image_i8_instances = std::tuple<
// clang-format off
//#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar|
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
// generic instance
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 16, 16, S<8, 8>, 1>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 32, 32, S<8, 8>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 64, 64, S<8, 8>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 32, 64, S<8, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 64, 128, S<8, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 64, 64, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 4>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 8>,
DeviceColumnToImageImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 256, 256, S<16, 16>, 16>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -13,6 +13,7 @@ namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using namespace ck::conv_tensor_rearrange_op;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
......@@ -28,17 +29,12 @@ using device_image_to_column_bf16_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 64, 64, 64, S<8, 8>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 128, 64, 128, S<8, 16>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, BF16, BF16, 256, 128, 128, S<16, 16>, 8>
......@@ -52,17 +48,13 @@ using device_image_to_column_f16_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 64, 64, 64, S<8, 8>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 128, 64, 128, S<8, 16>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F16, F16, 256, 128, 128, S<16, 16>, 8>
......@@ -76,15 +68,11 @@ using device_image_to_column_f32_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, F32, F32, 256, 128, 128, S<16, 16>, 4>
// clang-format on
......@@ -97,17 +85,13 @@ using device_image_to_column_i8_instances = std::tuple<
//#####################| Dim| | | | Size| Block| Block| Cluster| Per|
//#####################| Spatial| | | | | | | Lengths| Vector|
//#####################| | | | | | | | | |
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 8, 8, S<8, 8>, 1>,
// generic instance
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 16, 16, S<8, 8>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 32, 32, S<8, 8>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 64, 64, 64, S<8, 8>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 16, 16, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 64, 64, S<8, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 32, 64, S<8, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 128, 64, 128, S<8, 16>, 8>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 16, 16, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 64, 64, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 1>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 64, 64, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 4>,
DeviceImageToColumnImpl<NDimSpatial, InLayout, int8_t, int8_t, 256, 128, 128, S<16, 16>, 8>,
......
......@@ -312,6 +312,23 @@ void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP8
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -505,6 +522,32 @@ struct DeviceOperationInstanceFactory<
#endif
}
}
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<ADataType, ck::f8_t> && is_same_v<BDataType, ck::f8_t> &&
is_same_v<CDataType, ck::f8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
......
......@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
......@@ -68,7 +68,8 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Row,
......@@ -120,7 +121,7 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
// GEMM + Bilinear
template <typename ALayout,
typename BLayout,
......@@ -158,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
......@@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
......@@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......@@ -220,4 +223,3 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
......@@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
......@@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
return op_ptrs;
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdDataDefault =
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdData1x1S1P0 =
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
template <index_t NDSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename DsDatatype,
typename CDEElementOp,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_wmma_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 128, 64, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 64, 64, 8, 8, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 64, 32, 8, 8, 16, 16, 4, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 32, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
template <index_t NDSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename DsDatatype,
typename CDEElementOp,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_wmma_i8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 64, 256, 8, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 256, 8, 16, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 128, 256, 8, 16, 16, 16, 4, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 32, 256, 8, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 256, 128, 8, 16, 16, 16, 8, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 128, 8, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 64, 128, 8, 16, 16, 16, 2, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 128, 8, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 64, 8, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 64, 8, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 64, 64, 8, 16, 16, 16, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 32, 32, 8, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 64, 8, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple<
// clang-format off
//############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
//############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple<
// clang-format off
//############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
//############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec>
using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple<
// clang-format off
//############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector|
//############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | |
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using NHWGC = ck::tensor_layout::convolution::NHWGC;
using GNHWC = ck::tensor_layout::convolution::GNHWC;
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using NHWGK = ck::tensor_layout::convolution::NHWGK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename DsDatatype,
typename CDEElementOp,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv2d_fwd_wmma_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// blocksize=256
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename DsDatatype,
typename CDEElementOp,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv2d_fwd_wmma_i8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// blocksize=256
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -54,9 +54,8 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple<
// ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instances
// TODO: Change to ScalarPerVector = 1 when inner_product<half_t, half_t, float> will be supported
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 2, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 1, 2>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 1, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 1>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>,
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <index_t NDSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename DsDatatype,
typename CDEElementOp,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
template <index_t NDSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename DsDatatype,
typename CDEElementOp,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_i8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//generic instance
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -16,6 +16,7 @@ namespace device {
namespace instance {
// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
......@@ -30,6 +31,35 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
......@@ -43,7 +73,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
......@@ -57,7 +88,37 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
GKYXC,
Empty_Tuple,
GNHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
......@@ -72,6 +133,36 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
......@@ -85,7 +176,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
......@@ -99,8 +191,38 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
......@@ -115,6 +237,35 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
......@@ -128,7 +279,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
......@@ -142,7 +294,37 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
GKZYXC,
Empty_Tuple,
GNDHWC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
......@@ -157,6 +339,35 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
......@@ -170,7 +381,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
......@@ -184,7 +396,36 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
GKZYXC,
Empty_Tuple,
NDHWGC,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
......@@ -230,42 +471,80 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
......@@ -274,46 +553,86 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
op_ptrs);
}
#endif
}
}
......
......@@ -17,7 +17,9 @@ namespace tensor_operation {
namespace device {
namespace instance {
// xdl
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
......@@ -29,7 +31,8 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
......@@ -41,7 +44,8 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
......@@ -53,8 +57,9 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
......@@ -66,7 +71,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
......@@ -78,7 +84,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
......@@ -90,7 +97,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
......@@ -102,7 +110,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
......@@ -114,7 +123,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
......@@ -126,8 +136,9 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
......@@ -139,7 +150,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
......@@ -151,7 +163,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
......@@ -163,7 +176,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
......@@ -175,7 +189,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
......@@ -187,7 +202,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
......@@ -199,6 +215,248 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef DL_KERNELS
// dl
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
GKXC,
GNWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
GKXC,
NWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
GKYXC,
GNHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
GKZYXC,
GNDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
......@@ -239,23 +497,68 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, GNWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
{
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
#endif
}
}
else if constexpr(NumDimSpatial == 2)
......@@ -263,48 +566,84 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
......@@ -312,48 +651,84 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
#ifdef DL_KERNELS
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
#endif
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
}
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
// grouped conv1d forward, GNWC/GKXC/GNWK
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
......@@ -31,7 +31,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
......@@ -45,7 +46,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
......@@ -59,7 +61,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
GNWC,
......@@ -73,7 +76,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -88,7 +92,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
......@@ -102,7 +107,8 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
......@@ -116,7 +122,9 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
......@@ -130,7 +138,8 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
......@@ -144,7 +153,9 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
......@@ -158,6 +169,50 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef DL_KERNELS
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
......@@ -171,7 +226,9 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
......@@ -185,6 +242,50 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
GNHWC,
GKYXC,
Empty_Tuple,
GNHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
......@@ -199,7 +300,9 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
......@@ -213,6 +316,63 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
......@@ -227,7 +387,65 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
......@@ -241,7 +459,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
......@@ -256,7 +475,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
......@@ -271,6 +491,63 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
......@@ -284,7 +561,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
......@@ -299,6 +577,63 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
GKZYXC,
Empty_Tuple,
GNDHWK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
......@@ -313,7 +648,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
......@@ -328,6 +664,63 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
......@@ -341,7 +734,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
......@@ -356,6 +750,63 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
......@@ -397,127 +848,210 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs);
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(op_ptrs);
}
#endif
}
return op_ptrs;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// fp16_output
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// fp8_inputB
void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// i8_inputB
void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
F16,
I8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
F16,
I8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
}
// fp8_input
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
}
}
// i8_input
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_image_to_column.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// nhwc, 1d
void add_device_image_to_column_nhwc_1d_bf16_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<1, GNWC, BF16, BF16>>>& instances);
void add_device_image_to_column_nhwc_1d_f16_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<1, GNWC, F16, F16>>>& instances);
void add_device_image_to_column_nhwc_1d_f32_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<1, GNWC, F32, F32>>>& instances);
void add_device_image_to_column_nhwc_1d_i8_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<1, GNWC, int8_t, int8_t>>>& instances);
// nhwc, 2d
void add_device_image_to_column_nhwc_2d_bf16_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<2, GNHWC, BF16, BF16>>>& instances);
void add_device_image_to_column_nhwc_2d_f16_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<2, GNHWC, F16, F16>>>& instances);
void add_device_image_to_column_nhwc_2d_f32_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<2, GNHWC, F32, F32>>>& instances);
void add_device_image_to_column_nhwc_2d_i8_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<2, GNHWC, int8_t, int8_t>>>& instances);
// nhwc, 3d
void add_device_image_to_column_nhwc_3d_bf16_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<3, GNDHWC, BF16, BF16>>>& instances);
void add_device_image_to_column_nhwc_3d_f16_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<3, GNDHWC, F16, F16>>>& instances);
void add_device_image_to_column_nhwc_3d_f32_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<3, GNDHWC, F32, F32>>>& instances);
void add_device_image_to_column_nhwc_3d_i8_instances(
std::vector<std::unique_ptr<DeviceImageToColumn<3, GNDHWC, int8_t, int8_t>>>& instances);
template <ck::index_t NumDimSpatial, typename InLayout, typename InDataType, typename OutDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceImageToColumn<NumDimSpatial, InLayout, InDataType, OutDataType>>
{
using DeviceOp = DeviceImageToColumn<NumDimSpatial, InLayout, InDataType, OutDataType>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwc_1d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwc_1d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwc_1d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwc_1d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwc_2d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwc_2d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwc_2d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwc_2d_i8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<OutDataType, float>)
{
add_device_image_to_column_nhwc_3d_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<OutDataType, half_t>)
{
add_device_image_to_column_nhwc_3d_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_image_to_column_nhwc_3d_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<OutDataType, int8_t>)
{
add_device_image_to_column_nhwc_3d_i8_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -2,13 +2,20 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
......@@ -18,39 +25,10 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
......@@ -60,17 +38,38 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
......
function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
#make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 0)
endif()
#message("add_instance_library returns ${result}")
return(PROPAGATE result)
endfunction(add_instance_library INSTANCE_NAME)
......@@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
#message("instance should be built for all types!")
set(add_inst 1)
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR
NOT DEFINED DTYPES)
message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
......
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp
device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp
device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment