Commit d1a50f9f authored by Jing Zhang's avatar Jing Zhang
Browse files

seperate grouped gemm splitk out

parent c8a8385f
...@@ -68,58 +68,6 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( ...@@ -68,58 +68,6 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename ELayout, typename ELayout,
...@@ -161,17 +109,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -161,17 +109,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmSplitK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmSplitK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemmSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -4,9 +4,5 @@ add_instance_library(device_grouped_gemm_instance ...@@ -4,9 +4,5 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
) )
endif() endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// a[m, k] * b[k, n] = e[m, n]
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Currently AK1 must equal BK1 !
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// a[m, k] * b[n, k] = e[m, n]
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_splitk_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
)
endif()
...@@ -87,17 +87,17 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc ...@@ -87,17 +87,17 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc
>; >;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances( void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row, std::vector<std::unique_ptr<DeviceGroupedGemmSplitK<Row,
Row, Row,
Empty_Tuple, Empty_Tuple,
Row, Row,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
......
...@@ -59,17 +59,17 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instanc ...@@ -59,17 +59,17 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instanc
>; >;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances( void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row, std::vector<std::unique_ptr<DeviceGroupedGemmSplitK<Row,
Col, Col,
Empty_Tuple, Empty_Tuple,
Row, Row,
F16, F16,
F16, F16,
Empty_Tuple, Empty_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
PassThrough>>>& instances) PassThrough>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
...@@ -41,8 +40,7 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -41,8 +40,7 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs)
int kbatch = 1)
{ {
bool pass = true; bool pass = true;
...@@ -173,7 +171,6 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -173,7 +171,6 @@ bool profile_grouped_gemm_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{}; auto p_ds = std::vector<std::array<const void*, 0>>{};
...@@ -223,135 +220,85 @@ bool profile_grouped_gemm_impl(int do_verification, ...@@ -223,135 +220,85 @@ bool profile_grouped_gemm_impl(int do_verification,
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString(); std::string gemm_name = gemm_ptr->GetTypeString();
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout, if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// skip non-splitk grouped_gemm
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
{
continue;
}
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(kbatch > 0)
{
kbatch_list = {kbatch};
}
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{ {
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
auto kbatch_curr = kbatch_list[j]; invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) if(do_verification)
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero(); {
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(do_verification) instance_pass = instance_pass && ck::utils::check_err(c_m_n_device_results[i],
{ c_m_n_host_results[i]);
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); if(do_log)
{
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1) LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
{ << std::endl;
instance_pass = LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
instance_pass && ck::utils::check_err(c_m_n_device_results[i], LogRangeAsType<float>(
c_m_n_host_results[i], std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
"Error: Incorrect results!", << std::endl;
0.06); LogRangeAsType<float>(
} std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
else << std::endl;
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
} }
}
std::cout << "Instance: " << gemm_name << " verification " std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl; << (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass; pass = pass && instance_pass;
} }
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel) if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
std::size_t flop = 0, num_btype = 0; flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i]; sizeof(CDataType) * Ms[i] * Ns[i];
} }
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
<< kbatch_curr << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_gemm_name = gemm_name; best_gemm_name = gemm_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
} }
} }
else }
{ else
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" {
<< std::endl; std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
} << std::endl;
} }
} }
if(time_kernel) if(time_kernel)
{ {
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
<< std::endl;
} }
return pass; return pass;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_grouped_gemm_splitk_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
#if DEBUG_LOG
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i
<< "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
#endif // DEBUG_LOG
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_c,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString();
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// skip non-splitk grouped_gemm
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
{
continue;
}
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(kbatch > 0)
{
kbatch_list = {kbatch};
}
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
}
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
<< std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck
...@@ -39,6 +39,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) ...@@ -39,6 +39,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_splitk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
endif() endif()
...@@ -89,6 +90,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) ...@@ -89,6 +90,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
endif() endif()
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
...@@ -87,7 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -87,7 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]); const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]); const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
...@@ -106,8 +106,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -106,8 +106,7 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs);
kbatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
...@@ -126,8 +125,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -126,8 +125,7 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs);
kbatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
...@@ -146,8 +144,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -146,8 +144,7 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs);
kbatch);
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
...@@ -166,8 +163,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -166,8 +163,7 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks, Ks,
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs);
kbatch);
} }
else else
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_splitk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
#define OP_NAME "grouped_gemm_splitk"
#define OP_DESC "Grouped GEMM SplitK"
namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int profile_grouped_gemm_splitk(int argc, char* argv[])
{
if(argc < 14)
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
<< " 3: A[k, m] * B[n, k] = C[m, n])\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 4)\n"
<< std::endl;
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const auto Ms = argToIntArray(argv[8]);
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_grouped_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_grouped_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
#endif
return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_splitk);
...@@ -93,6 +93,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth) ...@@ -93,6 +93,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs));
} }
#if 0
TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
{ {
std::vector<int> Ms{128, 256, 256, 512}; std::vector<int> Ms{128, 256, 256, 512};
...@@ -116,6 +117,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) ...@@ -116,6 +117,7 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops)
EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch), EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch),
std::runtime_error); std::runtime_error);
} }
#endif
class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
{ {
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp" #include "test_grouped_gemm_splitk_util.hpp"
using F16 = ck::half_t; using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_splitk_impl.hpp"
namespace ck {
namespace test {
template <typename Range>
std::string serialize_range(const Range& range)
{
std::stringstream ss;
for(auto& r : range)
{
ss << r << ", ";
}
std::string str = ss.str();
return std::string(str.begin(), str.end() - 2);
}
template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1)
{
bool pass = ck::profiler::profile_grouped_gemm_splitk_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
EXPECT_TRUE(pass);
}
};
template <typename ALayout,
typename BLayout,
typename ELayout,
tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector,
index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
using F16 = half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using EmptyTuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t N>
using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
using DeviceGroupedGemmSplitKInstance =
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
ALayout,
BLayout,
EmptyTuple,
ELayout,
F16,
F16,
F32,
F16,
EmptyTuple,
F16,
PassThrough,
PassThrough,
PassThrough,
GemmSpec,
1,
128,
128,
128,
KPerBlock,
K1,
K1,
32,
32,
4,
2,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1::value,
BBlockLdsAddExtraM::value,
1,
1,
S<1, 16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
}
float Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
}
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
};
} // namespace test
} // namespace ck
...@@ -62,8 +62,7 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -62,8 +62,7 @@ class TestGroupedGemm : public testing::TestWithParam<int>
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs)
int kbatch = 1)
{ {
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType, BDataType,
...@@ -72,7 +71,7 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -72,7 +71,7 @@ class TestGroupedGemm : public testing::TestWithParam<int>
ALayout, ALayout,
BLayout, BLayout,
ELayout>( ELayout>(
verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch); verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
}; };
...@@ -171,8 +170,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -171,8 +170,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs) const
int kbatch = 1) const
{ {
std::size_t n_groups = Ms.size(); std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
...@@ -195,10 +193,6 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -195,10 +193,6 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument( auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument); return ggemm_instance.IsSupportedArgument(argument);
} }
...@@ -208,8 +202,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -208,8 +202,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs) const
int kbatch = 1) const
{ {
std::size_t n_groups = Ms.size(); std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
...@@ -232,10 +225,6 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -232,10 +225,6 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument( auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
}
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker(); auto invoker = ggemm_instance.MakeInvoker();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment