Unverified Commit 87ea11d0 authored by Illia Silin, committed by GitHub

Merge pull request #192 from ROCm/merge_from_public

Merge from public
parents 171ed358 09d4c3a4
@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
+ #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {
@@ -24,6 +25,14 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
+ static constexpr index_t AlignmentA = Problem::AlignmentA;
+ static constexpr index_t AlignmentB = Problem::AlignmentB;
+ static constexpr index_t AlignmentC = Problem::AlignmentC;
+ static constexpr bool kPadA = Problem::kPadA;
+ static constexpr bool kPadB = Problem::kPadB;
+ static constexpr bool kPadC = Problem::kPadC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
@@ -35,6 +44,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
+ CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+ {
+     return Policy::template GetSmemSize<Problem>();
+ }
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
@@ -140,8 +154,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
}
index_t iCounter = num_loop - 1;
- while(iCounter > 0)
+ do
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
@@ -167,8 +180,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
store_tile(b_copy_lds_window, b_block_tile_tmp);
iCounter--;
- }
+ } while(iCounter > 0);
// tail
{
......
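The loop hunks above replace the while main loop with a do/while. The two forms are only equivalent when the body is guaranteed to run at least once, i.e. iCounter = num_loop - 1 >= 1; under that assumption the do/while drops the loop-entry branch. A minimal sketch of the resulting control flow (tile loads and the block GEMM elided, names as in the diff):

// Sketch only; assumes num_loop >= 2 so the first trip is always taken.
index_t iCounter = num_loop - 1;
do
{
    // prefetch tile i + 1 from global, store it to LDS, run block GEMM on tile i
    iCounter--;
} while(iCounter > 0);
// a separate "tail" block then processes the final prefetched tile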
@@ -91,6 +91,33 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
return b_lds_block_desc;
}
+ template <typename Problem>
+ CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
+ {
+     constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
+                                     MakeALdsBlockDescriptor<Problem>().get_element_space_size();
+     return smem_size_a;
+ }
+ template <typename Problem>
+ CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
+ {
+     constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
+                                     MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+     return smem_size_b;
+ }
+ template <typename Problem>
+ CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+ {
+     return GetSmemSizeA<Problem>() + GetSmemSizeB<Problem>();
+ }
#elif 1
// fake XOR
template <typename Problem>
@@ -178,6 +205,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesce reading for each block
constexpr index_t M1 = kBlockSize / get_warp_size();
+ static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
+ static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
@@ -216,6 +245,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesce reading for each block
constexpr index_t N1 = kBlockSize / get_warp_size();
+ static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
+ static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
......
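The three new policy hooks report the LDS footprint of the A tile, the B tile, and their sum. A minimal sketch of how a kernel might consume them (assumed usage, not part of this PR; Policy and Problem stand for the types above):

// Size one LDS buffer from the combined hook and carve out the two regions.
constexpr ck_tile::index_t lds_bytes = Policy::template GetSmemSize<Problem>();
__shared__ char p_lds[lds_bytes];
char* p_lds_a = p_lds;                                            // A tile at offset 0
char* p_lds_b = p_lds + Policy::template GetSmemSizeA<Problem>(); // B tile directly after A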
@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
+ #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile {
......
@@ -5,13 +5,17 @@
#include "ck_tile/core.hpp"
+ #define VectorLoadSize 16
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
- index_t kBlockSize_,
- typename BlockGemmShape_>
+ typename BlockGemmShape_,
+ bool kPadA_ = false,
+ bool kPadB_ = false,
+ bool kPadC_ = false>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -19,7 +23,14 @@ struct BlockGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
- static constexpr index_t kBlockSize = kBlockSize_;
+ static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
+ static constexpr bool kPadA = kPadA_;
+ static constexpr bool kPadB = kPadB_;
+ static constexpr bool kPadC = kPadC_;
+ static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
+ static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
+ static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
};
} // namespace ck_tile
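With this change kBlockSize is derived from the shape's warp count instead of being passed explicitly, and each alignment collapses to 1 element whenever its dimension is padded. An illustrative instantiation (the data types and pad flags are assumptions, not taken from this PR):

// AlignmentA = AlignmentB = 16 / sizeof(half_t) = 8 elements; AlignmentC = 1 because kPadC is set.
using Problem = ck_tile::BlockGemmPipelineProblem<
    ck_tile::half_t, // ADataType
    ck_tile::half_t, // BDataType
    float,           // CDataType
    Shape,           // a TileGemmShape, defined in the next file; supplies NumWarps
    false,           // kPadA
    false,           // kPadB
    true>;           // kPadC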
@@ -7,12 +7,18 @@
namespace ck_tile {
- template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
+ template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
struct TileGemmShape
{
- static constexpr index_t kM = kMPerTile;
- static constexpr index_t kN = kNPerTile;
- static constexpr index_t kK = kKPerTile;
+ using BlockTile  = remove_cvref_t<BlockTile_>;
+ using BlockWarps = remove_cvref_t<BlockWarps_>;
+ using WarpTile   = remove_cvref_t<WarpTile_>;
+ static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
+ static constexpr index_t kM = BlockTile::at(number<0>{});
+ static constexpr index_t kN = BlockTile::at(number<1>{});
+ static constexpr index_t kK = BlockTile::at(number<2>{});
};
} // namespace ck_tile
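TileGemmShape now carries the block tile, the warp layout, and the per-warp tile as sequences, and the warp count falls out of the product reduction. A hypothetical instantiation (the tile sizes are illustrative only):

// A 128x128x32 block tile computed by a 2x2x1 grid of warps.
using Shape = ck_tile::TileGemmShape<ck_tile::sequence<128, 128, 32>, // BlockTile: kM, kN, kK
                                     ck_tile::sequence<2, 2, 1>,      // BlockWarps
                                     ck_tile::sequence<32, 32, 8>>;   // WarpTile
// NumWarps = 2 * 2 * 1 = 4, so the Problem above derives
// kBlockSize = 4 * get_warp_size(), i.e. 256 threads on a wave64 GPU.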
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename Problem_>
struct ImageToColumn
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static constexpr auto I4 = number<4>{};
using Problem = remove_cvref_t<Problem_>;
using InDataType = remove_cvref_t<typename Problem::InDataType>;
using OutDataType = remove_cvref_t<typename Problem::OutDataType>;
static constexpr index_t NDimSpatial = Problem::NDimSpatial;
static constexpr index_t AligmentIn = Problem::AligmentIn;
static constexpr index_t AligmentOut = Problem::AligmentOut;
static_assert(NDimSpatial == 2, "Not supported.");
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
struct Kargs
{
const void* p_in;
void* p_out;
const long_index_t G;
const long_index_t N;
const long_index_t C;
const array<long_index_t, NDimSpatial> input_spatial_lengths;
const array<long_index_t, NDimSpatial> filter_spatial_lengths;
const array<long_index_t, NDimSpatial> output_spatial_lengths;
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
const array<long_index_t, 3> gemm_g_m_k_strides;
const array<long_index_t, NDimSpatial> conv_filter_strides;
const array<long_index_t, NDimSpatial> conv_filter_dilations;
const array<long_index_t, NDimSpatial> input_left_pads;
const array<long_index_t, NDimSpatial> input_right_pads;
};
CK_TILE_HOST static constexpr Kargs
MakeKargs(const void* p_in,
void* p_out,
const long_index_t G,
const long_index_t N,
const long_index_t C,
const array<long_index_t, NDimSpatial> input_spatial_lengths,
const array<long_index_t, NDimSpatial> filter_spatial_lengths,
const array<long_index_t, NDimSpatial> output_spatial_lengths,
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
const array<long_index_t, 3> gemm_g_m_k_strides,
const array<long_index_t, NDimSpatial> conv_filter_strides,
const array<long_index_t, NDimSpatial> conv_filter_dilations,
const array<long_index_t, NDimSpatial> input_left_pads,
const array<long_index_t, NDimSpatial> input_right_pads)
{
return Kargs{p_in,
p_out,
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
{
return dim3(
integer_divide_ceil(GemmM, kMPerBlock), integer_divide_ceil(GemmK, kKPerBlock), Batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
{
static_assert(NDimSpatial == 2, "Not supported.");
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(
kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
make_tuple(kargs.image_g_n_c_wis_strides[I1],
kargs.image_g_n_c_wis_strides[I3],
kargs.image_g_n_c_wis_strides[I4],
kargs.image_g_n_c_wis_strides[I2]),
number<AligmentIn>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(kargs.N),
make_pad_transform(kargs.input_spatial_lengths[I0],
kargs.input_left_pads[I0],
kargs.input_right_pads[I0]),
make_pad_transform(kargs.input_spatial_lengths[I1],
kargs.input_left_pads[I1],
kargs.input_right_pads[I1]),
make_pass_through_transform(kargs.C)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(kargs.N),
make_embed_transform(
make_tuple(kargs.filter_spatial_lengths[I0], kargs.output_spatial_lengths[I0]),
make_tuple(kargs.conv_filter_dilations[I0], kargs.conv_filter_strides[I0])),
make_embed_transform(
make_tuple(kargs.filter_spatial_lengths[I1], kargs.output_spatial_lengths[I1]),
make_tuple(kargs.conv_filter_dilations[I1], kargs.conv_filter_strides[I1])),
make_pass_through_transform(kargs.C)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(
kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
make_merge_transform(make_tuple(
kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
{
static_assert(NDimSpatial == 2, "Not supported.");
const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
kargs.output_spatial_lengths[I1]);
const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
kargs.filter_spatial_lengths[I1]);
return make_tuple(M, K);
}
CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution()
{
using P = typename Problem::BlockShape;
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
// Y: {kMPerThread, kKPerThread}
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<P::kMWarpPerBlock, P::kMThreadPerWarp, P::kMPerThread>,
sequence<P::kKWarpPerBlock, P::kKThreadPerWarp, P::kKPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
{
const auto [M, K] = CalculateMKDims(kargs);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
const auto image_m_k = make_tensor_view<address_space_enum::global>(
static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(kargs.p_out) + out_offset,
make_tuple(M, K),
make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]),
number<AligmentOut>{},
I1);
const auto image_m_k_padded =
pad_tensor_view(image_m_k,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
sequence<false, true>{});
const auto gemm_m_k_padded =
pad_tensor_view(gemm_m_k,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
sequence<false, true>{});
constexpr auto dstr = MakeBlockTileDistribution();
const auto image_tile =
make_tile_window(image_m_k_padded,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{iM, iK},
dstr);
auto gemm_tile = make_tile_window(gemm_m_k_padded,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{iM, iK},
dstr);
// load from Global
const auto loaded_tile = load_tile(image_tile);
// save to Global
store_tile(gemm_tile, loaded_tile);
}
CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
};
} // namespace ck_tile
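For orientation, a hedged host-side sketch of how this kernel's entry points compose; the problem type and launch plumbing are assumptions, only MakeKargs, GridSize, and BlockSize come from the code above:

// GemmM = N * Ho * Wo and GemmK = C * Y * X, mirroring CalculateMKDims().
using Kernel = ck_tile::ImageToColumn<MyImageToColumnProblem>; // hypothetical problem type
auto kargs = Kernel::MakeKargs(p_in, p_out, G, N, C,
                               in_lengths, filter_lengths, out_lengths,
                               image_strides, gemm_strides,
                               conv_strides, conv_dilations,
                               left_pads, right_pads);
const dim3 grid  = Kernel::GridSize(GemmM, GemmK, G); // one z-slice per group
const auto block = Kernel::BlockSize();               // Problem::BlockShape::kBlockSize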
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename InDataType_,
typename OutDataType_,
typename BlockShape_,
index_t NDimSpatial_,
index_t AligmentIn_,
index_t AligmentOut_>
struct BlockImageToColumnProblem
{
using InDataType = remove_cvref_t<InDataType_>;
using OutDataType = remove_cvref_t<OutDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr index_t AligmentIn = AligmentIn_;
static constexpr index_t AligmentOut = AligmentOut_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ThreadTile, // Sequence<...
typename WarpTile, // Sequence<...
typename BlockTile> // Sequence<...
struct TileImageToColumnShape
{
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
};
} // namespace ck_tile
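The shape's divisions chain bottom-up: threads tile a warp, warps tile the block. An illustrative instantiation (all numbers are assumptions) with the arithmetic written out:

// Each thread moves a 4x4 sub-tile; a wave64 warp covers 16x64 (4 * 16 = 64 lanes);
// a 2x2 grid of warps covers the 32x128 block tile, so kBlockSize = 64 * 2 * 2 = 256.
using S = ck_tile::TileImageToColumnShape<ck_tile::sequence<4, 4>,     // ThreadTile
                                          ck_tile::sequence<16, 64>,   // WarpTile
                                          ck_tile::sequence<32, 128>>; // BlockTile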
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP16
void add_device_avgpool_2D_bwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_avgpool_2D_bwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_avgpool_2D_bwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F32, F32, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_INT8
void add_device_avgpool_2D_bwd_nhwc_int8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, I8, I8, NHWC, NHWC>>>&);
#endif
template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>>
{
using DeviceOp = DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
add_device_avgpool_2D_bwd_nhwc_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
add_device_avgpool_2D_bwd_nhwc_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
add_device_avgpool_2D_bwd_nhwc_f32_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8>)
add_device_avgpool_2D_bwd_nhwc_f8_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8>)
add_device_avgpool_2D_bwd_nhwc_int8_instances(op_ptrs);
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
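A hedged sketch of client-side enumeration through this factory (argument construction and dispatch are omitted; the F8 and NHWC aliases are the ones used above, and the fp8 path is only populated when CK_ENABLE_FP8 is defined):

using DeviceOp =
    ck::tensor_operation::device::DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
// op_ptrs now holds every registered 2-D NHWC fp8 avgpool-bwd instance.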
@@ -335,6 +335,105 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_insta
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
+ void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
+     std::vector<std::unique_ptr<
+         DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
+         instances);
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
@@ -618,6 +717,58 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
}
+ else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                   is_same_v<CLayout, Row>)
+ {
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(op_ptrs);
+ }
+ else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                   is_same_v<CLayout, Row>)
+ {
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(op_ptrs);
+     add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(op_ptrs);
+ }
}
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
......
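These hunks extend the universal bf16 GEMM factory from row-major A only to the remaining KM x KN and KM x NK layout combinations. A hedged usage sketch (aliases as in this header; argument setup omitted):

// Requesting column-major-A instances, which resolved to an empty list before this change.
using DeviceOp = ck::tensor_operation::device::DeviceGemmV2<
    Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();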
@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
+ if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
+              is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
+ {
+ #ifdef CK_ENABLE_FP32
+     if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
+                  is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
+                  is_same_v<BComputeType, float>)
+     {
+         add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(op_ptrs);
+     }
+ #endif
+ #ifdef CK_ENABLE_FP16
+     if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
+                  is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
+                  is_same_v<BComputeType, half_t>)
+     {
+         add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(op_ptrs);
+         add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(op_ptrs);
+     }
+ #endif
+ }
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
......
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough>>>& instances);
#endif
+ // grouped conv2d forward, NGCHW/GKYXC/NGKHW
+ #ifdef CK_ENABLE_FP16
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F16, F16, Empty_Tuple, F16,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
+ #ifdef CK_ENABLE_FP32
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F32, F32, Empty_Tuple, F32,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
......
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
PassThrough>>>& instances);
#endif
+ // grouped conv2d forward, NGCHW/GKYXC/NGKHW
+ #ifdef CK_ENABLE_FP16
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F16, F16, Empty_Tuple, F16,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
+ #ifdef CK_ENABLE_FP32
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F32, F32, Empty_Tuple, F32,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
......
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough>>>& instances);
#endif
+ // grouped conv2d forward, NGCHW/GKYXC/NGKHW
+ #ifdef CK_ENABLE_FP16
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F16, F16, Empty_Tuple, F16,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
+ #ifdef CK_ENABLE_FP32
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F32, F32, Empty_Tuple, F32,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
......
@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough>>>& instances);
#endif
+ // grouped conv2d forward, NGCHW/GKYXC/NGKHW
+ #ifdef CK_ENABLE_FP16
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F16, F16, Empty_Tuple, F16,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
+ #ifdef CK_ENABLE_FP32
+ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F32, F32, Empty_Tuple, F32,
+         PassThrough, PassThrough, PassThrough>>>& instances);
+ #endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
......
@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
+ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F16, F16, Empty_Tuple, F16,
+         PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
+ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
+     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
+         NGCHW, GKYXC, Empty_Tuple, NGKHW,
+         F32, F32, Empty_Tuple, F32,
+         PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
......
// SPDX-License-Identifier: MIT
- // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -23,6 +23,15 @@ void add_device_maxpool_bwd_bf16_instances(
void add_device_maxpool_bwd_f32_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F32, I32, F32>>>&);
#endif
+ #ifdef CK_ENABLE_FP8
+ void add_device_maxpool_bwd_f8_instances(
+     std::vector<std::unique_ptr<DeviceMaxPoolBwd<F8, I32, F8>>>&);
+ #endif
+ #ifdef CK_ENABLE_INT8
+ void add_device_maxpool_bwd_int8_instances(
+     std::vector<std::unique_ptr<DeviceMaxPoolBwd<I8, I32, I8>>>&);
+ #endif
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>>
@@ -32,6 +41,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16> &&
is_same_v<IndexDataType, I32>)
@@ -47,6 +57,16 @@ struct DeviceOperationInstanceFactory<
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_f32_instances(op_ptrs);
#endif
+ #ifdef CK_ENABLE_FP8
+ else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8> &&
+                   is_same_v<IndexDataType, I32>)
+     add_device_maxpool_bwd_f8_instances(op_ptrs);
+ #endif
+ #ifdef CK_ENABLE_INT8
+ else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8> &&
+                   is_same_v<IndexDataType, I32>)
+     add_device_maxpool_bwd_int8_instances(op_ptrs);
+ #endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_BF16
// BF16
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, AvgOp, false>>>&);
// BF16 - return index
void add_device_pool2d_fwd_nhwc_index_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_INT8
// I8
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, AvgOp, false>>>&);
// I8 - return index
void add_device_pool2d_fwd_nhwc_index_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP8
// F8
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, AvgOp, false>>>&);
// F8 - return index
void add_device_pool2d_fwd_nhwc_index_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_bf16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_bf16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_i8_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_i8_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f8_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f8_instances(op_ptrs);
}
}
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
- // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
- #ifdef CK_ENABLE_FP16
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<
@@ -36,8 +36,22 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
- #endif
- #ifdef CK_ENABLE_BF16
+ using F8 = ck::f8_t;
+ // F8
+ void add_device_pool3d_fwd_ndhwc_f8_instances(
+     std::vector<std::unique_ptr<
+         DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, false>>>&);
+ void add_device_pool3d_fwd_ndhwc_f8_instances(
+     std::vector<std::unique_ptr<
+         DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, AvgOp, false>>>&);
+ // FP8 - return index
+ void add_device_pool3d_fwd_ndhwc_index_f8_instances(
+     std::vector<std::unique_ptr<
+         DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, true>>>&);
// BF16
void add_device_pool3d_fwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<
@@ -51,8 +65,7 @@ void add_device_pool3d_fwd_ndhwc_bf16_instances(
void add_device_pool3d_fwd_ndhwc_index_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
- #endif
- #ifdef CK_ENABLE_FP32
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<
@@ -66,7 +79,21 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
- #endif
+ // I8
+ void add_device_pool3d_fwd_ndhwc_i8_instances(
+     std::vector<std::unique_ptr<
+         DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, MaxOp, false>>>&);
+ void add_device_pool3d_fwd_ndhwc_i8_instances(
+     std::vector<std::unique_ptr<
+         DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, AvgOp, false>>>&);
+ // I8 - return index
+ void add_device_pool3d_fwd_ndhwc_index_i8_instances(
+     std::vector<std::unique_ptr<
+         DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, MaxOp, true>>>&);
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
@@ -99,7 +126,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
{
- #ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
@@ -112,8 +138,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
- #endif
- #ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
is_same_v<IndexDataType, I32>)
{
@@ -126,8 +150,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_bf16_instances(op_ptrs);
}
}
- #endif
- #ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
@@ -140,7 +162,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
- #endif
+ else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
+                   is_same_v<IndexDataType, I32>)
+ {
+     if constexpr(OutputIndex && ReduceOpId == MaxOp)
+     {
+         add_device_pool3d_fwd_ndhwc_index_f8_instances(op_ptrs);
+     }
+     else
+     {
+         add_device_pool3d_fwd_ndhwc_f8_instances(op_ptrs);
+     }
+ }
+ else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
+                   is_same_v<IndexDataType, I32>)
+ {
+     if constexpr(OutputIndex && ReduceOpId == MaxOp)
+     {
+         add_device_pool3d_fwd_ndhwc_index_i8_instances(op_ptrs);
+     }
+     else
+     {
+         add_device_pool3d_fwd_ndhwc_i8_instances(op_ptrs);
+     }
+ }
}
return op_ptrs;
......
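Note the design shift in this file: the per-type #ifdef guards are dropped, so the fp16/bf16/fp32 declarations become unconditional and the new fp8/int8 paths join the same else-if chain. A hedged sketch of the newly reachable fp8 max-pool dispatch (aliases as in the header above; InOutRank is 5 for NDHWC, an assumption since the hunk truncates that line):

using DeviceOp = ck::tensor_operation::device::DevicePoolFwd<
    InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, /*OutputIndex=*/true>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();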