Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
386686d7
Commit
386686d7
authored
Sep 04, 2021
by
Chao Liu
Browse files
add gemm driver
parent
627d8ef3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
449 additions
and
0 deletions
+449
-0
host/driver_offline/CMakeLists.txt
host/driver_offline/CMakeLists.txt
+3
-0
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
+218
-0
host/driver_offline/src/gemm_driver_offline.cpp
host/driver_offline/src/gemm_driver_offline.cpp
+183
-0
host/host_tensor/include/gemm_common.hpp
host/host_tensor/include/gemm_common.hpp
+12
-0
host/host_tensor/include/host_gemm.hpp
host/host_tensor/include/host_gemm.hpp
+33
-0
No files found.
host/driver_offline/CMakeLists.txt
View file @
386686d7
...
@@ -14,11 +14,14 @@ include_directories(BEFORE
...
@@ -14,11 +14,14 @@ include_directories(BEFORE
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp
)
set
(
CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp
)
set
(
CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp
)
set
(
CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp
)
set
(
CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp
)
set
(
CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp
)
set
(
GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp
)
add_executable
(
conv_fwd_driver_offline
${
CONV_FWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_fwd_driver_offline
${
CONV_FWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_bwd_driver_offline
${
CONV_BWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_bwd_driver_offline
${
CONV_BWD_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_wrw_driver_offline
${
CONV_WRW_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
conv_wrw_driver_offline
${
CONV_WRW_DRIVER_OFFLINE_SOURCE
}
)
add_executable
(
gemm_driver_offline
${
GEMM_DRIVER_OFFLINE_SOURCE
}
)
target_link_libraries
(
conv_fwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_fwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_bwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_bwd_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_wrw_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
conv_wrw_driver_offline PRIVATE host_tensor
)
target_link_libraries
(
gemm_driver_offline PRIVATE host_tensor
)
host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp
0 → 100644
View file @
386686d7
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template
<
typename
ABType
,
typename
AccType
,
typename
CType
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
>
void
device_gemm_xdlops_mk_nk_mn
(
const
ADesc
&
a_m_k_grid_desc
,
const
BDesc
&
b_n_k_grid_desc
,
const
CDesc
&
c_m_n_grid_desc
,
const
Tensor
<
ABType
>&
a_m_k
,
const
Tensor
<
ABType
>&
b_n_k
,
Tensor
<
CType
>&
c_m_n
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
std
::
cout
<<
__func__
<<
std
::
endl
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
DeviceMem
a_m_k_device_buf
(
sizeof
(
ABType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
ABType
)
*
b_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CType
)
*
c_m_n
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
b_n_k
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n
.
mData
.
data
());
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t MPerBlock = 128;
constexpr index_t NPerBlock = 256;
constexpr index_t KPerBlock = 4;
constexpr index_t MPerXDL = 32;
constexpr index_t NPerXDL = 32;
constexpr index_t K1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>;
using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4;
using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>;
using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>;
constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4;
constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4;
constexpr index_t CThreadTransferDstScalarPerVector = 1;
#elif
1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
128
;
constexpr
index_t
NPerBlock
=
256
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
2
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
1
,
4
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
4
,
64
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#endif
const
auto
K
=
a_m_k_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_m_k_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
b_n_k_grid_desc
.
GetLength
(
I0
);
constexpr
auto
K1Number
=
Number
<
K1
>
{};
const
auto
K0
=
K
/
K1Number
;
const
auto
a_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_m_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
M
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
const
auto
b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_n_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
>
{},
// 1+: M
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{},
// 1-: M
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
b_k0_n_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0+: K0
Sequence
<
0
,
0
,
0
>
{},
// 1+: N
Sequence
<
0
,
0
,
0
>
{}),
// 2+: K1
make_tuple
(
Sequence
<
0
,
0
,
0
>
{},
// 0-: K0
Sequence
<
0
,
0
,
0
>
{},
// 1-: N
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
constexpr
auto
b_k0_n_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
float
ave_time
=
driver_gemm_xdlops_v2r3
<
BlockSize
,
ABType
,
AccType
,
CType
,
InMemoryDataOperationEnum_t
::
Set
,
decltype
(
a_k0_m_k1_grid_desc
),
decltype
(
b_k0_n_k1_grid_desc
),
decltype
(
c_m_n_grid_desc
),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MRepeat
,
NRepeat
,
ABlockTransferThreadSliceLengths_K0_M_K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
ABlockTransferSrcScalarPerVector_K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
BBlockTransferThreadSliceLengths_K0_N_K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
Sequence
<
1
,
0
,
2
>
,
Sequence
<
1
,
0
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector_K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// don't move back src coordinate after threadwise copy
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
7
,
CThreadTransferDstScalarPerVector
,
decltype
(
a_k0_m_k1_grid_step_hacks
),
decltype
(
b_k0_n_k1_grid_step_hacks
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
a_k0_m_k1_grid_move_slice_window_step_hacks
),
decltype
(
b_k0_n_k1_grid_move_slice_window_step_hacks
),
false
// CAccessOrderMRepeatNRepeat
>
(
static_cast
<
ABType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ABType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
,
a_k0_m_k1_grid_step_hacks
,
b_k0_n_k1_grid_step_hacks
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
a_k0_m_k1_grid_move_slice_window_step_hacks
,
b_k0_n_k1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
((
std
::
size_t
(
2
)
*
M
*
N
*
K
))
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
// copy result back to host
c_m_n_device_buf
.
FromDevice
(
c_m_n
.
mData
.
data
());
}
host/driver_offline/src/gemm_driver_offline.cpp
0 → 100644
View file @
386686d7
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "gemm_common.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdlops_mk_nk_mn.hpp"
#define USE_GEMM_XDL_MK_NK_MN 1
enum
GemmAlgo
{
Xdl_MK_KN_MN
,
// 0
Xdl_MK_NK_MN
,
// 1
};
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
// dynamic mode
if
(
argc
!=
10
)
{
printf
(
"arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat
\n
"
);
printf
(
"rest: M, N, K
\n
"
);
exit
(
1
);
}
const
auto
layout
=
static_cast
<
GemmMatrixLayout
>
(
std
::
stoi
(
argv
[
1
]));
const
auto
algo
=
static_cast
<
GemmAlgo
>
(
std
::
stoi
(
argv
[
2
]));
const
bool
do_verification
=
std
::
stoi
(
argv
[
3
]);
const
int
init_method
=
std
::
stoi
(
argv
[
4
]);
const
bool
do_log
=
std
::
stoi
(
argv
[
5
]);
const
int
nrepeat
=
std
::
stoi
(
argv
[
6
]);
const
index_t
M
=
std
::
stoi
(
argv
[
7
]);
const
index_t
N
=
std
::
stoi
(
argv
[
8
]);
const
index_t
K
=
std
::
stoi
(
argv
[
9
]);
#if 0
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
#elif
1
using
ab_data_t
=
half_t
;
using
acc_data_t
=
float
;
using
c_data_t
=
half_t
;
#elif 1
using
ab_data_t
=
int8_t
;
using
acc_data_t
=
int32_t
;
using
c_data_t
=
int8_t
;
#endif
std
::
vector
<
std
::
size_t
>
a_lengths_host
(
2
),
b_lengths_host
(
2
),
c_lengths_host
(
2
);
std
::
vector
<
std
::
size_t
>
a_strides_host
(
2
),
b_strides_host
(
2
),
c_strides_host
(
2
);
if
(
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
a_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
a_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
a_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
b_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
b_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
K
);
b_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
c_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
M
);
c_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
c_strides_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
1
);
}
else
{
std
::
runtime_error
(
"wrong! not implemented"
);
}
Tensor
<
ab_data_t
>
a
(
a_lengths_host
,
a_strides_host
);
Tensor
<
ab_data_t
>
b
(
b_lengths_host
,
b_strides_host
);
Tensor
<
c_data_t
>
c_host
(
c_lengths_host
,
c_strides_host
);
Tensor
<
c_data_t
>
c_device
(
c_lengths_host
,
c_strides_host
);
std
::
cout
<<
"layout: "
<<
layout
<<
std
::
endl
;
ostream_HostTensorDescriptor
(
a
.
mDesc
,
std
::
cout
<<
"a: "
);
ostream_HostTensorDescriptor
(
b
.
mDesc
,
std
::
cout
<<
"b: "
);
ostream_HostTensorDescriptor
(
c_host
.
mDesc
,
std
::
cout
<<
"c: "
);
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
switch
(
init_method
)
{
case
0
:
// no initialization
break
;
case
1
:
a
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
2
:
a
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
case
3
:
a
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_1
{},
num_thread
);
break
;
case
4
:
a
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_2
{
-
5
,
5
},
num_thread
);
break
;
default:
a
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
0.0
,
1.0
},
num_thread
);
b
.
GenerateTensorValue
(
GeneratorTensor_3
<
float
>
{
-
0.5
,
0.5
},
num_thread
);
}
auto
f_make_for_device_mk_nk_mn
=
[
&
]()
{
const
auto
a_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
K
,
I1
));
const
auto
b_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
K
,
I1
));
const
auto
c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
N
,
I1
));
return
make_tuple
(
a_desc
,
b_desc
,
c_desc
);
};
#if USE_GEMM_XDL_MK_NK_MN
if
(
algo
==
GemmAlgo
::
Xdl_MK_NK_MN
)
{
if
(
layout
!=
GemmMatrixLayout
::
MK_NK_MN
)
{
throw
std
::
runtime_error
(
"wrong! layout"
);
}
const
auto
descs
=
f_make_for_device_mk_nk_mn
();
device_gemm_xdlops_mk_nk_mn
<
ab_data_t
,
acc_data_t
,
c_data_t
>
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
a
,
b
,
c_device
,
nrepeat
);
}
#endif
if
(
do_verification
)
{
host_gemm
(
a
,
b
,
c_host
,
layout
);
check_error
(
c_host
,
c_device
);
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_host
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_device
.
mData
,
","
)
<<
std
::
endl
;
}
}
}
host/host_tensor/include/gemm_common.hpp
0 → 100644
View file @
386686d7
#ifndef GEMM_COMMON_HPP
#define GEMM_COMMON_HPP
enum
GemmMatrixLayout
{
MK_KN_MN
,
// 0
KM_KN_MN
,
// 1
KM_NK_MN
,
// 2
MK_NK_MN
,
// 3
};
#endif
host/host_tensor/include/host_gemm.hpp
0 → 100644
View file @
386686d7
#pragma once
#include "host_tensor.hpp"
#include "gemm_common.hpp"
template
<
typename
AType
,
typename
BType
,
typename
CType
>
void
host_gemm
(
const
Tensor
<
AType
>&
a
,
const
Tensor
<
BType
>&
b
,
Tensor
<
CType
>&
c
,
const
GemmMatrixLayout
layout
)
{
auto
f_mk_nk_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a
.
mDesc
.
GetLengths
()[
1
];
double
v
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a
(
m
,
k
))
*
static_cast
<
const
double
>
(
b
(
n
,
k
));
}
c
(
m
,
n
)
=
v
;
};
if
(
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
make_ParallelTensorFunctor
(
f_mk_nk_mn
,
c
.
mDesc
.
GetLengths
()[
0
],
c
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
}
else
{
throw
std
::
runtime_error
(
"wrong! not supported layout"
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment