Commit 874a78f9 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 6368be50 2fd6c6d4
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -7,6 +7,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>,
DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>
#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES
// pipeline v1, 2 waves
,
......
......@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
......@@ -110,17 +111,39 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format on
>;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
// clang-format on
>;
......@@ -141,9 +164,51 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
add_device_operation_instances(
instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<GemmMNKPadding>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
}
} // namespace instance
......
......@@ -27,6 +27,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
......@@ -95,6 +96,41 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>
// clang-format on
>;
template <ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
......@@ -112,6 +148,52 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmDefault,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v2,
ck::LoopScheduler::Default>{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
GemmMNKPadding,
ck::PipelineVersion::v1,
ck::LoopScheduler::Interwave>{});
}
} // namespace instance
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances(
void add_device_layernorm2d_bwd_gamma_beta_f16_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F16, F16, F16, F16, F16, 2, 1>>>&
instances)
{
......
......@@ -8,7 +8,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances(
void add_device_layernorm2d_bwd_gamma_beta_f32_instances(
std::vector<std::unique_ptr<DeviceNormalizationBwdGammaBeta<F32, F32, F32, F32, F32, 2, 1>>>&
instances)
{
......
......@@ -42,7 +42,9 @@ int profile_gemm_impl(int do_verification,
int K,
int StrideA,
int StrideB,
int StrideC)
int StrideC,
int n_warmup,
int n_iter)
{
bool pass = true;
......@@ -165,8 +167,8 @@ int profile_gemm_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, 10, 50});
float avg_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::size_t flop = std::size_t(2) * M * N * K;
......@@ -296,7 +298,7 @@ int profile_gemm_impl(int do_verification,
}
}
return pass ? 0 : 1;
return pass;
}
} // namespace profiler
......
......@@ -42,7 +42,9 @@ bool profile_gemm_splitk_impl(int do_verification,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
int KBatch,
int n_warmup,
int n_iter)
{
bool pass = true;
......@@ -143,7 +145,7 @@ bool profile_gemm_splitk_impl(int do_verification,
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 32, 36, 40, 64, 96, 128};
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38};
if(KBatch > 0)
{
......@@ -177,7 +179,8 @@ bool profile_gemm_splitk_impl(int do_verification,
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
......@@ -200,8 +203,8 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::size_t flop = std::size_t(2) * M * N * K;
......
......@@ -42,7 +42,9 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1)
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
......@@ -261,7 +263,8 @@ bool profile_grouped_gemm_impl(int do_verification,
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
......@@ -307,8 +310,8 @@ bool profile_grouped_gemm_impl(int do_verification,
pass = pass && instance_pass;
}
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
if(time_kernel)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp"
namespace ck {
namespace profiler {
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType>
bool profile_groupnorm_bwd_gamma_beta_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
{
// we don't need GammaDataType and DXDataType here, just for reference class
using GammaDataType = DYDataType;
using DXDataType = DYDataType;
if(length.size() != 5)
return false;
index_t N = length[0];
index_t G = length[3];
index_t C = length[4];
std::vector<index_t> reduce_dim = {0, 1, 2};
std::vector<index_t> gamma_beta_length = {G, C};
Tensor<DYDataType> dy(length);
Tensor<XDataType> x(length);
Tensor<GammaDataType> gamma(gamma_beta_length); // dummy tensor, for reference
Tensor<MeanInvStdDataType> mean({N, G});
Tensor<MeanInvStdDataType> inv_std({N, G});
Tensor<DGammaDataType> dgamma(gamma_beta_length);
Tensor<DBetaDataType> dbeta(gamma_beta_length);
Tensor<DXDataType> host_dx(length); // dummy tensor, for reference
Tensor<DGammaDataType> host_dgamma(gamma_beta_length);
Tensor<DBetaDataType> host_dbeta(gamma_beta_length);
std::vector<index_t> strideDy =
std::vector<ck::index_t>{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<index_t> strideX =
std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()};
std::vector<index_t> strideDGamma{dgamma.mDesc.GetStrides().begin(),
dgamma.mDesc.GetStrides().end()};
std::vector<index_t> strideDBeta{dbeta.mDesc.GetStrides().begin(),
dbeta.mDesc.GetStrides().end()};
std::vector<index_t> strideMeanInvStd = {G, 0, 0, 1, 0};
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_1<DYDataType>{});
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
mean.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
inv_std.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
dgamma.GenerateTensorValue(GeneratorTensor_1<DGammaDataType>{});
dbeta.GenerateTensorValue(GeneratorTensor_1<DBetaDataType>{});
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_2<DYDataType>{-5, 5});
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5});
mean.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
inv_std.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{0, 5});
dgamma.GenerateTensorValue(GeneratorTensor_2<DGammaDataType>{-5, 5});
dbeta.GenerateTensorValue(GeneratorTensor_2<DBetaDataType>{-5, 5});
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0, 1});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1});
mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0, 0.5});
dgamma.GenerateTensorValue(GeneratorTensor_3<DGammaDataType>{-0.5, 0.5});
dbeta.GenerateTensorValue(GeneratorTensor_3<DBetaDataType>{-0.5, 0.5});
}
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// add device normalization instances
using DeviceOp =
ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
5,
3>;
// get device op instances
const auto instance_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
if(do_verification)
{
using ReferenceInstance =
ck::tensor_operation::host::ReferenceGroupnormBwd<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
DXDataType,
ComputeDataType>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
}
std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) +
x.mDesc.GetElementSize() * sizeof(XDataType) +
mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) +
dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType);
int num_kernel = 0;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideDy,
strideX,
strideMeanInvStd,
strideMeanInvStd,
gamma_beta_length,
strideDGamma,
strideDBeta,
reduce_dim,
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dgamma_dev.GetDeviceBuffer(),
dbeta_dev.GetDeviceBuffer());
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
}
continue;
}
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
best_instance_name = inst_ptr->GetTypeString();
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
bool pass =
ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "dy : ", dy.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "host_dgamma : ", host_dgamma.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl;
}
if(!pass)
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s,"
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
namespace ck {
namespace profiler {
template <typename DYDataType,
typename XDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DGammaDataType,
typename DBetaDataType,
index_t Rank>
bool profile_layernorm_bwd_gamma_beta_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
{
// we don't need GammaDataType and DXDataType here, just for reference class
using GammaDataType = DYDataType;
using DXDataType = DYDataType;
if(length.size() != Rank || Rank < 2)
return false;
// Assume normalize dimension for first dimension
// Layernorm 2D, input = [M, K], reduce on M axis
// Layernorm 4D, input = [N, H, W, C], redice on N axis
constexpr int NumReduceDim = Rank - 1;
std::vector<index_t> reduce_dim = {0};
std::vector<index_t> invarient_length{length.begin() + 1, length.end()};
Tensor<DYDataType> dy(length);
Tensor<XDataType> x(length);
Tensor<GammaDataType> gamma(invarient_length); // dummy tensor, for reference
Tensor<MeanInvStdDataType> mean({length[0]});
Tensor<MeanInvStdDataType> inv_std({length[0]});
Tensor<DGammaDataType> dgamma(invarient_length);
Tensor<DBetaDataType> dbeta(invarient_length);
Tensor<DXDataType> host_dx(length); // dummy tensor, for reference
Tensor<DGammaDataType> host_dgamma(invarient_length);
Tensor<DBetaDataType> host_dbeta(invarient_length);
std::vector<index_t> strideDy =
std::vector<ck::index_t>{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()};
std::vector<index_t> strideX = strideDy;
std::vector<index_t> strideDGamma{dgamma.mDesc.GetStrides().begin(),
dgamma.mDesc.GetStrides().end()};
std::vector<index_t> strideDBeta{dbeta.mDesc.GetStrides().begin(),
dbeta.mDesc.GetStrides().end()};
std::vector<index_t> strideMeanInvStd{Rank, 0};
strideMeanInvStd[0] = 1;
switch(init_method)
{
case 0:
dy.GenerateTensorValue(GeneratorTensor_1<DYDataType>{});
x.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
mean.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
inv_std.GenerateTensorValue(GeneratorTensor_1<MeanInvStdDataType>{});
dgamma.GenerateTensorValue(GeneratorTensor_1<DGammaDataType>{});
dbeta.GenerateTensorValue(GeneratorTensor_1<DBetaDataType>{});
break;
case 1:
dy.GenerateTensorValue(GeneratorTensor_2<DYDataType>{-5, 5});
x.GenerateTensorValue(GeneratorTensor_2<XDataType>{-5, 5});
mean.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{-5, 5});
inv_std.GenerateTensorValue(GeneratorTensor_2<MeanInvStdDataType>{0, 5});
dgamma.GenerateTensorValue(GeneratorTensor_2<DGammaDataType>{-5, 5});
dbeta.GenerateTensorValue(GeneratorTensor_2<DBetaDataType>{-5, 5});
break;
default:
dy.GenerateTensorValue(GeneratorTensor_3<DYDataType>{0, 1});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1});
mean.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{-0.5, 0.5});
inv_std.GenerateTensorValue(GeneratorTensor_3<MeanInvStdDataType>{0, 0.5});
dgamma.GenerateTensorValue(GeneratorTensor_3<DGammaDataType>{-0.5, 0.5});
dbeta.GenerateTensorValue(GeneratorTensor_3<DBetaDataType>{-0.5, 0.5});
}
DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
dy_dev.ToDevice(dy.mData.data());
x_dev.ToDevice(x.mData.data());
mean_dev.ToDevice(mean.mData.data());
inv_std_dev.ToDevice(inv_std.mData.data());
// add device normalization instances
using DeviceOp =
ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta<DYDataType,
XDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
Rank,
NumReduceDim>;
// get device op instances
const auto instance_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << instance_ptrs.size() << " instances" << std::endl;
std::string best_instance_name;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
if(do_verification)
{
using ReferenceInstance =
ck::tensor_operation::host::ReferenceLayernormBwd<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DGammaDataType,
DBetaDataType,
DXDataType,
ComputeDataType>;
ReferenceInstance ref;
auto ref_argument =
ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length);
auto ref_invoker = ref.MakeInvoker();
ref_invoker.Run(ref_argument);
}
std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) +
x.mDesc.GetElementSize() * sizeof(XDataType) +
mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) +
dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) +
dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType);
int num_kernel = 0;
for(auto& inst_ptr : instance_ptrs)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(length,
strideDy,
strideX,
strideMeanInvStd,
strideMeanInvStd,
invarient_length,
strideDGamma,
strideDBeta,
reduce_dim,
dy_dev.GetDeviceBuffer(),
x_dev.GetDeviceBuffer(),
mean_dev.GetDeviceBuffer(),
inv_std_dev.GetDeviceBuffer(),
dgamma_dev.GetDeviceBuffer(),
dbeta_dev.GetDeviceBuffer());
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
}
else
{
if(time_kernel)
{
std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: ";
LogRange(std::cout << "input lengths = ", length, ", ") << std::endl;
}
continue;
}
size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
float gb_per_sec = num_bytes / 1.E6 / avg_time;
if(time_kernel)
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, "
<< inst_ptr->GetTypeString() << std::endl;
if(avg_time < best_avg_time)
{
best_instance_name = inst_ptr->GetTypeString();
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
dgamma_dev.FromDevice(dgamma.mData.data());
dbeta_dev.FromDevice(dbeta.mData.data());
bool pass =
ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3);
pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3);
if(do_log)
{
LogRangeAsType<float>(std::cout << "dy : ", dy.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "host_dgamma : ", host_dgamma.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl;
}
if(!pass)
{
std::cout << inst_ptr->GetTypeString() << " failed verification: ";
LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl;
return false;
}
else
{
if(time_kernel)
std::cout << "pass" << std::endl;
}
}
}
if(time_kernel)
{
LogRange(std::cout << "length = ", length, ",") << ", ";
LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl;
std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s,"
<< best_instance_name << std::endl;
}
if(num_kernel == 0)
{
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
return true;
}
} // namespace profiler
} // namespace ck
......@@ -19,6 +19,8 @@ set(PROFILER_SOURCES
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_bwd_gamma_beta.cpp
profile_groupnorm_bwd_gamma_beta.cpp
profile_layernorm_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
......@@ -82,6 +84,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
......
......@@ -42,12 +42,15 @@ static void print_helper_msg()
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"
<< "optional:\n"
<< "arg14: number of warm-up cycles (default 1)\n"
<< "arg15: number of iterations (default 10)\n"
<< std::endl;
}
int profile_gemm(int argc, char* argv[])
{
if(argc != 14)
if(argc != 14 && argc != 16)
{
print_helper_msg();
exit(1);
......@@ -68,6 +71,13 @@ int profile_gemm(int argc, char* argv[])
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
int n_warmup = 1;
int n_iter = 10;
if(argc == 16)
{
n_warmup = std::stoi(argv[14]);
n_iter = std::stoi(argv[15]);
}
using F32 = float;
using F16 = ck::half_t;
#ifdef CK_ENABLE_BF16
......@@ -120,13 +130,21 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC);
(StrideC < 0) ? DefaultStrideC : StrideC,
n_warmup,
n_iter);
return pass ? 0 : 1;
};
if(false)
;
if(data_type != GemmDataType::F32_F32_F32 && data_type != GemmDataType::F16_F16_F16 &&
data_type != GemmDataType::BF16_BF16_BF16 && data_type != GemmDataType::INT8_INT8_INT8 &&
data_type != GemmDataType::F8_F8_F8)
{
// dummy clause before the else clauses for different data types
std::cout << "Gemm: this data_type is not implemented" << std::endl;
return 1;
}
#ifdef CK_ENABLE_FP32
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -219,7 +237,7 @@ int profile_gemm(int argc, char* argv[])
#endif
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
std::cout << "Gemm: this data_type & layout is not implemented" << std::endl;
return 1;
}
......
......@@ -33,7 +33,7 @@ enum struct GemmDataType
int profile_gemm_splitk(int argc, char* argv[])
{
if(argc != 15)
if(argc != 15 && argc != 17)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
......@@ -48,6 +48,9 @@ int profile_gemm_splitk(int argc, char* argv[])
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
printf("optional:\n");
printf("arg15: number of warm-up cycles (default 1)\n");
printf("arg16: number of iterations (default 10)\n");
exit(1);
}
......@@ -67,6 +70,14 @@ int profile_gemm_splitk(int argc, char* argv[])
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
{
n_warmup = std::stoi(argv[15]);
n_iter = std::stoi(argv[16]);
}
using F32 = float;
using F16 = ck::half_t;
#if defined CK_ENABLE_FP8
......@@ -117,7 +128,9 @@ int profile_gemm_splitk(int argc, char* argv[])
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch);
KBatch,
n_warmup,
n_iter);
return pass ? 0 : 1;
};
......
......@@ -69,7 +69,10 @@ int profile_grouped_gemm(int argc, char* argv[])
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 4)\n"
<< "arg15: kbatch value (default 1)\n"
<< "optional:\n"
<< "arg16: number of warm-up cycles (default 1)\n"
<< "arg17: number of iterations (default 10)\n"
<< std::endl;
exit(1);
......@@ -90,6 +93,15 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
}
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -109,7 +121,9 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
......@@ -129,7 +143,9 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
......@@ -149,7 +165,9 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
......@@ -169,7 +187,9 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -189,7 +209,9 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -209,7 +231,9 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
n_warmup,
n_iter);
}
else
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
struct groupnormBwdGammaBetaArgParser
{
std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};
bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{
if(std::string("--") + key == argv[i])
{
int pos = i;
while(++i < argc && argv[i][0] != '-') {}
int end = i;
for(int j = pos + 1; j < end; j++)
{
long_opts[key].push_back(std::stoi(argv[j]));
}
return true;
}
return false;
}
void operator()(int argc, char* argv[])
{
for(auto& kv : long_opts)
{
for(int i = 1; i < argc; i++)
{
if(parse_opt(argc, argv, kv.first, i))
break;
}
}
}
};
void print_help_groupnorm_bwd_gamma_beta()
{
// eg: ckProfiler groupnorm_bwd_gamma_beta 1 0 2 0 1 --length 1 16 16 32 40
std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
<< "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n"
<< "arg5: time kernel (0=no, 1=yes)\n"
<< "--length: tensor extents (e.g, --length 1 16 16 32 40) \n"
<< std::endl;
}
int profile_groupnorm_bwd_gamma_beta(int argc, char* argv[])
{
if(argc <= 2)
{
print_help_groupnorm_bwd_gamma_beta();
return 0;
}
groupnormBwdGammaBetaArgParser arg_parser;
// short unnamed options
const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
// parse the long options
arg_parser(argc, argv);
const std::vector<index_t> length = arg_parser.long_opts["length"];
using F32 = float;
if(length.size() == 5)
{
if(data_type == ck::DataTypeEnum::Float)
{
ck::profiler::profile_groupnorm_bwd_gamma_beta_impl<F32, F32, F32, F32, F32, F32>(
do_verification, init_method, do_log, time_kernel, length);
}
else
{
throw std::runtime_error("not implemented yet");
}
}
else
{
throw std::runtime_error("length should be 5");
}
return 0;
}
REGISTER_PROFILER_OPERATION("groupnorm_bwd_gamma_beta",
"Group Normalization",
profile_groupnorm_bwd_gamma_beta);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
struct layernormBwdGammaBetaArgParser
{
std::unordered_map<std::string, std::vector<int>> long_opts = {{"length", {}}};
bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{
if(std::string("--") + key == argv[i])
{
int pos = i;
while(++i < argc && argv[i][0] != '-') {}
int end = i;
for(int j = pos + 1; j < end; j++)
{
long_opts[key].push_back(std::stoi(argv[j]));
}
return true;
}
return false;
}
void operator()(int argc, char* argv[])
{
for(auto& kv : long_opts)
{
for(int i = 1; i < argc; i++)
{
if(parse_opt(argc, argv, kv.first, i))
break;
}
}
}
};
void print_help_layernorm_bwd_gamma_beta()
{
// eg: ckProfiler layernorm_bwd_gamma_beta 0 0 2 0 1 --length 1502 4096
std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
<< "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n"
<< "arg5: time kernel (0=no, 1=yes)\n"
<< "--length: tensor extents (e.g, --length 1024 1024) \n"
<< std::endl;
}
int profile_layernorm_bwd_gamma_beta(int argc, char* argv[])
{
if(argc <= 2)
{
print_help_layernorm_bwd_gamma_beta();
return 0;
}
layernormBwdGammaBetaArgParser arg_parser;
// short unnamed options
const ck::DataTypeEnum data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
// parse the long options
arg_parser(argc, argv);
const std::vector<index_t> length = arg_parser.long_opts["length"];
using F16 = ck::half_t;
using F32 = float;
if(length.size() == 2)
{
constexpr int rank = 2;
if(data_type == ck::DataTypeEnum::Half)
{
ck::profiler::profile_layernorm_bwd_gamma_beta_impl<F16, F16, F16, F32, F16, F16, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else if(data_type == ck::DataTypeEnum::Float)
{
ck::profiler::profile_layernorm_bwd_gamma_beta_impl<F32, F32, F32, F32, F32, F32, rank>(
do_verification, init_method, do_log, time_kernel, length);
}
else
{
throw std::runtime_error("not implemented yet");
}
}
else
{
throw std::runtime_error("not implemented yet");
}
return 0;
}
REGISTER_PROFILER_OPERATION("layernorm_bwd_gamma_beta",
"Layer Normalization",
profile_layernorm_bwd_gamma_beta);
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}'
......@@ -140,6 +140,7 @@ add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization_fwd)
add_subdirectory(normalization_bwd_data)
add_subdirectory(normalization_bwd_gamma_beta)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment