Commit 000db488 authored by ltqin's avatar ltqin
Browse files

add tunning parameter for NT

parent b98e339d
...@@ -123,6 +123,11 @@ ...@@ -123,6 +123,11 @@
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif #endif
// using split k xdlops kernel
#ifndef CK_USE_SPLITK_XDLOPS
#define CK_USE_SPLITK_XDLOPS 1
#endif
namespace ck { namespace ck {
enum InMemoryDataOperationEnum_t enum InMemoryDataOperationEnum_t
......
...@@ -4,9 +4,6 @@ ...@@ -4,9 +4,6 @@
#include "device_gemm_splitk_xdl.hpp" #include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
#ifndef CK_USE_SPLITK_XDLOPS
#define CK_USE_SPLITK_XDLOPS 1
#endif
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
#include <stdlib.h> #include <stdlib.h>
#include "config.hpp" #include "config.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp" #include "device_gemm_instance.hpp"
namespace ck { namespace ck {
...@@ -18,6 +19,29 @@ template <ck::index_t... Is> ...@@ -18,6 +19,29 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n] // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
#if CK_USE_SPLITK_XDLOPS
using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
// clang-format on
>;
#else
using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
// clang-format off // clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds| //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
...@@ -39,6 +63,7 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple< ...@@ -39,6 +63,7 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true> DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 2, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 7, 1, true, true>
// clang-format on // clang-format on
>; >;
#endif
template <> template <>
void add_device_gemm_instance<F32, F32, F32, Row, Col, Row>( void add_device_gemm_instance<F32, F32, F32, Row, Col, Row>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment