Commit 9c0811f3 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents ded0d83d 3528a523
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockGemmShape_>
struct GemmTilePartitioner
{
using BlockGemmShape = ck_tile::remove_cvref_t<BlockGemmShape_>;
static constexpr ck_tile::index_t kM = BlockGemmShape::kM;
static constexpr ck_tile::index_t kN = BlockGemmShape::kN;
static constexpr ck_tile::index_t kK = BlockGemmShape::kK;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t batch_size)
{
ck_tile::index_t GridDimX = (M + kM - 1) / kM;
ck_tile::index_t GridDimY = (N + kN - 1) / kN;
ck_tile::index_t GridDimZ = batch_size;
return dim3(GridDimX, GridDimY, GridDimZ);
}
CK_TILE_DEVICE auto operator()()
{
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
return ck_tile::make_tuple(iM, iN);
}
};
} // namespace ck_tile
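The partitioner above maps one workgroup to each kM x kN output tile (batch on grid Z), and __builtin_amdgcn_readfirstlane makes the per-block offsets wavefront-uniform. A minimal host-side sketch of the same grid arithmetic, with a hypothetical 256x128 tile standing in for BlockGemmShape:

#include <cstdio>

// Standalone sketch; kM/kN are assumed tile extents, not taken from a
// real BlockGemmShape instantiation.
constexpr int kM = 256;
constexpr int kN = 128;

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int M = 1000, N = 500, batch = 4;
    // Same computation as GridSize(): one workgroup per output tile.
    std::printf("grid = (%d, %d, %d)\n", ceil_div(M, kM), ceil_div(N, kN), batch);
    // -> grid = (4, 4, 4)
}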
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace ck_tile {
......@@ -24,6 +25,14 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t AlignmentA = Problem::AlignmentA;
static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
......@@ -35,6 +44,11 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
......@@ -140,8 +154,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
}
index_t iCounter = num_loop - 1;
do
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
......@@ -167,8 +180,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
store_tile(b_copy_lds_window, b_block_tile_tmp);
iCounter--;
} while(iCounter > 0);
}
// tail
{
......
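The loop change in the hunk above replaces a do/while with a pre-checked while, so with num_loop == 1 the prefetch body (which reads tile i + 1) no longer executes once spuriously; iCounter starts at num_loop - 1 == 0 and the body is skipped. A small standalone sketch of the two iteration counts, independent of the pipeline itself:

#include <cassert>

// Counts body executions for both loop forms; illustrates why the
// pre-checked while is the safe form when num_loop == 1.
int iterations_do_while(int num_loop)
{
    int iCounter = num_loop - 1, body = 0;
    do { ++body; --iCounter; } while(iCounter > 0);
    return body; // executes at least once, even when iCounter starts at 0
}

int iterations_while(int num_loop)
{
    int iCounter = num_loop - 1, body = 0;
    while(iCounter > 0) { ++body; --iCounter; }
    return body;
}

int main()
{
    assert(iterations_do_while(1) == 1); // one spurious prefetch of tile 1
    assert(iterations_while(1) == 0);    // body correctly skipped
    assert(iterations_while(4) == 3);
}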
......@@ -91,6 +91,33 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
return b_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
index_t smem_size = 0;
smem_size += smem_size_a + smem_size_b;
return smem_size;
}
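GetSmemSize() above is just the byte footprint of the A tile plus the B tile in LDS, each taken from its LDS block descriptor. A rough sketch of the magnitude, assuming fp16 operands, a 256x128x32 block tile, and densely packed descriptors (real descriptors may add padding):

#include <cstdio>

int main()
{
    // Assumed block tile; the real sizes come from the LDS block descriptors.
    constexpr long kM = 256, kN = 128, kK = 32;
    constexpr long bytes_per_elem = 2; // fp16

    constexpr long smem_a = kM * kK * bytes_per_elem; // 16384 bytes
    constexpr long smem_b = kN * kK * bytes_per_elem; //  8192 bytes
    std::printf("LDS: A=%ld B=%ld total=%ld bytes\n", smem_a, smem_b, smem_a + smem_b);
}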
#elif 1
// fake XOR
template <typename Problem>
......@@ -178,6 +205,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesced reads for each block
constexpr index_t M1 = kBlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
......@@ -216,6 +245,8 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesced reads for each block
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile {
......
......@@ -5,13 +5,17 @@
#include "ck_tile/core.hpp"
#define VectorLoadSize 16
namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_>
typename BlockGemmShape_,
bool kPadA_ = false,
bool kPadB_ = false,
bool kPadC_ = false>
struct BlockGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
......@@ -19,7 +23,14 @@ struct BlockGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
};
} // namespace ck_tile
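The new Alignment constants give the widest safe vector access in elements: with padding enabled a row may start anywhere, so the alignment degrades to one element; otherwise it is VectorLoadSize (16 bytes) divided by the element size. A sketch of the resulting values (std::uint16_t stands in for fp16 here):

#include <cstdint>
#include <cstdio>

// Mirrors Alignment{A,B,C} = kPad ? 1 : VectorLoadSize / sizeof(T).
constexpr int VectorLoadSize = 16; // bytes, as defined in the header above

template <typename T>
constexpr int alignment(bool pad) { return pad ? 1 : int(VectorLoadSize / sizeof(T)); }

int main()
{
    std::printf("fp32  : %d elements per access\n", alignment<float>(false));         // 4
    std::printf("fp16  : %d elements per access\n", alignment<std::uint16_t>(false)); // 8
    std::printf("padded: %d element\n", alignment<float>(true));                      // 1
}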
......@@ -7,12 +7,18 @@
namespace ck_tile {
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
template <typename BlockTile_, typename BlockWarps_, typename WarpTile_>
struct TileGemmShape
{
static constexpr index_t kM = kMPerTile;
static constexpr index_t kN = kNPerTile;
static constexpr index_t kK = kKPerTile;
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});
static constexpr index_t kK = BlockTile::at(number<2>{});
};
} // namespace ck_tile
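TileGemmShape now derives its extents from sequences, and NumWarps is the product over BlockWarps; together with the BlockGemmPipelineProblem change above, kBlockSize becomes NumWarps * get_warp_size() rather than an independent template parameter. A sketch of the derivation, assuming a 2x2x1 warp layout on a wave64 target:

#include <cstdio>

int main()
{
    // Assumed BlockWarps sequence <2, 2, 1>: 2 warps on M, 2 on N, 1 on K.
    const int block_warps[] = {2, 2, 1};
    const int warp_size = 64; // wave64; get_warp_size() on the device side

    int num_warps = 1;
    for(int w : block_warps)
        num_warps *= w; // reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{})

    std::printf("NumWarps   = %d\n", num_warps);             // 4
    std::printf("kBlockSize = %d\n", num_warps * warp_size); // 256
}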
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP16
void add_device_avgpool_2D_bwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_avgpool_2D_bwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_avgpool_2D_bwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F32, F32, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_INT8
void add_device_avgpool_2D_bwd_nhwc_int8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, I8, I8, NHWC, NHWC>>>&);
#endif
template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>>
{
using DeviceOp = DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
add_device_avgpool_2D_bwd_nhwc_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
add_device_avgpool_2D_bwd_nhwc_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
add_device_avgpool_2D_bwd_nhwc_f32_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8>)
add_device_avgpool_2D_bwd_nhwc_f8_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8>)
add_device_avgpool_2D_bwd_nhwc_int8_instances(op_ptrs);
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
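The factory follows the usual CK client pattern: GetInstances() returns type-erased DeviceOp pointers for the requested type/layout combination, and callers probe each instance against their concrete problem. A hedged usage sketch (the include path is assumed; GetTypeString() is the standard BaseOperator accessor):

#include <iostream>
// Assumed factory header for this operation:
#include "ck/library/tensor_operation_instance/gpu/avg_pool2d_bwd.hpp"

using namespace ck::tensor_operation::device;

void list_avgpool2d_bwd_instances()
{
    using Layout  = ck::tensor_layout::convolution::NHWC;
    using Factory = instance::DeviceOperationInstanceFactory<
        DeviceAvgPoolBwd<2, ck::half_t, ck::half_t, Layout, Layout>>;

    auto op_ptrs = Factory::GetInstances();
    std::cout << op_ptrs.size() << " instances:\n";
    for(const auto& op : op_ptrs)
        std::cout << "  " << op->GetTypeString() << '\n';
}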
......@@ -335,6 +335,105 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_insta
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
......@@ -618,6 +717,58 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
op_ptrs);
}
}
#endif
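For readers scanning the long list of new declarations: the mk/km, kn/nk, and mn infixes encode the in-memory dimension order of A, B, and C, which pairs each name with its DeviceGemmV2 layout arguments (km_kn_mn with <Col, Row, Row>, km_nk_mn with <Col, Col, Row>, and so on, as the declarations above show). A tiny illustrative decoder:

#include <cstdio>

// Illustrative only: maps a name infix to the layout tag it implies.
enum class Layout { Row, Col };

constexpr Layout decode_a(const char* tok) { return tok[0] == 'm' ? Layout::Row : Layout::Col; } // "mk"/"km"
constexpr Layout decode_b(const char* tok) { return tok[0] == 'k' ? Layout::Row : Layout::Col; } // "kn"/"nk"

int main()
{
    static_assert(decode_a("km") == Layout::Col, "A stored K-major -> Col");
    static_assert(decode_b("nk") == Layout::Col, "B stored N-major -> Col");
    std::puts("km_nk_mn -> DeviceGemmV2<Col, Col, Row, ...>");
}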
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
......
......@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
#endif
}
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
op_ptrs);
}
#endif
}
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
......
......@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
......
......@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
......
......@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
......
......@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough>>>& instances);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
......
......@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
......@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKYXC,
Empty_Tuple,
NGKHW,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -23,6 +23,15 @@ void add_device_maxpool_bwd_bf16_instances(
void add_device_maxpool_bwd_f32_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F32, I32, F32>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_maxpool_bwd_f8_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F8, I32, F8>>>&);
#endif
#ifdef CK_ENABLE_INT8
void add_device_maxpool_bwd_int8_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<I8, I32, I8>>>&);
#endif
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>>
......@@ -32,6 +41,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16> &&
is_same_v<IndexDataType, I32>)
......@@ -47,6 +57,16 @@ struct DeviceOperationInstanceFactory<
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_f32_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8> &&
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_f8_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8> &&
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_int8_instances(op_ptrs);
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_BF16
// BF16
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, AvgOp, false>>>&);
// BF16 - return index
void add_device_pool2d_fwd_nhwc_index_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_INT8
// I8
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, AvgOp, false>>>&);
// I8 - return index
void add_device_pool2d_fwd_nhwc_index_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP8
// F8
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, AvgOp, false>>>&);
// F8 - return index
void add_device_pool2d_fwd_nhwc_index_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_bf16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_bf16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_i8_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_i8_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f8_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f8_instances(op_ptrs);
}
}
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<
......@@ -36,8 +36,22 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
void add_device_pool3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_BF16
using F8 = ck::f8_t;
// F8
void add_device_pool3d_fwd_ndhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// FP8 - return index
void add_device_pool3d_fwd_ndhwc_index_f8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F8, F8, I32, NDHWC, NDHWC, MaxOp, true>>>&);
// BF16
void add_device_pool3d_fwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<
......@@ -51,8 +65,7 @@ void add_device_pool3d_fwd_ndhwc_bf16_instances(
void add_device_pool3d_fwd_ndhwc_index_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<
......@@ -66,7 +79,21 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
void add_device_pool3d_fwd_ndhwc_index_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NDHWC, NDHWC, MaxOp, true>>>&);
#endif
// I8
void add_device_pool3d_fwd_ndhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, MaxOp, false>>>&);
void add_device_pool3d_fwd_ndhwc_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, AvgOp, false>>>&);
// I8 - return index
void add_device_pool3d_fwd_ndhwc_index_i8_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, I8, I8, I32, NDHWC, NDHWC, MaxOp, true>>>&);
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
......@@ -99,7 +126,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
......@@ -112,8 +138,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
is_same_v<IndexDataType, I32>)
{
......@@ -126,8 +150,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_bf16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
......@@ -140,7 +162,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f32_instances(op_ptrs);
}
}
#endif
else if constexpr(is_same_v<InDataType, F8> && is_same_v<OutDataType, F8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_f8_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_f8_instances(op_ptrs);
}
}
else if constexpr(is_same_v<InDataType, I8> && is_same_v<OutDataType, I8> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool3d_fwd_ndhwc_index_i8_instances(op_ptrs);
}
else
{
add_device_pool3d_fwd_ndhwc_i8_instances(op_ptrs);
}
}
}
return op_ptrs;
......
......@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME)
set(FMHA_FWD_FAST_EXP2 true)
endif()
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
list(APPEND FMHA_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS})
endif()
target_compile_features(${INSTANCE_NAME} PUBLIC)
......
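Besides consolidating the options into a single FMHA_COMPILE_OPTIONS list and enabling the split-KV and append-KV APIs, the CMake hunk above keeps the CK_TILE_FMHA_FWD_FAST_EXP2 switch, which, as the name suggests, selects an exp2-based softmax path. The identity it relies on, as a host-side sketch (the actual kernel code differs):

#include <cassert>
#include <cmath>

// e^x == 2^(x * log2(e)); exp2 maps onto a cheap GPU instruction, which is
// why CK_TILE_FMHA_FWD_FAST_EXP2=1 prefers it in the softmax inner loop.
int main()
{
    const float log2e = 1.4426950408889634f;
    for(float x : {-3.0f, 0.0f, 0.5f, 2.0f})
    {
        const float a = std::exp(x);
        const float b = std::exp2(x * log2e);
        assert(std::fabs(a - b) <= 1e-5f * (1.0f + std::fabs(a)));
    }
}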
set(DEVICE_AVGPOOL_2D_BWD_INSTANCES)
list(APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
device_avg_pool2d_bwd_nhwc_f16_instance.cpp
device_avg_pool2d_bwd_nhwc_f32_instance.cpp
device_avg_pool2d_bwd_nhwc_f8_instance.cpp
device_avg_pool2d_bwd_nhwc_int8_instance.cpp)
add_instance_library(device_avg_pool2d_bwd_instance ${DEVICE_AVGPOOL_2D_BWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>& instances)
{
add_device_operation_instances(instances,
device_avgpool_2D_bwd_nhwc_instances<BF16, BF16, F32>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
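Note the third template argument in device_avgpool_2D_bwd_nhwc_instances<BF16, BF16, F32>: even with bf16 tensors, accumulation runs in fp32, since bf16 carries only 8 significand bits and long running sums stall. A small numeric sketch of the motivation (round_to_bf16 is a hypothetical truncation helper, not CK code):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 is the top 16 bits of an fp32; truncation is used here for brevity.
float round_to_bf16(float x)
{
    std::uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    // Sum 1024 values of ~1.0: the bf16 running sum stops growing once its
    // ulp exceeds the addend, while the fp32 sum stays accurate.
    float acc_bf16 = 0.f, acc_f32 = 0.f;
    for(int i = 0; i < 1024; ++i)
    {
        acc_bf16 = round_to_bf16(acc_bf16 + 1.001f);
        acc_f32 += 1.001f;
    }
    std::printf("bf16-accumulated mean = %f\n", acc_bf16 / 1024); // far from 1.001
    std::printf("fp32-accumulated mean = %f\n", acc_f32 / 1024);  // ~1.001
}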