Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
29deceb6
Unverified
Commit
29deceb6
authored
Nov 28, 2023
by
Illia Silin
Committed by
GitHub
Nov 28, 2023
Browse files
Merge pull request #18 from ROCmSoftwarePlatform/merge-from-public
Merge from public
parents
91c1d147
c997bbf6
Changes
422
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
309 additions
and
105 deletions
+309
-105
Config.cmake.in
Config.cmake.in
+1
-1
Jenkinsfile
Jenkinsfile
+3
-3
client_example/01_gemm/CMakeLists.txt
client_example/01_gemm/CMakeLists.txt
+1
-1
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
+6
-6
client_example/03_gemm_layernorm/CMakeLists.txt
client_example/03_gemm_layernorm/CMakeLists.txt
+2
-2
client_example/04_contraction/CMakeLists.txt
client_example/04_contraction/CMakeLists.txt
+5
-5
client_example/05_layernorm/CMakeLists.txt
client_example/05_layernorm/CMakeLists.txt
+5
-2
client_example/05_layernorm/layernorm2d_fwd.cpp
client_example/05_layernorm/layernorm2d_fwd.cpp
+10
-10
client_example/05_layernorm/layernorm4d_fwd.cpp
client_example/05_layernorm/layernorm4d_fwd.cpp
+201
-0
client_example/06_softmax/CMakeLists.txt
client_example/06_softmax/CMakeLists.txt
+1
-1
client_example/07_grouped_convnd_fwd/CMakeLists.txt
client_example/07_grouped_convnd_fwd/CMakeLists.txt
+2
-2
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp
+12
-12
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp
+12
-12
client_example/08_fused_attention/CMakeLists.txt
client_example/08_fused_attention/CMakeLists.txt
+2
-2
client_example/09_quantization/CMakeLists.txt
client_example/09_quantization/CMakeLists.txt
+7
-7
client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp
...tization/conv2d_fwd_bias_relu_perchannel_quantization.cpp
+1
-1
client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp
...antization/conv2d_fwd_bias_relu_perlayer_quantization.cpp
+12
-12
client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp
...tization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp
+1
-1
client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp
...antization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp
+12
-12
client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp
...le/09_quantization/conv2d_fwd_perchannel_quantization.cpp
+13
-13
No files found.
Config.cmake.in
View file @
29deceb6
@PACKAGE_INIT@
@PACKAGE_INIT@
set(_composable_kernel_supported_components device_operations utility)
set(_composable_kernel_supported_components device_
other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_
operations utility)
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)
if(NOT _comp IN_LIST _composable_kernel_supported_components)
...
...
Jenkinsfile
View file @
29deceb6
...
@@ -694,8 +694,8 @@ pipeline {
...
@@ -694,8 +694,8 @@ pipeline {
description:
"Use the CK build to verify hipTensor build and tests (default: ON)"
)
description:
"Use the CK build to verify hipTensor build and tests (default: ON)"
)
string
(
string
(
name:
'hipTensor_branch'
,
name:
'hipTensor_branch'
,
defaultValue:
'
mainline
'
,
defaultValue:
'
develop
'
,
description:
'Specify which branch of hipTensor to use (default:
mainline
)'
)
description:
'Specify which branch of hipTensor to use (default:
develop
)'
)
booleanParam
(
booleanParam
(
name:
"USE_SCCACHE"
,
name:
"USE_SCCACHE"
,
defaultValue:
true
,
defaultValue:
true
,
...
@@ -759,7 +759,7 @@ pipeline {
...
@@ -759,7 +759,7 @@ pipeline {
}
}
agent
{
label
rocmnode
(
"gfx908 || gfx90a"
)
}
agent
{
label
rocmnode
(
"gfx908 || gfx90a"
)
}
environment
{
environment
{
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" """
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942"
-DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert "
"""
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
}
}
steps
{
steps
{
...
...
client_example/01_gemm/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_gemm gemm.cpp
)
add_executable
(
client_gemm gemm.cpp
)
target_link_libraries
(
client_gemm PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm PRIVATE composable_kernel::device_
other_operations composable_kernel::device_gemm_
operations
)
client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
View file @
29deceb6
add_custom_target
(
client_gemm_fastgelu_examples
)
add_custom_target
(
client_gemm_fastgelu_examples
)
add_executable
(
client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp
)
add_executable
(
client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp
)
target_link_libraries
(
client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_
gemm_
operations
)
add_executable
(
client_gemm_add_fastgelu gemm_add_fastgelu.cpp
)
add_executable
(
client_gemm_add_fastgelu gemm_add_fastgelu.cpp
)
target_link_libraries
(
client_gemm_add_fastgelu PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_add_fastgelu PRIVATE composable_kernel::device_
gemm_
operations
)
add_executable
(
client_gemm_fastgelu gemm_fastgelu.cpp
)
add_executable
(
client_gemm_fastgelu gemm_fastgelu.cpp
)
target_link_libraries
(
client_gemm_fastgelu PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_fastgelu PRIVATE composable_kernel::device_
gemm_
operations
)
add_dependencies
(
client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
add_dependencies
(
client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu
client_gemm_fastgelu
)
client_gemm_fastgelu
)
...
@@ -15,13 +15,13 @@ add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu clie
...
@@ -15,13 +15,13 @@ add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu clie
add_custom_target
(
client_gemm_fastgelu_generic_examples
)
add_custom_target
(
client_gemm_fastgelu_generic_examples
)
add_executable
(
client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp
)
add_executable
(
client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_add_add_fastgelu_generic
PRIVATE
composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_add_add_fastgelu_generic composable_kernel::device_
gemm_
operations
)
add_executable
(
client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp
)
add_executable
(
client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_
gemm_
operations
)
add_executable
(
client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp
)
add_executable
(
client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp
)
target_link_libraries
(
client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_fastgelu_generic PRIVATE composable_kernel::device_
gemm_
operations
)
add_dependencies
(
client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
add_dependencies
(
client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic
)
client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic
)
client_example/03_gemm_layernorm/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp
)
add_executable
(
client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp
)
target_link_libraries
(
client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_
gemm_operations composable_kernel::device_other_
operations
)
add_executable
(
client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp
)
add_executable
(
client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp
)
target_link_libraries
(
client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_
gemm_operations composable_kernel::device_other_
operations
)
client_example/04_contraction/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_contraction_scale_fp32 contraction_scale_fp32.cpp
)
add_executable
(
client_contraction_scale_fp32 contraction_scale_fp32.cpp
)
target_link_libraries
(
client_contraction_scale_fp32 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_contraction_scale_fp32 PRIVATE composable_kernel::device_
other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp
)
add_executable
(
client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp
)
target_link_libraries
(
client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_
other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_contraction_scale_fp64 contraction_scale_fp64.cpp
)
add_executable
(
client_contraction_scale_fp64 contraction_scale_fp64.cpp
)
target_link_libraries
(
client_contraction_scale_fp64 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_contraction_scale_fp64 PRIVATE composable_kernel::device_
other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp
)
add_executable
(
client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp
)
target_link_libraries
(
client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_
other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_
operations
)
add_executable
(
contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp
)
add_executable
(
contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp
)
target_link_libraries
(
contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_
other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_
operations
)
client_example/05_layernorm/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_layernorm2d layernorm2d.cpp
)
add_executable
(
client_layernorm2d_fwd layernorm2d_fwd.cpp
)
target_link_libraries
(
client_layernorm2d PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations
)
add_executable
(
client_layernorm4d_fwd layernorm4d_fwd.cpp
)
target_link_libraries
(
client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations
)
client_example/05_layernorm/layernorm2d.cpp
→
client_example/05_layernorm/layernorm2d
_fwd
.cpp
View file @
29deceb6
...
@@ -7,10 +7,10 @@
...
@@ -7,10 +7,10 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization
_fwd
.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization
_fwd
.hpp"
using
XDataType
=
ck
::
half_t
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
...
@@ -57,14 +57,14 @@ int main(int argc, char* argv[])
...
@@ -57,14 +57,14 @@ int main(int argc, char* argv[])
SimpleDeviceMem
save_inv_std_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
M
);
SimpleDeviceMem
save_inv_std_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
M
);
#endif
#endif
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalization
<
XDataType
,
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalization
Fwd
<
XDataType
,
GammaDataType
,
GammaDataType
,
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
SaveMeanInvStdDataType
,
SaveMeanInvStdDataType
,
PassThrough
,
PassThrough
,
Rank
,
Rank
,
NumReduceDim
>
;
NumReduceDim
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
client_example/05_layernorm/layernorm4d_fwd.cpp
0 → 100644
View file @
29deceb6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp"
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
SaveMeanInvStdDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
#define SAVE_MEAN_INV_STD
constexpr
int
Rank
=
4
;
constexpr
int
NumReduceDim
=
3
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
ck
::
index_t
N
=
256
;
ck
::
index_t
H
=
16
;
ck
::
index_t
W
=
16
;
ck
::
index_t
C
=
8
;
std
::
vector
<
ck
::
index_t
>
strideXY
=
{
H
*
W
*
C
,
W
*
C
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
strideGammaBeta
=
{
0
,
W
*
C
,
C
,
1
};
std
::
vector
<
ck
::
index_t
>
strideSaveMeanInvStd
=
{
1
};
SimpleDeviceMem
x_device_buf
(
sizeof
(
XDataType
)
*
N
*
H
*
W
*
C
);
SimpleDeviceMem
gamma_device_buf
(
sizeof
(
GammaDataType
)
*
H
*
W
*
C
);
SimpleDeviceMem
beta_device_buf
(
sizeof
(
BetaDataType
)
*
H
*
W
*
C
);
SimpleDeviceMem
y_device_buf
(
sizeof
(
YDataType
)
*
N
*
H
*
W
*
C
);
#ifdef SAVE_MEAN_INV_STD
SimpleDeviceMem
save_mean_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
N
);
SimpleDeviceMem
save_inv_std_device_buf
(
sizeof
(
SaveMeanInvStdDataType
)
*
N
);
#endif
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationFwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
PassThrough
,
Rank
,
NumReduceDim
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
C
},
// lengths
strideXY
,
// xStrides
strideGammaBeta
,
// gammaStrides
strideGammaBeta
,
// betaStrides
strideXY
,
// yStrides
strideSaveMeanInvStd
,
// save_mean Strides
strideSaveMeanInvStd
,
// save_inv_std Strides
{
1
,
2
,
3
},
// reduceDims
1e-4
,
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
#endif
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
SimpleDeviceMem
workspace
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace
.
GetDeviceBuffer
());
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_byte
=
sizeof
(
XDataType
)
*
N
*
H
*
W
*
C
+
sizeof
(
GammaDataType
)
*
H
*
W
*
C
+
sizeof
(
BetaDataType
)
*
H
*
W
*
C
+
sizeof
(
YDataType
)
*
N
*
H
*
W
*
C
;
#ifdef SAVE_MEAN_INV_STD
num_byte
+=
sizeof
(
SaveMeanInvStdDataType
)
*
N
*
2
;
#endif
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
({
N
,
H
,
W
,
C
},
// lengths
strideXY
,
// xStrides
strideGammaBeta
,
// gammaStrides
strideGammaBeta
,
// betaStrides
strideXY
,
// yStrides
strideSaveMeanInvStd
,
// save_mean Strides
strideSaveMeanInvStd
,
// save_inv_std Strides
{
1
,
2
,
3
},
// reduceDims
1e-4
,
x_device_buf
.
GetDeviceBuffer
(),
gamma_device_buf
.
GetDeviceBuffer
(),
beta_device_buf
.
GetDeviceBuffer
(),
y_device_buf
.
GetDeviceBuffer
(),
#ifdef SAVE_MEAN_INV_STD
save_mean_device_buf
.
GetDeviceBuffer
(),
save_inv_std_device_buf
.
GetDeviceBuffer
(),
#else
nullptr
,
nullptr
,
#endif
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
SimpleDeviceMem
workspace
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace
.
GetDeviceBuffer
());
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
client_example/06_softmax/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_softmax4d softmax4d.cpp
)
add_executable
(
client_softmax4d softmax4d.cpp
)
target_link_libraries
(
client_softmax4d PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_softmax4d PRIVATE composable_kernel::device_
other_operations composable_kernel::device_reduction_
operations
)
client_example/07_grouped_convnd_fwd/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp
)
add_executable
(
client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp
)
target_link_libraries
(
client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_conv2d_fwd PRIVATE composable_kernel::device_
conv_
operations
)
add_executable
(
client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp
)
add_executable
(
client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp
)
target_link_libraries
(
client_grouped_conv1d_fwd PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_grouped_conv1d_fwd PRIVATE composable_kernel::device_
conv_
operations
)
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp
View file @
29deceb6
...
@@ -100,18 +100,18 @@ int main()
...
@@ -100,18 +100,18 @@ int main()
SimpleDeviceMem
wei
(
sizeof
(
WeiDataType
)
*
G
*
K
*
X
*
C
);
SimpleDeviceMem
wei
(
sizeof
(
WeiDataType
)
*
G
*
K
*
X
*
C
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
G
*
N
*
Wo
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
G
*
N
*
Wo
*
K
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultiple
AB
D
<
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
OutLayout
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
OutDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>
;
PassThrough
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp
View file @
29deceb6
...
@@ -71,18 +71,18 @@ int main()
...
@@ -71,18 +71,18 @@ int main()
SimpleDeviceMem
wei
(
sizeof
(
WeiDataType
)
*
G
*
K
*
Y
*
X
*
C
);
SimpleDeviceMem
wei
(
sizeof
(
WeiDataType
)
*
G
*
K
*
Y
*
X
*
C
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultiple
AB
D
<
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
OutLayout
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
OutDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>
;
PassThrough
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
client_example/08_fused_attention/CMakeLists.txt
View file @
29deceb6
add_executable
(
client_fused_attention fused_attention.cpp
)
add_executable
(
client_fused_attention fused_attention.cpp
)
target_link_libraries
(
client_fused_attention PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_fused_attention PRIVATE composable_kernel::device_
other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_fused_attention_bias fused_attention_bias.cpp
)
add_executable
(
client_fused_attention_bias fused_attention_bias.cpp
)
target_link_libraries
(
client_fused_attention_bias PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_fused_attention_bias PRIVATE composable_kernel::device_
other_operations composable_kernel::device_gemm_
operations
)
client_example/09_quantization/CMakeLists.txt
View file @
29deceb6
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_executable
(
client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp
)
add_executable
(
client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp
)
add_executable
(
client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp
)
add_executable
(
client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp
)
target_link_libraries
(
client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
add_executable
(
client_gemm_quantization gemm_quantization.cpp
)
add_executable
(
client_gemm_quantization gemm_quantization.cpp
)
target_link_libraries
(
client_gemm_quantization PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_gemm_quantization PRIVATE composable_kernel::device_
conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_
operations
)
endif
()
endif
()
client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp
View file @
29deceb6
...
@@ -80,7 +80,7 @@ int main(int argc, char* argv[])
...
@@ -80,7 +80,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem
requant_scale
(
sizeof
(
RequantScaleDataType
)
*
G
*
K
);
SimpleDeviceMem
requant_scale
(
sizeof
(
RequantScaleDataType
)
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultiple
AB
D
<
NumDimSpatial
,
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
...
...
client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp
View file @
29deceb6
...
@@ -78,18 +78,18 @@ int main(int argc, char* argv[])
...
@@ -78,18 +78,18 @@ int main(int argc, char* argv[])
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultiple
AB
D
<
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
ck
::
Tuple
<
BiasLayout
>
,
ck
::
Tuple
<
BiasLayout
>
,
OutLayout
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<
BiasDataType
>
,
ck
::
Tuple
<
BiasDataType
>
,
OutDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
OutElementOp
>
;
OutElementOp
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
DeviceOp
>::
GetInstances
();
...
...
client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp
View file @
29deceb6
...
@@ -83,7 +83,7 @@ int main(int argc, char* argv[])
...
@@ -83,7 +83,7 @@ int main(int argc, char* argv[])
SimpleDeviceMem
requant_scale
(
sizeof
(
RequantScaleDataType
)
*
G
*
K
);
SimpleDeviceMem
requant_scale
(
sizeof
(
RequantScaleDataType
)
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultiple
AB
D
<
NumDimSpatial
,
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
...
...
client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp
View file @
29deceb6
...
@@ -79,18 +79,18 @@ int main(int argc, char* argv[])
...
@@ -79,18 +79,18 @@ int main(int argc, char* argv[])
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultiple
AB
D
<
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
ck
::
Tuple
<
BiasLayout
>
,
ck
::
Tuple
<
BiasLayout
>
,
OutLayout
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<
BiasDataType
>
,
ck
::
Tuple
<
BiasDataType
>
,
OutDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
OutElementOp
>
;
OutElementOp
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
DeviceOp
>::
GetInstances
();
...
...
client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp
View file @
29deceb6
...
@@ -76,19 +76,19 @@ int main(int argc, char* argv[])
...
@@ -76,19 +76,19 @@ int main(int argc, char* argv[])
SimpleDeviceMem
requant_scale
(
sizeof
(
RequantScaleDataType
)
*
G
*
K
);
SimpleDeviceMem
requant_scale
(
sizeof
(
RequantScaleDataType
)
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleD
<
NumDimSpatial
,
NumDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
ck
::
Tuple
<
RequantScaleLayout
>
,
ck
::
Tuple
<
RequantScaleLayout
>
,
OutLayout
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<
RequantScaleDataType
>
,
ck
::
Tuple
<
RequantScaleDataType
>
,
OutDataType
,
OutDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
OutElementOp
>
;
OutElementOp
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
DeviceOp
>::
GetInstances
();
...
...
Prev
1
2
3
4
5
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment