Commit d1e9d6cc authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into refactor_threadwise_copy_using_sfcurve

parents 7fcff6f5 992f71e3
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, 1, 9, S<1, 2, 1, 72>, 2>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{});
add_device_operation_instances(
instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -44,7 +44,7 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// clang-format off
#if 1
#if 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch|
......@@ -52,6 +52,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// [256, 128, 4, 8], 1 stage, 2 occupancy
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1>;
#elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
//######|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>;
#elif 0
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| Num|
......@@ -81,14 +88,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
// [ 64, 144, 4, 8], 2 stage, 2 occupancy
// 85 TFlops
// < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 144, 4, 8, 16, 16, 1, 9, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1, 2>;
#elif 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// [128, 144, 8, 8], 1 stage, 1 occupancy, bounded by LDS size
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 144, 8, 8, 16, 16, 2, 9, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 1, 9, S<1, 1, 8, 1, 9, 2>, 8, 1>;
#endif
// clang-format on
......@@ -155,8 +154,8 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<BDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
......@@ -169,9 +168,13 @@ int main(int argc, char* argv[])
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
......
......@@ -154,4 +154,15 @@ struct GeneratorTensor_Checkboard
}
};
template <ck::index_t Dim>
struct GeneratorTensor_Sequential
{
template <typename... Ts>
float operator()(Ts... Xs) const
{
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
return dims[Dim];
}
};
#endif
......@@ -44,6 +44,11 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
......@@ -68,7 +73,7 @@ void profile_gemm_impl(int do_verification,
int StrideA,
int StrideB,
int StrideC,
int KBatch = 1)
int KBatch)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -181,7 +186,6 @@ void profile_gemm_impl(int do_verification,
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
......@@ -213,6 +217,13 @@ void profile_gemm_impl(int do_verification,
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
......@@ -220,9 +231,17 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
......@@ -233,9 +252,17 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
......@@ -243,9 +270,17 @@ void profile_gemm_impl(int do_verification,
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
if(KBatch > 1)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
else
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
......@@ -254,6 +289,7 @@ void profile_gemm_impl(int do_verification,
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
}
if(gemm_ptrs.size() <= 0)
{
......
......@@ -78,7 +78,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
......@@ -97,7 +98,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
......@@ -116,7 +118,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
......@@ -135,7 +138,8 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
......
......@@ -13,40 +13,24 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/test/include
)
# test_magic_number_division
set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp)
add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISISON_SOURCE})
target_link_libraries(test_magic_number_division PRIVATE host_tensor)
set(CONV2D_FWD_SOURCE conv2d_fwd/main.cpp)
add_executable(test_conv2d_fwd ${CONV2D_FWD_SOURCE})
target_link_libraries(test_conv2d_fwd PRIVATE host_tensor)
target_link_libraries(test_conv2d_fwd PRIVATE device_conv2d_fwd_instance)
# test_split_k
set(SPLIT_K_SOURCE split_k/main.cpp)
add_executable(test_split_k ${SPLIT_K_SOURCE})
target_link_libraries(test_split_k PRIVATE host_tensor)
target_link_libraries(test_split_k PRIVATE device_gemm_instance)
# test_conv_util
set(CONV_UTIL_SOURCE conv_util/main.cpp)
add_executable(test_conv_util ${CONV_UTIL_SOURCE})
target_link_libraries(test_conv_util PRIVATE host_tensor)
# test_reference_conv_fwd
set(REFERENCE_CONV_FWD_SOURCE reference_conv_fwd/main.cpp)
add_executable(test_reference_conv_fwd ${REFERENCE_CONV_FWD_SOURCE})
target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor)
# test_convnd_fwd_xdl
set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/main.cpp)
add_executable(test_convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
target_link_libraries(test_convnd_fwd_xdl PRIVATE host_tensor)
# test space_filling_curve_
set(SPACE_FILLING_CURVE_SOURCE space_filling_curve/space_filling_curve.cpp)
add_executable(space_filling_curve ${SPACE_FILLING_CURVE_SOURCE})
target_link_libraries(space_filling_curve PRIVATE host_tensor)
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
add_custom_target(tests)
function(add_test_executeable TEST_NAME)
add_executable(${TEST_NAME} ${ARGN})
target_link_libraries(${TEST_NAME} PRIVATE host_tensor)
target_link_libraries(${TEST_NAME} PRIVATE device_gemm_instance)
target_link_libraries(${TEST_NAME} PRIVATE device_conv2d_fwd_instance)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> )
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
endfunction(add_test_executeable TEST_NAME)
file(GLOB TESTS *.cpp)
foreach(TEST ${TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE)
message("adding test ${BASE_NAME}")
add_test_executeable(test_${BASE_NAME} ${TEST})
endforeach(TEST ${TESTS})
......@@ -75,8 +75,12 @@ int main(int argc, char* argv[])
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 3)
if(argc == 1)
{
init_method = 1;
data_type = 0;
}
else if(argc == 3)
{
data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
......@@ -275,33 +279,31 @@ int main(int argc, char* argv[])
if(success)
{
std::cout << "test conv2d fwd : Pass" << std::endl;
return 0;
}
else
{
std::cout << "test conv2d fwd: Fail " << std::endl;
return -1;
}
};
int res = -1;
if(data_type == 0)
{
Run(float(), float(), float());
res = Run(float(), float(), float());
}
else if(data_type == 1)
{
Run(ck::half_t(), ck::half_t(), ck::half_t());
res = Run(ck::half_t(), ck::half_t(), ck::half_t());
}
else if(data_type == 2)
{
Run(ushort(), ushort(), ushort());
res = Run(ushort(), ushort(), ushort());
}
else if(data_type == 3)
{
Run(int8_t(), int8_t(), int8_t());
}
else
{
return 1;
res = Run(int8_t(), int8_t(), int8_t());
}
return 0;
return res;
}
......@@ -161,11 +161,12 @@ int main(int, char*[])
if(pass)
{
std::cout << "test magic number division: Pass" << std::endl;
return 0;
}
else
{
std::cout << "test magic number division: Fail" << std::endl;
return -1;
}
return 1;
}
......@@ -57,32 +57,24 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
return true;
}
int main(int argc, char* argv[])
struct gemmArgs
{
if(argc != 9)
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return 1;
}
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const int M = std::stoi(argv[2]);
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
int layout;
int M;
int N;
int K;
int StrideA;
int StrideB;
int StrideC;
int KBatch;
};
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int KBatch = std::stoi(argv[8]);
int test_gemm(const gemmArgs& args)
{
bool a_row_major, b_row_major, c_row_major;
switch(layout)
switch(args.layout)
{
case GemmMatrixLayout::MK_KN_MN:
a_row_major = true;
......@@ -121,10 +113,10 @@ int main(int argc, char* argv[])
}
};
Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, a_row_major));
Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, b_row_major));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major));
Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major));
Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
// init data
std::size_t num_thread = std::thread::hardware_concurrency();
......@@ -151,17 +143,17 @@ int main(int argc, char* argv[])
// add device GEMM instances
std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
if(layout == GemmMatrixLayout::MK_KN_MN)
if(args.layout == GemmMatrixLayout::MK_KN_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else if(layout == GemmMatrixLayout::MK_NK_MN)
else if(args.layout == GemmMatrixLayout::MK_NK_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else if(layout == GemmMatrixLayout::KM_KN_MN)
else if(args.layout == GemmMatrixLayout::KM_KN_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
......@@ -179,16 +171,16 @@ int main(int argc, char* argv[])
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
args.M,
args.N,
args.K,
args.StrideA,
args.StrideB,
args.StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
args.KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......@@ -205,7 +197,7 @@ int main(int argc, char* argv[])
success = true;
}
}
auto error_code = 0;
if(success)
{
std::cout << "test split k : Pass" << std::endl;
......@@ -213,6 +205,49 @@ int main(int argc, char* argv[])
else
{
std::cout << "test split k: Fail " << std::endl;
error_code = -1; // test needs to report failure
}
return error_code;
}
int main(int argc, char* argv[])
{
std::vector<gemmArgs> test_cases;
if(argc == 1)
{
test_cases = {{0, 3, 3, 3, 3, 3, 3, 1}};
// JD: Populate with more and meaningful
return 0;
}
else if(argc == 9)
{
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const int M = std::stoi(argv[2]);
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int KBatch = std::stoi(argv[8]);
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
}
else
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return -1;
}
for(const auto& kinder: test_cases)
{
const auto res = test_gemm(kinder);
if(!res)
return -1;
}
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment