"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "bac32ec1f777e4ed65a363e1c93ccd7f4beb990b"
Commit 1b9e6e11 authored by ltqin's avatar ltqin
Browse files

seperate split-k instance files

parent 303b1a86
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
/*using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
// clang-format on
>;
*/
template <>
void add_device_splitk_gemm_instance<F32, F32, F32, Col, Row, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& )
{
/* using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});*/
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
/*using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
// clang-format on
>;
*/
template <>
void add_device_splitk_gemm_instance<F32, F32, F32, Col, Col, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& )
{
/* using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});*/
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 1, 3, 4>, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 8>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
/* DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>*/
>;
template <>
void add_device_splitk_gemm_instance<F32, F32, F32, Row, Row, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& device_op_instances)
{
using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
/*using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
// clang-format on
>;
*/
template <>
void add_device_splitk_gemm_instance<F32, F32, F32, Row, Col, Row>(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& )
{
/* using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});*/
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -20,6 +20,17 @@ void add_device_gemm_instance(
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void add_device_splitk_gemm_instance(
std::vector<DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
......
......@@ -2,6 +2,7 @@
#define DEVICE_GEMM_SPLITK_XDL_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
......@@ -578,6 +579,24 @@ struct DeviceGemmSplitKXdl
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmXdlSplitK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
......
......@@ -10,7 +10,7 @@ using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::Pass
ck::tensor_operation::element_wise::PassThrough>;
template <>
void add_device_gemm_instance<float,
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -18,7 +18,7 @@ void add_device_gemm_instance<float,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -26,7 +26,7 @@ void add_device_gemm_instance<float,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -34,44 +34,13 @@ void add_device_gemm_instance<float,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
void add_device_splitk_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
......
......@@ -22,6 +22,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_nk_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_kn_mn.cpp;
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_nk_mn.cpp;
)
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
......
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
......
......@@ -26,10 +26,14 @@ using DeviceGemmNoOpPtr =
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
using GEMM_PTR = std::vector<DeviceGemmNoOpPtr>;
static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}};
static std::vector<std::vector<bool>>& GetLayoutType(){
static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}};
return LayOut;
}
static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
float,
float,
......@@ -39,7 +43,7 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
}
static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
float,
float,
......@@ -49,7 +53,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
}
static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
float,
float,
......@@ -59,7 +63,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
}
static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
float,
float,
......@@ -68,13 +72,19 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = {add_device_gemm_instance_mk_kn_mn,
static auto& GetAddDeviceGemmInstance()
{
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = {add_device_gemm_instance_mk_kn_mn,
add_device_gemm_instance_mk_nk_mn,
add_device_gemm_instance_km_kn_mn,
add_device_gemm_instance_km_nk_mn};
return AddDeviceGemmInstance;
}
static void add_device_gemm_instance(GEMM_PTR& gemm_ptrs, int layout)
{
AddDeviceGemmInstance[layout](gemm_ptrs);
GetAddDeviceGemmInstance()[layout](gemm_ptrs);
}
template <typename T>
......@@ -95,7 +105,6 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
}
int main(int argc, char* argv[])
{
if(argc != 8)
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
......@@ -121,6 +130,8 @@ int main(int argc, char* argv[])
printf("arg1 must be 0 ,1 ,2 or 3 \n");
return 1;
}
auto LayOut = GetLayoutType();
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, bool isRevert) {
if(isRevert)
......@@ -205,4 +216,4 @@ int main(int argc, char* argv[])
std::cout << "test split k: Fail " << std::endl;
}
return 0;
}
\ No newline at end of file
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment