Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
85712f19
Commit
85712f19
authored
Sep 05, 2021
by
Chao Liu
Browse files
add GEMM km_nk_mn
parent
e6106577
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
279 additions
and
0 deletions
+279
-0
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
+219
-0
host/driver_offline/src/gemm_driver_offline.cpp
host/driver_offline/src/gemm_driver_offline.cpp
+42
-0
host/host_tensor/include/host_gemm.hpp
host/host_tensor/include/host_gemm.hpp
+18
-0
No files found.
host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
0 → 100644
View file @
85712f19
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
>
void
device_gemm_xdlops_km_nk_mn
(
const
ADesc
&
a_k_m_grid_desc
,
const
BDesc
&
b_n_k_grid_desc
,
const
CDesc
&
c_m_n_grid_desc
,
const
Tensor
<
ABType
>&
a_k_m
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
DeviceMem
a_k_m_device_buf
(
sizeof
(
ABType
)
*
a_k_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_k_m_device_buf
.
ToDevice
(
a_k_m
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 256;
constexpr index_t NPerBlock = 128;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_M = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
256
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_M
=
4
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
2
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_k_m_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
a_k_m_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_n_k_grid_desc
.
GetLength
(
I0
);
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k_m_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_n_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
>
{},
// 1+: N
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{},
// 1-: N
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
0
,
2
,
1
>
,
Sequence
<
0
,
2
,
1
>
,
1
,
ABlockTransferSrcScalarPerVector_M
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
7
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_k_m_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
}
host/driver_offline/src/gemm_driver_offline.cpp
View file @
85712f19
...
@@ -15,10 +15,12 @@
...
@@ -15,10 +15,12 @@
#include "device_gemm_xdlops_mk_kn_mn.hpp"
#include "device_gemm_xdlops_mk_kn_mn.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_kn_mn.hpp"
#include "device_gemm_xdlops_km_nk_mn.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_MK_NK_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_KN_MN 1
#define USE_GEMM_XDL_KM_NK_MN 1
enum
GemmAlgo
enum
GemmAlgo
{
{
...
@@ -123,6 +125,23 @@ int main(int argc, char* argv[])
...
@@ -123,6 +125,23 @@ int main(int argc, char* argv[])
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
else
{
{
std
::
runtime_error
(
"wrong! not implemented"
);
std
::
runtime_error
(
"wrong! not implemented"
);
...
@@ -190,6 +209,14 @@ int main(int argc, char* argv[])
...
@@ -190,6 +209,14 @@ int main(int argc, char* argv[])
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
};
auto
f_make_for_device_km_nk_mn
=
[
&
]()
{
const
auto
a_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K
,
M
),
make_tuple
(
M
,
I1
));
const
auto
b_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
K
,
I1
));
const
auto
c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
#if USE_GEMM_XDL_MK_KN_MN
#if USE_GEMM_XDL_MK_KN_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_MN
)
if
(
algo
==
GemmAlgo
::
Xdl_MK_KN_MN
)
{
{
...
@@ -235,6 +262,21 @@ int main(int argc, char* argv[])
...
@@ -235,6 +262,21 @@ int main(int argc, char* argv[])
}
}
#endif
#endif
#if USE_GEMM_XDL_KM_NK_MN
if
(
algo
==
GemmAlgo
::
Xdl_KM_NK_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
KM_NK_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
descs
=
f_make_for_device_km_nk_mn
();
device_gemm_xdlops_km_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
a
,
b
,
c_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
if
(
do_verification
)
{
{
host_gemm
(
a
,
b
,
c_host
,
layout
);
host_gemm
(
a
,
b
,
c_host
,
layout
);
...
...
host/host_tensor/include/host_gemm.hpp
View file @
85712f19
...
@@ -62,6 +62,24 @@ void host_gemm(const Tensor<AType>& a,
...
@@ -62,6 +62,24 @@ void host_gemm(const Tensor<AType>& a,
make_ParallelTensorFunctor
(
f_km_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
make_ParallelTensorFunctor
(
f_km_kn_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
}
}
else
if
(
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
auto
f_km_nk_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
0
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
k
,
m
))
*
static_cast
<
const
double
>
(
b
(
n
,
k
));
}
c
(
m
,
n
)
=
v
;
};
make_ParallelTensorFunctor
(
f_km_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
else
{
{
throw
std
::
runtime_error
(
"wrong! not supported layout"
);
throw
std
::
runtime_error
(
"wrong! not supported layout"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment