Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4885c38a
Commit
4885c38a
authored
Sep 03, 2024
by
aska-0096
Browse files
Merge branch 'transpose_opt' of
https://github.com/ROCm/composable_kernel
into rowwise_opt
parents
cbf14ee1
7c8e92fa
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1730 additions
and
55 deletions
+1730
-55
CMakeLists.txt
CMakeLists.txt
+5
-6
Jenkinsfile
Jenkinsfile
+37
-11
client_example/24_grouped_conv_activation/CMakeLists.txt
client_example/24_grouped_conv_activation/CMakeLists.txt
+20
-4
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp
...activation/grouped_convnd_fwd_convscale_reduce/common.hpp
+834
-0
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp
...nd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp
+58
-0
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp
...d_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp
+58
-0
codegen/CMakeLists.txt
codegen/CMakeLists.txt
+2
-2
codegen/test/CMakeLists.txt
codegen/test/CMakeLists.txt
+16
-12
codegen/test/rtc/CMakeLists.txt
codegen/test/rtc/CMakeLists.txt
+0
-2
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-0
example/01_gemm/gemm_xdl_bf16_v3.cpp
example/01_gemm/gemm_xdl_bf16_v3.cpp
+9
-9
example/01_gemm/gemm_xdl_fp8.cpp
example/01_gemm/gemm_xdl_fp8.cpp
+2
-2
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+5
-5
example/62_convnd_activ/CMakeLists.txt
example/62_convnd_activ/CMakeLists.txt
+1
-0
example/62_convnd_activ/convscale_reduce/CMakeLists.txt
example/62_convnd_activ/convscale_reduce/CMakeLists.txt
+14
-0
example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
...v/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
+502
-0
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
...iv/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
+82
-0
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp
...nvscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp
+82
-0
No files found.
CMakeLists.txt
View file @
4885c38a
...
...
@@ -553,12 +553,7 @@ if(NOT DEFINED INSTANCES_ONLY)
PACKAGE_NAME examples
)
add_subdirectory
(
example
)
if
(
GPU_TARGETS MATCHES
"gfx9"
AND NOT INSTANCES_ONLY
)
add_subdirectory
(
codegen
)
endif
()
if
(
BUILD_TESTING
)
add_subdirectory
(
test
)
endif
()
add_subdirectory
(
test
)
rocm_package_setup_component
(
profiler
LIBRARY_NAME composablekernel
...
...
@@ -575,6 +570,10 @@ if(NOT DEFINED INSTANCES_ONLY)
endif
()
endif
()
if
(
NOT DEFINED PROFILER_ONLY
AND
(
GPU_TARGETS MATCHES
"gfx9"
OR DEFINED INSTANCES_ONLY
))
add_subdirectory
(
codegen
)
endif
()
#Create an interface target for the include only files and call it "composablekernels"
include
(
CMakePackageConfigHelpers
)
...
...
Jenkinsfile
View file @
4885c38a
...
...
@@ -262,10 +262,19 @@ def cmake_build(Map conf=[:]){
// reduce parallelism when compiling, clang uses too much memory
def
nt
=
nthreads
()
def
cmd
def
setup_cmd
def
build_cmd
def
execute_cmd
=
conf
.
get
(
"execute_cmd"
,
""
)
if
(!
setup_args
.
contains
(
"NO_CK_BUILD"
)){
def
setup_cmd
=
conf
.
get
(
"setup_cmd"
,
"${cmake_envs} cmake ${setup_args} .. "
)
def
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} dumb-init make -j${nt} ${config_targets}"
)
if
(
setup_args
.
contains
(
"gfx90a"
)
&&
params
.
NINJA_BUILD_TRACE
){
echo
"running ninja build trace"
setup_cmd
=
conf
.
get
(
"setup_cmd"
,
"${cmake_envs} cmake -G Ninja ${setup_args} .. "
)
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} ninja -j${nt} ${config_targets}"
)
}
else
{
setup_cmd
=
conf
.
get
(
"setup_cmd"
,
"${cmake_envs} cmake ${setup_args} .. "
)
build_cmd
=
conf
.
get
(
"build_cmd"
,
"${build_envs} dumb-init make -j${nt} ${config_targets}"
)
}
cmd
=
conf
.
get
(
"cmd"
,
"""
${setup_cmd}
${build_cmd}
...
...
@@ -281,7 +290,19 @@ def cmake_build(Map conf=[:]){
echo
cmd
dir
(
"build"
){
//build CK
sh
cmd
//run tests
if
(!
setup_args
.
contains
(
"NO_CK_BUILD"
)){
if
(
setup_args
.
contains
(
"gfx90a"
)
&&
params
.
NINJA_BUILD_TRACE
){
sh
"/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
archiveArtifacts
"ck_build_trace.json"
sh
"ninja test"
}
else
{
sh
"make check"
}
}
}
// Only archive from master or develop
...
...
@@ -543,7 +564,7 @@ def Build_CK(Map conf=[:]){
cmake_build
(
conf
)
dir
(
"build"
){
//run tests and examples
sh
'make -j check'
//
sh 'make -j check'
if
(
params
.
RUN_PERFORMANCE_TESTS
&&
do_perf_tests
==
0
){
//we only need the ckProfiler to run the performance tests, so we pack and stash it
//do not stash profiler on nodes where we don't need to run performance tests
...
...
@@ -684,8 +705,8 @@ def process_results(Map conf=[:]){
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS
=
BRANCH_NAME
==
"develop"
?
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2; RUN_CK_TILE_TESTS=true
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false
;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false
;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false'''
:
""
pipeline
{
...
...
@@ -765,7 +786,10 @@ pipeline {
name:
"BUILD_GFX12"
,
defaultValue:
false
,
description:
"Build CK and run tests on gfx12 (default: OFF)"
)
booleanParam
(
name:
"NINJA_BUILD_TRACE"
,
defaultValue:
false
,
description:
"Generate a ninja build trace (default: OFF)"
)
}
environment
{
dbuser
=
"${dbuser}"
...
...
@@ -799,6 +823,7 @@ pipeline {
}
agent
{
label
rocmnode
(
"nogpu"
)
}
environment
{
setup_args
=
"NO_CK_BUILD"
execute_cmd
=
"find .. -not -path \'*.git*\' -iname \'*.h\' \
-o -not -path \'*.git*\' -iname \'*.hpp\' \
-o -not -path \'*.git*\' -iname \'*.cpp\' \
...
...
@@ -815,7 +840,7 @@ pipeline {
--file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
}
steps
{
buildHipClangJobAndReboot
(
setup_cmd:
""
,
build_cmd:
""
,
execute_cmd:
execute_cmd
,
no_reboot:
true
)
buildHipClangJobAndReboot
(
setup_args:
setup_args
,
setup_cmd:
""
,
build_cmd:
""
,
execute_cmd:
execute_cmd
,
no_reboot:
true
)
archiveArtifacts
"build/ck_cppcheck.log"
cleanWs
()
}
...
...
@@ -827,6 +852,7 @@ pipeline {
}
agent
{
label
rocmnode
(
"nogpu"
)
}
environment
{
setup_args
=
"NO_CK_BUILD"
execute_cmd
=
"find .. -not -path \'*.git*\' -iname \'*.h\' \
-o -not -path \'*.git*\' -iname \'*.hpp\' \
-o -not -path \'*.git*\' -iname \'*.cpp\' \
...
...
@@ -838,7 +864,7 @@ pipeline {
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
}
steps
{
buildHipClangJobAndReboot
(
setup_cmd:
""
,
build_cmd:
""
,
execute_cmd:
execute_cmd
,
no_reboot:
true
)
buildHipClangJobAndReboot
(
setup_args:
setup_args
,
setup_cmd:
""
,
build_cmd:
""
,
execute_cmd:
execute_cmd
,
no_reboot:
true
)
cleanWs
()
}
}
...
...
@@ -967,10 +993,10 @@ pipeline {
}
agent
{
label
rocmnode
(
"gfx90a"
)
}
environment
{
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="
gfx1100;
gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args
=
""" -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args
=
""" cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="
gfx1100;
gfx90a" \
-DGPU_TARGETS="gfx90a" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
...
...
@@ -1074,7 +1100,7 @@ pipeline {
options
{
retry
(
1
)
}
agent
{
label
rocmnode
(
"gfx90a"
)}
environment
{
setup_args
=
"
"" -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On ""
"
setup_args
=
"
NO_CK_BUILD
"
}
steps
{
runPerfTest
(
setup_args:
setup_args
,
config_targets:
"ckProfiler"
,
no_reboot:
true
,
build_type:
'Release'
)
...
...
client_example/24_grouped_conv_activation/CMakeLists.txt
View file @
4885c38a
if
(
GPU_TARGETS MATCHES
"gfx9"
)
# Fwd scaleadd scaleadd relu
add_executable
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32
add_executable
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32
grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_conv_operations
)
...
...
@@ -36,7 +36,7 @@ add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16
grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp
)
target_link_libraries
(
client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations
)
# Fwd convinvscale
add_executable
(
client_conv3d_fwd_convinvscale_fp8
add_executable
(
client_conv3d_fwd_convinvscale_fp8
grouped_convnd_fwd_convinvscale/conv3d_fwd_convinvscale_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convinvscale_fp8 PRIVATE composable_kernel::device_conv_operations
)
# Fwd convscale + Bias
...
...
@@ -47,6 +47,22 @@ target_link_libraries(client_conv3d_fwd_convscale_add_fp8 PRIVATE composable_ker
add_executable
(
client_conv3d_fwd_convscale_relu_fp8
grouped_convnd_fwd_convscale_relu/conv3d_fwd_convscale_relu_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_relu_fp8 PRIVATE composable_kernel::device_conv_operations
)
# Fwd convscale + ReLU + AMAX
add_executable
(
client_conv3d_fwd_convscale_relu_amax_fp8
grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_relu_amax_fp8
PRIVATE composable_kernel::device_conv_operations
composable_kernel::device_other_operations
composable_kernel::device_reduction_operations
utility
)
# Fwd convscale + AMAX
add_executable
(
client_conv3d_fwd_convscale_amax_fp8
grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_amax_fp8
PRIVATE composable_kernel::device_conv_operations
composable_kernel::device_other_operations
composable_kernel::device_reduction_operations
utility
)
# Fwd convscale
add_executable
(
client_conv3d_fwd_convscale_fp8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp
)
...
...
@@ -56,11 +72,11 @@ add_executable(client_conv3d_fwd_convscale_bf8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_bf8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_conv3d_fwd_convscale_fp8_bf8
add_executable
(
client_conv3d_fwd_convscale_fp8_bf8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8_bf8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_fp8_bf8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_conv3d_fwd_convscale_bf8_fp8
add_executable
(
client_conv3d_fwd_convscale_bf8_fp8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_bf8_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_convscale_bf8_fp8 PRIVATE composable_kernel::device_conv_operations
)
# Bwd data bilinear
...
...
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/type.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale_relu.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/reduce.hpp"
#include "ck/library/utility/host_tensor.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ConvScaleRelu
=
ck
::
tensor_operation
::
element_wise
::
ScaleScaleRelu
;
using
ConvScale
=
ck
::
tensor_operation
::
element_wise
::
ScaleScalePass
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
template
<
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetFlops
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
weights_lengths
,
const
std
::
size_t
&
ds_size
)
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product> +
// + ds_size * <output tensor size> =>
// => <output tensor size> * ( 2 * C * <filter spatial lengths product> + ds_size) =>
// => G * N * K * <output spatial lengths product> * (2 * C * <filter spatial lengths product> +
// ds_size)
ck
::
index_t
G
=
weights_lengths
[
0
];
ck
::
index_t
N
=
output_lengths
[
1
];
ck
::
index_t
K
=
weights_lengths
[
1
];
ck
::
index_t
C
=
weights_lengths
[
2
];
return
G
*
N
*
K
*
std
::
accumulate
(
std
::
next
(
std
::
begin
(
output_lengths
),
NumNonSpatialDim
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
())
*
(
ds_size
+
static_cast
<
std
::
size_t
>
(
2
)
*
C
*
std
::
accumulate
(
std
::
next
(
std
::
begin
(
weights_lengths
),
NumNonSpatialDim
),
std
::
end
(
weights_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
()));
}
template
<
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetTensorSize
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
lengths
)
{
return
std
::
accumulate
(
std
::
begin
(
lengths
),
std
::
end
(
lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
());
}
template
<
typename
InDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetInputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
input_lengths
)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
GetTensorSize
<
NumDimSpatial
>
(
input_lengths
);
}
template
<
typename
WeiDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetWeightByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
weights_lengths
)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
GetTensorSize
<
NumDimSpatial
>
(
weights_lengths
);
}
template
<
typename
OutDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetOutputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
output_lengths
)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
GetTensorSize
<
NumDimSpatial
>
(
output_lengths
);
}
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
ConvElementOp
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
,
typename
AComputeType
=
InDataType
,
typename
BComputeType
=
AComputeType
>
bool
ConvolutionScale
(
SimpleDeviceMem
&
in
,
SimpleDeviceMem
&
wei
,
SimpleDeviceMem
&
out
,
ConvElementOp
elementwise_op
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
in_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
in_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
wei_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
wei_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
out_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
out_strides
);
template
<
typename
InDataType
,
typename
OutDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
bool
TensorScaleConvert
(
SimpleDeviceMem
&
in
,
SimpleDeviceMem
&
out
,
float
scale_out
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
strides
);
template
<
typename
InDataType
,
typename
OutDataType
,
ck
::
ReduceTensorOp
ReduceOpId
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
bool
TensorFullReduction
(
SimpleDeviceMem
&
tensor
,
SimpleDeviceMem
&
out_amax
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
strides
);
template
<
ck
::
index_t
NumDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
ConvOutDataType
,
typename
OutDataType
,
typename
ConvElementOp
,
ck
::
ReduceTensorOp
ReduceOp
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
ck
::
index_t
NumNonSpatialDim
=
3
,
typename
AComputeType
=
InDataType
,
typename
BComputeType
=
AComputeType
>
bool
run_grouped_conv_fwd_convscale_reduce
(
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
in_lengths
,
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
wei_lengths
,
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
out_lengths
)
{
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
static_assert
(
NumDimSpatial
==
3
&&
ck
::
is_same_v
<
InLayout
,
ctc
::
NDHWGC
>
&&
ck
::
is_same_v
<
WeiLayout
,
ctc
::
GKZYXC
>
&&
ck
::
is_same_v
<
OutLayout
,
ctc
::
NDHWGK
>
,
"Unsupported configuration"
);
const
ck
::
index_t
G
=
in_lengths
[
4
];
const
ck
::
index_t
N
=
in_lengths
[
0
];
const
ck
::
index_t
K
=
wei_lengths
[
1
];
const
ck
::
index_t
C
=
in_lengths
[
5
];
const
ck
::
index_t
Z
=
wei_lengths
[
2
];
const
ck
::
index_t
Y
=
wei_lengths
[
3
];
const
ck
::
index_t
X
=
wei_lengths
[
4
];
const
ck
::
index_t
Di
=
in_lengths
[
1
];
const
ck
::
index_t
Hi
=
in_lengths
[
2
];
const
ck
::
index_t
Wi
=
in_lengths
[
3
];
const
ck
::
index_t
Do
=
out_lengths
[
1
];
const
ck
::
index_t
Ho
=
out_lengths
[
2
];
const
ck
::
index_t
Wo
=
out_lengths
[
3
];
const
std
::
size_t
in_mem_size
=
sizeof
(
InDataType
)
*
N
*
Di
*
Hi
*
Wi
*
G
*
C
;
const
std
::
size_t
wei_mem_size
=
sizeof
(
WeiDataType
)
*
G
*
K
*
Z
*
Y
*
X
*
C
;
const
std
::
size_t
conv_out_mem_size
=
sizeof
(
ConvOutDataType
)
*
N
*
Do
*
Ho
*
Wo
*
G
*
K
;
const
std
::
size_t
out_mem_size
=
sizeof
(
OutDataType
)
*
N
*
Do
*
Ho
*
Wo
*
G
*
K
;
SimpleDeviceMem
in
(
in_mem_size
);
SimpleDeviceMem
wei
(
wei_mem_size
);
SimpleDeviceMem
conv_out
(
conv_out_mem_size
);
SimpleDeviceMem
out
(
out_mem_size
);
float
scale_in
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
scale_wei
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
scale_out
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
// We have NDHWGC/GKZYXC/NDHWGK (x, weight, y) in memory space.
// However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW.
// Hence, we need to adjust the order of strides.
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_lengths
{
G
,
N
,
C
,
Di
,
Hi
,
Wi
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
input_strides
{
C
,
Di
*
Hi
*
Wi
*
G
*
C
,
1
,
Hi
*
Wi
*
G
*
C
,
Wi
*
G
*
C
,
G
*
C
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_lengths
{
G
,
K
,
C
,
Z
,
Y
,
X
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
weights_strides
{
K
*
Z
*
Y
*
X
*
C
,
Z
*
Y
*
X
*
C
,
1
,
Y
*
X
*
C
,
X
*
C
,
C
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_lengths
{
G
,
N
,
K
,
Do
,
Ho
,
Wo
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
output_strides
{
K
,
Do
*
Ho
*
Wo
*
G
*
K
,
1
,
Ho
*
Wo
*
G
*
K
,
Wo
*
G
*
K
,
G
*
K
};
/*
* FP8 Convolution with Scaling
*/
std
::
cout
<<
"
\n\n
Convolution with scale Benchmarking:"
<<
std
::
endl
;
auto
elementwise_op
=
ConvElementOp
{
ck
::
tensor_operation
::
element_wise
::
Scale
{
scale_in
},
ck
::
tensor_operation
::
element_wise
::
Scale
{
scale_wei
},
{}};
auto
conv_ok
=
ConvolutionScale
<
InDataType
,
WeiDataType
,
ConvOutDataType
,
ConvElementOp
,
InLayout
,
WeiLayout
,
OutLayout
,
NumDimSpatial
>
(
in
,
wei
,
conv_out
,
elementwise_op
,
input_lengths
,
input_strides
,
weights_lengths
,
weights_strides
,
output_lengths
,
output_strides
);
if
(
!
conv_ok
)
return
false
;
/*
* Scale with output weight and convert to FP8
*/
std
::
cout
<<
"
\n\n
Element-wise scale + convert Benchmarking:"
<<
std
::
endl
;
auto
elem_wise_ok
=
TensorScaleConvert
<
ConvOutDataType
,
OutDataType
,
NumDimSpatial
>
(
conv_out
,
out
,
scale_out
,
output_lengths
,
output_strides
);
if
(
!
elem_wise_ok
)
return
false
;
/*
* Compute AMAX
*/
std
::
cout
<<
"
\n\n
AMAX Benchmarking:"
<<
std
::
endl
;
SimpleDeviceMem
amax_device
(
sizeof
(
ConvOutDataType
));
auto
reduction_ok
=
TensorFullReduction
<
ConvOutDataType
,
ConvOutDataType
,
ck
::
ReduceTensorOp
::
AMAX
,
NumDimSpatial
>
(
conv_out
,
amax_device
,
output_lengths
,
output_strides
);
if
(
!
reduction_ok
)
return
false
;
return
true
;
}
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
ConvElementOp
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
,
typename
AComputeType
,
typename
BComputeType
>
bool
ConvolutionScale
(
SimpleDeviceMem
&
in
,
SimpleDeviceMem
&
wei
,
SimpleDeviceMem
&
out
,
ConvElementOp
elementwise_op
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
in_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
in_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
wei_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
wei_strides
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
out_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
out_strides
)
{
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
{
1
,
1
,
1
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
{
1
,
1
,
1
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
,
1
};
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
,
1
,
1
};
const
auto
in_mem_size
=
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
in_lengths
);
const
auto
wei_mem_size
=
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
wei_lengths
);
const
auto
out_mem_size
=
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
out_lengths
);
std
::
size_t
ds_size
=
2
;
// 2 element-wise scale multipliers
if
constexpr
(
ck
::
is_same_v
<
ConvElementOp
,
ConvScaleRelu
>
)
{
ds_size
+=
1
;
// +1 element-wise relu
}
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
>
(
out_lengths
,
wei_lengths
,
ds_size
);
std
::
size_t
num_bytes
=
in_mem_size
+
wei_mem_size
+
sizeof
(
float
)
+
sizeof
(
float
)
+
out_mem_size
;
using
ConvDeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
WeiDataType
,
ck
::
Tuple
<>
,
OutDataType
,
PassThrough
,
PassThrough
,
ConvElementOp
,
AComputeType
,
BComputeType
>
;
// get device op instances
const
auto
conv_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
ConvDeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
conv_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
conv_best_op_name
;
int
conv_best_op_id
=
-
1
;
float
conv_best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
conv_best_gb_per_sec
=
0
;
float
conv_best_tflops
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all convolution instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
conv_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
conv_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{},
out_lengths
,
out_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
elementwise_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
conv_best_tflops
)
{
conv_best_op_id
=
i
;
conv_best_op_name
=
op_name
;
conv_best_avg_time
=
avg_time
;
conv_best_gb_per_sec
=
gb_per_sec
;
conv_best_tflops
=
tflops
;
}
}
else
{
std
::
cerr
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
conv_best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance"
<<
std
::
endl
;
return
false
;
}
std
::
cout
<<
"Best Perf: "
<<
std
::
setw
(
10
)
<<
conv_best_avg_time
<<
" ms, "
<<
conv_best_tflops
<<
" TFlops, "
<<
conv_best_gb_per_sec
<<
" GB/s, "
<<
conv_best_op_name
<<
std
::
endl
;
// run the best instance
{
auto
&
op_ptr
=
conv_ptrs
[
conv_best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{},
out_lengths
,
out_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
elementwise_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
true
;
}
template
<
typename
InDataType
,
typename
OutDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
>
bool
TensorScaleConvert
(
SimpleDeviceMem
&
in
,
SimpleDeviceMem
&
out
,
float
scale_out
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
strides
)
{
const
auto
tensor_size
=
GetTensorSize
<
NumDimSpatial
>
(
lengths
);
const
std
::
size_t
in_mem_size
=
sizeof
(
InDataType
)
*
tensor_size
;
const
std
::
size_t
out_mem_size
=
sizeof
(
OutDataType
)
*
tensor_size
;
std
::
size_t
flop
=
2
*
tensor_size
;
// element-wise scale + convert
std
::
size_t
bytes
=
in_mem_size
+
sizeof
(
float
)
+
out_mem_size
;
// read from in, scale, write to out
using
DeviceScaleConvert
=
ck
::
tensor_operation
::
device
::
DeviceElementwise
<
ck
::
Tuple
<
InDataType
>
,
ck
::
Tuple
<
OutDataType
>
,
ck
::
tensor_operation
::
element_wise
::
Scale
,
NumDimSpatial
+
NumNonSpatialDim
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceScaleConvert
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all DeviceScaleConvert instances and do timing"
<<
std
::
endl
;
auto
scale_convert
=
ck
::
tensor_operation
::
element_wise
::
Scale
{
scale_out
};
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
lengths
,
{
strides
},
{
strides
},
{
in
.
GetDeviceBuffer
()},
{
out
.
GetDeviceBuffer
()},
scale_convert
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
best_tflops
=
tflops
;
}
}
else
{
std
::
cerr
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance found."
<<
std
::
endl
;
return
false
;
}
else
{
std
::
cout
<<
"Best Perf: "
<<
std
::
setw
(
10
)
<<
best_avg_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
lengths
,
{
strides
},
{
strides
},
{
in
.
GetDeviceBuffer
()},
{
out
.
GetDeviceBuffer
()},
scale_convert
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
true
;
}
template
<
typename
InDataType
,
typename
OutDataType
,
ck
::
ReduceTensorOp
ReduceOpId
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
>
bool
TensorFullReduction
(
SimpleDeviceMem
&
tensor
,
SimpleDeviceMem
&
out_amax
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
strides
)
{
const
auto
spatial_dim_size
=
std
::
accumulate
(
std
::
next
(
std
::
begin
(
lengths
),
NumNonSpatialDim
),
std
::
end
(
lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
());
const
auto
tensor_size
=
GetTensorSize
<
NumDimSpatial
>
(
lengths
);
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
)
{
ck
::
ranges
::
copy
(
x
,
y
.
begin
());
};
// Get the reduction operation
using
ReduceOperation
=
typename
ck
::
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
InElementwiseOperation
in_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
std
::
tie
(
in_elementwise_op
,
acc_elementwise_op
)
=
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
static_cast
<
int32_t
>
(
tensor_size
));
std
::
array
<
ck
::
index_t
,
1
>
reduce_out_lengths
{
1
};
std
::
array
<
ck
::
index_t
,
1
>
reduce_out_strides
{
1
};
SimpleDeviceMem
partial_reduce_tensor
(
sizeof
(
OutDataType
)
*
spatial_dim_size
);
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
reduce_part_lengths
;
std
::
copy
(
std
::
next
(
std
::
begin
(
lengths
),
NumNonSpatialDim
),
std
::
end
(
lengths
),
std
::
begin
(
reduce_part_lengths
));
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
reduce_part_strides
;
copy
(
HostTensorDescriptor
(
reduce_part_lengths
).
GetStrides
(),
reduce_part_strides
);
{
std
::
cout
<<
"
\n
Reduction of nonspatial dimensions:"
<<
std
::
endl
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceReduce
<
InDataType
,
OutDataType
,
OutDataType
,
NumDimSpatial
+
NumNonSpatialDim
,
NumNonSpatialDim
,
ReduceOperation
,
InElementwiseOperation
,
PassThrough
,
true
,
// PropagateNan
false
>
;
// OutputIndex
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
std
::
array
<
int
,
NumNonSpatialDim
>
reduce_dims
;
std
::
iota
(
reduce_dims
.
begin
(),
reduce_dims
.
end
(),
0
);
// 0,..., NumNonSpatialDim-1
ck
::
index_t
num_in_elements
=
tensor_size
;
ck
::
index_t
num_out_elements
=
spatial_dim_size
;
// profile device operation instances
std
::
cout
<<
"Run partial reduction and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
lengths
,
strides
,
reduce_part_lengths
,
reduce_part_strides
,
reduce_dims
,
1.0
,
0.0
,
tensor
.
GetDeviceBuffer
(),
nullptr
,
partial_reduce_tensor
.
GetDeviceBuffer
(),
nullptr
,
in_elementwise_op
,
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_bytes
=
num_in_elements
*
sizeof
(
InDataType
)
+
num_out_elements
*
sizeof
(
OutDataType
);
float
gb_per_sec
=
num_bytes
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance found."
<<
std
::
endl
;
return
false
;
}
else
{
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best instance
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
lengths
,
strides
,
reduce_part_lengths
,
reduce_part_strides
,
reduce_dims
,
1.0
,
0.0
,
tensor
.
GetDeviceBuffer
(),
nullptr
,
partial_reduce_tensor
.
GetDeviceBuffer
(),
nullptr
,
in_elementwise_op
,
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
}
{
std
::
cout
<<
"
\n
Reduction of spatial dimensions:"
<<
std
::
endl
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceReduce
<
OutDataType
,
OutDataType
,
OutDataType
,
NumDimSpatial
,
NumDimSpatial
,
ReduceOperation
,
PassThrough
,
AccElementwiseOperation
,
true
,
// PropagateNan
false
>
;
// OutputIndex
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
std
::
array
<
int
,
NumDimSpatial
>
reduce_dims
;
std
::
iota
(
reduce_dims
.
begin
(),
reduce_dims
.
end
(),
0
);
// 0,..., NumDimSpatial-1
ck
::
index_t
num_in_elements
=
spatial_dim_size
;
ck
::
index_t
num_out_elements
=
1
;
// profile device operation instances
std
::
cout
<<
"Run final reduction and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
reduce_part_lengths
,
reduce_part_strides
,
reduce_out_lengths
,
reduce_out_strides
,
reduce_dims
,
1.0
,
0.0
,
partial_reduce_tensor
.
GetDeviceBuffer
(),
nullptr
,
out_amax
.
GetDeviceBuffer
(),
nullptr
,
PassThrough
{},
acc_elementwise_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_bytes
=
num_in_elements
*
sizeof
(
OutDataType
)
+
num_out_elements
*
sizeof
(
OutDataType
);
float
gb_per_sec
=
num_bytes
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance found."
<<
std
::
endl
;
return
false
;
}
else
{
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best instance
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
reduce_part_lengths
,
reduce_part_strides
,
reduce_out_lengths
,
reduce_out_strides
,
reduce_dims
,
1.0
,
0.0
,
partial_reduce_tensor
.
GetDeviceBuffer
(),
nullptr
,
out_amax
.
GetDeviceBuffer
(),
nullptr
,
PassThrough
{},
acc_elementwise_op
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
}
return
true
;
}
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using
InDataType
=
ck
::
f8_t
;
using
WeiDataType
=
ck
::
f8_t
;
using
CShuffleDataType
=
float
;
using
ConvOutDataType
=
float
;
// data type of convolution result
using
OutDataType
=
ck
::
f8_t
;
// data type of final result
using
AComputeDataType
=
ck
::
f8_t
;
using
BComputeDataType
=
ck
::
f8_t
;
using
ConvElementOp
=
ConvScale
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
GKZYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AMAX
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
static
constexpr
ck
::
index_t
G
=
1
;
static
constexpr
ck
::
index_t
N
=
64
;
static
constexpr
ck
::
index_t
K
=
128
;
static
constexpr
ck
::
index_t
C
=
64
;
static
constexpr
ck
::
index_t
Z
=
3
;
static
constexpr
ck
::
index_t
Y
=
3
;
static
constexpr
ck
::
index_t
X
=
3
;
static
constexpr
ck
::
index_t
Di
=
28
;
static
constexpr
ck
::
index_t
Hi
=
28
;
static
constexpr
ck
::
index_t
Wi
=
3
;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
int
main
()
{
return
run_grouped_conv_fwd_convscale_reduce
<
NumDimSpatial
,
InDataType
,
WeiDataType
,
ConvOutDataType
,
OutDataType
,
ConvElementOp
,
ReduceOpId
,
InLayout
,
WeiLayout
,
OutLayout
,
3
,
AComputeDataType
,
BComputeDataType
>
(
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
G
,
K
,
Z
,
Y
,
X
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_relu_amax_fp8.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using
InDataType
=
ck
::
f8_t
;
using
WeiDataType
=
ck
::
f8_t
;
using
CShuffleDataType
=
float
;
using
ConvOutDataType
=
float
;
// data type of convolution result
using
OutDataType
=
ck
::
f8_t
;
// data type of final result
using
AComputeDataType
=
ck
::
f8_t
;
using
BComputeDataType
=
ck
::
f8_t
;
using
ConvElementOp
=
ConvScaleRelu
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
GKZYXC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NDHWGK
;
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AMAX
;
static
constexpr
ck
::
index_t
NumDimSpatial
=
3
;
static
constexpr
ck
::
index_t
G
=
1
;
static
constexpr
ck
::
index_t
N
=
64
;
static
constexpr
ck
::
index_t
K
=
128
;
static
constexpr
ck
::
index_t
C
=
64
;
static
constexpr
ck
::
index_t
Z
=
3
;
static
constexpr
ck
::
index_t
Y
=
3
;
static
constexpr
ck
::
index_t
X
=
3
;
static
constexpr
ck
::
index_t
Di
=
28
;
static
constexpr
ck
::
index_t
Hi
=
28
;
static
constexpr
ck
::
index_t
Wi
=
3
;
static
constexpr
ck
::
index_t
Do
=
28
;
static
constexpr
ck
::
index_t
Ho
=
28
;
static
constexpr
ck
::
index_t
Wo
=
3
;
int
main
()
{
return
run_grouped_conv_fwd_convscale_reduce
<
NumDimSpatial
,
InDataType
,
WeiDataType
,
ConvOutDataType
,
OutDataType
,
ConvElementOp
,
ReduceOpId
,
InLayout
,
WeiLayout
,
OutLayout
,
3
,
AComputeDataType
,
BComputeDataType
>
(
{
N
,
Di
,
Hi
,
Wi
,
G
,
C
},
{
G
,
K
,
Z
,
Y
,
X
,
C
},
{
N
,
Do
,
Ho
,
Wo
,
G
,
K
})
?
EXIT_SUCCESS
:
EXIT_FAILURE
;
}
codegen/CMakeLists.txt
View file @
4885c38a
...
...
@@ -27,6 +27,8 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
CK_ROOT
}
/include
)
file
(
GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp
)
##message(STATUS "SOURCE_FILES: ${SOURCES}")
# TODO: Use object library
add_library
(
ck_host STATIC
${
SOURCES
}
)
target_link_libraries
(
ck_host PRIVATE ck_headers
)
...
...
@@ -48,6 +50,4 @@ rocm_install(
)
rocm_install
(
DIRECTORY include/ck DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
if
(
BUILD_TESTING
)
add_subdirectory
(
test
)
endif
()
codegen/test/CMakeLists.txt
View file @
4885c38a
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
)
add_subdirectory
(
rtc
)
file
(
GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp
)
foreach
(
TEST_SRC
${
TEST_SRCS
}
)
set_source_files_properties
(
${
TEST_SRC
}
PROPERTIES LANGUAGE HIP
)
get_filename_component
(
BASE_NAME
${
TEST_SRC
}
NAME_WE
)
add_executable
(
test_host_
${
BASE_NAME
}
${
TEST_SRC
}
)
add_dependencies
(
codegen test_host_
${
BASE_NAME
}
)
add_test
(
NAME codegen_test_
${
BASE_NAME
}
COMMAND test_host_
${
BASE_NAME
}
)
target_link_libraries
(
test_host_
${
BASE_NAME
}
ck_rtc ck_host
)
# target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a)
target_include_directories
(
test_host_
${
BASE_NAME
}
PUBLIC
include
())
target_include_directories
(
test_host_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/include
)
target_include_directories
(
test_host_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/library/include
)
endforeach
()
if
(
NOT INSTANCES_ONLY
)
foreach
(
TEST_SRC
${
TEST_SRCS
}
)
set_source_files_properties
(
${
TEST_SRC
}
PROPERTIES LANGUAGE HIP
)
get_filename_component
(
BASE_NAME
${
TEST_SRC
}
NAME_WE
)
add_executable
(
codegen_test_
${
BASE_NAME
}
${
TEST_SRC
}
)
add_dependencies
(
codegen codegen_test_
${
BASE_NAME
}
)
add_dependencies
(
tests codegen_test_
${
BASE_NAME
}
)
add_dependencies
(
check codegen_test_
${
BASE_NAME
}
)
add_test
(
NAME codegen_test_
${
BASE_NAME
}
COMMAND codegen_test_
${
BASE_NAME
}
)
message
(
"adding test codegen_test_
${
BASE_NAME
}
"
)
target_link_libraries
(
codegen_test_
${
BASE_NAME
}
ck_rtc ck_host
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/codegen/test/include
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/include
)
target_include_directories
(
codegen_test_
${
BASE_NAME
}
PUBLIC
${
CK_ROOT
}
/library/include
)
endforeach
()
endif
()
codegen/test/rtc/CMakeLists.txt
View file @
4885c38a
find_package
(
hip
)
file
(
GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp
)
add_library
(
ck_rtc
${
RTC_SOURCES
}
)
target_include_directories
(
ck_rtc PUBLIC include
)
...
...
docs/sphinx/requirements.in
View file @
4885c38a
rocm-docs-core==1.7.
0
rocm-docs-core==1.7.
2
sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt
View file @
4885c38a
...
...
@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.7.
0
rocm-docs-core==1.7.
2
# via -r requirements.in
six==1.16.0
# via pybtex
...
...
example/01_gemm/CMakeLists.txt
View file @
4885c38a
...
...
@@ -34,6 +34,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
add_example_executable
(
example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp
)
target_compile_options
(
example_gemm_xdl_bf16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_bf16_v3
)
target_compile_options
(
example_gemm_xdl_bf16_v3 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker
)
add_example_executable
(
example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_wavelet_fp16
)
...
...
example/01_gemm/gemm_xdl_bf16_v3.cpp
View file @
4885c38a
...
...
@@ -12,7 +12,7 @@ using CShuffleDataType = ck::bhalf_t;
using
CDataType
=
ck
::
bhalf_t
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
...
...
@@ -28,15 +28,15 @@ using DeviceGemmV2Instance =
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
CShuffleDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
2
56
,
256
,
32
,
8
,
8
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
24
,
256
,
64
,
8
,
1
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
8
,
1
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
>
;
// clang-format on
...
...
example/01_gemm/gemm_xdl_fp8.cpp
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
...
...
@@ -7,7 +7,7 @@
using
ADataType
=
ck
::
f8_t
;
using
BDataType
=
ck
::
f8_t
;
using
CDataType
=
ck
::
hal
f_t
;
using
CDataType
=
ck
::
f
8
_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
...
...
example/01_gemm/run_gemm_example.inc
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -34,11 +34,11 @@ inline __host__ __device__ constexpr double get_rtol()
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
1
e
-
1
;
// 240 and 224 are acceptable
return
2
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
1.5
e-1
;
// 57344 and 49152 are acceptable
return
2
e
-
1
;
}
else
{
...
...
@@ -75,11 +75,11 @@ inline __host__ __device__ constexpr double get_atol()
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
16.1
;
// 240 and 224 are acceptable
return
2
e
-
1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
8192.1
;
// 57344 and 49152 are acceptable
return
2
e
-
1
;
}
else
{
...
...
example/62_convnd_activ/CMakeLists.txt
View file @
4885c38a
...
...
@@ -3,6 +3,7 @@ add_subdirectory(convinvscale)
add_subdirectory
(
convscale
)
add_subdirectory
(
convscale_relu
)
add_subdirectory
(
convscale_add
)
add_subdirectory
(
convscale_reduce
)
add_subdirectory
(
multi_AB
)
add_subdirectory
(
unary
)
...
...
example/62_convnd_activ/convscale_reduce/CMakeLists.txt
0 → 100644
View file @
4885c38a
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
example_convnd_activ_xdl_convscale_reduce
)
add_example_executable
(
example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp
)
add_example_dependencies
(
example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8
)
add_example_executable
(
example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp
)
add_example_dependencies
(
example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8
)
set
(
target 1
)
endif
()
endforeach
()
example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_reduce.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/type.hpp"
namespace
ew
=
ck
::
tensor_operation
::
element_wise
;
using
PassThrough
=
ew
::
PassThrough
;
using
ConvScaleRelu
=
ew
::
UnaryCombinedOp
<
ew
::
Scale
,
ew
::
Scale
,
ew
::
Relu
>
;
using
ConvScale
=
ew
::
UnaryCombinedOp
<
ew
::
Scale
,
ew
::
Scale
,
PassThrough
>
;
using
UnaryScaleConvert
=
ew
::
Scale
;
void
print_helper_msg
()
{
std
::
cout
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
ck
::
utils
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
template
<
typename
DataType
>
inline
__host__
__device__
constexpr
double
get_rtol
()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
float
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
double
>
)
{
return
1e-6
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
half_t
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
{
return
5e-2
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int32_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int8_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
1e-1
;
// 240 and 224 are acceptable
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
1.5e-1
;
// 57344 and 49152 are acceptable
}
else
{
return
1e-3
;
}
}
template
<
typename
DataType
>
inline
__host__
__device__
constexpr
double
get_atol
()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
float
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
double
>
)
{
return
1e-6
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
half_t
>
)
{
return
1e-3
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bhalf_t
>
)
{
return
5e-2
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int32_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
int8_t
>
)
{
return
1e-1
;
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
f8_t
>
)
{
return
16.1
;
// 240 and 224 are acceptable
}
else
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck
::
bf8_t
>
)
{
return
8192.1
;
// 57344 and 49152 are acceptable
}
else
{
return
1e-3
;
}
}
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
ConvOutDataType
,
typename
OutDataType
,
typename
InElementOp
,
typename
WeiElementOp
,
typename
ConvElementOp
,
typename
DeviceConvNDFwdInstance
>
bool
run_grouped_conv_fwd
(
bool
do_verification
,
int
init_method
,
bool
time_kernel
,
const
ck
::
utils
::
conv
::
ConvParam
&
conv_param
,
const
HostTensorDescriptor
&
in_g_n_c_wis_desc
,
const
HostTensorDescriptor
&
wei_g_k_c_xs_desc
,
const
HostTensorDescriptor
&
out_g_n_k_wos_desc
,
const
InElementOp
&
in_element_op
,
const
WeiElementOp
&
wei_element_op
)
{
Tensor
<
InDataType
>
in
(
in_g_n_c_wis_desc
);
Tensor
<
WeiDataType
>
wei
(
wei_g_k_c_xs_desc
);
Tensor
<
ConvOutDataType
>
host_conv
(
out_g_n_k_wos_desc
);
Tensor
<
ConvOutDataType
>
device_conv
(
out_g_n_k_wos_desc
);
Tensor
<
OutDataType
>
out_host
(
out_g_n_k_wos_desc
);
Tensor
<
OutDataType
>
out_device
(
out_g_n_k_wos_desc
);
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out: "
<<
out_host
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
break
;
case
11
:
// used for debugging
in
.
GenerateTensorValue
(
GeneratorTensor_1
<
InDataType
>
{
1
});
wei
.
GenerateTensorValue
(
GeneratorTensor_1
<
WeiDataType
>
{
1
});
break
;
default:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
1.0
,
1.0
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
conv_device_buf
(
conv_param
.
GetOutputByte
<
ConvOutDataType
>
());
DeviceMem
out_device_buf
(
conv_param
.
GetOutputByte
<
OutDataType
>
());
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
{};
auto
copy
=
[](
const
auto
&
x
,
auto
&
y
)
{
ck
::
ranges
::
copy
(
x
,
y
.
begin
());
};
copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
a_g_n_c_wis_lengths
);
copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
a_g_n_c_wis_strides
);
copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
b_g_k_c_xs_lengths
);
copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
b_g_k_c_xs_strides
);
copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
e_g_n_k_wos_lengths
);
copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
e_g_n_k_wos_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
copy
(
conv_param
.
conv_filter_dilations_
,
conv_filter_dilations
);
copy
(
conv_param
.
input_left_pads_
,
input_left_pads
);
copy
(
conv_param
.
input_right_pads_
,
input_right_pads
);
// random scale values
float
scale_in
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
scale_wei
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
scale_out
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"scale_in: "
<<
scale_in
<<
std
::
endl
;
std
::
cout
<<
"scale_wei: "
<<
scale_wei
<<
std
::
endl
;
std
::
cout
<<
"scale_out: "
<<
scale_out
<<
std
::
endl
;
// convolution elementwise operation
auto
conv_element_op
=
ConvElementOp
{
ew
::
Scale
{
scale_in
},
ew
::
Scale
{
scale_wei
},
{}};
auto
scale_convert
=
UnaryScaleConvert
{
scale_out
};
// elementwise scale and type cast
// do Conv
auto
conv
=
DeviceConvNDFwdInstance
{};
auto
conv_invoker
=
conv
.
MakeInvoker
();
auto
conv_argument
=
conv
.
MakeArgument
(
in_device_buf
.
GetDeviceBuffer
(),
wei_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{},
conv_device_buf
.
GetDeviceBuffer
(),
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
,
0
>
{},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
,
0
>
{},
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
conv_element_op
);
if
(
!
conv
.
IsSupportedArgument
(
conv_argument
))
{
throw
std
::
runtime_error
(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
);
}
std
::
string
kernels
=
conv
.
GetTypeString
();
float
avg_time
=
conv_invoker
.
Run
(
conv_argument
,
StreamConfig
{
nullptr
,
time_kernel
});
using
DeviceElementwiseScale
=
ck
::
tensor_operation
::
device
::
DeviceElementwiseImpl
<
ck
::
Tuple
<
ConvOutDataType
>
,
// InDataTypeTuple
ck
::
Tuple
<
OutDataType
>
,
// OutDataTypeTuple
UnaryScaleConvert
,
// UnaryScaleConvert
NDimSpatial
+
3
,
// NumDim
256
,
// BlockSize
128
,
// M0PerBlock
128
,
// M1PerBlock
8
,
// M0PerThread
8
,
// M1PerThread
ck
::
Sequence
<
1
,
0
>
,
// ThreadClusterArrangeOrder
ck
::
Sequence
<
8
>
,
// InScalarPerVectorSeq
ck
::
Sequence
<
8
>>
;
// OutScalarPerVectorSeq
auto
device_ew_scale
=
DeviceElementwiseScale
{};
auto
scale_invoker
=
device_ew_scale
.
MakeInvoker
();
auto
scale_argument
=
device_ew_scale
.
MakeArgument
(
e_g_n_k_wos_lengths
,
{
e_g_n_k_wos_strides
},
{
e_g_n_k_wos_strides
},
{
conv_device_buf
.
GetDeviceBuffer
()},
{
out_device_buf
.
GetDeviceBuffer
()},
scale_convert
);
if
(
!
device_ew_scale
.
IsSupportedArgument
(
scale_argument
))
{
throw
std
::
runtime_error
(
"wrong! DeviceElementwiseScale with the specified compilation parameters does "
"not support this problem"
);
}
kernels
+=
std
::
string
(
"
\n\t\t
"
)
+
device_ew_scale
.
GetTypeString
();
avg_time
+=
scale_invoker
.
Run
(
scale_argument
,
StreamConfig
{
nullptr
,
time_kernel
});
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AMAX
;
using
ReduceOperation
=
typename
ck
::
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
using
DeviceReduceInstance
=
ck
::
tensor_operation
::
device
::
DeviceReduceMultiBlock
<
ConvOutDataType
,
ConvOutDataType
,
ConvOutDataType
,
NDimSpatial
+
3
,
NDimSpatial
+
3
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
ck
::
InMemoryDataOperationEnum
::
Set
,
true
,
// PropagateNan
false
,
// OutputIndex
false
,
// HaveIndexInputIfOutputIndex
256
,
// BlockSize
4
,
// MThreadClusterSize
64
,
// KThreadClusterSize
1
,
// MThreadSliceSize
1
,
// KThreadSliceSize
1
,
// InSrcVectorDim
1
,
// InSrceVectorSize
1
>
;
// OutDstVectorSize
std
::
vector
<
size_t
>
outLengths
=
{
1
};
Tensor
<
ConvOutDataType
>
amax_host
(
outLengths
);
Tensor
<
ConvOutDataType
>
amax_from_device
(
outLengths
);
auto
amax_host_strides
=
amax_host
.
mDesc
.
GetStrides
();
std
::
array
<
int
,
NDimSpatial
+
3
>
reduce_dims
;
std
::
iota
(
reduce_dims
.
begin
(),
reduce_dims
.
end
(),
0
);
// 0,..., NDimSpatial+3-1
std
::
array
<
ck
::
index_t
,
1
>
reduce_out_lengths
{
1
};
std
::
array
<
ck
::
index_t
,
1
>
reduce_out_strides
{
static_cast
<
ck
::
index_t
>
(
amax_host_strides
[
0
])};
DeviceMem
amax_device
(
sizeof
(
ConvOutDataType
)
*
amax_host
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
index_device
;
InElementwiseOperation
in_elementwise_op
;
AccElementwiseOperation
acc_elementwise_op
;
std
::
tie
(
in_elementwise_op
,
acc_elementwise_op
)
=
ck
::
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
static_cast
<
int32_t
>
(
host_conv
.
mDesc
.
GetElementSize
()));
// Hack convolution output strides for reduction as kernel expects stride 1 for the last
// dimension. It only works because the reduction is done on the whole tensor and result is
// independent of the order of elements.
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
reduction_strides
{};
copy
(
HostTensorDescriptor
(
e_g_n_k_wos_lengths
).
GetStrides
(),
reduction_strides
);
auto
device_reduce
=
DeviceReduceInstance
{};
auto
reduce_invoker
=
device_reduce
.
MakeInvokerPointer
();
auto
reduce_argument
=
device_reduce
.
MakeArgumentPointer
(
e_g_n_k_wos_lengths
,
reduction_strides
,
reduce_out_lengths
,
reduce_out_strides
,
reduce_dims
,
1.0
,
0.0
,
conv_device_buf
.
GetDeviceBuffer
(),
nullptr
,
amax_device
.
GetDeviceBuffer
(),
nullptr
,
in_elementwise_op
,
acc_elementwise_op
);
if
(
!
device_reduce
.
IsSupportedArgument
(
reduce_argument
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! DeviceReduceInstance with the specified compilation parameters does "
"not support this runtime parameters!"
);
};
kernels
+=
std
::
string
(
"
\n\t\t
"
)
+
device_reduce
.
GetTypeString
();
float
reduce_time
=
reduce_invoker
->
Run
(
reduce_argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
if
(
time_kernel
)
std
::
cout
<<
"
\n
Reduce time: "
<<
reduce_time
<<
" ms"
<<
std
::
endl
;
avg_time
+=
reduce_time
;
std
::
size_t
flop
=
conv_param
.
GetFlops
();
// convolution FLOPs
auto
conv_out_elems
=
host_conv
.
GetElementSize
();
// number of elements in conv result tensor
// 3 element-wise scale multipliers + 1 AMAX
std
::
size_t
elementwise_ops
=
3
+
1
;
if
constexpr
(
ck
::
is_same_v
<
ConvElementOp
,
ConvScaleRelu
>
)
{
elementwise_ops
+=
1
;
// +1 element-wise relu
}
flop
+=
elementwise_ops
*
conv_out_elems
;
// convolution + elementwise scaling (in + wei + output byte count)
std
::
size_t
num_btype
=
conv_param
.
GetByte
<
InDataType
,
WeiDataType
,
ConvOutDataType
>
();
num_btype
+=
sizeof
(
float
)
+
sizeof
(
float
);
// + 2 scales
// elementwise scaling + F8 conversion
num_btype
+=
conv_param
.
GetOutputByte
<
ConvOutDataType
>
()
+
sizeof
(
float
)
+
conv_param
.
GetOutputByte
<
OutDataType
>
();
// AMAX
num_btype
+=
conv_param
.
GetOutputByte
<
ConvOutDataType
>
()
+
sizeof
(
float
);
if
(
time_kernel
)
{
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
std
::
endl
;
}
std
::
cout
<<
"
\n
Kernels: "
<<
kernels
<<
std
::
endl
;
if
(
do_verification
)
{
auto
ref_conv
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
NDimSpatial
,
InDataType
,
WeiDataType
,
ConvOutDataType
,
InElementOp
,
WeiElementOp
,
ConvElementOp
>
();
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in
,
wei
,
host_conv
,
conv_param
.
conv_filter_strides_
,
conv_param
.
conv_filter_dilations_
,
conv_param
.
input_left_pads_
,
conv_param
.
input_right_pads_
,
in_element_op
,
wei_element_op
,
conv_element_op
);
ref_invoker
.
Run
(
ref_argument
);
conv_device_buf
.
FromDevice
(
device_conv
.
mData
.
data
());
out_device_buf
.
FromDevice
(
out_device
.
mData
.
data
());
out_host
.
ForEach
([
&
](
auto
&
,
auto
idx
)
{
scale_convert
(
out_host
(
idx
),
host_conv
(
idx
));
});
std
::
cout
<<
"
\n
Comparing output to reference: "
<<
std
::
endl
;
auto
tight_tol_check
=
ck
::
utils
::
check_err
(
out_device
,
out_host
,
"Error: "
);
if
(
!
tight_tol_check
)
{
std
::
cout
<<
"
\n\t
Recompare applying tolerances...
\n
"
;
std
::
cout
<<
"
\t\t
rtol = "
<<
get_rtol
<
OutDataType
>
()
<<
std
::
endl
;
std
::
cout
<<
"
\t\t
atol = "
<<
get_atol
<
OutDataType
>
()
<<
std
::
endl
;
auto
loose_tol_check
=
ck
::
utils
::
check_err
(
out_device
,
out_host
,
"Error: incorrect convolution results!"
,
get_rtol
<
OutDataType
>
(),
get_atol
<
OutDataType
>
());
if
(
!
loose_tol_check
)
{
return
false
;
}
}
std
::
cout
<<
"Success!"
<<
std
::
endl
;
/// Verify AMAX
using
RefReduceInstance
=
ck
::
tensor_operation
::
host
::
ReferenceReduce
<
ConvOutDataType
,
ConvOutDataType
,
ConvOutDataType
,
NDimSpatial
+
3
,
NDimSpatial
+
3
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
true
,
false
>
;
auto
ref_reduce
=
RefReduceInstance
{};
auto
ref_reduce_invoker
=
ref_reduce
.
MakeInvokerPointer
();
auto
ref_reduce_argument
=
ref_reduce
.
MakeArgumentPointer
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
reduce_out_lengths
,
reduce_out_strides
,
reduce_dims
,
1.0
,
0.0
,
host_conv
.
mData
.
data
(),
nullptr
,
amax_host
.
mData
.
data
(),
nullptr
,
in_elementwise_op
,
acc_elementwise_op
);
if
(
!
ref_reduce
.
IsSupportedArgument
(
ref_reduce_argument
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! RefReduceInstance with the specified compilation parameters does "
"not support this runtime parameters!"
);
};
ref_reduce_invoker
->
Run
(
ref_reduce_argument
.
get
());
amax_device
.
FromDevice
(
amax_from_device
.
mData
.
data
());
std
::
cout
<<
"
\n
amax: "
<<
amax_from_device
.
mData
[
0
]
<<
std
::
endl
;
std
::
cout
<<
"amax_ref: "
<<
amax_host
.
mData
[
0
]
<<
std
::
endl
;
return
ck
::
utils
::
check_err
(
amax_from_device
,
amax_host
,
"Error: incorrect AMAX results!"
);
}
return
true
;
}
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using
InDataType
=
ck
::
f8_t
;
using
WeiDataType
=
ck
::
f8_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
ConvOutDataType
=
float
;
// data type of convolution result
using
OutDataType
=
ck
::
f8_t
;
// data type of final result
using
AComputeDataType
=
ck
::
f8_t
;
using
BComputeDataType
=
ck
::
f8_t
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
ConvScale
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
using
DeviceGroupedConvNDFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
WeiDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
ConvOutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
ConvSpec
,
// ConvForwardSpecialization
GemmSpec
,
// GemmSpecialization
1
,
//
256
,
// BlockSize
128
,
// MPerBlock
256
,
// NPerBlock
32
,
// KPerBlock
8
,
// AK1
8
,
// BK1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
1
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
1
,
// BBlockLdsExtraN
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
AComputeDataType
,
BComputeDataType
>
;
#include "run_convnd_fwd_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_convnd_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
using
InDataType
=
ck
::
f8_t
;
using
WeiDataType
=
ck
::
f8_t
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
float
;
using
ConvOutDataType
=
float
;
// data type of convolution result
using
OutDataType
=
ck
::
f8_t
;
// data type of final result
using
AComputeDataType
=
ck
::
f8_t
;
using
BComputeDataType
=
ck
::
f8_t
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
InElementOp
=
PassThrough
;
using
WeiElementOp
=
PassThrough
;
using
OutElementOp
=
ConvScaleRelu
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
template
<
ck
::
index_t
NDimSpatial
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
using
DeviceGroupedConvNDFwdInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<
NDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
WeiDataType
,
AccDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
ConvOutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
ConvSpec
,
// ConvForwardSpecialization
GemmSpec
,
// GemmSpecialization
1
,
//
256
,
// BlockSize
128
,
// MPerBlock
256
,
// NPerBlock
32
,
// KPerBlock
8
,
// AK1
8
,
// BK1
32
,
// MPerXdl
32
,
// NPerXdl
2
,
// MXdlPerWave
4
,
// NXdlPerWave
S
<
4
,
64
,
1
>
,
// ABlockTransferThreadClusterLengths_AK0_M_AK1
S
<
1
,
0
,
2
>
,
// ABlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// ABlockTransferSrcAccessOrder
2
,
// ABlockTransferSrcVectorDim
8
,
// ABlockTransferSrcScalarPerVector
8
,
// ABlockTransferDstScalarPerVector_AK1
1
,
// ABlockLdsExtraM
S
<
4
,
64
,
1
>
,
// BBlockTransferThreadClusterLengths_BK0_N_BK1
S
<
1
,
0
,
2
>
,
// BBlockTransferThreadClusterArrangeOrder
S
<
1
,
0
,
2
>
,
// BBlockTransferSrcAccessOrder
2
,
// BBlockTransferSrcVectorDim
8
,
// BBlockTransferSrcScalarPerVector
8
,
// BBlockTransferDstScalarPerVector_BK1
1
,
// BBlockLdsExtraN
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
AComputeDataType
,
BComputeDataType
>
;
#include "run_convnd_fwd_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_convnd_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment