Unverified Commit f7d28f3e authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Gemm+layernorm instance, ckProfiler, client example (#568)

* Add gemm + layernorm instance

* Add ckProfiler

* Add test

* Add client example

* Detect if user forger to set the workrspace

* Use literal in the example

* [What] use builtin function for sqrt
[Why] compiler will not use v_sqrt_f64_e64 if we use ::sqrt()

* check gemm vaildity in IsSupportedArgument

* Add more testcases

* Merge duplicated folder in client example

* Print more infomation

* Use better kernel parameter for MS problem size

* clang format

* Add constexpr for if condition and remove redundant include

* Remove cstdlib and add constexpr
parent 76d144fa
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// outout: h[m, n]
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple<
// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
// irregular tile size
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
// clang-format on
>;
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances<
LoopScheduler::Default,
PipelineVersion::v1>{});
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances<
LoopScheduler::Interwave,
PipelineVersion::v1>{});
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances<
LoopScheduler::Default,
PipelineVersion::v2>{});
#endif
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using F16_F16_Tuple = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// e = elementwise((a * b), d0, d1)
// h = layernorm(e, gamma, beta)
// outout: h[m, n]
// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n]
template <LoopScheduler GemmLoopScheduler, PipelineVersion GemmPipeline>
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple<
// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>
// clang-format on
>;
// irregular tile size
using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances =
std::tuple<
// clang-format off
//#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline|
//#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | |
//#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | |
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// pipeline v1, 1 wave
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1>
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
// pipeline v2, 1 wave
,
DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2>
#endif
// clang-format on
>;
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances<
LoopScheduler::Default,
PipelineVersion::v1>{});
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances<
LoopScheduler::Interwave,
PipelineVersion::v1>{});
#endif
#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances<
LoopScheduler::Default,
PipelineVersion::v2>{});
#endif
add_device_operation_instances(
instances,
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename D0DataType,
typename D1DataType,
typename EMeanVarDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
typename HElementOp>
void host_gemm_layernorm(Tensor<HDataType>& h_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<D0DataType>& d0_m_n,
const Tensor<D1DataType>& d1_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n,
AElementOp a_element_op,
BElementOp b_element_op,
CDEElementOp cde_element_op,
HElementOp h_element_op,
int M,
int N,
AccDataType epsilon = 1e-5)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemm = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm<EMeanVarDataType,
GammaDataType,
BetaDataType,
HDataType,
AccDataType,
HElementOp,
2,
1>;
Tensor<EMeanVarDataType> e_m_n(HostTensorDescriptor{M, N});
Tensor<AccDataType> c_m_n(HostTensorDescriptor{M, N});
auto ref_gemm = ReferenceGemm{};
auto ref_gemm_invoker = ref_gemm.MakeInvoker();
auto ref_gemm_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_gemm_invoker.Run(ref_gemm_argument);
for(int n = 0; n < N; ++n)
{
for(int m = 0; m < M; ++m)
{
AccDataType e = static_cast<AccDataType>(e_m_n(m, n));
AccDataType d0 = static_cast<AccDataType>(d0_m_n(m, n));
AccDataType d1 = static_cast<AccDataType>(d1_m_n(m, n));
cde_element_op(e, c_m_n(m, n), d0, d1);
e_m_n(m, n) = static_cast<EMeanVarDataType>(e);
}
}
ReferenceLayernorm ref_layernorm;
auto ref_layernorm_invoker = ref_layernorm.MakeInvoker();
auto ref_layernorm_argument = ref_layernorm.MakeArgument(
e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon);
ref_layernorm_invoker.Run(ref_layernorm_argument);
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename D0DataType,
typename D1DataType,
typename EMeanVarDataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType,
typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename HLayout>
bool profile_gemm_add_relu_add_layernorm_impl(int do_verification,
int init_method,
bool /*do_log*/,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideH,
AccDataType epsilon = 1e-5)
{
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor({len}, {stride});
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if constexpr(std::is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<D1DataType> d0_m_n(f_host_tensor_descriptor2d(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
Tensor<HDataType> h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
break;
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddReluAdd;
using HElementOp = PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
const auto h_element_op = HElementOp{};
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
HLayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
GammaDataType,
BetaDataType,
HDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddReluAdd,
ck::tensor_operation::element_wise::PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// run reference
if(do_verification)
{
host_gemm_layernorm<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EMeanVarDataType,
GammaDataType,
BetaDataType,
HDataType>(h_m_n_host,
a_m_k,
b_k_n,
d0_m_n,
d1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
cde_element_op,
h_element_op,
M,
N,
epsilon);
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
std::string best_op_name;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
bool pass = true;
int num_kernel = 0;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{d0_m_n_device_buf.GetDeviceBuffer(), d1_m_n_device_buf.GetDeviceBuffer()},
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{StrideD0, StrideD1},
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
h_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
// re-init E to zero before profiling a kernel
h_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
(sizeof(D0DataType) + sizeof(D1DataType) + sizeof(HDataType)) * M * N +
(sizeof(GammaDataType) + sizeof(BetaDataType)) * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(ave_time < best_ave_time)
{
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
h_device_buf.FromDevice(h_m_n.mData.data());
pass = pass && ck::utils::check_err(
h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
}
}
else
{
if(time_kernel)
std::cout << op_name << " does not support this problem" << std::endl;
}
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
pass = false;
}
else
{
if(time_kernel)
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck
......@@ -8,6 +8,7 @@ set(PROFILER_SOURCES
profile_gemm_add_add_fastgelu.cpp
profile_gemm_add_multiply.cpp
profile_gemm_add_fastgelu.cpp
profile_gemm_add_relu_add_layernorm.cpp
profile_gemm_fastgelu.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
......@@ -43,6 +44,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgel
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
......@@ -66,5 +68,4 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp"
#include "profiler_operation_registry.hpp"
#define OP_NAME "gemm_add_relu_add_layernorm"
#define OP_DESC "GEMM+Add+Relu+Add+Layernorm"
int profile_gemm_add_relu_add_layernorm(int argc, char* argv[])
{
enum struct MatrixLayout
{
MK_KN_MN_MN_MN, // 0
MK_NK_MN_MN_MN, // 1
KM_KN_MN_MN_MN, // 2
KM_NK_MN_MN_MN, // 3
};
enum struct MatrixDataType
{
F32, // 0
F16, // 1
BF16, // 2
};
if(argc != 16)
{
// clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16)\n");
printf("arg3: matrix layout (0: H[m, n] = Layernorm(Relu(A[m, k] * B[k, n] + D0[m, n]) + D1[m, n]);\n");
printf(" 1: H[m, n] = Layernorm(Relu(A[m, k] * B[n, k] + D0[m, n]) + D1[m, n]);\n");
printf(" 2: H[m, n] = Layernorm(Relu(A[k, m] * B[k, n] + D0[m, n]) + D1[m, n]);\n");
printf(" 3: H[m, n] = Layernorm(Relu(A[k, m] * B[n, k] + D0[m, n]) + D1[m, n]))\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideH\n");
// clang-format on
exit(1);
}
const auto data_type = static_cast<MatrixDataType>(std::stoi(argv[2]));
const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideD0 = std::stoi(argv[13]);
const int StrideD1 = std::stoi(argv[14]);
const int StrideH = std::stoi(argv[15]);
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto d0_type,
auto d1_type,
auto e_mean_var_type,
auto gamma_type,
auto beta_type,
auto h_type,
auto a_layout,
auto b_layout,
auto d0_layout,
auto d1_layout,
auto h_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using D0DataType = decltype(d0_type);
using D1DataType = decltype(d1_type);
using EMeanVarDataType = decltype(e_mean_var_type);
using GammaDataType = decltype(gamma_type);
using BetaDataType = decltype(beta_type);
using HDataType = decltype(h_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using D0Layout = decltype(d0_layout);
using D1Layout = decltype(d1_layout);
using HLayout = decltype(h_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
const int DefaultStrideH = ck::is_same_v<HLayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_add_relu_add_layernorm_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EMeanVarDataType,
GammaDataType,
BetaDataType,
HDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
HLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
(StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
(StrideH < 0) ? DefaultStrideH : StrideH);
return pass ? 0 : 1;
};
if(data_type == MatrixDataType::F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
{
return profile(F16{},
F16{},
F32{},
F16{},
F16{},
F16{},
F16{},
F16{},
F16{},
Row{},
Row{},
Row{},
Row{},
Row{});
}
else if(data_type == MatrixDataType::F16 && layout == MatrixLayout::MK_NK_MN_MN_MN)
{
return profile(F16{},
F16{},
F32{},
F16{},
F16{},
F16{},
F16{},
F16{},
F16{},
Row{},
Col{},
Row{},
Row{},
Row{});
}
else if(data_type == MatrixDataType::F16 && layout == MatrixLayout::KM_KN_MN_MN_MN)
{
return profile(F16{},
F16{},
F32{},
F16{},
F16{},
F16{},
F16{},
F16{},
F16{},
Col{},
Row{},
Row{},
Row{},
Row{});
}
else if(data_type == MatrixDataType::F16 && layout == MatrixLayout::KM_NK_MN_MN_MN)
{
return profile(F16{},
F16{},
F32{},
F16{},
F16{},
F16{},
F16{},
F16{},
F16{},
Col{},
Col{},
Row{},
Row{},
Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_relu_add_layernorm);
......@@ -27,7 +27,7 @@ function(add_gtest_executable TEST_NAME)
# suppress gtest warnings
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> )
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
endfunction(add_gtest_executable TEST_NAME)
......@@ -36,6 +36,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
......
add_custom_target(test_gemm_layernorm)
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using F16 = ck::half_t;
using F32 = float;
using ck::index_t;
template <typename Tuple>
class TestGemmAddReluAddLayernorm : public ::testing::Test
{
protected:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using D1DataType = std::tuple_element_t<4, Tuple>;
using EMeanVarDataType = std::tuple_element_t<5, Tuple>;
using GammaDataType = std::tuple_element_t<6, Tuple>;
using BetaDataType = std::tuple_element_t<7, Tuple>;
using HDataType = std::tuple_element_t<8, Tuple>;
using ALayout = std::tuple_element_t<9, Tuple>;
using BLayout = std::tuple_element_t<10, Tuple>;
using D0Layout = std::tuple_element_t<11, Tuple>;
using D1Layout = std::tuple_element_t<12, Tuple>;
using HLayout = std::tuple_element_t<13, Tuple>;
void Run()
{
std::vector<std::vector<ck::index_t>> lengths = {
{1024, 1024, 1024}, {2048, 640, 640}, {1, 1, 1}};
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
int StrideD0 = 0;
int StrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
int StrideH = ck::is_same_v<HLayout, Row> ? N : M;
bool success = ck::profiler::profile_gemm_add_relu_add_layernorm_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
D1DataType,
EMeanVarDataType,
GammaDataType,
BetaDataType,
HDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
HLayout>(
true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideH);
EXPECT_TRUE(success);
}
}
};
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, F32, F16, F16, F16, F16, F16, F16, Row, Row, Row, Row, Row>,
std::tuple<F16, F16, F32, F16, F16, F16, F16, F16, F16, Row, Col, Row, Row, Row>,
std::tuple<F16, F16, F32, F16, F16, F16, F16, F16, F16, Col, Row, Row, Row, Row>,
std::tuple<F16, F16, F32, F16, F16, F16, F16, F16, F16, Col, Col, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes);
TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16) { this->Run(); }
add_custom_target(test_layernorm)
add_custom_target(test_normalization)
add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp)
add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp)
add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp)
add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance)
target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance)
target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance)
target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance)
add_dependencies(test_layernorm test_layernorm2d_fp32)
add_dependencies(test_layernorm test_layernorm2d_fp16)
add_dependencies(test_layernorm test_groupnorm_fp16)
add_dependencies(test_layernorm test_groupnorm_fp32)
add_dependencies(test_normalization test_layernorm2d_fp32)
add_dependencies(test_normalization test_layernorm2d_fp16)
add_dependencies(test_normalization test_groupnorm_fp16)
add_dependencies(test_normalization test_groupnorm_fp32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment