Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b097be17
Commit
b097be17
authored
Jun 23, 2022
by
root
Browse files
merge changes for upstream/latest update
parents
8a891bbd
a49115b9
Changes
140
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
606 additions
and
250 deletions
+606
-250
Jenkinsfile
Jenkinsfile
+5
-3
LICENSE
LICENSE
+28
-0
README.md
README.md
+1
-1
example/01_gemm/gemm_xdl_fp16.cpp
example/01_gemm/gemm_xdl_fp16.cpp
+17
-13
example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp
example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp
+154
-115
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+1
-0
example/04_gemm_add_add_fastgelu/README.md
example/04_gemm_add_add_fastgelu/README.md
+23
-0
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
..._gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp
+245
-0
example/04_gemm_bias_relu_add/CMakeLists.txt
example/04_gemm_bias_relu_add/CMakeLists.txt
+0
-1
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
+2
-2
example/12_reduce/README.md
example/12_reduce/README.md
+6
-7
example/12_reduce/reduce_blockwise.cpp
example/12_reduce/reduce_blockwise.cpp
+30
-19
example/12_reduce/reduce_blockwise_two_call.cpp
example/12_reduce/reduce_blockwise_two_call.cpp
+44
-33
example/13_pool2d_fwd/pool2d_fwd_common.hpp
example/13_pool2d_fwd/pool2d_fwd_common.hpp
+9
-10
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
+8
-6
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
...e/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
+15
-18
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
...e/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
+15
-16
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
+1
-2
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
+1
-2
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+1
-2
No files found.
Jenkinsfile
View file @
b097be17
...
@@ -7,7 +7,6 @@ def show_node_info() {
...
@@ -7,7 +7,6 @@ def show_node_info() {
echo "NODE_NAME = \$NODE_NAME"
echo "NODE_NAME = \$NODE_NAME"
lsb_release -sd
lsb_release -sd
uname -r
uname -r
cat /sys/module/amdgpu/version
ls /opt/ -la
ls /opt/ -la
"""
"""
}
}
...
@@ -101,7 +100,8 @@ def buildHipClangJob(Map conf=[:]){
...
@@ -101,7 +100,8 @@ def buildHipClangJob(Map conf=[:]){
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
'7126e5fe-eb51-4576-b52b-9aaf1de8f0fd'
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
if
(
params
.
USE_DOCKERFILE
){
if
(
params
.
USE_DOCKERFILE
){
try
{
try
{
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
...
@@ -191,7 +191,8 @@ def runCKProfiler(Map conf=[:]){
...
@@ -191,7 +191,8 @@ def runCKProfiler(Map conf=[:]){
def
variant
=
env
.
STAGE_NAME
def
variant
=
env
.
STAGE_NAME
def
retimage
def
retimage
gitStatusWrapper
(
credentialsId:
'7126e5fe-eb51-4576-b52b-9aaf1de8f0fd'
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
gitStatusWrapper
(
credentialsId:
"${status_wrapper_creds}"
,
gitHubContext:
"Jenkins - ${variant}"
,
account:
'ROCmSoftwarePlatform'
,
repo:
'composable_kernel'
)
{
if
(
params
.
USE_DOCKERFILE
){
if
(
params
.
USE_DOCKERFILE
){
try
{
try
{
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
retimage
=
docker
.
build
(
"${image}"
,
dockerArgs
+
'.'
)
...
@@ -317,6 +318,7 @@ pipeline {
...
@@ -317,6 +318,7 @@ pipeline {
dbsshport
=
"${dbsshport}"
dbsshport
=
"${dbsshport}"
dbsshuser
=
"${dbsshuser}"
dbsshuser
=
"${dbsshuser}"
dbsshpassword
=
"${dbsshpassword}"
dbsshpassword
=
"${dbsshpassword}"
status_wrapper_creds
=
"${status_wrapper_creds}"
}
}
stages
{
stages
{
stage
(
"Static checks"
)
{
stage
(
"Static checks"
)
{
...
...
LICENSE
0 → 100644
View file @
b097be17
Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang)
Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang)
Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan)
Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang)
Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah)
Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)
SPDX-License-Identifier: MIT
Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
View file @
b097be17
...
@@ -6,7 +6,7 @@ docker run \
...
@@ -6,7 +6,7 @@ docker run \
--group-add
sudo
\
--group-add
sudo
\
-w
/root/workspace
\
-w
/root/workspace
\
-v
${
PATH_TO_LOCAL_WORKSPACE
}
:/root/workspace
\
-v
${
PATH_TO_LOCAL_WORKSPACE
}
:/root/workspace
\
rocm/tensorflow:rocm
4.3
.1-tf2.6-dev
\
rocm/tensorflow:rocm
5
.1-tf2.6-dev
\
/bin/bash
/bin/bash
```
```
...
...
example/01_gemm/gemm_xdl_fp16.cpp
View file @
b097be17
...
@@ -28,18 +28,19 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -28,18 +28,19 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
ck
::
half_t
;
using
ADataType
=
F16
;
using
BDataType
=
ck
::
half_t
;
using
BDataType
=
F16
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
F32
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
F32
;
using
CDataType
=
F16
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ALayout
=
Row
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
Col
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
Row
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
...
@@ -59,7 +60,6 @@ using DeviceGemmInstance_WaveletModel = ck::tensor_operation::device::DeviceGemm
...
@@ -59,7 +60,6 @@ using DeviceGemmInstance_WaveletModel = ck::tensor_operation::device::DeviceGemm
//######| | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F16
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
@@ -79,7 +79,11 @@ int main(int argc, char* argv[])
...
@@ -79,7 +79,11 @@ int main(int argc, char* argv[])
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
StrideC
=
4096
;
if
(
argc
==
4
)
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
@@ -103,7 +107,7 @@ int main(int argc, char* argv[])
...
@@ -103,7 +107,7 @@ int main(int argc, char* argv[])
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n
0
, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=n
o
, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
...
example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp
View file @
b097be17
...
@@ -3,83 +3,103 @@
...
@@ -3,83 +3,103 @@
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm_bias_activation.hpp"
#include "gemm_specialization.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
ADataType
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
F32
=
float
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
// C = A * B
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// E = Relu(C + D);
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
struct
AddRelu
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
{
__host__
__device__
void
// clang-format off
operator
()(
ck
::
half_t
&
e
,
const
ck
::
half_t
&
c
,
const
ck
::
half_t
&
d
)
const
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle_Bias_Activation
<
{
ADataType
,
// ADataType
const
ck
::
half_t
x
=
c
+
d
;
BDataType
,
// BDataType
CDataType
,
// CDataType
e
=
x
>
0
?
x
:
0
;
AccDataType
,
// AccDataType
}
ALayout
,
// ALayout
};
BLayout
,
// BLayout
CLayout
,
// CLayout
using
ADataType
=
F16
;
AElementOp
,
// AElementwiseOperation
using
BDataType
=
F16
;
BElementOp
,
// BElementwiseOperation
using
AccDataType
=
F32
;
CElementOp
,
// CElementwiseOperation
using
CShuffleDataType
=
F16
;
256
,
// BlockSize
using
DDataType
=
F16
;
256
,
// MPerBlock
using
DsDataType
=
ck
::
Tuple
<
DDataType
>
;
128
,
// NPerBlock
using
EDataType
=
F16
;
4
,
// K0PerBlock
8
,
// K1
using
ALayout
=
Row
;
32
,
// MPerXDL
using
BLayout
=
Col
;
32
,
// NPerXDL
using
ELayout
=
Row
;
4
,
// MXdlPerWave
2
,
// NXdlPerWave
using
AElementOp
=
PassThrough
;
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
using
BElementOp
=
PassThrough
;
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
using
CDEElementOp
=
AddRelu
;
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
using
DeviceOpInstance
=
true
,
// ABlockLdsAddExtraM
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Xdl_CShuffle
<
ALayout
,
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
BLayout
,
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
ELayout
,
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
ADataType
,
2
,
// BBlockTransferSrcVectorDim
BDataType
,
8
,
// BBlockTransferSrcScalarPerVector
AccDataType
,
8
,
// BBlockTransferDstScalarPerVector_K1
CShuffleDataType
,
true
,
// BBlockLdsAddExtraN
DsDataType
,
1
,
// CShuffleMXdlPerWavePerShuffle
EDataType
,
1
,
// CShuffleNXdlPerWavePerShuffle
AElementOp
,
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
BElementOp
,
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
CDEElementOp
,
// clang-format on
GemmDefault
,
1
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemmBiasActivation
<
ADataType
,
256
,
BDataType
,
256
,
CDataType
,
128
,
AElementOp
,
32
,
BElementOp
,
8
,
CElementOp
>
;
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -94,9 +114,13 @@ int main(int argc, char* argv[])
...
@@ -94,9 +114,13 @@ int main(int argc, char* argv[])
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
Stride
C
=
4096
;
ck
::
index_t
Stride
E
=
4096
;
if
(
argc
==
4
)
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
@@ -114,14 +138,14 @@ int main(int argc, char* argv[])
...
@@ -114,14 +138,14 @@ int main(int argc, char* argv[])
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
Stride
C
=
std
::
stoi
(
argv
[
9
]);
Stride
E
=
std
::
stoi
(
argv
[
9
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n
0
, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=n
o
, 1=yes)
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, Stride
C
\n
"
);
printf
(
"arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, Stride
E
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -141,17 +165,14 @@ int main(int argc, char* argv[])
...
@@ -141,17 +165,14 @@ int main(int argc, char* argv[])
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
DDataType
>
d_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
0
,
ELayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
// c0_n[n]
Tensor
<
CDataType
>
c0_n
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
N
)}),
std
::
vector
<
std
::
size_t
>
({
1
})));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c
_m_n: "
<<
c
_m_n
_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
d
_m_n: "
<<
d
_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c0
_n: "
<<
c0_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
e_m
_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -159,59 +180,59 @@ int main(int argc, char* argv[])
...
@@ -159,59 +180,59 @@ int main(int argc, char* argv[])
case
1
:
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
c0
_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
C
DataType
>
{
-
5
,
5
});
d_m
_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D
DataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
c0
_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
C
DataType
>
{
0.0
,
1.0
});
d_m
_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D
DataType
>
{
0.0
,
1.0
});
}
}
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c
_m_n_device_buf
(
sizeof
(
C
DataType
)
*
c
_m_n
_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d
_m_n_device_buf
(
sizeof
(
D
DataType
)
*
d
_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c0
_n_device_buf
(
sizeof
(
C
DataType
)
*
c0_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e_m
_n_device_buf
(
sizeof
(
E
DataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
d_m_n_device_buf
.
ToDevice
(
d_m_n
.
mData
.
data
());
c0_n_device_buf
.
ToDevice
(
c0_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c
de
_element_op
=
C
DE
ElementOp
{};
// do GEMM
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
device_op
=
DeviceOpInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
device_op
.
MakeArgument
(
a_m_k_device_buf
.
GetDeviceBuffer
(),
static_cast
<
CDataType
*>
(
c0_n_device_buf
.
GetDeviceBuffer
()),
b_k_n_device_buf
.
GetDeviceBuffer
(),
M
,
std
::
array
<
const
void
*
,
1
>
{
d_m_n_device_buf
.
GetDeviceBuffer
()},
N
,
e_m_n_device_buf
.
GetDeviceBuffer
(),
K
,
M
,
StrideA
,
N
,
StrideB
,
K
,
StrideC
,
StrideA
,
a_element_op
,
StrideB
,
b_element_op
,
std
::
array
<
ck
::
index_t
,
1
>
{
0
},
c_element_op
);
StrideE
,
a_element_op
,
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! this device_op instance does not support this problem"
);
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
M
+
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
C
DataType
)
*
M
*
N
+
sizeof
(
C
DataType
)
*
N
;
sizeof
(
E
DataType
)
*
M
*
N
+
sizeof
(
E
DataType
)
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -220,19 +241,37 @@ int main(int argc, char* argv[])
...
@@ -220,19 +241,37 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
<<
std
::
endl
;
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
if
(
do_verification
)
{
{
e_m_n_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
Tensor
<
AccDataType
>
c_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
auto
ref_argument
=
a_m_k
,
b_k_n
,
c_m_n
_host_result
,
c0_n
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
PassThrough
{}
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
)
?
0
:
1
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d_m_n
(
m
,
n
));
}
}
return
ck
::
utils
::
check_err
(
e_m_n_device_result
.
mData
,
e_m_n_host_result
.
mData
)
?
0
:
1
;
}
}
return
0
;
return
0
;
...
...
example/04_gemm_add_add_fastgelu/CMakeLists.txt
0 → 100644
View file @
b097be17
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp
)
example/04_gemm_
bias_relu_add
/README.md
→
example/04_gemm_
add_add_fastgelu
/README.md
View file @
b097be17
# Instructions for ```example_gemm_
xdl_bias_relu_add
```
# Instructions for ```example_gemm_
add_add_fastgelu_xdl_fp16
```
## Run ```example_gemm_
xdl_bias_relu_add
```
## Run ```example_gemm_
add_add_fastgelu_xdl_fp16
```
```
bash
```
bash
#arg1: verification (0=no, 1=yes)
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3:
run
kernel
# of times (>1
)
#arg3:
time
kernel
(0=no, 1=yes
)
#arg4 to
9
: M (256x), N(128x), K(32x), StrideA, StrideB, Stride
C
#arg4 to
11
: M (256x), N(128x), K(32x), StrideA, StrideB, Stride
D0, StrideD1, StrideE
./bin/example_gemm_
xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096
./bin/example_gemm_
add_add_fastgelu_xdl_fp16 1 1 1
```
```
Result (MI100 @ 1087MHz, 133.5TFlops peak FP16)
Result (MI100 @ 1087MHz, 133.5TFlops peak FP16)
```
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0}
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
arg.c_grid_desc_m_n_{ 3840, 4096}
arg.c0_grid_desc_m_n_{ 3840, 4096}
arg.c1_grid_desc_m_n_{ 3840, 4096}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up
Warm up
1 time
Start running
5
times...
Start running
10
times...
Perf: 1.2
7583
ms, 10
0.992
TFlops,
73.9688 GB/s
Perf: 1.2
6914
ms, 10
1.525
TFlops,
100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
```
```
example/04_gemm_
bias_relu_add/gemm_xdl_bias_relu_add
.cpp
→
example/04_gemm_
add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16
.cpp
View file @
b097be17
...
@@ -3,84 +3,60 @@
...
@@ -3,84 +3,60 @@
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
#include <stdlib.h>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm_bias_activation_add.hpp"
#include "gemm_specialization.hpp"
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
ADataType
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
F32
=
float
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AddAddFastGelu
=
ck
::
tensor_operation
::
element_wise
::
AddAddFastGelu
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddReluAdd
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
F16
;
using
D1DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
D0Layout
=
Row
;
using
D1Layout
=
Row
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
AddAddFastGelu
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
<
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD_Xdl_CShuffle
ADataType
,
// ADataType
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
BDataType
,
// BDataType
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
CDataType
,
// CDataType
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
AccDataType
,
// AccDataType
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
ALayout
,
// ALayout
<
ALayout
,
BLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
BLayout
,
// BLayout
CLayout
,
// CLayout
AElementOp
,
// AElementwiseOperation
BElementOp
,
// BElementwiseOperation
CElementOp
,
// CElementwiseOperation
256
,
// BlockSize
256
,
// MPerBlock
128
,
// NPerBlock
4
,
// K0PerBlock
8
,
// K1
32
,
// MPerXDL
32
,
// NPerXDL
4
,
// MXdlPerWave
2
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_K0_M_K1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_K1
true
,
// ABlockLdsAddExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_K0_N_K1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_K1
true
,
// BBlockLdsAddExtraN
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
1
,
32
,
1
,
1
,
8
>
,
// CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
8
>
;
// CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemmBiasActivationAdd
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
...
@@ -94,16 +70,21 @@ int main(int argc, char* argv[])
...
@@ -94,16 +70,21 @@ int main(int argc, char* argv[])
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideA
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideB
=
4096
;
ck
::
index_t
StrideC
=
4096
;
ck
::
index_t
StrideD0
=
0
;
ck
::
index_t
StrideC1
=
4096
;
ck
::
index_t
StrideD1
=
4096
;
ck
::
index_t
StrideE
=
4096
;
if
(
argc
==
4
)
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
4
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
}
else
if
(
argc
==
1
1
)
else
if
(
argc
==
1
2
)
{
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
...
@@ -115,15 +96,17 @@ int main(int argc, char* argv[])
...
@@ -115,15 +96,17 @@ int main(int argc, char* argv[])
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideA
=
std
::
stoi
(
argv
[
7
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideB
=
std
::
stoi
(
argv
[
8
]);
StrideC
=
std
::
stoi
(
argv
[
9
]);
StrideD0
=
std
::
stoi
(
argv
[
9
]);
StrideC1
=
std
::
stoi
(
argv
[
10
]);
StrideD1
=
std
::
stoi
(
argv
[
10
]);
StrideE
=
std
::
stoi
(
argv
[
11
]);
}
}
else
else
{
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1
\n
"
);
printf
(
"arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, "
"StrideE
\n
"
);
exit
(
0
);
exit
(
0
);
}
}
...
@@ -143,21 +126,16 @@ int main(int argc, char* argv[])
...
@@ -143,21 +126,16 @@ int main(int argc, char* argv[])
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
D0DataType
>
d0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD0
,
D0Layout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
D1DataType
>
d1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideD1
,
D1Layout
{}));
Tensor
<
EDataType
>
e_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
// c0_n[n]
Tensor
<
EDataType
>
e_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideE
,
ELayout
{}));
Tensor
<
CDataType
>
c0_n
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
N
)}),
std
::
vector
<
std
::
size_t
>
({
1
})));
// c1_m_n[m ,n]
Tensor
<
CDataType
>
c1_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c
_m_n: "
<<
c
_m_n
_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
d0
_m_n: "
<<
d0
_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c0
_n: "
<<
c0
_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
d1_m
_n: "
<<
d1_m
_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
c1
_m_n: "
<<
c1
_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"
e
_m_n: "
<<
e
_m_n
_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
{
{
...
@@ -165,92 +143,102 @@ int main(int argc, char* argv[])
...
@@ -165,92 +143,102 @@ int main(int argc, char* argv[])
case
1
:
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
c0
_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
C
DataType
>
{
-
5
,
5
});
d0_m
_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0
DataType
>
{
-
5
,
5
});
c
1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
C
DataType
>
{
-
5
,
5
});
d
1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1
DataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
c0
_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
C
DataType
>
{
0.0
,
1.0
});
d0_m
_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0
DataType
>
{
0.0
,
1.0
});
c
1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
C
DataType
>
{
0.0
,
1.0
});
d
1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1
DataType
>
{
0.0
,
1.0
});
}
}
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c
_m_n_device_buf
(
sizeof
(
C
DataType
)
*
c
_m_n
_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
d0
_m_n_device_buf
(
sizeof
(
D0
DataType
)
*
d0
_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c0
_n_device_buf
(
sizeof
(
C
DataType
)
*
c0
_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_m
_n_device_buf
(
sizeof
(
D1
DataType
)
*
d1_m
_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c1
_m_n_device_buf
(
sizeof
(
C
DataType
)
*
c1
_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
e
_m_n_device_buf
(
sizeof
(
E
DataType
)
*
e
_m_n
_device_result
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
d0_m_n_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
c0_n_device_buf
.
ToDevice
(
c0_n
.
mData
.
data
());
d1_m_n_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
c1_m_n_device_buf
.
ToDevice
(
c1_m_n
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c
de
_element_op
=
C
DE
ElementOp
{};
// do GEMM
// do GEMM
auto
gemm
=
Device
Gemm
Instance
{};
auto
device_op
=
Device
Op
Instance
{};
auto
invoker
=
device_op
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()
)
,
device_op
.
MakeArgument
(
a_m_k_device_buf
.
GetDeviceBuffer
(),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()
)
,
b_k_n_device_buf
.
GetDeviceBuffer
(),
static_cast
<
CDataType
*>
(
c
_m_n_device_buf
.
GetDeviceBuffer
()
)
,
std
::
array
<
const
void
*
,
2
>
{
d0
_m_n_device_buf
.
GetDeviceBuffer
(),
static_cast
<
CDataType
*>
(
c0
_n_device_buf
.
GetDeviceBuffer
()
)
,
d1_m
_n_device_buf
.
GetDeviceBuffer
()
}
,
static_cast
<
CDataType
*>
(
c1
_m_n_device_buf
.
GetDeviceBuffer
()
)
,
e
_m_n_device_buf
.
GetDeviceBuffer
(),
M
,
M
,
N
,
N
,
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
Stride
C
,
std
::
array
<
ck
::
index_t
,
2
>
{
StrideD0
,
Stride
D1
}
,
Stride
C1
,
Stride
E
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c
_element_op
);
cde
_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! this device_op instance does not support this problem"
);
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
M
+
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
C
DataType
)
*
M
*
N
+
sizeof
(
C
DataType
)
*
N
+
sizeof
(
D0
DataType
)
*
N
+
sizeof
(
D1
DataType
)
*
M
*
N
+
sizeof
(
C
DataType
)
*
M
*
N
;
sizeof
(
E
DataType
)
*
M
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
device_op
.
GetTypeString
()
<<
std
::
endl
;
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
if
(
do_verification
)
if
(
do_verification
)
{
{
Tensor
<
AccDataType
>
c_m_n
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
M
),
static_cast
<
std
::
size_t
>
(
N
)}));
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
a_m_k
,
auto
ref_argument
=
b_k_n
,
ref_gemm
.
MakeArgument
(
a_m_k
,
b_k_n
,
c_m_n
,
a_element_op
,
b_element_op
,
PassThrough
{});
c_m_n_host_result
,
c0_n
,
c1_m_n
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
return
ck
::
utils
::
check_err
(
c_m_n_device_result
.
mData
,
c_m_n_host_result
.
mData
)
?
0
:
1
;
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d0_m_n
(
m
,
n
),
d1_m_n
(
m
,
n
));
}
}
e_m_n_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
e_m_n_device_result
.
mData
,
e_m_n_host_result
.
mData
)
?
0
:
1
;
}
}
return
0
;
return
0
;
...
...
example/04_gemm_bias_relu_add/CMakeLists.txt
deleted
100644 → 0
View file @
8a891bbd
add_example_executable
(
example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp
)
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
View file @
b097be17
...
@@ -291,8 +291,8 @@ int main(int argc, char* argv[])
...
@@ -291,8 +291,8 @@ int main(int argc, char* argv[])
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
conv
->
GetTypeString
()
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
<<
conv
->
GetTypeString
()
<<
std
::
endl
;
if
(
do_verification
)
if
(
do_verification
)
{
{
...
...
example/12_reduce/README.md
View file @
b097be17
...
@@ -5,14 +5,14 @@
...
@@ -5,14 +5,14 @@
# -D <xxx> : input 4-d tensor lengths
# -D <xxx> : input 4-d tensor lengths
# -v <x> : verification (0=no, 1=yes)
# -v <x> : verification (0=no, 1=yes)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: time kernel (0=no, 1=yes)
#arg2: time kernel (0=no, 1=yes)
./bin/example_reduce_blockwise
-D
16,64,32,960
-v
1 1 1
./bin/example_reduce_blockwise
-D
16,64,32,960
-v
1 1 1
```
```
Result
Result
```
```
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Warm up 1 time
Start running 10 times...
Start running 10 times...
Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
...
@@ -24,19 +24,18 @@ Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
...
@@ -24,19 +24,18 @@ Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr
```
bash
```
bash
#arg1: verification (0=no, 1=yes)
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
#arg3: time kernel (0=no, 1=yes)
#arg3: time kernel (0=no, 1=yes)
./bin/example_reduce_blockwise_two_call 1 2 1
./bin/example_reduce_blockwise_two_call 1 2 1
```
Result
Result
```
```
./bin/example_reduce_blockwise_two_call 1 2 1
./bin/example_reduce_blockwise_two_call 1 2 1
launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Warm up 1 time
Start running 10 times...
Start running 10 times...
launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Warm up 1 time
Start running 10 times...
Start running 10 times...
Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
```
```
example/12_reduce/reduce_blockwise.cpp
View file @
b097be17
...
@@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
...
@@ -33,11 +33,11 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
constexpr
bool
PropagateNan
=
true
;
constexpr
bool
PropagateNan
=
true
;
constexpr
bool
OutputIndex
=
false
;
constexpr
bool
OutputIndex
=
false
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
AccDataType
,
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
using
DeviceReduceInstance
=
DeviceReduceMultiBlock
<
InDataType
,
using
DeviceReduceInstance
=
DeviceReduceMultiBlock
<
InDataType
,
AccDataType
,
AccDataType
,
...
@@ -247,6 +247,13 @@ int main(int argc, char* argv[])
...
@@ -247,6 +247,13 @@ int main(int argc, char* argv[])
DeviceMem
out_index_dev
(
indicesSizeInBytes
);
DeviceMem
out_index_dev
(
indicesSizeInBytes
);
InElementwiseOperation
in_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
std
::
tie
(
in_elementwise_op
,
acc_elementwise_op
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
static_cast
<
int32_t
>
(
reduce_total_length
));
if
(
args
.
do_verification
)
if
(
args
.
do_verification
)
{
{
ReductionHost
<
InDataType
,
ReductionHost
<
InDataType
,
...
@@ -261,8 +268,13 @@ int main(int argc, char* argv[])
...
@@ -261,8 +268,13 @@ int main(int argc, char* argv[])
OutputIndex
>
OutputIndex
>
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
(
in
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
.
Run
(
hostReduce
.
Run
(
alpha
,
alpha
,
in
.
mData
.
data
(),
beta
,
out_ref
.
mData
.
data
(),
out_indices_ref
.
mData
.
data
());
in
.
mData
.
data
(),
beta
,
out_ref
.
mData
.
data
(),
out_indices_ref
.
mData
.
data
(),
in_elementwise_op
,
acc_elementwise_op
);
};
};
std
::
vector
<
ck
::
index_t
>
i_inLengths
;
std
::
vector
<
ck
::
index_t
>
i_inLengths
;
...
@@ -277,20 +289,19 @@ int main(int argc, char* argv[])
...
@@ -277,20 +289,19 @@ int main(int argc, char* argv[])
auto
reduce
=
DeviceReduceInstance
{};
auto
reduce
=
DeviceReduceInstance
{};
auto
argument_ptr
=
reduce
.
MakeArgumentPointer
(
auto
argument_ptr
=
reduce
.
MakeArgumentPointer
(
i_inLengths
,
i_inLengths
,
i_inStrides
,
i_inStrides
,
i_outLengths
,
i_outLengths
,
i_outStrides
,
i_outStrides
,
reduceDims
,
reduceDims
,
alpha
,
alpha
,
beta
,
beta
,
in_dev
.
GetDeviceBuffer
(),
in_dev
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
out_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
out_index_dev
.
GetDeviceBuffer
(),
out_index_dev
.
GetDeviceBuffer
(),
in_elementwise_op
,
InElementwiseOperation
{
static_cast
<
int32_t
>
(
reduce_total_length
)},
acc_elementwise_op
);
AccElementwiseOperation
{
static_cast
<
int32_t
>
(
reduce_total_length
)});
if
(
!
reduce
.
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
reduce
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
...
...
example/12_reduce/reduce_blockwise_two_call.cpp
View file @
b097be17
...
@@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
...
@@ -31,13 +31,13 @@ constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
constexpr
bool
PropagateNan
=
true
;
constexpr
bool
PropagateNan
=
true
;
constexpr
bool
OutputIndex
=
false
;
constexpr
bool
OutputIndex
=
false
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
AccDataType
,
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
UnaryIdentic
<
AccDataType
,
AccDataType
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceReduceInstance_1
=
DeviceReduceMultiBlock
<
InOutDataType
,
using
DeviceReduceInstance_1
=
DeviceReduceMultiBlock
<
InOutDataType
,
AccDataType
,
AccDataType
,
...
@@ -184,6 +184,13 @@ int main(int argc, char* argv[])
...
@@ -184,6 +184,13 @@ int main(int argc, char* argv[])
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
out_dev
.
ToDevice
(
out
.
mData
.
data
());
out_dev
.
ToDevice
(
out
.
mData
.
data
());
InElementwiseOperation
in_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
std
::
tie
(
in_elementwise_op
,
acc_elementwise_op
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
static_cast
<
int32_t
>
(
reduce_total_length
));
if
(
do_verify
)
if
(
do_verify
)
{
{
ReductionHost
<
InOutDataType
,
ReductionHost
<
InOutDataType
,
...
@@ -198,7 +205,13 @@ int main(int argc, char* argv[])
...
@@ -198,7 +205,13 @@ int main(int argc, char* argv[])
OutputIndex
>
OutputIndex
>
hostReduce
(
in_1
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
(
in_1
.
mDesc
,
out_ref
.
mDesc
,
invariantDims
,
reduceDims
);
hostReduce
.
Run
(
alpha
,
in_1
.
mData
.
data
(),
beta
,
out_ref
.
mData
.
data
(),
nullptr
);
hostReduce
.
Run
(
alpha
,
in_1
.
mData
.
data
(),
beta
,
out_ref
.
mData
.
data
(),
nullptr
,
in_elementwise_op
,
acc_elementwise_op
);
};
};
std
::
vector
<
ck
::
index_t
>
i_inLengths_1
;
std
::
vector
<
ck
::
index_t
>
i_inLengths_1
;
...
@@ -217,20 +230,19 @@ int main(int argc, char* argv[])
...
@@ -217,20 +230,19 @@ int main(int argc, char* argv[])
auto
reduce_1
=
DeviceReduceInstance_1
{};
auto
reduce_1
=
DeviceReduceInstance_1
{};
auto
argument_ptr_1
=
reduce_1
.
MakeArgumentPointer
(
auto
argument_ptr_1
=
reduce_1
.
MakeArgumentPointer
(
i_inLengths_1
,
i_inLengths_1
,
i_inStrides_1
,
i_inStrides_1
,
i_inLengths_2
,
i_inLengths_2
,
i_inStrides_2
,
i_inStrides_2
,
reduceDims_1
,
reduceDims_1
,
1.0
f
,
1.0
f
,
0.0
f
,
0.0
f
,
in_1_dev
.
GetDeviceBuffer
(),
in_1_dev
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
in_2_dev
.
GetDeviceBuffer
(),
in_2_dev
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
in_elementwise_op
,
InElementwiseOperation
{
static_cast
<
int32_t
>
(
reduce_total_length
)},
PassThroughOp
{});
PassThroughOp
{});
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
if
(
!
reduce_1
.
IsSupportedArgument
(
argument_ptr_1
.
get
()))
{
{
...
@@ -243,20 +255,19 @@ int main(int argc, char* argv[])
...
@@ -243,20 +255,19 @@ int main(int argc, char* argv[])
auto
reduce_2
=
DeviceReduceInstance_2
{};
auto
reduce_2
=
DeviceReduceInstance_2
{};
auto
argument_ptr_2
=
reduce_2
.
MakeArgumentPointer
(
auto
argument_ptr_2
=
reduce_2
.
MakeArgumentPointer
(
i_inLengths_2
,
i_inLengths_2
,
i_inStrides_2
,
i_inStrides_2
,
i_outLengths
,
i_outLengths
,
i_outStrides
,
i_outStrides
,
reduceDims_2
,
reduceDims_2
,
alpha
,
alpha
,
beta
,
beta
,
in_2_dev
.
GetDeviceBuffer
(),
in_2_dev
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
out_dev
.
GetDeviceBuffer
(),
out_dev
.
GetDeviceBuffer
(),
nullptr
,
nullptr
,
PassThroughOp
{},
PassThroughOp
{},
acc_elementwise_op
);
AccElementwiseOperation
{
static_cast
<
int32_t
>
(
reduce_total_length
)});
if
(
!
reduce_2
.
IsSupportedArgument
(
argument_ptr_2
.
get
()))
if
(
!
reduce_2
.
IsSupportedArgument
(
argument_ptr_2
.
get
()))
{
{
...
...
example/13_pool2d_fwd/pool2d_fwd_common.hpp
View file @
b097be17
...
@@ -31,16 +31,15 @@ static void pool_host_verify(const Tensor<InDataType>& in,
...
@@ -31,16 +31,15 @@ static void pool_host_verify(const Tensor<InDataType>& in,
const
std
::
array
<
ck
::
index_t
,
2
>&
in_left_pads
,
const
std
::
array
<
ck
::
index_t
,
2
>&
in_left_pads
,
const
std
::
array
<
ck
::
index_t
,
2
>&
/*in_right_pads*/
)
const
std
::
array
<
ck
::
index_t
,
2
>&
/*in_right_pads*/
)
{
{
const
int32_t
divider
=
window_spatial_lengths
[
0
]
*
window_spatial_lengths
[
1
];
const
int32_t
reduceLength
=
window_spatial_lengths
[
0
]
*
window_spatial_lengths
[
1
];
using
ReduceOperation
=
typename
ck
::
reduce_binary_operator
<
AccDataType
,
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
ck
::
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
const
InElementwiseOperation
in_elementwise_op
(
divider
);
auto
elementwise_ops
=
const
AccElementwiseOperation
acc_elementwise_op
(
divider
);
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
reduceLength
);
auto
in_elementwise_op
=
std
::
get
<
0
>
(
elementwise_ops
);
auto
acc_elementwise_op
=
std
::
get
<
1
>
(
elementwise_ops
);
if
constexpr
(
!
OutputIndex
)
if
constexpr
(
!
OutputIndex
)
{
{
...
@@ -48,7 +47,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
...
@@ -48,7 +47,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
ck
::
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
ck
::
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
ho
,
auto
wo
)
{
auto
accuVal
=
ReduceOperation
::
GetIdentityValue
();
auto
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
for
(
ck
::
index_t
y
=
0
;
y
<
window_spatial_lengths
[
0
];
++
y
)
for
(
ck
::
index_t
y
=
0
;
y
<
window_spatial_lengths
[
0
];
++
y
)
{
{
...
@@ -86,7 +85,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
...
@@ -86,7 +85,7 @@ static void pool_host_verify(const Tensor<InDataType>& in,
AccDataType
,
AccDataType
,
IndexDataType
>
;
IndexDataType
>
;
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
c
,
auto
ho
,
auto
wo
)
{
auto
accuVal
=
ReduceOperation
::
GetIdentityValue
();
auto
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
IndexDataType
accuIndex
=
0
;
IndexDataType
accuIndex
=
0
;
for
(
ck
::
index_t
y
=
0
;
y
<
window_spatial_lengths
[
0
];
++
y
)
for
(
ck
::
index_t
y
=
0
;
y
<
window_spatial_lengths
[
0
];
++
y
)
...
...
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
View file @
b097be17
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
@@ -41,9 +40,8 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -41,9 +40,8 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DsReduceOp
=
ck
::
Tuple
<
ck
::
reduce
::
Max
<
ReduceAccDataType
>>
;
using
DsReduceOp
=
ck
::
Tuple
<
ck
::
reduce
::
Max
>
;
using
DsElementOp
=
ck
::
Tuple
<
using
DsElementOp
=
ck
::
Tuple
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>>
;
using
DGlobalMemOp
=
using
DGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicMax
>
;
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicMax
>
;
...
@@ -236,10 +234,14 @@ int main(int argc, char* argv[])
...
@@ -236,10 +234,14 @@ int main(int argc, char* argv[])
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
ReduceAccDataType
d_acc
=
d_reduce_op
.
GetIdentityValue
();
ReduceAccDataType
d_acc
=
d_reduce_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
d_reduce_op
(
d_acc
,
c_m_n_host_result
(
m
,
n
));
{
ReduceAccDataType
curr_val
=
ck
::
type_convert
<
ReduceAccDataType
>
(
c_m_n_host_result
(
m
,
n
));
d_reduce_op
(
d_acc
,
curr_val
);
};
d_m_host_result
(
m
)
=
d_acc
;
d_m_host_result
(
m
)
=
d_acc
;
}
}
...
...
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
View file @
b097be17
...
@@ -41,18 +41,15 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -41,18 +41,15 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
UnaryIdenticElementOp
=
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
UnaryDivElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryDivide
;
using
UnaryDivElementOp
=
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
true
>
;
using
DxsInElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
UnarySquareElementOp
=
using
DxsOutElementOps
=
ck
::
Tuple
<
UnaryDivElementOp
,
UnaryDivElementOp
>
;
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOp
=
ck
::
Tuple
<
UnaryDivElementOp
,
UnaryDivElementOp
>
;
using
DGlobalMemOp
=
using
DGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
...
@@ -67,7 +64,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
...
@@ -67,7 +64,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
,
DxsOutElementOp
,
DGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
s
,
DxsOutElementOp
s
,
DGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
...
@@ -204,8 +201,8 @@ int main(int argc, char* argv[])
...
@@ -204,8 +201,8 @@ int main(int argc, char* argv[])
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
auto
dxs_global
=
ck
::
make_tuple
(
static_cast
<
DDataType
*>
(
d0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()));
static_cast
<
DDataType
*>
(
d1_device_buf
.
GetDeviceBuffer
()));
auto
dxs_in_element_op
=
DxsInElementOp
{};
auto
dxs_in_element_op
=
DxsInElementOp
s
{};
auto
dxs_out_element_op
=
DxsOutElementOp
{
M
,
M
};
auto
dxs_out_element_op
=
DxsOutElementOp
s
{
N
,
N
};
// do GEMM
// do GEMM
auto
gemm
=
DeviceGemmReduceInstance
{};
auto
gemm
=
DeviceGemmReduceInstance
{};
...
@@ -261,14 +258,14 @@ int main(int argc, char* argv[])
...
@@ -261,14 +258,14 @@ int main(int argc, char* argv[])
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
float
d0_acc
=
d0_reduce_op
.
GetIdentityValue
();
auto
d0_acc
=
d0_reduce_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
float
d1_acc
=
d1_reduce_op
.
GetIdentityValue
();
auto
d1_acc
=
d1_reduce_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
float
c_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
auto
c_val
=
ck
::
type_convert
<
ReduceAccDataType
>
(
c_m_n_host_result
(
m
,
n
));
float
d0_val
=
0
;
ReduceAccDataType
d0_val
;
float
d1_val
=
0
;
ReduceAccDataType
d1_val
;
dxs_in_element_op
(
ck
::
Number
<
0
>
{})(
d0_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
0
>
{})(
d0_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
1
>
{})(
d1_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
1
>
{})(
d1_val
,
c_val
);
...
...
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
View file @
b097be17
...
@@ -39,16 +39,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -39,16 +39,14 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
ReduceAccDataType
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
DxsReduceOp
=
ck
::
Tuple
<
D0ReduceOp
,
D1ReduceOp
>
;
using
UnaryIdenticElementOp
=
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
UnarySquareElementOp
=
using
DxsInElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
ReduceAccDataType
,
ReduceAccDataType
,
false
>
;
using
DxsOutElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnaryIdenticElementOp
>
;
using
DxsInElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOp
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnaryIdenticElementOp
>
;
using
DGlobalMemOp
=
using
DGlobalMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
...
@@ -63,7 +61,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
...
@@ -63,7 +61,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
,
DxsOutElementOp
,
DGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
AElementOp
,
BElementOp
,
CElementOp
,
DxsReduceOp
,
DxsInElementOp
s
,
DxsOutElementOp
s
,
DGlobalMemOp
,
GemmSpecialization
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceBatchedGemmInstance
=
ck
::
tensor_operation
::
host
::
...
@@ -206,8 +204,8 @@ int main(int argc, char* argv[])
...
@@ -206,8 +204,8 @@ int main(int argc, char* argv[])
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
DxsInElementOp
{},
DxsInElementOp
s
{},
DxsOutElementOp
{},
DxsOutElementOp
s
{},
BatchCount
);
BatchCount
);
if
(
!
batched_gemm
.
IsSupportedArgument
(
argument
))
if
(
!
batched_gemm
.
IsSupportedArgument
(
argument
))
...
@@ -259,14 +257,15 @@ int main(int argc, char* argv[])
...
@@ -259,14 +257,15 @@ int main(int argc, char* argv[])
{
{
for
(
int
m
=
0
;
m
<
M
;
++
m
)
for
(
int
m
=
0
;
m
<
M
;
++
m
)
{
{
float
d0_acc
=
d0_reduce_op
.
GetIdentityValue
();
auto
d0_acc
=
d0_reduce_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
float
d1_acc
=
d1_reduce_op
.
GetIdentityValue
();
auto
d1_acc
=
d1_reduce_op
.
GetIdentityValue
<
ReduceAccDataType
>
();
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
float
c_val
=
ck
::
type_convert
<
float
>
(
c_g_m_n_host_result
(
batch
,
m
,
n
));
auto
c_val
=
float
d0_val
=
0
;
ck
::
type_convert
<
ReduceAccDataType
>
(
c_g_m_n_host_result
(
batch
,
m
,
n
));
float
d1_val
=
0
;
ReduceAccDataType
d0_val
;
ReduceAccDataType
d1_val
;
UnaryIdenticElementOp
{}(
d0_val
,
c_val
);
UnaryIdenticElementOp
{}(
d0_val
,
c_val
);
UnarySquareElementOp
{}(
d1_val
,
c_val
);
UnarySquareElementOp
{}(
d1_val
,
c_val
);
...
...
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
View file @
b097be17
...
@@ -42,8 +42,7 @@ using ABDataType = F16;
...
@@ -42,8 +42,7 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
...
...
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
View file @
b097be17
...
@@ -17,8 +17,7 @@ using ABDataType = F16;
...
@@ -17,8 +17,7 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
...
...
example/19_binary_elementwise/elementwise_add_1d.cpp
View file @
b097be17
...
@@ -42,8 +42,7 @@ using ABDataType = F16;
...
@@ -42,8 +42,7 @@ using ABDataType = F16;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
Add
<
EltwiseComputeDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
>
;
using
DeviceElementwiseAddInstance
=
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
...
...
Prev
1
2
3
4
5
…
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment