Unverified Commit 87ea11d0 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #192 from ROCm/merge_from_public

Merge from public
parents 171ed358 09d4c3a4
...@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME) ...@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME)
set(FMHA_FWD_FAST_EXP2 true) set(FMHA_FWD_FAST_EXP2 true)
endif() endif()
if(FMHA_FWD_FAST_EXP2) if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else() else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif() endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND FMHA_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS})
endif() endif()
target_compile_features(${INSTANCE_NAME} PUBLIC) target_compile_features(${INSTANCE_NAME} PUBLIC)
......
set(DEVICE_AVGPOOL_2D_BWD_INSTANCES)
list(APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
device_avg_pool2d_bwd_nhwc_f16_instance.cpp
device_avg_pool2d_bwd_nhwc_f32_instance.cpp
device_avg_pool2d_bwd_nhwc_f8_instance.cpp
device_avg_pool2d_bwd_nhwc_int8_instance.cpp)
add_instance_library(device_avg_pool2d_bwd_instance ${DEVICE_AVGPOOL_2D_BWD_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>& instances)
{
add_device_operation_instances(instances,
device_avgpool_2D_bwd_nhwc_instances<BF16, BF16, F32>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_2D_bwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>>>& instances)
{
add_device_operation_instances(instances,
device_avgpool_2D_bwd_nhwc_instances<F16, F16, F32>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_2D_bwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F32, F32, NHWC, NHWC>>>& instances)
{
add_device_operation_instances(instances,
device_avgpool_2D_bwd_nhwc_instances<F32, F32, F32>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_2D_bwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_2D_bwd_nhwc_instances<F8, F8, F32>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using I32 = int32_t;
using F32 = float;
using NHWC = ck::tensor_layout::convolution::NHWC;
template <typename OutType, typename InType, typename ComputeType>
using device_avgpool_2D_bwd_nhwc_instances = std::tuple<
// clang-format off
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 1, 1, 1>,
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 2, 2, 2>,
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 4, 4, 4>,
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 256, 1, 8, 8, 8>,
DeviceAvgPool2dBwd_NHWC_NHWC<OutType, InType, ComputeType, 256, 32, 8, 8, 8, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_avg_pool2d_bwd_nhwc_instance_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_avgpool_2D_bwd_nhwc_int8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, I8, I8, NHWC, NHWC>>>& instances)
{
add_device_operation_instances(instances, device_avgpool_2D_bwd_nhwc_instances<I8, I8, I32>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -35,6 +35,26 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES ...@@ -35,6 +35,26 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp
) )
set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
...@@ -50,6 +70,14 @@ set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm ...@@ -50,6 +70,14 @@ set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
list(APPEND GEMM_UNIVERSAL_INSTANCES list(APPEND GEMM_UNIVERSAL_INSTANCES
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// Can we support this kind of odd case? 224(256) = 28*8 + (4*8)
//DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances<GemmMNPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances<Intrawave, GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances<Intrawave,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances<Interwave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances<Interwave, GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances<Interwave,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment