Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
97a5b74a
Commit
97a5b74a
authored
Dec 08, 2021
by
Jing Zhang
Browse files
add padding to M/N for irr tile size
parent
b6116d2f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
113 additions
and
59 deletions
+113
-59
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
+6
-0
device_operation/include/device_gemm_xdl.hpp
device_operation/include/device_gemm_xdl.hpp
+51
-26
host/driver_offline/src/gemm_driver_offline.cpp
host/driver_offline/src/gemm_driver_offline.cpp
+1
-1
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+11
-9
profiler/gemm_profiler.cpp
profiler/gemm_profiler.cpp
+21
-0
profiler/profiler.cpp
profiler/profiler.cpp
+4
-4
script/profile_gemm.sh
script/profile_gemm.sh
+19
-19
No files found.
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
View file @
97a5b74a
...
@@ -31,6 +31,12 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
...
@@ -31,6 +31,12 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
160
,
128
,
4
,
4
,
16
,
16
,
5
,
4
,
S
<
1
,
5
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
160
,
4
,
4
,
16
,
16
,
4
,
5
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
5
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
192
,
128
,
4
,
4
,
32
,
32
,
3
,
2
,
S
<
1
,
3
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
192
,
4
,
4
,
32
,
32
,
2
,
3
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
3
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
96
,
128
,
4
,
4
,
16
,
16
,
3
,
4
,
S
<
1
,
3
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
96
,
4
,
4
,
16
,
16
,
4
,
3
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
3
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
...
...
device_operation/include/device_gemm_xdl.hpp
View file @
97a5b74a
...
@@ -78,10 +78,14 @@ struct DeviceGemmXdl
...
@@ -78,10 +78,14 @@ struct DeviceGemmXdl
}
}
}();
}();
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
std
::
cout
<<
"PadM = "
<<
PadM
<<
" M = "
<<
M
+
PadM
<<
std
::
endl
;
const
auto
a_grid_desc_k0_m_k1
=
const
auto
a_grid_desc_k0_m_k1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pa
ss_through
_transform
(
M
)),
make_pa
d
_transform
(
M
,
I0
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -105,10 +109,14 @@ struct DeviceGemmXdl
...
@@ -105,10 +109,14 @@ struct DeviceGemmXdl
}
}
}();
}();
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
std
::
cout
<<
"PadN = "
<<
PadN
<<
" N = "
<<
N
+
PadN
<<
std
::
endl
;
const
auto
b_grid_desc_k0_n_k1
=
const
auto
b_grid_desc_k0_n_k1
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pa
ss_through
_transform
(
N
)),
make_pa
d
_transform
(
N
,
I0
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
@@ -117,6 +125,7 @@ struct DeviceGemmXdl
...
@@ -117,6 +125,7 @@ struct DeviceGemmXdl
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
...
@@ -125,6 +134,18 @@ struct DeviceGemmXdl
...
@@ -125,6 +134,18 @@ struct DeviceGemmXdl
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}
}();
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
const
auto
c_grid_desc_m_n_
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pad_transform
(
M
,
I0
,
PadM
),
make_pad_transform
(
N
,
I0
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
c_grid_desc_m_n_
;
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
...
@@ -149,22 +170,22 @@ struct DeviceGemmXdl
...
@@ -149,22 +170,22 @@ struct DeviceGemmXdl
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
Sequence
<
0
,
0
,
0
>
{}));
// 2-: K1
static
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
static
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
make_tuple
(
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0+: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1+: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2+: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3+: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4+: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5+: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6+: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}),
// 7+: N2
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
make_tuple
(
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 0-: M0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 1-: N0
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 2-: M1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 3-: N1
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 4-: M2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 5-: M3
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{},
// 6-: M4
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
Sequence
<
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
>
{}));
// 7-: N2
static
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
static
constexpr
auto
a_k0_m_k1_grid_move_slice_window_step_hacks
=
Sequence
<
0
,
0
,
0
>
{};
...
@@ -293,6 +314,10 @@ struct DeviceGemmXdl
...
@@ -293,6 +314,10 @@ struct DeviceGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
{
{
std
::
cout
<<
"MPerBlock = "
<<
MPerBlock
<<
" NPerBlock = "
<<
NPerBlock
<<
" MXdlPerWave = "
<<
MXdlPerWave
<<
" NXdlPerWave = "
<<
NXdlPerWave
<<
std
::
endl
;
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
...
...
host/driver_offline/src/gemm_driver_offline.cpp
View file @
97a5b74a
...
@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
...
@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
=
std
::
stoi
(
argv
[
10
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
M01
=
std
::
stoi
(
argv
[
10
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
=
std
::
stoi
(
argv
[
11
]);
debug
::
debug_driver_gemm_xdlops_v2r3
::
N01
=
std
::
stoi
(
argv
[
11
]);
#if
0
#if
1
using
ab_data_t
=
float
;
using
ab_data_t
=
float
;
using
acc_data_t
=
float
;
using
acc_data_t
=
float
;
using
c_data_t
=
float
;
using
c_data_t
=
float
;
...
...
profiler/CMakeLists.txt
View file @
97a5b74a
...
@@ -15,13 +15,13 @@ include_directories(BEFORE
...
@@ -15,13 +15,13 @@ include_directories(BEFORE
# device_gemm_instance
# device_gemm_instance
set
(
DEVICE_GEMM_INSTANCE_SOURCE
set
(
DEVICE_GEMM_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
)
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
...
@@ -43,8 +43,10 @@ set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE
...
@@ -43,8 +43,10 @@ set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE
install
(
TARGETS device_conv_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_conv_instance LIBRARY DESTINATION lib
)
# ck_profiler
# ck_profiler
set
(
PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp
)
#set(PROFILER_SOURCE profiler.cpp gemm_profiler.cpp conv_profiler.cpp)
set
(
PROFILER_SOURCE profiler.cpp gemm_profiler.cpp
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance device_conv_instance
)
#target_link_libraries(ckProfiler PRIVATE device_gemm_instance device_conv_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
profiler/gemm_profiler.cpp
View file @
97a5b74a
...
@@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[])
...
@@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[])
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
#if 0
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
{
ck::profiler::profile_gemm<ck::half_t,
ck::profiler::profile_gemm<ck::half_t,
...
@@ -210,6 +211,26 @@ int gemm_profiler(int argc, char* argv[])
...
@@ -210,6 +211,26 @@ int gemm_profiler(int argc, char* argv[])
(StrideB < 0) ? K : StrideB,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
(StrideC < 0) ? N : StrideC);
}
}
#endif
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
else
{
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/profiler.cpp
View file @
97a5b74a
...
@@ -14,10 +14,10 @@ int main(int argc, char* argv[])
...
@@ -14,10 +14,10 @@ int main(int argc, char* argv[])
{
{
return
gemm_profiler
(
argc
,
argv
);
return
gemm_profiler
(
argc
,
argv
);
}
}
else
if
(
strcmp
(
argv
[
1
],
"conv"
)
==
0
)
//
else if(strcmp(argv[1], "conv") == 0)
{
//
{
return
conv_profiler
(
argc
,
argv
);
//
return conv_profiler(argc, argv);
}
//
}
else
else
{
{
printf
(
"arg1: tensor operation (gemm=GEMM, conv=Convolution)
\n
"
);
printf
(
"arg1: tensor operation (gemm=GEMM, conv=Convolution)
\n
"
);
...
...
script/profile_gemm.sh
View file @
97a5b74a
...
@@ -24,22 +24,22 @@ REPEAT=$7
...
@@ -24,22 +24,22 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
960 1024 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
960 1024 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1920 2048 2048
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1920 2048 2048
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
3840 4096 4096
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
3840 4096 4096
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
7680 8192 8192
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
7680 8192 8192
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1024 1024 1024
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1024 1024 1024
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4096 4096 4096
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4096 4096 4096
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8192 8192 8192
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8192 8192 8192
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1056 1056 1056
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1056 1056 1056
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2080 2080 2080
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2080 2080 2080
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4128 4128 4128
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4128 4128 4128
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8224 8224 8224
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8224 8224 8224
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1088 1088 1088
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1088 1088 1088
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2112 2112 2112
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2112 2112 2112
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4160 4160 4160
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4160 4160 4160
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8256 8256 8256
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8256 8256 8256
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment