Commit 036c5234 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 22995e9a 7843a8a7
......@@ -24,6 +24,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
PassThrough,
F8>>>& instances)
{
#if CK_BUILD_DEPRECATED
#pragma message "These instances are getting deprecated"
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3,
......@@ -48,6 +50,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
Empty_Tuple,
NDHWGK,
ConvFwd1x1S1P0>{});
#else
#pragma message "These instances were deprecated"
std::ignore = instances;
#endif
}
} // namespace instance
......
......@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
......@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
......@@ -47,14 +47,14 @@ using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using BElementOp = Multiply;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
......
# ONLY XDL_KERNELS
set(GROUPED_GEMM_TILE_LOOP_INSTANCES)
list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp
)
add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using DsDataType = ck::Tuple<>;
using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using DsDataType = ck::Tuple<>;
using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename DsLayout, typename DsDataType, typename CDEElementwiseOp>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if 1
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
#endif
#if 0
//comp
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
//latency
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
//mem
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmMNKPadding, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8>
#endif
// clang-format on
>;
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply>{});
}
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAdd>{});
}
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAddFastGelu>{});
}
void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row>,
ck::Tuple<BF16>,
MultiplyFastGelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -13,15 +13,6 @@
./bin/ckProfiler gemm 1 1 1 1 0 5 3840 4096 4096 4096 4096 4096
```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
```bash
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
....
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
```
## Profile 2D forward convolution kernels
```bash
#arg1: tensor operation (conv=Convolution)
......@@ -37,15 +28,6 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
```
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
```bash
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
....
Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
```
## Profile contraction kernels
```bash
......@@ -71,16 +53,6 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s
./bin/ckProfiler contraction_bilinear 0 0 2 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128
```
Result (MI100)
```bash
a_m_k: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
b_k_n: dim 4, lengths {128, 128, 128, 128}, strides {128, 1, 2097152, 16384}
d_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
e_m_n: dim 4, lengths {128, 128, 128, 128}, strides {2097152, 16384, 128, 1}
....
Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s
```
## Profile batched gemm multiple D kernels
```bash
#arg1: tensor operation (batched_gemm_multi_d=Batched GEMM multi D);
......@@ -99,14 +71,6 @@ Best Perf: 211.405 ms, 41.6077 TFlops, 15.2372 GB/s
./bin/ckProfiler batched_gemm_multi_d 0 1 0 0 0 1 4096 4096 4096 4096 4096 4096 16777216 16777216 16777216 16
```
Result (Radeon RX 6800 XT)
```bash
arg.a_grid_desc_k0_m0_m1_k1_{2048, 4096, 2}
arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2}
arg.e_grid_desc_m_n_{ 4096, 4096}
....
Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s
```
## Profile grouped convolution backward data kernels
```bash
# arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data)
......@@ -134,20 +98,6 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s
```
Result (MI100, FP16, GNHWC_GKYXC_GNHWK)
```bash
out: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192}
wei: dim 5, lengths {32, 192, 192, 3, 3}, strides {331776, 1728, 1, 576, 192}
in: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192}
....
Best configuration parameters:
name: DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<256, 128, 256, 32, 8, 2, Default, 32, 32, 2, 4, 8, 4, 1, 1>
avg_time: 0.768321
tflops: 86.6679
GB/s: 127.947
```
## Profile grouped convolution backward weight kernels
```bash
# arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight)
......@@ -179,19 +129,6 @@ GB/s: 127.947
```
Result (MI100, FP16, GNHWC_GKYXC_GNHWK)
```bash
input: dim 5, lengths {32, 512, 1024, 28, 28}, strides {411041792, 802816, 1, 28672, 1024}
weight: dim 5, lengths {32, 512, 1024, 3, 3}, strides {4718592, 9216, 1, 3072, 1024}
output: dim 5, lengths {32, 512, 512, 26, 26}, strides {177209344, 346112, 1, 13312, 512}
....
Best configuration parameters:
name: DeviceGroupedConvBwdWeight_Xdl_CShuffle<256, 256, 128, 4, Default, 8, 4, 2, 8, 4, 8, 2, 1, 1, 8>
avg_time: 68.5216
tflops: 95.337
GB/s: 69.2301
```
Note: This kernel use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time.
## Profile image to column/column to image kernels
......@@ -224,17 +161,6 @@ Note: This kernel use atomic add, this will cause output buffer to be accumulate
```
Result (MI210, FP32, NHWC)
```bash
input: dim 5, lengths {1, 256, 512, 28, 28}, strides {102760448, 401408, 1, 14336, 512}
output: dim 2, lengths {173056, 4608}, strides {4608, 1}
....
Best configuration parameters:
name: DeviceImageToColumn<128, 32, 64, 4>
avg_time: 3.12326
GB/s: 2042.59
```
Note: Column to image kernel adds to the output memory, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time.
## Profile Permute scale kernels
......@@ -254,12 +180,3 @@ Note: Column to image kernel adds to the output memory, this will cause output b
################ op datatype verify init log time dim0 dim1 dim2 in_stride0 in_stride1 in_stride2 out_stride0 out_stride1 out_stride2
./bin/ckProfiler permute_scale 0 1 1 0 1 64 64 64 4096 64 1 1 64 4096
```
Result (MI100, FP32)
```bash
A: dim 3, lengths {64, 64, 64}, strides {4096, 64, 1}
B: dim 3, lengths {64, 64, 64}, strides {1, 64, 4096}
....
Best perf = 0.0146878 ms, 142.782 GB/s, DeviceElementwiseNormalizationImpl<3, 2>
```
......@@ -43,7 +43,8 @@ bool profile_gemm_universal_impl(int do_verification,
int StrideC,
int KBatch,
int n_warmup,
int n_iter)
int n_iter,
uint64_t rotating = 0)
{
bool pass = true;
......@@ -66,9 +67,16 @@ bool profile_gemm_universal_impl(int do_verification,
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
switch(init_method)
{
......@@ -200,8 +208,14 @@ bool profile_gemm_universal_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
std::size_t flop = std::size_t(2) * M * N * K;
......
......@@ -188,6 +188,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
out_element_op,
split_k);
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// using atomic add, so need to reset input
......
......@@ -73,9 +73,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
int sum_of_m = 0;
for(std::size_t i = 0; i < group_count; i++)
{
sum_of_m += Ms[i];
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
......@@ -86,11 +88,12 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
#if DEBUG_LOG
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
#endif // DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
......@@ -146,7 +149,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
......
......@@ -87,11 +87,12 @@ bool profile_grouped_gemm_impl(int do_verification,
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
#if DEBUG_LOG
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
#endif // DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_grouped_gemm_tile_loop_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int n_warmup = 10,
int n_iter = 50)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n[i]);
break;
case 2:
ck::utils::FillUniformDistribution<ADataType>{.0, 1.}(a_m_k[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
break;
default:
ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<>;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<KernelArguments> gemm_kargs;
gemm_descs.reserve(group_count);
gemm_kargs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->SetZero();
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
gemm_descs.push_back({0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
b_device_buf[i]->GetDeviceBuffer(),
{},
c_device_buf[i]->GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideCs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_c,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
std::string gemm_name = gemm_ptr->GetTypeString();
DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
gemm_kargs.data(),
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
instance_pass = instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
if(time_kernel)
{
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck
......@@ -88,11 +88,12 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
#if DEBUG_LOG
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
#endif // DEBUG_LOG
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
......
......@@ -43,6 +43,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiple_d_splitk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
......@@ -113,6 +114,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_multiple_d_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance)
endif()
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
......
......@@ -33,7 +33,7 @@ enum struct GemmDataType
int profile_gemm_universal(int argc, char* argv[])
{
if(argc != 15 && argc != 17)
if(argc != 15 && argc != 18)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
......@@ -51,6 +51,7 @@ int profile_gemm_universal(int argc, char* argv[])
printf("optional:\n");
printf("arg15: number of warm-up cycles (default 1)\n");
printf("arg16: number of iterations (default 10)\n");
printf("arg17: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
......@@ -70,17 +71,20 @@ int profile_gemm_universal(int argc, char* argv[])
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc == 18)
{
n_warmup = std::stoi(argv[15]);
n_iter = std::stoi(argv[16]);
rotating = std::stoull(argv[17]) * 1024 * 1024;
}
using F32 = float;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
......@@ -124,7 +128,8 @@ int profile_gemm_universal(int argc, char* argv[])
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch,
n_warmup,
n_iter);
n_iter,
rotating);
return pass ? 0 : 1;
};
......@@ -153,6 +158,14 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "profiler/profile_grouped_gemm_tile_loop_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 0
};
enum struct GemmDataType
{
F16_F16_F16, // 0
};
#define OP_NAME "grouped_gemm_tile_loop"
#define OP_DESC "Grouped GEMM Multiple D Tile Loop"
namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int profile_grouped_gemm_tile_loop(int argc, char* argv[])
{
if(argc < 14)
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp16)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "optional:\n"
<< "arg14: number of warm-up cycles (default 1)\n"
<< "arg15: number of iterations (default 10)\n"
<< std::endl;
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const auto Ms = argToIntArray(argv[8]);
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
auto StrideAs = argToIntArray(argv[11]);
auto StrideBs = argToIntArray(argv[12]);
auto StrideCs = argToIntArray(argv[13]);
const int DefaultStrideA = Ks[0];
const int DefaultStrideB = Ns[0];
const int DefaultStrideC = Ns[0];
for(size_t i = 0; i < Ms.size(); ++i)
{
StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
}
int n_warmup = 10;
int n_iter = 50;
if(argc == 16)
{
n_warmup = std::stoi(argv[14]);
n_iter = std::stoi(argv[15]);
}
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_tile_loop_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_tile_loop_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
n_warmup,
n_iter);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_tile_loop);
......@@ -65,7 +65,7 @@ set -- "${POSITIONAL[@]}" # restore positional parameters
# NUMACTL="numactl --cpunodebind=1 --membind=1"
NUMACTL=
# ENV_CONF=
GPU=mi100
GPU=gfx908
PROF_ITER_COUNT=10000
LOG_DIR_PATH=../log/${LOG_DIR}
set -x
......
......@@ -182,3 +182,4 @@ add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
add_subdirectory(position_embedding)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment