Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e5bcd2bb
Commit
e5bcd2bb
authored
Dec 03, 2021
by
Jing Zhang
Browse files
debug
parent
41cdd380
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
206 additions
and
101 deletions
+206
-101
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+5
-2
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+28
-15
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
...ude/tensor_operation/threadwise_tensor_slice_transfer.hpp
+1
-0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+2
-1
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+1
-1
composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp
...ernel/include/utility/static_buffer_of_vector_type_v2.hpp
+5
-0
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
+12
-8
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
...on_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
+11
-11
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
+37
-10
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
+28
-0
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+4
-4
host/driver_offline/src/gemm_driver_offline.cpp
host/driver_offline/src/gemm_driver_offline.cpp
+5
-5
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+18
-18
profiler/gemm_profiler.cpp
profiler/gemm_profiler.cpp
+22
-0
profiler/profiler.cpp
profiler/profiler.cpp
+5
-5
script/cmake-rocm.sh
script/cmake-rocm.sh
+1
-1
script/conv_driver.sh
script/conv_driver.sh
+1
-1
script/gemm_driver.sh
script/gemm_driver.sh
+2
-1
script/profile_gemm.sh
script/profile_gemm.sh
+18
-18
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
e5bcd2bb
...
...
@@ -157,10 +157,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
Repeat
,
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
Repeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
)
,
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
)
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
e5bcd2bb
...
...
@@ -288,13 +288,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
//
if constexpr(ABlockLdsExtraM)
//
{
//
return make_naive_tensor_descriptor(
//
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
//
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
//
}
//
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
...
...
@@ -303,13 +303,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
//
if constexpr(BBlockLdsExtraN)
//
{
//
return make_naive_tensor_descriptor(
//
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
//
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
//
}
//
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
...
...
@@ -619,6 +619,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
index_t
n_thread_data_on_grid
=
n_block_data_idx_on_grid
+
c_thread_mtx_on_block
[
I1
];
printf
(
"%d %d %d
\n
"
,
get_thread_local_1d_id
(),
c_thread_mtx_on_block
[
I0
],
c_thread_mtx_on_block
[
I1
]);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks
=
CGridStepHacks
{};
const
auto
m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor
=
...
...
@@ -640,6 +645,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
n_thread_data_on_grid_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_grid
));
c_thread_buf
.
Fill
(
get_thread_local_1d_id
());
if
(
get_thread_local_1d_id
()
==
0
)
printf
(
"%d %d %d
\n
"
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I0
),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I1
),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I2
));
auto
c_thread_copy
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
...
...
@@ -652,7 +665,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
CThreadTransferDstScalarPerVector
,
CGlobalMemoryDataOperation
,
1
,
tru
e
>
{
fals
e
>
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
n_thread_data_on_grid_idx
[
I0
],
...
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp
View file @
e5bcd2bb
...
...
@@ -214,6 +214,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector
.
template
AsType
<
dst_vector_t
>()[
Number
<
0
>
{}]);
printf
(
"copy: %d %d
\n
"
,
dst_coord_
.
GetOffset
(),
dst_coord_
.
GetIndex
()[
I0
]);
}
else
if
constexpr
(
DstInMemOp
==
InMemoryDataOperationEnum_t
::
AtomicAdd
)
{
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
e5bcd2bb
...
...
@@ -589,6 +589,7 @@ struct XdlopsGemm
const
auto
N0
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I1
);
const
auto
M1
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I2
);
const
auto
N1
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I3
);
const
auto
N2
=
c_desc_m0_n0_m1_n1_m2_n2
.
GetLength
(
I5
);
return
transform_tensor_descriptor
(
c_desc_m0_n0_m1_n1_m2_n2
,
...
...
@@ -599,7 +600,7 @@ struct XdlopsGemm
make_unmerge_transform
(
make_tuple
(
mfma_instr
.
num_groups_per_blk
,
mfma_instr
.
num_input_blks
,
mfma_instr
.
group_size
)),
make_pass_through_transform
(
mfma_instr
.
num_threads_per_blk
)),
make_pass_through_transform
(
N2
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
composable_kernel/include/utility/config.hpp
View file @
e5bcd2bb
...
...
@@ -57,7 +57,7 @@
// AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING
1
#define CK_USE_AMD_BUFFER_ADDRESSING
0
#endif
// only gfx908 support native floating point atomic add
...
...
composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp
View file @
e5bcd2bb
...
...
@@ -104,6 +104,11 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
[
&
](
auto
i
)
{
GetElement
(
i
,
true
)
=
invalid_element_value_
;
});
}
__host__
__device__
void
Fill
(
VecBaseType
val
)
{
static_for
<
0
,
GetNumElements
(),
1
>
{}([
&
](
auto
i
)
{
GetElement
(
i
,
true
)
=
val
;
});
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
...
...
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
View file @
e5bcd2bb
...
...
@@ -27,14 +27,18 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
4
,
32
,
32
,
2
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
4
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
2
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
4
,
32
,
32
,
2
,
1
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
4
,
7
,
1
,
true
,
true
>
,
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
4
,
32
,
32
,
1
,
2
,
S
<
1
,
1
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
2
,
4
>
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 4, 16, 16, 3, 4, S<1, 3, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 32, 128, 4, 4, 16, 16, 1, 4, S<1, 1, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
DeviceGemmXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
96
,
128
,
4
,
4
,
32
,
32
,
3
,
2
,
S
<
1
,
3
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
4
,
4
,
S
<
1
,
4
,
4
>
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
4
,
7
,
1
,
true
,
true
>
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 2, 4>, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 4, 4>, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 2, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 1, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 7, 1, true, true>,
//DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 1, 4>, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, S<1, 2, 4>, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 7, 1, true, true>
// clang-format on
>
;
...
...
host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp
View file @
e5bcd2bb
...
...
@@ -287,27 +287,27 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
25
6
;
constexpr
index_t
BlockSize
=
6
4
;
constexpr
index_t
GemmMPerBlock
=
12
8
;
constexpr
index_t
GemmNPerBlock
=
25
6
;
constexpr
index_t
GemmMPerBlock
=
4
8
;
constexpr
index_t
GemmNPerBlock
=
1
6
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerXDL
=
32
;
constexpr
index_t
GemmNPerXDL
=
32
;
constexpr
index_t
GemmMPerXDL
=
16
;
constexpr
index_t
GemmNPerXDL
=
16
;
constexpr
index_t
GemmK1
=
8
;
constexpr
index_t
MRepeat
=
2
;
constexpr
index_t
NRepeat
=
4
;
constexpr
index_t
MRepeat
=
3
;
constexpr
index_t
NRepeat
=
1
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
2
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
6
4
,
1
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
4
,
1
,
8
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
=
Sequence
<
1
,
4
8
,
1
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmK1
=
8
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
4
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
6
4
,
1
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
4
,
1
,
8
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1
=
Sequence
<
1
,
1
6
,
1
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
...
...
host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp
View file @
e5bcd2bb
...
...
@@ -162,23 +162,23 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
4
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif
1
#elif
0
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
MPerBlock
=
2
56
;
constexpr
index_t
MPerBlock
=
3
2
;
constexpr
index_t
NPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
32
;
constexpr
index_t
NPerXDL
=
32
;
constexpr
index_t
MPerXDL
=
16
;
constexpr
index_t
NPerXDL
=
16
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
4
;
constexpr
index_t
NRepeat
=
2
;
constexpr
index_t
MRepeat
=
1
;
constexpr
index_t
NRepeat
=
4
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
4
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
64
,
1
>
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
1
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
8
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
8
;
...
...
@@ -189,6 +189,34 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
2
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
8
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
MPerBlock
=
48
;
constexpr
index_t
NPerBlock
=
16
;
constexpr
index_t
KPerBlock
=
4
;
constexpr
index_t
MPerXDL
=
16
;
constexpr
index_t
NPerXDL
=
16
;
constexpr
index_t
K1
=
8
;
constexpr
index_t
MRepeat
=
3
;
constexpr
index_t
NRepeat
=
1
;
using
ABlockTransferThreadSliceLengths_K0_M_K1
=
Sequence
<
4
,
1
,
8
>
;
using
ABlockTransferThreadClusterLengths_K0_M_K1
=
Sequence
<
1
,
48
,
1
>
;
constexpr
index_t
ABlockTransferSrcScalarPerVector_K1
=
1
;
constexpr
index_t
ABlockTransferDstScalarPerVector_K1
=
1
;
using
BBlockTransferThreadSliceLengths_K0_N_K1
=
Sequence
<
4
,
1
,
8
>
;
using
BBlockTransferThreadClusterLengths_K0_N_K1
=
Sequence
<
1
,
16
,
1
>
;
constexpr
index_t
BBlockTransferSrcScalarPerVector_N
=
1
;
constexpr
index_t
BBlockTransferDstScalarPerVector_K1
=
1
;
constexpr
index_t
CThreadTransferDstScalarPerVector
=
1
;
#elif 0
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
...
...
@@ -351,8 +379,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
b_k_n
.
mDesc
.
GetStrides
()[
1
],
b_k_n
.
mDesc
.
GetStrides
()[
0
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
c_m_n
.
mDesc
.
GetStrides
()[
0
],
c_m_n
.
mDesc
.
GetStrides
()[
1
]));
const
auto
c_m_n_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
M
,
N
));
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
// 0+: K0
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
View file @
e5bcd2bb
...
...
@@ -6,6 +6,15 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
struct
OpPassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
;
}
};
template
<
ck
::
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
...
...
@@ -70,6 +79,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
using
ElementwiseOperation
=
OpPassThrough
;
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
FloatAB
,
...
...
@@ -79,6 +90,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K
,
CMNGridDesc
,
ElementwiseOperation
,
ElementwiseOperation
,
ElementwiseOperation
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
...
...
@@ -152,6 +166,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
float
ave_time
=
0
;
auto
element_op_
=
OpPassThrough
{};
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if
(
has_main_k0_block_loop
)
{
...
...
@@ -162,6 +178,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t
<
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
BGridDesc_K0_N_K
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
ElementwiseOperation
,
ElementwiseOperation
,
ElementwiseOperation
,
remove_reference_t
<
Block2CTileMap
>
,
true
>
;
...
...
@@ -176,6 +195,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
element_op_
,
element_op_
,
element_op_
,
block_2_ctile_map
);
}
else
...
...
@@ -187,6 +209,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
remove_reference_t
<
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
BGridDesc_K0_N_K
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
ElementwiseOperation
,
ElementwiseOperation
,
ElementwiseOperation
,
remove_reference_t
<
Block2CTileMap
>
,
false
>
;
...
...
@@ -201,6 +226,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
,
element_op_
,
element_op_
,
element_op_
,
block_2_ctile_map
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
e5bcd2bb
...
...
@@ -12,10 +12,10 @@
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
//
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
...
...
host/driver_offline/src/gemm_driver_offline.cpp
View file @
e5bcd2bb
...
...
@@ -22,9 +22,9 @@
#include "device_gemm_xdlops_km_nk_nm.hpp"
#define USE_GEMM_XDL_MK_KN_MN 1
#define USE_GEMM_XDL_MK_NK_MN
1
#define USE_GEMM_XDL_KM_KN_MN
1
#define USE_GEMM_XDL_KM_NK_MN
1
#define USE_GEMM_XDL_MK_NK_MN
0
#define USE_GEMM_XDL_KM_KN_MN
0
#define USE_GEMM_XDL_KM_NK_MN
0
#define USE_GEMM_XDL_MK_KN_NM 0
#define USE_GEMM_XDL_MK_NK_NM 0
#define USE_GEMM_XDL_KM_KN_NM 0
...
...
@@ -445,8 +445,8 @@ int main(int argc, char* argv[])
if
(
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a : "
,
a
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b: "
,
b
.
mData
,
","
)
<<
std
::
endl
;
//
LogRangeAsType<float>(std::cout << "a : ", a.mData, ",") << std::endl;
//
LogRangeAsType<float>(std::cout << "b: ", b.mData, ",") << std::endl;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host : "
,
c_host
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device: "
,
c_device
.
mData
,
","
)
<<
std
::
endl
;
}
...
...
profiler/CMakeLists.txt
View file @
e5bcd2bb
...
...
@@ -15,13 +15,13 @@ include_directories(BEFORE
# device_gemm_instance
set
(
DEVICE_GEMM_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp;
#
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp;
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
...
...
@@ -31,20 +31,20 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
install
(
TARGETS device_gemm_instance LIBRARY DESTINATION lib
)
# device_conv_instance
set
(
DEVICE_CONV_INSTANCE_SOURCE
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp;
${
PROJECT_SOURCE_DIR
}
/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp;
)
#
set(DEVICE_CONV_INSTANCE_SOURCE
##
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f32_f32_f32_nhwc_kyxc_nhwk.cpp;
##
${PROJECT_SOURCE_DIR}/device_operation/device_conv_xdl_instance_f16_f16_f16_nhwc_kyxc_nhwk.cpp;
#
)
add_library
(
device_conv_instance SHARED
${
DEVICE_CONV_INSTANCE_SOURCE
}
)
target_include_directories
(
device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:
${
HALF_INCLUDE_DIR
}
>
)
target_compile_features
(
device_conv_instance PUBLIC
)
set_target_properties
(
device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv_instance LIBRARY DESTINATION lib
)
#
add_library(device_conv_instance SHARED ${DEVICE_CONV_INSTANCE_SOURCE})
#
target_include_directories(device_conv_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
#
target_compile_features(device_conv_instance PUBLIC)
#
set_target_properties(device_conv_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
#
install(TARGETS device_conv_instance LIBRARY DESTINATION lib)
# ck_profiler
set
(
PROFILER_SOURCE profiler.cpp gemm_profiler.cpp
conv_profiler.cpp
)
set
(
PROFILER_SOURCE profiler.cpp gemm_profiler.cpp
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
device_conv_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
profiler/gemm_profiler.cpp
View file @
e5bcd2bb
...
...
@@ -66,6 +66,7 @@ int gemm_profiler(int argc, char* argv[])
const
int
StrideB
=
std
::
stoi
(
argv
[
12
]);
const
int
StrideC
=
std
::
stoi
(
argv
[
13
]);
#if 0
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm<ck::half_t,
...
...
@@ -210,6 +211,27 @@ int gemm_profiler(int argc, char* argv[])
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC);
}
#endif
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
do_verification
,
init_method
,
do_log
,
nrepeat
,
M
,
N
,
K
,
(
StrideA
<
0
)
?
K
:
StrideA
,
(
StrideB
<
0
)
?
N
:
StrideB
,
(
StrideC
<
0
)
?
N
:
StrideC
);
}
else
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
...
...
profiler/profiler.cpp
View file @
e5bcd2bb
...
...
@@ -6,7 +6,7 @@
#include <half.hpp>
int
gemm_profiler
(
int
,
char
*
[]);
int
conv_profiler
(
int
,
char
*
[]);
//
int conv_profiler(int, char*[]);
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -14,10 +14,10 @@ int main(int argc, char* argv[])
{
return
gemm_profiler
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"conv"
)
==
0
)
{
return
conv_profiler
(
argc
,
argv
);
}
//
else if(strcmp(argv[1], "conv") == 0)
//
{
//
return conv_profiler(argc, argv);
//
}
else
{
printf
(
"arg1: tensor operation (gemm=GEMM, conv=Convolution)
\n
"
);
...
...
script/cmake-rocm.sh
View file @
e5bcd2bb
...
...
@@ -10,7 +10,7 @@ cmake
-D
CMAKE_INSTALL_PREFIX
=
${
MY_PROJECT_INSTALL
}
\
-D
BUILD_DEV
=
OFF
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_CXX_FLAGS
=
"-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O
3
-ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=
$PWD
"
\
-D
CMAKE_CXX_FLAGS
=
"-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O
1
-ftemplate-backtrace-limit=0 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=
$PWD
"
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
...
...
script/conv_driver.sh
View file @
e5bcd2bb
...
...
@@ -22,7 +22,7 @@ REPEAT=$6
######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE
$DRIVER
$LAYOUT
$ALGO
$VERIFY
$INIT
$LOG
$REPEAT
128 256 192 3 3 7
1
7
1
2 2
1 1
1 1
1 1
$DESIRED_GRID_SIZE
$DRIVER
$LAYOUT
$ALGO
$VERIFY
$INIT
$LOG
$REPEAT
1 16 32 1
1 1
48 1 1
1 1
0 0
0 0
$DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE
...
...
script/gemm_driver.sh
View file @
e5bcd2bb
...
...
@@ -19,7 +19,8 @@ REPEAT=$6
######### layout algo verify init log repeat M___ N___ K___ M01_ N01_
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01
$DRIVER
$LAYOUT
$ALGO
$VERIFY
$INIT
$LOG
$REPEAT
48 16 32
$M01
$N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01
$DRIVER
$LAYOUT
$ALGO
$VERIFY
$INIT
$LOG
$REPEAT
3840 4096 4096
$M01
$N01
#
$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01
#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01
script/profile_gemm.sh
View file @
e5bcd2bb
...
...
@@ -25,21 +25,21 @@ REPEAT=$7
#$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
960 1024 1024
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1920 2048 2048
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
3840 4096 4096
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
7680 8192 8192
-1
-1
-1
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1024 1024 1024
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2048 2048 2048
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4096 4096 4096
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8192 8192 8192
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1056 1056 1056
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2080 2080 2080
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4128 4128 4128
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8224 8224 8224
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
1024 1024 1024 1088 1088 1088
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
2048 2048 2048 2112 2112 2112
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
4096 4096 4096 4160 4160 4160
$DRIVER
$OP
$DATATYPE
$LAYOUT
$VERIFY
$INIT
$LOG
$REPEAT
8192 8192 8192 8256 8256 8256
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 -1 -1 -1
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1024 1024 1024
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2048 2048 2048
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4096 4096 4096
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8192 8192 8192
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1056 1056 1056
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2080 2080 2080
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4128 4128 4128
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8224 8224 8224
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1088 1088
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160
#
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment