Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1b9e6e11
Commit
1b9e6e11
authored
Dec 29, 2021
by
ltqin
Browse files
seperate split-k instance files
parent
303b1a86
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
389 additions
and
45 deletions
+389
-45
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_kn_mn.cpp
.../device_gemm_xdl_splitk_instance_f32_f32_f32_km_kn_mn.cpp
+62
-0
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_nk_mn.cpp
.../device_gemm_xdl_splitk_instance_f32_f32_f32_km_nk_mn.cpp
+62
-0
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_kn_mn.cpp
.../device_gemm_xdl_splitk_instance_f32_f32_f32_mk_kn_mn.cpp
+62
-0
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_nk_mn.cpp
.../device_gemm_xdl_splitk_instance_f32_f32_f32_mk_nk_mn.cpp
+67
-0
device_operation/include/device_gemm_instance.hpp
device_operation/include/device_gemm_instance.hpp
+11
-0
device_operation/include/device_gemm_splitk_xdl.hpp
device_operation/include/device_gemm_splitk_xdl.hpp
+19
-0
device_operation/include/device_gemm_xdl_instance.hpp
device_operation/include/device_gemm_xdl_instance.hpp
+4
-35
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+4
-0
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+78
-1
test/split_k/main.cpp
test/split_k/main.cpp
+20
-9
No files found.
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_kn_mn.cpp
0 → 100644
View file @
1b9e6e11
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
/*using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>
// clang-format on
>;
*/
template
<
>
void
add_device_splitk_gemm_instance
<
F32
,
F32
,
F32
,
Col
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
)
{
/* using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});*/
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_nk_mn.cpp
0 → 100644
View file @
1b9e6e11
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
/*using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 2, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2> , S<0, 1, 3, 2> , 2, 1, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
// clang-format on
>;
*/
template
<
>
void
add_device_splitk_gemm_instance
<
F32
,
F32
,
F32
,
Col
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
)
{
/* using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});*/
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_kn_mn.cpp
0 → 100644
View file @
1b9e6e11
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
96
,
128
,
4
,
8
,
16
,
16
,
3
,
4
,
S
<
1
,
1
,
3
,
4
>
,
S
<
1
,
4
,
32
,
2
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
8
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
720
>
/* DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, 7, 1, true, true, 720>*/
>
;
template
<
>
void
add_device_splitk_gemm_instance
<
F32
,
F32
,
F32
,
Row
,
Row
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
device_op_instances
)
{
using
DeviceGemms
=
device_gemm_instance
::
device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn
;
const
auto
device_gemms
=
DeviceGemms
{};
ck
::
static_for
<
0
,
std
::
tuple_size_v
<
DeviceGemms
>
,
1
>
{}([
&
](
auto
i
)
{
using
Gemm
=
remove_cvref_t
<
decltype
(
std
::
get
<
i
>
(
device_gemms
))
>
;
auto
gemm
=
Gemm
{};
device_op_instances
.
push_back
(
std
::
make_unique
<
Gemm
>
(
gemm
));
});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_nk_mn.cpp
0 → 100644
View file @
1b9e6e11
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
/*using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| ABlockLds| BBlockLds|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 1, 2, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 1, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>,
DeviceGemmSplitKXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 1, 2, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, S<1, 1, 4, 4>, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 7, 1, true, true, 720>
// clang-format on
>;
*/
template
<
>
void
add_device_splitk_gemm_instance
<
F32
,
F32
,
F32
,
Row
,
Col
,
Row
>
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
)
{
/* using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn;
const auto device_gemms = DeviceGemms{};
ck::static_for<0, std::tuple_size_v<DeviceGemms>, 1>{}([&](auto i) {
using Gemm = remove_cvref_t<decltype(std::get<i>(device_gemms))>;
auto gemm = Gemm{};
device_op_instances.push_back(std::make_unique<Gemm>(gemm));
});*/
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
device_operation/include/device_gemm_instance.hpp
View file @
1b9e6e11
...
@@ -20,6 +20,17 @@ void add_device_gemm_instance(
...
@@ -20,6 +20,17 @@ void add_device_gemm_instance(
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
add_device_splitk_gemm_instance
(
std
::
vector
<
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>&
);
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
device_operation/include/device_gemm_splitk_xdl.hpp
View file @
1b9e6e11
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define DEVICE_GEMM_SPLITK_XDL_HPP
#define DEVICE_GEMM_SPLITK_XDL_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "device_gemm.hpp"
...
@@ -578,6 +579,24 @@ struct DeviceGemmSplitKXdl
...
@@ -578,6 +579,24 @@ struct DeviceGemmSplitKXdl
{
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmXdlSplitK"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
};
}
// namespace device
}
// namespace device
...
...
device_operation/include/device_gemm_xdl_instance.hpp
View file @
1b9e6e11
...
@@ -10,7 +10,7 @@ using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::Pass
...
@@ -10,7 +10,7 @@ using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::Pass
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
template
<
>
template
<
>
void
add_device_gemm_instance
<
float
,
void
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
@@ -18,7 +18,7 @@ void add_device_gemm_instance<float,
...
@@ -18,7 +18,7 @@ void add_device_gemm_instance<float,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
template
<
>
void
add_device_gemm_instance
<
float
,
void
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
@@ -26,7 +26,7 @@ void add_device_gemm_instance<float,
...
@@ -26,7 +26,7 @@ void add_device_gemm_instance<float,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
template
<
>
void
add_device_gemm_instance
<
float
,
void
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
@@ -34,44 +34,13 @@ void add_device_gemm_instance<float,
...
@@ -34,44 +34,13 @@ void add_device_gemm_instance<float,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
template
<
>
void
add_device_gemm_instance
<
float
,
void
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace device
...
...
profiler/CMakeLists.txt
View file @
1b9e6e11
...
@@ -22,6 +22,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
...
@@ -22,6 +22,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_splitk_instance_f32_f32_f32_km_nk_mn.cpp;
)
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
...
...
profiler/include/profile_gemm_impl.hpp
View file @
1b9e6e11
#pragma once
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmNoOpPtr
=
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
namespace
ck
{
namespace
ck
{
namespace
profiler
{
namespace
profiler
{
...
...
test/split_k/main.cpp
View file @
1b9e6e11
...
@@ -26,10 +26,14 @@ using DeviceGemmNoOpPtr =
...
@@ -26,10 +26,14 @@ using DeviceGemmNoOpPtr =
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
using
GEMM_PTR
=
std
::
vector
<
DeviceGemmNoOpPtr
>
;
using
GEMM_PTR
=
std
::
vector
<
DeviceGemmNoOpPtr
>
;
static
std
::
vector
<
std
::
vector
<
bool
>>
LayOut
=
{{
0
,
0
,
0
},
{
0
,
1
,
0
},
{
1
,
0
,
0
},
{
1
,
1
,
0
}};
static
std
::
vector
<
std
::
vector
<
bool
>>&
GetLayoutType
(){
static
std
::
vector
<
std
::
vector
<
bool
>>
LayOut
=
{{
0
,
0
,
0
},
{
0
,
1
,
0
},
{
1
,
0
,
0
},
{
1
,
1
,
0
}};
return
LayOut
;
}
static
void
add_device_gemm_instance_mk_kn_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_mk_kn_mn
(
GEMM_PTR
&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
float
,
...
@@ -39,7 +43,7 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
...
@@ -39,7 +43,7 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
}
}
static
void
add_device_gemm_instance_mk_nk_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_mk_nk_mn
(
GEMM_PTR
&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
float
,
...
@@ -49,7 +53,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
...
@@ -49,7 +53,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
}
}
static
void
add_device_gemm_instance_km_kn_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_km_kn_mn
(
GEMM_PTR
&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
float
,
...
@@ -59,7 +63,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
...
@@ -59,7 +63,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
}
}
static
void
add_device_gemm_instance_km_nk_mn
(
GEMM_PTR
&
gemm_ptrs
)
static
void
add_device_gemm_instance_km_nk_mn
(
GEMM_PTR
&
gemm_ptrs
)
{
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_instance
<
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_
splitk_
gemm_instance
<
float
,
float
,
float
,
float
,
float
,
float
,
...
@@ -68,13 +72,19 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
...
@@ -68,13 +72,19 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
gemm_ptrs
);
}
}
static
std
::
vector
<
void
(
*
)(
GEMM_PTR
&
)
>
AddDeviceGemmInstance
=
{
add_device_gemm_instance_mk_kn_mn
,
static
auto
&
GetAddDeviceGemmInstance
()
{
static
std
::
vector
<
void
(
*
)(
GEMM_PTR
&
)
>
AddDeviceGemmInstance
=
{
add_device_gemm_instance_mk_kn_mn
,
add_device_gemm_instance_mk_nk_mn
,
add_device_gemm_instance_mk_nk_mn
,
add_device_gemm_instance_km_kn_mn
,
add_device_gemm_instance_km_kn_mn
,
add_device_gemm_instance_km_nk_mn
};
add_device_gemm_instance_km_nk_mn
};
return
AddDeviceGemmInstance
;
}
static
void
add_device_gemm_instance
(
GEMM_PTR
&
gemm_ptrs
,
int
layout
)
static
void
add_device_gemm_instance
(
GEMM_PTR
&
gemm_ptrs
,
int
layout
)
{
{
AddDeviceGemmInstance
[
layout
](
gemm_ptrs
);
Get
AddDeviceGemmInstance
()
[
layout
](
gemm_ptrs
);
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -95,7 +105,6 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -95,7 +105,6 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
}
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
if
(
argc
!=
8
)
if
(
argc
!=
8
)
{
{
printf
(
"arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
printf
(
"arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];
\n
"
);
...
@@ -121,6 +130,8 @@ int main(int argc, char* argv[])
...
@@ -121,6 +130,8 @@ int main(int argc, char* argv[])
printf
(
"arg1 must be 0 ,1 ,2 or 3
\n
"
);
printf
(
"arg1 must be 0 ,1 ,2 or 3
\n
"
);
return
1
;
return
1
;
}
}
auto
LayOut
=
GetLayoutType
();
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
bool
isRevert
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
bool
isRevert
)
{
if
(
isRevert
)
if
(
isRevert
)
...
@@ -205,4 +216,4 @@ int main(int argc, char* argv[])
...
@@ -205,4 +216,4 @@ int main(int argc, char* argv[])
std
::
cout
<<
"test split k: Fail "
<<
std
::
endl
;
std
::
cout
<<
"test split k: Fail "
<<
std
::
endl
;
}
}
return
0
;
return
0
;
}
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment