Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0a65fc55
Commit
0a65fc55
authored
May 23, 2023
by
Adam Osewski
Browse files
Add example for new kernel.
parent
9e696586
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
115 additions
and
19 deletions
+115
-19
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+2
-0
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+17
-13
example/01_gemm/gemm_xdl_direct_c_write_out_fp16.cpp
example/01_gemm/gemm_xdl_direct_c_write_out_fp16.cpp
+57
-0
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+39
-6
No files found.
example/01_gemm/CMakeLists.txt
View file @
0a65fc55
...
@@ -20,11 +20,13 @@ add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
...
@@ -20,11 +20,13 @@ add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_bf16 gemm_xdl_bf16.cpp
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_example_executable
(
example_gemm_xdl_int8 gemm_xdl_int8.cpp
)
add_example_executable
(
example_gemm_xdl_direct_c_write_out_fp16 gemm_xdl_direct_c_write_out_fp16.cpp
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_int8
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
add_dependencies
(
example_gemm_xdl example_gemm_xdl_direct_c_write_out_fp16
)
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
add_example_executable
(
example_gemm_xdl_int4 gemm_xdl_int4.cpp
)
...
...
example/01_gemm/common.hpp
View file @
0a65fc55
...
@@ -38,6 +38,7 @@ struct ExecutionConfig final
...
@@ -38,6 +38,7 @@ struct ExecutionConfig final
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
bool
do_log
=
false
;
};
};
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
...
@@ -55,33 +56,36 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
...
@@ -55,33 +56,36 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
{
{
// use default case
// use default case
}
}
else
if
(
argc
==
4
)
else
if
(
argc
==
5
)
{
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
do_log
=
std
::
stoi
(
argv
[
4
]);
}
}
else
if
(
argc
==
1
0
)
else
if
(
argc
==
1
1
)
{
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
config
.
do_log
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
StrideA
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
StrideA
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
StrideB
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
StrideB
=
std
::
stoi
(
argv
[
9
]);
problem_size
.
StrideC
=
std
::
stoi
(
argv
[
9
]);
problem_size
.
StrideC
=
std
::
stoi
(
argv
[
10
]);
}
}
else
else
{
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)"
<<
std
::
endl
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
<<
std
::
endl
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)"
<<
std
::
endl
<<
"arg4: print tensor (0=no, 1=yes)
\n
"
<<
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
;
<<
"arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
...
example/01_gemm/gemm_xdl_direct_c_write_out_fp16.cpp
0 → 100644
View file @
0a65fc55
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_direct_c_write_out.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
// Element types: A, B and C are stored as half precision, while accumulation
// and the C-shuffle staging type are single precision.
using ADataType        = ck::half_t;
using BDataType        = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = float;
using CDataType        = ck::half_t;

// Short-hand name for the half-precision type.
using F16 = ck::half_t;

// Layout tags for the A, B and C tensors.
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Element-wise operations applied to A, B and C.
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// Compile-time kernel configuration shared by the device instances below.
static constexpr auto GemmDefault      = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSchedDefault = ck::LoopScheduler::Default;
static constexpr auto GemmPipeline     = ck::PipelineVersion::v1;
// Device GEMM instance built on the DeviceGemm_Xdl_DirectCWriteOut kernel.
// The comment after each template argument is the column name from the
// original parameter table. A larger tile configuration was left commented out:
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, LoopSchedDefault, GemmPipeline>;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_DirectCWriteOut<
    ALayout,          // ALayout
    BLayout,          // BLayout
    CLayout,          // CLayout
    ADataType,        // AData Type
    BDataType,        // BData Type
    CDataType,        // CData Type
    AccDataType,      // AccData Type
    AElementOp,       // A Elementwise Operation
    BElementOp,       // B Elementwise Operation
    CElementOp,       // C Elementwise Operation
    GemmDefault,      // GEMM Specialization
    1,                // NumGemmK Prefetch Stage
    64,               // Block Size
    32,               // MPer Block
    32,               // NPer Block
    32,               // KPer Block
    8,                // AK1
    8,                // BK1
    32,               // MPer XDL
    32,               // NPer XDL
    1,                // MXdl Per Wave
    1,                // NXdl Per Wave
    S<2, 32, 1>,      // ABlockTransfer ThreadCluster Lengths_K0_M_K1
    S<1, 0, 2>,       // ABlockTransfer ThreadCluster ArrangeOrder
    S<1, 0, 2>,       // ABlockTransfer SrcAccessOrder
    2,                // ABlockTransfer SrcVectorDim
    8,                // ABlockTransfer SrcScalar PerVector
    8,                // ABlockTransfer DstScalar PerVector_K1
    1,                // ABlockLds AddExtraM
    S<2, 32, 1>,      // BBlockTransfer ThreadCluster Lengths_K0_N_K1
    S<1, 0, 2>,       // BBlockTransfer ThreadCluster ArrangeOrder
    S<1, 0, 2>,       // BBlockTransfer SrcAccessOrder
    2,                // BBlockTransfer SrcVectorDim
    8,                // BBlockTransfer SrcScalar PerVector
    8,                // BBlockTransfer DstScalar PerVector_K1
    1,                // BBlockLds AddExtraN
    LoopSchedDefault, // LoopScheduler
    GemmPipeline>;    // PipelineVersion
// clang-format on
// Device GEMM instance built on the DeviceGemm_Xdl_CShuffle kernel, using a
// 256-thread block and a 256x128x32 tile. The comment after each template
// argument is the column name from the original parameter table.
// clang-format off
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle<
    ALayout,           // ALayout
    BLayout,           // BLayout
    CLayout,           // CLayout
    ADataType,         // AData Type
    BDataType,         // BData Type
    CDataType,         // CData Type
    AccDataType,       // AccData Type
    CShuffleDataType,  // CShuffle DataType
    AElementOp,        // A Elementwise Operation
    BElementOp,        // B Elementwise Operation
    CElementOp,        // C Elementwise Operation
    GemmDefault,       // GEMM Specialization
    1,                 // NumGemmK Prefetch Stage
    256,               // Block Size
    256,               // MPer Block
    128,               // NPer Block
    32,                // KPer Block
    8,                 // AK1
    8,                 // BK1
    32,                // MPer XDL
    32,                // NPer XDL
    4,                 // MXdl Per Wave
    2,                 // NXdl Per Wave
    S<4, 64, 1>,       // ABlockTransfer ThreadCluster Lengths_K0_M_K1
    S<1, 0, 2>,        // ABlockTransfer ThreadCluster ArrangeOrder
    S<1, 0, 2>,        // ABlockTransfer SrcAccessOrder
    2,                 // ABlockTransfer SrcVectorDim
    8,                 // ABlockTransfer SrcScalar PerVector
    8,                 // ABlockTransfer DstScalar PerVector_K1
    1,                 // ABlockLds AddExtraM
    S<4, 64, 1>,       // BBlockTransfer ThreadCluster Lengths_K0_N_K1
    S<1, 0, 2>,        // BBlockTransfer ThreadCluster ArrangeOrder
    S<1, 0, 2>,        // BBlockTransfer SrcAccessOrder
    2,                 // BBlockTransfer SrcVectorDim
    8,                 // BBlockTransfer SrcScalar PerVector
    8,                 // BBlockTransfer DstScalar PerVector_K1
    1,                 // BBlockLds AddExtraN
    1,                 // CShuffle MXdlPerWave PerShuffle
    1,                 // CShuffle NXdlPerWave PerShuffle
    S<1, 32, 1, 8>,    // CBlockTransferClusterLengths _MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl
    8,                 // CBlockTransfer ScalarPerVector _NWaveNPerXdl
    LoopSchedDefault,  // LoopScheduler
    GemmPipeline>;     // PipelineVersion
// clang-format on
// NOTE(review): this alias names DeviceGemmInstance as itself, so it does not
// change which kernel type the name refers to. Presumably the right-hand side
// is meant to be switched between DeviceGemmInstance and DeviceGemmInstance1
// when comparing the two kernels — confirm, or drop the redundant alias.
using
DeviceGemmInstance
=
DeviceGemmInstance
;
// Host-side reference GEMM instantiated with the same element types and
// element-wise operations as the device instances.
using ReferenceGemmInstance =
    ck::tensor_operation::host::ReferenceGemm<ADataType,
                                              BDataType,
                                              CDataType,
                                              AccDataType,
                                              AElementOp,
                                              BElementOp,
                                              CElementOp>;
#include "run_gemm_example.inc"
// Program entry point: forwards the command line to run_gemm_example and
// converts its boolean result into a process exit status
// (true -> 0, false -> 1).
int main(int argc, char* argv[])
{
    const bool succeeded = run_gemm_example(argc, argv);
    return succeeded ? 0 : 1;
}
example/01_gemm/run_gemm_example.inc
View file @
0a65fc55
...
@@ -30,10 +30,31 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -30,10 +30,31 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
switch
(
config
.
init_method
)
switch
(
config
.
init_method
)
{
{
case
0
:
break
;
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
1.
f
}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
0.
f
}(
b_k_n
);
// for (ck::index_t m = 0; m < M; ++m)
// {
// for (ck::index_t k = 0; k < K; ++k)
// {
// a_m_k(m, k) = (m * M + k) % 5;
// }
// }
for
(
ck
::
index_t
n
=
0
;
n
<
N
;
++
n
)
{
for
(
ck
::
index_t
k
=
0
;
k
<
K
;
++
k
)
{
if
(
n
==
k
)
b_k_n
(
k
,
n
)
=
n
*
2
;
}
}
break
;
case
1
:
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5
.
f
,
5
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
1
.
f
,
3
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5
.
f
,
5
.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
1
.
f
,
3
.
f
}(
b_k_n
);
break
;
break
;
default
:
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
...
@@ -65,6 +86,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -65,6 +86,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
SetZero
();
#endif
#endif
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
...
@@ -114,6 +136,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -114,6 +136,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
result
=
true
;
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
...
@@ -131,15 +154,25 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -131,15 +154,25 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
c_m_n_device_result
=
c_m_n_device_result_converted
.
CopyAsType
<
CDataType
>
();
c_m_n_device_result
=
c_m_n_device_result_converted
.
CopyAsType
<
CDataType
>
();
re
turn
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
re
sult
=
result
&&
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
#else
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
re
turn
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
re
sult
=
result
&&
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#endif
#endif
}
}
return
true
;
if
(
config
.
do_log
)
{
LogRangeAsType
<
float
>
(
std
::
cout
<<
"a:
\n
"
,
a_m_k
.
mData
,
","
,
32
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"b:
\n
"
,
b_k_n
.
mData
,
","
,
32
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_host:
\n
"
,
c_m_n_host_result
.
mData
,
","
,
32
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_device:
\n
"
,
c_m_n_device_result
.
mData
,
","
,
32
)
<<
std
::
endl
;
}
return
result
;
}
}
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
bool
run_gemm_example
(
int
argc
,
char
*
argv
[])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment