gaoqiong / composable_kernel · Commits

Commit dcf15456, authored Sep 14, 2023 by Rostyslav Geyyer

    Merge branch 'develop' into lwpck-911

Parents: 1e3eb1c6, 5fe687fa
Changes: 85 files; showing 20 changed files with 1027 additions and 535 deletions (+1027 -535)
.github/CODEOWNERS                                                                   +6    -0
CMakeLists.txt                                                                       +25   -1
Jenkinsfile                                                                          +19   -11
client_example/20_splitk_gemm/CMakeLists.txt                                         +4    -2
example/01_gemm/CMakeLists.txt                                                       +4    -2
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp               +50   -39
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc           +1    -13
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp                                  +10   -6
include/ck/config.h.in                                                               +7    -0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  +3    -4
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp    +376  -317
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp             +26   -5
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp          +4    -5
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp                                    +223  -7
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                 +9    -1
include/ck/utility/amd_buffer_addressing.hpp                                         +16   -2
include/ck/utility/amd_xdlops.hpp                                                    +2    -0
include/ck/utility/data_type.hpp                                                     +123  -8
include/ck/utility/f8_utils.hpp                                                      +107  -112
include/ck/utility/inner_product.hpp                                                 +12   -0
.github/CODEOWNERS  (new file, 0 → 100644)  @ dcf15456

# Documentation files
docs/* @saadrahim @LisaDelaney
*.md   @saadrahim @LisaDelaney
*.rst  @saadrahim @LisaDelaney
# Header directory
library/include/* @saadrahim @LisaDelaney
CMakeLists.txt  @ dcf15456

 cmake_minimum_required(VERSION 3.14)

 # This has to be initialized before the project() command appears
 # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
 if (NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE)
     set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
 endif()

 # Default installation path
 if (WIN32)
     set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "")
 else()
     set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
 endif()

 set(version 1.1.0)
 # Check support for CUDA/HIP in Cmake
 project(composable_kernel VERSION ${version})
...
@@ -15,6 +28,12 @@ if (DTYPES)
     if (DTYPES MATCHES "fp8")
         add_definitions(-DCK_ENABLE_FP8)
         set(CK_ENABLE_FP8 "ON")
+        add_compile_options(-Wno-bit-int-extension)
     endif()
+    if (DTYPES MATCHES "bf8")
+        add_definitions(-DCK_ENABLE_BF8)
+        set(CK_ENABLE_BF8 "ON")
+        add_compile_options(-Wno-bit-int-extension)
+    endif()
     if (DTYPES MATCHES "fp16")
         add_definitions(-DCK_ENABLE_FP16)
...
@@ -34,8 +53,9 @@ if (DTYPES)
     endif()
     message("DTYPES macro set to ${DTYPES}")
 else()
-    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
+    add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
     set(CK_ENABLE_ALL_DTYPES "ON")
+    add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8
 endif()
 if (DL_KERNELS)
...
@@ -365,6 +385,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
         #message("fp8 instance found!")
         set(add_inst 1)
     endif()
+    if ("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\"" AND DTYPES MATCHES "bf8")
+        #message("bf8 instance found!")
+        set(add_inst 1)
+    endif()
     if ("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
         #message("fp16 instance found!")
         set(add_inst 1)
...
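For context, the -DCK_ENABLE_* compile definitions added above are CMake's standard mechanism for gating sources on data-type support. A minimal C++ sketch of how such a macro is typically consumed; register_bf8_instances() is a hypothetical placeholder, not CK's API:

// Compile bf8-specific code only when the build enabled it.
// CK_ENABLE_BF8 would come from add_definitions(-DCK_ENABLE_BF8) above.
#ifdef CK_ENABLE_BF8
void register_bf8_instances() {} // hypothetical
#endif

int main()
{
#ifdef CK_ENABLE_BF8
    register_bf8_instances();
#endif
    return 0;
}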
Jenkinsfile  @ dcf15456

@@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){
     }
     else
     {
         setup_args = ' -DBUILD_DEV=On' + setup_args
     }
+    if (params.DL_KERNELS){
+        setup_args = setup_args + " -DDL_KERNELS=ON "
+    }
     if (build_type_debug){
         setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args
...
@@ -367,8 +370,6 @@ def runCKProfiler(Map conf=[:]){
         withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
             timeout(time: 24, unit: 'HOURS')
             {
-                //cmake_build(conf)
-                //instead of building, just unstash the ckProfiler and install it
                 sh """
                     rm -rf build
                     mkdir build
...
@@ -614,7 +615,7 @@ def process_results(Map conf=[:]){
 //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
 CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1
                                               0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT=
-                                              0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
+                                              0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""

 pipeline {
     agent none
...
@@ -649,6 +650,10 @@ pipeline {
         booleanParam(
             name: "RUN_FULL_QA",
             defaultValue: false,
             description: "Select whether to run small set of performance tests (default) or full QA")
+        booleanParam(
+            name: "DL_KERNELS",
+            defaultValue: false,
+            description: "Select whether to build DL kernels (default: OFF)")
     }
     environment{
         dbuser = "${dbuser}"
...
@@ -663,15 +668,12 @@ pipeline {
     }
     stages{
         stage("Build Docker"){
-            //when {
-            //    beforeAgent true
-            //    expression { params.BUILD_DOCKER.toBoolean() }
-            //}
             parallel{
                 stage('Docker /opt/rocm'){
                     agent{ label rocmnode("nogpu") }
                     steps{
                         buildDocker('/opt/rocm')
                         cleanWs()
                     }
                 }
...
@@ -693,6 +695,7 @@ pipeline {
             }
             steps{
                 buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot: true)
                 cleanWs()
             }
         }
...
@@ -715,6 +718,7 @@ pipeline {
             }
             steps{
                 Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot: true,
                                     build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
                 cleanWs()
             }
         }
         stage("Build CK and run Tests on MI100/MI200")
...
@@ -730,6 +734,7 @@ pipeline {
             }
             steps{
                 Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot: true,
                                     build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
                 cleanWs()
             }
         }
         stage("Build CK and run Tests on Navi21")
...
@@ -742,10 +747,10 @@ pipeline {
             environment{
                 setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """
                 execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
             }
             steps{
                 Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot: true,
                                     build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
                 cleanWs()
             }
         }
         stage("Build CK and run Tests on Navi32")
...
@@ -756,12 +761,12 @@ pipeline {
             }
             agent{ label rocmnode("navi32") }
             environment{
-                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
-                execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
+                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """
+                execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
             }
             steps{
                 Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot: true,
                                     build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
                 cleanWs()
             }
         }
     }
...
@@ -784,6 +789,7 @@ pipeline {
             }
             steps{
                 runPerfTest(setup_args: setup_args, config_targets: "ckProfiler", no_reboot: true, build_type: 'Release')
                 cleanWs()
             }
         }
         stage("Run ckProfiler: gfx90a")
...
@@ -799,6 +805,7 @@ pipeline {
             }
             steps{
                 runPerfTest(setup_args: setup_args, config_targets: "ckProfiler", no_reboot: true, build_type: 'Release')
                 cleanWs()
             }
         }
     }
...
@@ -811,6 +818,7 @@ pipeline {
         agent{ label 'mici' }
         steps{
             process_results()
             cleanWs()
         }
     }
 }
...
client_example/20_splitk_gemm/CMakeLists.txt  @ dcf15456

-add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
-target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
+if ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
+    add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
+    target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
+endif()
example/01_gemm/CMakeLists.txt  @ dcf15456

@@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
     endif()
 endif()
-add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
-add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
+if ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
+    add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
+    add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
+endif()
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp  @ dcf15456

@@ -3,7 +3,7 @@
 #include "common.hpp"

-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"

 using InDataType  = F16;
 using WeiDataType = F16;
...
@@ -15,9 +15,20 @@ using WeiElementOp = PassThrough;
 using OutElementOp = PassThrough;

 template <ck::index_t NDimSpatial>
-using DeviceConvBwdWeightInstance =
-    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl<
+using DeviceConvBwdWeightInstance =
+    ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl<
         NDimSpatial, // NDimSpatial
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                      ck::tensor_layout::convolution::GNHWC,
+                                      ck::tensor_layout::convolution::GNDHWC>>, // InLayout
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                      ck::tensor_layout::convolution::GKYXC,
+                                      ck::tensor_layout::convolution::GKZYXC>>, // WeiLayout
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                      ck::tensor_layout::convolution::GNHWK,
+                                      ck::tensor_layout::convolution::GNDHWK>>, // OutLayout
         InDataType,  // InDataType
         WeiDataType, // WeiDataType
         OutDataType, // OutDataType
...
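The ck::tuple_element_t<NDimSpatial - 1, ck::Tuple<...>> pattern above picks a layout tag by spatial rank at compile time. A standalone C++ sketch of the same trick, using std::tuple and stand-in tag types rather than CK's real layouts:

// Select a layout tag from a compile-time list by NDimSpatial - 1.
#include <tuple>
#include <type_traits>

struct GNWC {};   // 1d stand-in
struct GNHWC {};  // 2d stand-in
struct GNDHWC {}; // 3d stand-in

template <int NDimSpatial>
using InLayout =
    std::tuple_element_t<NDimSpatial - 1, std::tuple<GNWC, GNHWC, GNDHWC>>;

static_assert(std::is_same_v<InLayout<1>, GNWC>);
static_assert(std::is_same_v<InLayout<2>, GNHWC>);
static_assert(std::is_same_v<InLayout<3>, GNDHWC>);

int main() { return 0; }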
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc  @ dcf15456

@@ -14,20 +14,8 @@ template <ck::index_t NDimSpatial>
 bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
                                  const ck::utils::conv::ConvParam& conv_param)
 {
-    ck::index_t split_k;
-    // Set split_k = 2 for xdl op, split_k = 1 for dl
-    // Dl op doesn't support split_k > 1
-    // TODO: Add Dl op split_k > 1 support
-    if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
-         ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-         ck::get_device_name() == "gfx1102"))
-    {
-        split_k = 2;
-    }
-    else
-    {
-        split_k = 1;
-    }
+    constexpr ck::index_t split_k = 1;

     const auto in_g_n_c_wis_desc =
         ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
...
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp  @ dcf15456

@@ -14,18 +14,22 @@ using ComputeDataType = float;
 struct YElementOp
 {
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
-                          ck::is_same<T, ck::half_t>::value,
-                      "Data type is not supported by this operation!");
-
-        T a;
+        static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
+                          ck::is_same<X, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+        static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
+                          ck::is_same<Y, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+
+        X a;

         ck::tensor_operation::element_wise::Sigmoid{}(a, x);

-        y = x * a;
+        y = ck::type_convert<Y>(x * a);
     };
 };
...
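The generalization above lets the output type Y differ from the input type X: x * sigmoid(x) is computed in the input precision and converted once at the end. A minimal standalone C++ sketch of that behavior, where std::exp stands in for CK's Sigmoid and static_cast for ck::type_convert:

#include <cassert>
#include <cmath>

struct YElementOpSketch
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        const X a = X(1) / (X(1) + std::exp(-x)); // sigmoid, in X's precision
        y = static_cast<Y>(x * a);                // convert once, at the end
    }
};

int main()
{
    float y = 0.f;
    YElementOpSketch{}(y, 2.0);      // X = double, Y = float
    assert(y > 1.7f && y < 1.8f);    // 2 * sigmoid(2) ~= 1.7616
    return 0;
}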
include/ck/config.h.in  @ dcf15456

@@ -43,6 +43,9 @@
 #ifndef CK_ENABLE_FP8
 #define CK_ENABLE_FP8 "ON"
 #endif
+#ifndef CK_ENABLE_BF8
+#define CK_ENABLE_BF8 "ON"
+#endif
 #ifndef CK_ENABLE_FP16
 #define CK_ENABLE_FP16 "ON"
 #endif
...
@@ -66,6 +69,10 @@
 #ifndef CK_ENABLE_FP8
 #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@
 #endif
+#ifndef CK_ENABLE_BF8
+#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@
+#endif
 #ifndef CK_ENABLE_FP16
 #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@
 #endif
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  @ dcf15456

@@ -144,7 +144,8 @@ template <typename ALayout,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler(),
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          typename ComputeDataType = EDataType>
 struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                                                                      BLayout,
                                                                      DsLayout,
...
@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

-    using ComputeDataType = EDataType;
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         BDataType,
         ComputeDataType,
         AccDataType,
...
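The change above promotes ComputeDataType from a hard-coded member alias to a defaulted template parameter, so existing instantiations keep their behavior while new callers can override the compute precision. A small C++ sketch of that refactor pattern, with stand-in types rather than CK's:

#include <type_traits>

struct EData {}; // stand-in for EDataType
struct F32 {};   // stand-in for an overriding compute type

// before: struct Op { using ComputeDataType = EData; };
template <typename EDataType, typename ComputeDataType = EDataType>
struct Op
{
    using Compute = ComputeDataType;
};

static_assert(std::is_same_v<Op<EData>::Compute, EData>);    // default preserved
static_assert(std::is_same_v<Op<EData, F32>::Compute, F32>); // caller override

int main() { return 0; }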
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp  @ dcf15456

@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
@@ -72,6 +73,9 @@ __global__ void
             const Block2CTileMap block_2_ctile_map,
             const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
+    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
@@ -96,9 +100,27 @@ __global__ void
             block_2_ctile_map,
             integral_constant<bool, HasMainKBlockLoop>{},
             integral_constant<bool, HasDoubleTailKBlockLoop>{});
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = batch_count;
+    ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
+    ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
+    ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
+    ignore = block_2_ctile_map;
+    ignore = compute_ptr_offset_of_batch;
+
+    compute_ptr_offset_of_batch.GetAPtrOffset(0);
+    compute_ptr_offset_of_batch.GetBPtrOffset(0);
+    compute_ptr_offset_of_batch.GetCPtrOffset(0);
+#endif
 }

 template <ck::index_t NDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename OutLayout,
           typename InDataType,
           typename WeiDataType,
           typename OutDataType,
...
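The #if/#else guard added in the hunk above is a common HIP kernel pattern: the body compiles only for the listed gfx targets, and on every other target the parameters are consumed by an "ignore" sink so unused-parameter warnings do not fire. A host-side C++ sketch of the same pattern; the macro here is a stand-in for the __gfx*__ checks:

// std::ignore-style sink: assigning to it discards the value.
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(T&&) const noexcept { return *this; }
};
inline constexpr ignore_t ignore{};

#define SUPPORTED_TARGET 0 // stand-in for defined(__gfx906__) || ...

void kernel_stub(int a, int b)
{
#if SUPPORTED_TARGET
    (void)(a + b); // the real grid-wise GEMM body would run here
#else
    ignore = a;    // parameters consumed; no -Wunused-parameter
    ignore = b;
#endif
}

int main() { kernel_stub(1, 2); return 0; }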
@@ -134,21 +156,10 @@ template <ck::index_t NDimSpatial,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
-    : public DeviceGroupedConvBwdWeight<
-          NDimSpatial,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GNWC,
-                                        ck::tensor_layout::convolution::GNHWC,
-                                        ck::tensor_layout::convolution::GNDHWC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GKXC,
-                                        ck::tensor_layout::convolution::GKYXC,
-                                        ck::tensor_layout::convolution::GKZYXC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GNWK,
-                                        ck::tensor_layout::convolution::GNHWK,
-                                        ck::tensor_layout::convolution::GNDHWK>>,
+struct DeviceGroupedConvBwdWeight_Dl
+    : public DeviceGroupedConvBwdWeight<
+          NDimSpatial,
+          InLayout,
+          WeiLayout,
+          OutLayout,
           InDataType,
           WeiDataType,
           OutDataType,
...
@@ -156,7 +167,35 @@ struct DeviceGroupedConvBwdWeight_Dl
           WeiElementwiseOperation,
           OutElementwiseOperation>
 {
-    using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
+    // 1d
+    static constexpr bool is_NWGK_GKXC_NWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
+    static constexpr bool is_GNWK_GKXC_GNWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
+    // 2d
+    static constexpr bool is_NHWGK_GKYXC_NHWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
+    static constexpr bool is_GNHWK_GKYXC_GNHWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
+    // 3d
+    static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
+    static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
+
+    using DeviceOp = DeviceGroupedConvBwdWeight_Dl;

     using ADataType = OutDataType;
     using BDataType = InDataType;
...
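The is_same_v predicates introduced above reduce an (InLayout, WeiLayout, OutLayout) triple to one compile-time flag per supported combination, which IsSupportedArgument later checks. A standalone C++ sketch of the same gating pattern, with stand-in tag types rather than CK's layout tags:

#include <type_traits>

struct NWGC {}; struct GKXC {}; struct NWGK {}; struct GNWC {};

template <typename In, typename Wei, typename Out>
struct LayoutGate
{
    static constexpr bool is_NWGK_GKXC_NWGC =
        std::is_same_v<In, NWGC> && std::is_same_v<Wei, GKXC> &&
        std::is_same_v<Out, NWGK>;
};

static_assert(LayoutGate<NWGC, GKXC, NWGK>::is_NWGK_GKXC_NWGC);   // supported triple
static_assert(!LayoutGate<GNWC, GKXC, NWGK>::is_NWGK_GKXC_NWGC);  // rejected triple

int main() { return 0; }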
@@ -176,6 +215,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};

+    static constexpr auto spatial_offset = I3;
+
     static constexpr auto K1Number     = Number<K1>{};
     static constexpr auto GemmK1Number = K1Number;
...
@@ -195,12 +236,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        const ck::index_t N,
-        const ck::index_t K,
-        const ck::index_t C,
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...
@@ -209,90 +250,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     {
         using namespace ck;

-        const index_t Wi = input_spatial_lengths[0];
-        const index_t Wo = output_spatial_lengths[0];
-        const index_t X  = filter_spatial_lengths[0];
-        const index_t InLeftPadW    = input_left_pads[0];
-        const index_t InRightPadW   = input_right_pads[0];
-        const index_t ConvStrideW   = conv_filter_strides[0];
-        const index_t ConvDilationW = conv_filter_dilations[0];
+        const index_t N  = a_g_n_c_wis_lengths[I1];
+        const index_t K  = b_g_k_c_xs_lengths[I1];
+        const index_t C  = a_g_n_c_wis_lengths[I2];
+        const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
+        const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
+        const index_t X  = b_g_k_c_xs_lengths[spatial_offset];
+        const index_t InLeftPadW    = input_left_pads[I0];
+        const index_t InRightPadW   = input_right_pads[I0];
+        const index_t ConvStrideW   = conv_filter_strides[I0];
+        const index_t ConvDilationW = conv_filter_dilations[I0];
+
+        const auto InNStride  = a_g_n_c_wis_strides[I1];
+        const auto InCStride  = a_g_n_c_wis_strides[I2];
+        const auto InWStride  = a_g_n_c_wis_strides[spatial_offset];
+        const auto WeiKStride = b_g_k_c_xs_strides[I1];
+        const auto WeiCStride = b_g_k_c_xs_strides[I2];
+        const auto OutKStride = e_g_n_k_wos_strides[I2];
+        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];

         const index_t GemmKTotal = N * Wo;
         const index_t GemmM      = K;
         const index_t GemmN      = C * X;

         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
             K0PerBlock;
         const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

         if constexpr(ConvBackwardWeightSpecialization ==
                      ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
         {
             // A: output tensor
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));

-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
+                out_gemmkpad_gemmmpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // B: input tensor
-            const auto in_gemmktotal_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
+            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
+                in_gemmkpad_gemmnpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weights tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmm_gemmn_grid_desc,
+                    make_tuple(MPerBlock, NPerBlock),
+                    Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
         else
         {
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
-            const auto in_n_wi_c_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));

             // A: output tensor
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
+                out_gemmkpad_gemmmpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...
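A worked example of the K-padding arithmetic used in the descriptors above: GemmK0 rounds GemmKTotal up to a multiple of GemmK1Number * K0PerBlock * GemmKBatch, and GemmKPad is the resulting padded K extent. A small C++ sketch with made-up values:

#include <cstdio>

int main()
{
    const int GemmKTotal = 1000; // e.g. N * Wo (hypothetical)
    const int GemmK1     = 4;    // GemmK1Number
    const int K0PerBlock = 16;
    const int GemmKBatch = 1;

    const int step     = GemmK1 * K0PerBlock * GemmKBatch;              // 64
    const int GemmK0   = ((GemmKTotal + step - 1) / step) * K0PerBlock; // ceil(1000/64)*16 = 256
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;                  // 1024 >= GemmKTotal

    std::printf("GemmK0 = %d, GemmKPad = %d (>= GemmKTotal = %d)\n",
                GemmK0, GemmKPad, GemmKTotal);
    return 0;
}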
@@ -321,38 +374,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                 make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
                 make_tuple(Sequence<1>{}, Sequence<0>{}));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
+                in_gemmkpad_gemmnpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmN)),
+                           make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmm_gemmn_grid_desc,
+                    make_tuple(MPerBlock, NPerBlock),
+                    Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
     } // function end

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        const ck::index_t N,
-        const ck::index_t K,
-        const ck::index_t C,
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...
@@ -361,103 +419,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     {
         using namespace ck;

-        const index_t Hi = input_spatial_lengths[0];
-        const index_t Wi = input_spatial_lengths[1];
-        const index_t Ho = output_spatial_lengths[0];
-        const index_t Wo = output_spatial_lengths[1];
-        const index_t Y  = filter_spatial_lengths[0];
-        const index_t X  = filter_spatial_lengths[1];
-        const index_t InLeftPadH    = input_left_pads[0];
-        const index_t InLeftPadW    = input_left_pads[1];
-        const index_t InRightPadH   = input_right_pads[0];
-        const index_t InRightPadW   = input_right_pads[1];
-        const index_t ConvStrideH   = conv_filter_strides[0];
-        const index_t ConvStrideW   = conv_filter_strides[1];
-        const index_t ConvDilationH = conv_filter_dilations[0];
-        const index_t ConvDilationW = conv_filter_dilations[1];
+        const index_t N  = a_g_n_c_wis_lengths[I1];
+        const index_t K  = b_g_k_c_xs_lengths[I1];
+        const index_t C  = a_g_n_c_wis_lengths[I2];
+        const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
+        const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
+        const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
+        const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
+        const index_t Y  = b_g_k_c_xs_lengths[spatial_offset];
+        const index_t X  = b_g_k_c_xs_lengths[spatial_offset + I1];
+        const index_t InLeftPadH    = input_left_pads[I0];
+        const index_t InLeftPadW    = input_left_pads[I1];
+        const index_t InRightPadH   = input_right_pads[I0];
+        const index_t InRightPadW   = input_right_pads[I1];
+        const index_t ConvStrideH   = conv_filter_strides[I0];
+        const index_t ConvStrideW   = conv_filter_strides[I1];
+        const index_t ConvDilationH = conv_filter_dilations[I0];
+        const index_t ConvDilationW = conv_filter_dilations[I1];
+
+        const auto InNStride  = a_g_n_c_wis_strides[I1];
+        const auto InCStride  = a_g_n_c_wis_strides[I2];
+        const auto InHStride  = a_g_n_c_wis_strides[spatial_offset];
+        const auto InWStride  = a_g_n_c_wis_strides[spatial_offset + I1];
+        const auto WeiKStride = b_g_k_c_xs_strides[I1];
+        const auto WeiCStride = b_g_k_c_xs_strides[I2];
+        const auto OutKStride = e_g_n_k_wos_strides[I2];
+        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];

         const index_t GemmKTotal = N * Ho * Wo;
         const index_t GemmM      = K;
         const index_t GemmN      = C * X * Y;

         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
             K0PerBlock;
         const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

         if constexpr(ConvBackwardWeightSpecialization ==
                      ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
         {
             // A: output tensor
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));

-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
+                out_gemmkpad_gemmmpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // B: input tensor
-            const auto in_gemmktotal_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
+            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
+                in_gemmkpad_gemmnpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmm_gemmn_grid_desc,
+                    make_tuple(MPerBlock, NPerBlock),
+                    Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
         else
         {
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
-            const auto in_n_hi_wi_c_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N, Hi, Wi, C),
+                make_tuple(InNStride, InHStride, InWStride, InCStride));

             // A: output tensor
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
+                out_gemmkpad_gemmmpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...
@@ -488,39 +554,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                 make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                 make_tuple(Sequence<1>{}, Sequence<0>{}));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
+                in_gemmkpad_gemmnpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmN)),
+                           make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmm_gemmn_grid_desc,
+                    make_tuple(MPerBlock, NPerBlock),
+                    Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
     } // function end

     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        const ck::index_t N,
-        const ck::index_t K,
-        const ck::index_t C,
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...
@@ -529,110 +600,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     {
         using namespace ck;

-        const index_t Di = input_spatial_lengths[0];
-        const index_t Hi = input_spatial_lengths[1];
-        const index_t Wi = input_spatial_lengths[2];
-        const index_t Do = output_spatial_lengths[0];
-        const index_t Ho = output_spatial_lengths[1];
-        const index_t Wo = output_spatial_lengths[2];
-        const index_t Z  = filter_spatial_lengths[0];
-        const index_t Y  = filter_spatial_lengths[1];
-        const index_t X  = filter_spatial_lengths[2];
-        const index_t InLeftPadD    = input_left_pads[0];
-        const index_t InLeftPadH    = input_left_pads[1];
-        const index_t InLeftPadW    = input_left_pads[2];
-        const index_t InRightPadD   = input_right_pads[0];
-        const index_t InRightPadH   = input_right_pads[1];
-        const index_t InRightPadW   = input_right_pads[2];
-        const index_t ConvStrideD   = conv_filter_strides[0];
-        const index_t ConvStrideH   = conv_filter_strides[1];
-        const index_t ConvStrideW   = conv_filter_strides[2];
-        const index_t ConvDilationD = conv_filter_dilations[0];
-        const index_t ConvDilationH = conv_filter_dilations[1];
-        const index_t ConvDilationW = conv_filter_dilations[2];
+        const index_t N  = a_g_n_c_wis_lengths[I1];
+        const index_t K  = b_g_k_c_xs_lengths[I1];
+        const index_t C  = a_g_n_c_wis_lengths[I2];
+        const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
+        const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
+        const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
+        const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
+        const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
+        const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
+        const index_t Z  = b_g_k_c_xs_lengths[spatial_offset + I0];
+        const index_t Y  = b_g_k_c_xs_lengths[spatial_offset + I1];
+        const index_t X  = b_g_k_c_xs_lengths[spatial_offset + I2];
+        const index_t InLeftPadD    = input_left_pads[I0];
+        const index_t InLeftPadH    = input_left_pads[I1];
+        const index_t InLeftPadW    = input_left_pads[I2];
+        const index_t InRightPadD   = input_right_pads[I0];
+        const index_t InRightPadH   = input_right_pads[I1];
+        const index_t InRightPadW   = input_right_pads[I2];
+        const index_t ConvStrideD   = conv_filter_strides[I0];
+        const index_t ConvStrideH   = conv_filter_strides[I1];
+        const index_t ConvStrideW   = conv_filter_strides[I2];
+        const index_t ConvDilationD = conv_filter_dilations[I0];
+        const index_t ConvDilationH = conv_filter_dilations[I1];
+        const index_t ConvDilationW = conv_filter_dilations[I2];
+
+        const auto InNStride  = a_g_n_c_wis_strides[I1];
+        const auto InCStride  = a_g_n_c_wis_strides[I2];
+        const auto InDStride  = a_g_n_c_wis_strides[spatial_offset];
+        const auto InHStride  = a_g_n_c_wis_strides[spatial_offset + I1];
+        const auto InWStride  = a_g_n_c_wis_strides[spatial_offset + I2];
+        const auto WeiKStride = b_g_k_c_xs_strides[I1];
+        const auto WeiCStride = b_g_k_c_xs_strides[I2];
+        const auto OutKStride = e_g_n_k_wos_strides[I2];
+        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];

         const index_t GemmKTotal = N * Do * Ho * Wo;
         const index_t GemmM      = K;
         const index_t GemmN      = C * Z * X * Y;

         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
             K0PerBlock;
         const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

         if constexpr(ConvBackwardWeightSpecialization ==
                      ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
         {
             // A: output tensor
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));

-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
+                out_gemmkpad_gemmmpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // B: input tensor
-            const auto in_gemmktotal_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
+            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
+                in_gemmkpad_gemmnpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmm_gemmn_grid_desc,
+                    make_tuple(MPerBlock, NPerBlock),
+                    Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
         else
         {
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
-            const auto in_n_di_hi_wi_c_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N, Di, Hi, Wi, C),
+                make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));

             // A: output tensor
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
+                out_gemmkpad_gemmmpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                           make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
...
@@ -672,27 +753,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                 make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
                 make_tuple(Sequence<1>{}, Sequence<0>{}));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
+                in_gemmkpad_gemmnpad_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmN)),
+                           make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    wei_gemmm_gemmn_grid_desc,
+                    make_tuple(MPerBlock, NPerBlock),
+                    Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
     } // function end
...
@@ -701,22 +787,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     static auto GetABCGridDesc()
     {
-        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
-            1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
+        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
+            {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
     static auto GetABCGridDesc()
     {
-        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
-            1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
+        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
+            {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     static auto GetABCGridDesc()
     {
-        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
-            1, 1, 1,
+        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
             {1, 1, 1}, {1, 1, 1}, {1, 1, 1},
             {1, 1, 1}, {1, 1, 1}, {1, 1, 1},
...
@@ -785,11 +871,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                  WeiDataType* p_wei_grid,
                  const OutDataType* p_out_grid,
                  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
-                 const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
+                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
-                 const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
+                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
-                 const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
+                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                  const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
                  const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
                  const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...
@@ -809,38 +895,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
               a_element_op_{out_element_op},
               b_element_op_{wei_element_op},
               c_element_op_{in_element_op},
-              Conv_G_{a_g_n_c_wis_lengths[0]},
-              Conv_N_{a_g_n_c_wis_lengths[1]},
-              Conv_K_{b_g_k_c_xs_lengths[1]},
-              Conv_C_{a_g_n_c_wis_lengths[2]},
-              input_spatial_lengths_{},
-              filter_spatial_lengths_{},
-              output_spatial_lengths_{},
+              Conv_G_{a_g_n_c_wis_lengths[I0]},
+              Conv_N_{a_g_n_c_wis_lengths[I1]},
+              Conv_K_{b_g_k_c_xs_lengths[I1]},
+              Conv_C_{a_g_n_c_wis_lengths[I2]},
+              filter_lengths_{b_g_k_c_xs_lengths},
               conv_filter_strides_{conv_filter_strides},
               conv_filter_dilations_{conv_filter_dilations},
               input_left_pads_{input_left_pads},
               input_right_pads_{input_right_pads},
               k_batch_{split_k}
         {
-            constexpr index_t spatial_offset = 3;
-            std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
-                      end(a_g_n_c_wis_lengths),
-                      begin(input_spatial_lengths_));
-            std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
-                      end(b_g_k_c_xs_lengths),
-                      begin(filter_spatial_lengths_));
-            std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
-                      end(e_g_n_k_wos_lengths),
-                      begin(output_spatial_lengths_));
-
             const auto descs =
                 DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
-                    Conv_N_,
-                    Conv_K_,
-                    Conv_C_,
-                    input_spatial_lengths_,
-                    filter_spatial_lengths_,
-                    output_spatial_lengths_,
+                    a_g_n_c_wis_lengths, // input
+                    a_g_n_c_wis_strides,
+                    b_g_k_c_xs_lengths, // weight
+                    b_g_k_c_xs_strides,
+                    e_g_n_k_wos_lengths, // output
+                    e_g_n_k_wos_strides,
                     conv_filter_strides,
                     conv_filter_dilations,
                     input_left_pads,
...
@@ -863,24 +935,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                 GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);

             // A/B/C Batch Stride
-            compute_ptr_offset_of_batch_.BatchStrideA_ =
-                Conv_N_ * Conv_K_ *
-                std::accumulate(begin(output_spatial_lengths_),
-                                end(output_spatial_lengths_),
-                                index_t{1},
-                                std::multiplies<>{});
-            compute_ptr_offset_of_batch_.BatchStrideB_ =
-                Conv_N_ * Conv_C_ *
-                std::accumulate(begin(input_spatial_lengths_),
-                                end(input_spatial_lengths_),
-                                index_t{1},
-                                std::multiplies<>{});
-            compute_ptr_offset_of_batch_.BatchStrideC_ =
-                Conv_K_ * Conv_C_ *
-                std::accumulate(begin(filter_spatial_lengths_),
-                                end(filter_spatial_lengths_),
-                                index_t{1},
-                                std::multiplies<>{});
+            compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
+            compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
+            compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
         }

         const ADataType* p_a_grid_;
...
@@ -908,13 +965,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
        // for checking IsSupportedArgument()
        const index_t Conv_G_;
-       const index_t Conv_N_;
        const index_t Conv_K_;
        const index_t Conv_C_;

-       std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
-       std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
-       std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
+       std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
        const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
...
...
@@ -1036,10 +1090,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    static bool IsSupportedArgument(const Argument& arg)
    {
-       // check device
-       if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
-            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-            ck::get_device_name() == "gfx1102"))
+       // DL version only supports split_k equal to 1
+       if(arg.k_batch_ != 1)
+           return false;
+
+       if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
+                      (NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
+                      (NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
        {
            return false;
        }
...
...
@@ -1050,8 +1108,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
        // check if it's 1x1, stride=1 pad = 0 conv
        for(int i = 0; i < NDimSpatial; i++)
        {
-           if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
-                arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
+           if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
+                arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
+                arg.input_right_pads_[i] == 0))
            {
                return false;
            }
...
...
@@ -1206,7 +1265,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
        auto str = std::stringstream();

        // clang-format off
-       str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
+       str << "DeviceGroupedConvBwdWeight_Dl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
...
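The batch-stride hunk above is the heart of this file's layout generalization: instead of assuming a packed layout and recomputing each group's offset from the tensor lengths, the kernel now reads the stride of the G dimension directly from the caller's stride arrays, which also covers non-packed layouts. A minimal standalone sketch of the difference (illustrative values only, not CK code):

    // Standalone sketch: old packed-stride computation vs. reading the G stride.
    #include <array>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>

    int main()
    {
        // Example 2D conv output, lengths ordered as [G, N, K, Ho, Wo]
        std::array<int64_t, 5> out_lengths = {4, 8, 64, 28, 28};
        // Packed strides for the same tensor: stride[G] = N * K * Ho * Wo
        std::array<int64_t, 5> out_strides = {8 * 64 * 28 * 28, 64 * 28 * 28, 28 * 28, 28, 1};

        // Old approach: assume packed layout, multiply N * K * prod(spatial)
        int64_t batch_stride_old =
            out_lengths[1] * out_lengths[2] *
            std::accumulate(out_lengths.begin() + 3, out_lengths.end(), int64_t{1},
                            std::multiplies<>{});

        // New approach: read the G stride, correct even for non-packed layouts
        int64_t batch_stride_new = out_strides[0];

        std::cout << batch_stride_old << " == " << batch_stride_new << '\n'; // 401408
    }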
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
dcf15456
...
...
@@ -27,6 +27,12 @@ struct PassThrough
        y = x;
    }

+   template <>
+   __host__ __device__ void operator()<float, double>(float& y, const double& x) const
+   {
+       y = type_convert<float>(x);
+   }
+
    template <>
    __host__ __device__ void operator()<float, float>(float& y, const float& x) const
    {
...
...
@@ -81,6 +87,12 @@ struct PassThrough
        y = type_convert<int8_t>(x);
    }

+   template <>
+   __host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
+   {
+       y = type_convert<int8_t>(x);
+   }
+
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
...
...
@@ -89,6 +101,7 @@ struct PassThrough
    }
#endif

+#if defined CK_ENABLE_FP8
    template <>
    __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
    {
...
...
@@ -118,6 +131,7 @@ struct PassThrough
    {
        y = type_convert<f8_t>(x);
    }
+#endif
};

struct UnaryConvert
...
...
@@ -146,6 +160,7 @@ struct ConvertBF16RTN
    }
};

+#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
    // convert to fp8 using stochastic rounding (SR)
...
...
@@ -162,6 +177,7 @@ struct ConvertF8SR
        y = f8_convert_sr<Y>(x);
    }
};
+#endif

struct Scale
{
...
...
@@ -412,14 +428,19 @@ struct Swish
{
    Swish(float beta = 1.0f) : beta_(beta) {}

-   template <typename T>
-   __host__ __device__ void operator()(T& y, const T& x) const
+   template <typename Y, typename X>
+   __host__ __device__ void operator()(Y& y, const X& x) const
    {
-       static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                         is_same<T, ck::half_t>::value,
+       static_assert(is_same<X, float>::value || is_same<X, double>::value ||
+                         is_same<X, ck::half_t>::value,
                      "Data type is not supported by this operation!");
+       static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
+                         is_same<Y, ck::half_t>::value,
+                     "Data type is not supported by this operation!");

-       y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
+       float bx = -beta_ * type_convert<float>(x);
+       y        = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
    };

    float beta_ = 1.0f;
...
...
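The Swish change above splits the input and output template parameters and routes the arithmetic through float. A standalone sketch of the same formula, with static_cast standing in for ck::type_convert:

    // Standalone sketch of the templated Swish: separate output (Y) and
    // input (X) types, math done in float as in the new code.
    #include <cmath>
    #include <iostream>

    template <typename Y, typename X>
    void swish(Y& y, const X& x, float beta = 1.0f)
    {
        float bx = -beta * static_cast<float>(x); // compute in float
        y        = static_cast<Y>(static_cast<float>(x) / (1.f + std::exp(bx)));
    }

    int main()
    {
        double x = 2.0;
        float  y = 0.f;
        swish(y, x);            // Y = float, X = double
        std::cout << y << '\n'; // ~1.7616
    }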
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
dcf15456
...
...
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-               SrcData v;
+               DstData v;

                // apply element-wise operation
                element_op_(v, src_buf[Number<src_offset>{}]);

-               // apply type convert
-               dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
+               dst_vector.template AsType<DstData>()(i) = v;
            });

            const bool is_dst_valid =
...
...
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-               SrcData v;
+               DstData v;

                // apply element-wise operation
                element_op_(v, src_buf[Number<src_offset>{}]);

-               // apply type convert
-               dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
+               dst_buf(Number<dst_offset>{}) = v;
            });
        });
    }
...
...
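Both transfer hunks follow one pattern: the temporary is now declared as DstData and the elementwise operation itself is expected to produce the destination type (which the new mixed-type PassThrough specializations make possible), so the trailing type_convert disappears. A minimal standalone sketch of the pattern:

    // Standalone sketch (not CK code): the element-wise op writes straight
    // into a destination-typed temporary, so no separate convert step is needed.
    #include <iostream>

    struct PassThroughLike
    {
        // the op handles any widening/narrowing, as in the new specializations
        template <typename Y, typename X>
        void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
    };

    int main()
    {
        using DstData = float;
        double src = 3.14159265358979;

        DstData v; // was SrcData v; before the change
        PassThroughLike{}(v, src); // conversion happens inside the op
        std::cout << v << '\n';
    }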
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
dcf15456
...
...
@@ -11,8 +11,14 @@ namespace ck {
enum struct DppInstr
{
-   dpp8_f16_16x16x2 = 0,
+   dpp8_f16_1x32x2 = 0,
+   dpp8_f16_2x16x2,
+   dpp8_f16_2x32x2,
+   dpp8_f16_4x16x2,
+   dpp8_f16_4x32x2,
+   dpp8_f16_8x16x2,
    dpp8_f16_8x32x2,
+   dpp8_f16_16x16x2,
    dpp8_f16_32x8x2
};
...
...
@@ -101,6 +107,36 @@ struct dpp_type<DppInstr::dpp8_f16_8x32x2>
    }
};

+template <>
+struct dpp_type<DppInstr::dpp8_f16_8x16x2>
+{
+    static constexpr index_t wave_size       = 32;
+    static constexpr index_t lanegroup_size  = 8;
+    static constexpr index_t m_per_wave      = 8;
+    static constexpr index_t n_per_wave      = 16;
+    static constexpr index_t m_per_lanegroup = 4;
+    static constexpr index_t n_per_lanegroup = 8;
+    static constexpr index_t m_per_thread    = 4;
+    static constexpr index_t n_per_thread    = 1;
+    static constexpr index_t k_per_dpp       = 2;
+    static constexpr bool share_a            = true;
+    using BaseType                           = half_t;
+
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
+    {
+        dpp8::DppLanegroupGemm<m_per_thread,
+                               n_per_thread,
+                               k_per_dpp,
+                               BaseType,
+                               ADataType,
+                               BDataType,
+                               CDataType,
+                               share_a>{}
+            .Run(a, b, reg_c);
+    }
+};
+
template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
...
...
@@ -131,6 +167,156 @@ struct dpp_type<DppInstr::dpp8_f16_16x16x2>
    }
};

+template <>
+struct dpp_type<DppInstr::dpp8_f16_4x32x2>
+{
+    static constexpr index_t wave_size       = 32;
+    static constexpr index_t lanegroup_size  = 8;
+    static constexpr index_t m_per_wave      = 4;
+    static constexpr index_t n_per_wave      = 32;
+    static constexpr index_t m_per_lanegroup = 4;
+    static constexpr index_t n_per_lanegroup = 8;
+    static constexpr index_t m_per_thread    = 4;
+    static constexpr index_t n_per_thread    = 1;
+    static constexpr index_t k_per_dpp       = 2;
+    static constexpr bool share_a            = true;
+    using BaseType                           = half_t;
+
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
+    {
+        dpp8::DppLanegroupGemm<m_per_thread,
+                               n_per_thread,
+                               k_per_dpp,
+                               BaseType,
+                               ADataType,
+                               BDataType,
+                               CDataType,
+                               share_a>{}
+            .Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct dpp_type<DppInstr::dpp8_f16_4x16x2>
+{
+    static constexpr index_t wave_size       = 32;
+    static constexpr index_t lanegroup_size  = 8;
+    static constexpr index_t m_per_wave      = 4;
+    static constexpr index_t n_per_wave      = 16;
+    static constexpr index_t m_per_lanegroup = 2;
+    static constexpr index_t n_per_lanegroup = 8;
+    static constexpr index_t m_per_thread    = 2;
+    static constexpr index_t n_per_thread    = 1;
+    static constexpr index_t k_per_dpp       = 2;
+    static constexpr bool share_a            = true;
+    using BaseType                           = half_t;
+
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
+    {
+        dpp8::DppLanegroupGemm<m_per_thread,
+                               n_per_thread,
+                               k_per_dpp,
+                               BaseType,
+                               ADataType,
+                               BDataType,
+                               CDataType,
+                               share_a>{}
+            .Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct dpp_type<DppInstr::dpp8_f16_1x32x2>
+{
+    static constexpr index_t wave_size       = 32;
+    static constexpr index_t lanegroup_size  = 8;
+    static constexpr index_t m_per_wave      = 1;
+    static constexpr index_t n_per_wave      = 32;
+    static constexpr index_t m_per_lanegroup = 1;
+    static constexpr index_t n_per_lanegroup = 8;
+    static constexpr index_t m_per_thread    = 1;
+    static constexpr index_t n_per_thread    = 1;
+    static constexpr index_t k_per_dpp       = 2;
+    static constexpr bool share_a            = true;
+    using BaseType                           = half_t;
+
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
+    {
+        dpp8::DppLanegroupGemm<m_per_thread,
+                               n_per_thread,
+                               k_per_dpp,
+                               BaseType,
+                               ADataType,
+                               BDataType,
+                               CDataType,
+                               share_a>{}
+            .Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct dpp_type<DppInstr::dpp8_f16_2x32x2>
+{
+    static constexpr index_t wave_size       = 32;
+    static constexpr index_t lanegroup_size  = 8;
+    static constexpr index_t m_per_wave      = 2;
+    static constexpr index_t n_per_wave      = 32;
+    static constexpr index_t m_per_lanegroup = 2;
+    static constexpr index_t n_per_lanegroup = 8;
+    static constexpr index_t m_per_thread    = 2;
+    static constexpr index_t n_per_thread    = 1;
+    static constexpr index_t k_per_dpp       = 2;
+    static constexpr bool share_a            = true;
+    using BaseType                           = half_t;
+
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
+    {
+        dpp8::DppLanegroupGemm<m_per_thread,
+                               n_per_thread,
+                               k_per_dpp,
+                               BaseType,
+                               ADataType,
+                               BDataType,
+                               CDataType,
+                               share_a>{}
+            .Run(a, b, reg_c);
+    }
+};
+
+template <>
+struct dpp_type<DppInstr::dpp8_f16_2x16x2>
+{
+    static constexpr index_t wave_size       = 32;
+    static constexpr index_t lanegroup_size  = 8;
+    static constexpr index_t m_per_wave      = 2;
+    static constexpr index_t n_per_wave      = 16;
+    static constexpr index_t m_per_lanegroup = 1;
+    static constexpr index_t n_per_lanegroup = 8;
+    static constexpr index_t m_per_thread    = 1;
+    static constexpr index_t n_per_thread    = 1;
+    static constexpr index_t k_per_dpp       = 2;
+    static constexpr bool share_a            = true;
+    using BaseType                           = half_t;
+
+    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
+    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
+    {
+        dpp8::DppLanegroupGemm<m_per_thread,
+                               n_per_thread,
+                               k_per_dpp,
+                               BaseType,
+                               ADataType,
+                               BDataType,
+                               CDataType,
+                               share_a>{}
+            .Run(a, b, reg_c);
+    }
+};
+
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
...
...
@@ -143,6 +329,12 @@ struct DppSelector
        return DppInstr::dpp8_f16_8x32x2;
    }

+   template <>
+   static constexpr auto GetDpp<half_t, 8, 16>()
+   {
+       return DppInstr::dpp8_f16_8x16x2;
+   }
+
    template <>
    static constexpr auto GetDpp<half_t, 16, 16>()
    {
...
...
@@ -155,6 +347,36 @@ struct DppSelector
        return DppInstr::dpp8_f16_32x8x2;
    }

+   template <>
+   static constexpr auto GetDpp<half_t, 1, 32>()
+   {
+       return DppInstr::dpp8_f16_1x32x2;
+   }
+
+   template <>
+   static constexpr auto GetDpp<half_t, 2, 32>()
+   {
+       return DppInstr::dpp8_f16_2x32x2;
+   }
+
+   template <>
+   static constexpr auto GetDpp<half_t, 2, 16>()
+   {
+       return DppInstr::dpp8_f16_2x16x2;
+   }
+
+   template <>
+   static constexpr auto GetDpp<half_t, 4, 16>()
+   {
+       return DppInstr::dpp8_f16_4x16x2;
+   }
+
+   template <>
+   static constexpr auto GetDpp<half_t, 4, 32>()
+   {
+       return DppInstr::dpp8_f16_4x32x2;
+   }
+
    static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};

    __host__ __device__ constexpr DppSelector()
...
...
@@ -191,7 +413,6 @@ struct DppSelector
        // in the future when the implementation is more generalized.
        static_assert(selected_dpp.share_a);
        static_assert(selected_dpp.n_per_thread == 1);
-       static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size);
        static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
        static_assert(selected_dpp.n_per_lanegroup ==
                      selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
...
...
@@ -215,11 +436,6 @@ struct DppGemm
    __host__ __device__ constexpr DppGemm()
    {
-       static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32,
-                     "MPerDpp must be either 8, 16 or 32.");
-       static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32,
-                     "NPerDpp must be either 8, 16 or 32.");
-
        static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
    }
...
...
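With the new dpp_type specializations, DppSelector::GetDpp can map several more (MPerDpp, NPerDpp) tile shapes onto DPP8 instructions, which is why the old "must be 8, 16 or 32" asserts were dropped. A standalone sketch of that compile-time tile-to-instruction dispatch (the selection function below is a simplified stand-in for the specializations in the diff):

    // Standalone sketch of compile-time tile -> instruction dispatch.
    // Enum values match the diff; the if-constexpr chain is illustrative.
    #include <iostream>

    enum struct DppInstr
    {
        dpp8_f16_1x32x2 = 0,
        dpp8_f16_2x16x2,
        dpp8_f16_2x32x2,
        dpp8_f16_4x16x2,
        dpp8_f16_4x32x2,
        dpp8_f16_8x16x2,
        dpp8_f16_8x32x2,
        dpp8_f16_16x16x2,
        dpp8_f16_32x8x2
    };

    template <int MPerDpp, int NPerDpp>
    constexpr DppInstr get_dpp()
    {
        if constexpr(MPerDpp == 8 && NPerDpp == 16) { return DppInstr::dpp8_f16_8x16x2; }
        else if constexpr(MPerDpp == 4 && NPerDpp == 32) { return DppInstr::dpp8_f16_4x32x2; }
        else if constexpr(MPerDpp == 1 && NPerDpp == 32) { return DppInstr::dpp8_f16_1x32x2; }
        else { return DppInstr::dpp8_f16_16x16x2; }
    }

    int main()
    {
        static_assert(get_dpp<8, 16>() == DppInstr::dpp8_f16_8x16x2);
        std::cout << static_cast<int>(get_dpp<4, 32>()) << '\n';
    }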
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
dcf15456
...
...
@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
    }
};

+#if defined CK_ENABLE_FP8
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
...
...
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
        intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
    }
};
+#endif

template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
...
...
@@ -640,6 +642,7 @@ struct MfmaSelector
    }
#endif

+#if defined CK_ENABLE_FP8
    template <>
    static constexpr auto GetMfma<f8_t, 32, 32>()
    {
...
...
@@ -651,6 +654,7 @@ struct MfmaSelector
    {
        return MfmaInstr::mfma_f32_16x16x32f8f8;
    }
+#endif

    static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
...
...
@@ -852,7 +856,11 @@ struct XdlopsGemm
    {
        static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                          is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                         is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
+                         is_same<base_type, int8_t>::value
+#if defined CK_ENABLE_FP8
+                         || is_same<base_type, f8_t>::value
+#endif
+                     ,
                      "base base_type must be double, float, half, bfloat16, and int8_t!");

        static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
dcf15456
...
...
@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

+#if defined CK_ENABLE_FP8
    if constexpr(is_same<scalar_t, f8_t>::value)
    {
        auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...
...
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
    }
    else
    {
+#endif
        return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
            src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
    }
+#endif
#else
+#if defined CK_ENABLE_FP8
    if constexpr(is_same<scalar_t, f8_t>::value)
    {
        auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
...
...
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
    }
    else
    {
+#endif
        vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
            src_wave_buffer_resource, src_thread_addr_offset, 0);

        return src_thread_element_valid ? tmp : vector_t(0);
+#if defined CK_ENABLE_FP8
    }
+#endif
#endif
}

// buffer_load requires:
...
...
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

+#if defined CK_ENABLE_FP8
    if constexpr(is_same<scalar_t, f8_t>::value)
    {
        auto tmp =
...
...
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
    }
    else
    {
+#endif
        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
    }
+#endif
#else
    if(dst_thread_element_valid)
    {
+#if defined CK_ENABLE_FP8
        if constexpr(is_same<scalar_t, f8_t>::value)
        {
            auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
...
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
}
else
{
#endif
amd_buffer_store_impl
<
scalar_t
,
vector_size
,
coherence
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
#if defined CK_ENABLE_FP8
}
#endif
}
#endif
}
...
...
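All four buffer-addressing hunks apply the same workaround, now wrapped in CK_ENABLE_FP8 guards: f8 data is loaded and stored through the int8 buffer intrinsics and the payload is reinterpreted bit-for-bit, since the 8-bit float type has no native buffer overloads. A minimal portable sketch of the reinterpretation step (uint8_t stands in here for _BitInt(8)):

    // Standalone sketch (not CK code) of the f8-as-int8 reinterpretation.
    #include <cstdint>
    #include <cstring>
    #include <iostream>

    using f8_t = uint8_t; // stand-in for _BitInt(8) to keep the sketch portable

    template <typename To, typename From>
    To bit_cast_like(const From& src)
    {
        static_assert(sizeof(To) == sizeof(From), "same-size reinterpretation only");
        To dst;
        std::memcpy(&dst, &src, sizeof(To)); // defined-behavior bit copy
        return dst;
    }

    // stand-in for the int8 amd_buffer_load_impl path
    int8_t load_as_int8(const int8_t* p) { return *p; }

    int main()
    {
        int8_t raw = 0x38;
        f8_t   v   = bit_cast_like<f8_t>(load_as_int8(&raw));
        std::cout << int(v) << '\n';
    }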
include/ck/utility/amd_xdlops.hpp
View file @
dcf15456
...
...
@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
    }
};

+#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
...
...
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
    }
};
+#endif

} // namespace ck
#endif
include/ck/utility/data_type.hpp
View file @
dcf15456
...
...
@@ -12,7 +12,12 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
-using f8_t = uint8_t;
+#if defined CK_ENABLE_FP8
+using f8_t = _BitInt(8);
+#endif
+#if defined CK_ENABLE_BF8
+using bf8_t = unsigned _BitInt(8);
+#endif

// vector_type
template <typename T, index_t N>
...
...
@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
};
#endif

+#if defined CK_ENABLE_FP8
template <>
struct scalar_type<f8_t>
{
    using type                           = f8_t;
    static constexpr index_t vector_size = 1;
};
+#endif
+
+#if defined CK_ENABLE_BF8
+template <>
+struct scalar_type<bf8_t>
+{
+    using type                           = bf8_t;
+    static constexpr index_t vector_size = 1;
+};
+#endif

//
template <typename T>
struct vector_type<T, 1>
{
...
...
@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;

// f8
+#if defined CK_ENABLE_FP8
using f8x2_t  = typename vector_type<f8_t, 2>::type;
using f8x4_t  = typename vector_type<f8_t, 4>::type;
using f8x8_t  = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
+#endif

// bf8
+#if defined CK_ENABLE_BF8
+using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
+using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
+using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
+using bf8x16_t = typename vector_type<bf8_t, 16>::type;
+using bf8x32_t = typename vector_type<bf8_t, 32>::type;
+using bf8x64_t = typename vector_type<bf8_t, 64>::type;
+#endif

template <typename T>
struct NumericLimits
...
...
@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

#if defined CK_ENABLE_FP8
template <>
struct NumericLimits<f8_t>
{
    // negative zero nan mode with exp bias = 8
    static constexpr uint8_t binary_min    = 0x08; // 0b00001000
-   static constexpr uint8_t binary_max    = 0x77; // 0b01110111
-   static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
+   static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+   static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
    // ieee mode with exp bias = 7
    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

    __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
    __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
    __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
    __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
};
#endif

#if defined CK_ENABLE_BF8
template <>
struct NumericLimits<bf8_t>
{
    // negative zero nan mode with exp bias = 16
    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
    // ieee mode with exp bias = 15
    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0

-   __host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
+   __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
-   __host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
+   __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
-   __host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
+   __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
-   __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
+   __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif

template <typename T>
struct NumericUtils
{
};

template <>
struct NumericUtils<float>
{
    static constexpr int exp            = 8;
    static constexpr int mant           = 23;
    static constexpr uint32_t nan_mask  = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;
    static constexpr uint32_t Inf       = 0x7F800000;
    static constexpr uint32_t NegInf    = 0xFF800000;
    static constexpr uint32_t NaN       = 0x7F800001;
    static constexpr uint32_t Neg0      = 0x80000000;
    using bitwise_type                  = uint32_t;
};

template <>
struct NumericUtils<half_t>
{
    static constexpr int exp            = 5;
    static constexpr int mant           = 10;
    static constexpr uint16_t nan_mask  = 0x7C00;
    static constexpr uint16_t head_mask = 0xFC00;
    static constexpr uint16_t mant_mask = 0x3FF;
    static constexpr uint16_t exp_mask  = 0x1F;
    static constexpr uint32_t Inf       = 0x7C00;
    static constexpr uint32_t NegInf    = 0xFC00;
    static constexpr uint32_t NaN       = 0x7C01;
    static constexpr uint32_t Neg0      = 0x8000;
    using bitwise_type                  = uint16_t;
};

#if defined CK_ENABLE_FP8
template <>
struct NumericUtils<f8_t>
{
    static constexpr int exp  = 4;
    static constexpr int mant = 3;
};
#endif

#if defined CK_ENABLE_BF8
template <>
struct NumericUtils<bf8_t>
{
    static constexpr int exp  = 5;
    static constexpr int mant = 2;
};
#endif
} // namespace ck
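The new binary_max/binary_lowest values encode the FNUZ ("negative zero NaN") ranges: with an exponent bias of 8, 0x7F decodes to 240 for f8 (e4m3), and with a bias of 16, 0x7F decodes to 57344 for bf8 (e5m2). A standalone decoder to check those constants (normal numbers only, FNUZ interpretation assumed):

    // Standalone check (not CK code) of the fp8 limit constants above.
    // Decodes (-1)^s * 2^(e - bias) * (1 + m / 2^mant_bits).
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    double decode_fnuz(uint8_t bits, int exp_bits, int mant_bits, int bias)
    {
        int sign = bits >> (exp_bits + mant_bits);
        int exp  = (bits >> mant_bits) & ((1 << exp_bits) - 1);
        int mant = bits & ((1 << mant_bits) - 1);
        double v = std::ldexp(1.0 + double(mant) / (1 << mant_bits), exp - bias);
        return sign ? -v : v;
    }

    int main()
    {
        std::cout << decode_fnuz(0x7F, 4, 3, 8) << '\n';  // f8  binary_max    ->  240
        std::cout << decode_fnuz(0xFF, 4, 3, 8) << '\n';  // f8  binary_lowest -> -240
        std::cout << decode_fnuz(0x7F, 5, 2, 16) << '\n'; // bf8 binary_max    ->  57344
    }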
include/ck/utility/f8_utils.hpp
View file @
dcf15456
...
...
@@ -7,6 +7,7 @@
// these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8

namespace ck {

// fp8 rounding modes
...
...
@@ -24,53 +25,38 @@ namespace ck::utils {
namespace {

-template <typename T, bool negative_zero_nan, bool clip, bool stoch>
-__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
+template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
+__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
{
-   // check data type
-   constexpr bool is_half  = std::is_same<T, half_t>::value;
-   constexpr bool is_float = std::is_same<T, float>::value;
-
-   // fp8 exponent/mantissa layout
-   constexpr int f8_exp  = 4;
-   constexpr int f8_mant = 3;
-
-   // resulting type exponent/mantissa layout
-   constexpr int type_exp  = is_half ? 5 : 8;
-   constexpr int type_mant = is_half ? 10 : 23;
+   // fp8/bf8 exponent/mantissa layout
+   constexpr int out_exp  = NumericUtils<Y>::exp;
+   constexpr int out_mant = NumericUtils<Y>::mant;
+
+   // original type exponent/mantissa layout
+   constexpr int in_exp  = NumericUtils<X>::exp;
+   constexpr int in_mant = NumericUtils<X>::mant;

    int exponent;
    uint32_t head, mantissa, sign;
    // nan code is same for float and half
-   constexpr uint8_t nan_code  = 0x80;
-   constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
+   constexpr Y nan_code        = 0x80;
+   constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;

    // convert to bitwise
-   typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
-       T_bitwise;
+   using T_bitwise     = typename NumericUtils<X>::bitwise_type;
    T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));

-   // unpack the input, depends on datatype
-   if constexpr(is_float)
-   {
-       head     = x_bitwise & 0xFF800000;
-       mantissa = x_bitwise & 0x7FFFFF;
-       exponent = (head >> type_mant) & 0xFF;
-       sign     = head >> (type_exp + type_mant);
-   }
-   else if constexpr(is_half)
-   {
-       head     = x_bitwise & 0xFC00;
-       mantissa = x_bitwise & 0x3FF;
-       exponent = (head >> type_mant) & 0x1F;
-       sign     = head >> (type_exp + type_mant);
-   }
-
-   uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
-   uint32_t drop_mask  = (1 << (type_mant - f8_mant)) - 1;
-   constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
+   // unpack the input
+   head     = x_bitwise & NumericUtils<X>::head_mask;
+   mantissa = x_bitwise & NumericUtils<X>::mant_mask;
+   exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
+   sign     = head >> (in_exp + in_mant);
+
+   uint32_t signed_inf   = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
+   uint32_t drop_mask    = (1 << (in_mant - out_mant)) - 1;
+   constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
    constexpr int exp_low_cutoff =
-       (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
+       (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

    if constexpr(negative_zero_nan)
    {
...
...
@@ -83,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
        return signed_inf + (mantissa != 0 ? 1 : 0);
    }

+   // if input is half and output is bf8
+   if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
+      exponent == 0)
+   {
+       exponent += 1;
+       while(mantissa < (1 << in_mant))
+       {
+           mantissa <<= 1;
+           exponent -= 1;
+       }
+       mantissa &= ~(1 << in_mant);
+   }
+
    // check if x is 0.0
    if(x_bitwise == 0)
        return 0;

    exponent -= exp_low_cutoff - 1;
    if(exponent <= 0)
-       drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
-   mantissa += 1 << type_mant;
+       drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
+   mantissa += 1 << in_mant;
    // apply random number if needed
    mantissa += (stoch ? rng : mantissa) & drop_mask;
-   if(mantissa >= (2 << type_mant))
+   if(mantissa >= (2 << in_mant))
    {
        mantissa >>= 1;
        exponent++;
    }
-   mantissa >>= (type_mant - f8_mant);
+   mantissa >>= (in_mant - out_mant);

    // check negative exponent
    if(exponent <= 0)
...
...
@@ -118,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
    {
        if(clip)
        {
-           mantissa = (1 << f8_mant) - 1;
+           mantissa = (1 << out_mant) - 1;
            exponent = max_exp;
        }
        else
...
...
@@ -129,125 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
    // check if x is 0.0 or -0.0
    if(exponent == 0 && mantissa == 0)
-       return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
-   mantissa &= (1 << f8_mant) - 1;
-   return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
+       return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
+   mantissa &= (1 << out_mant) - 1;
+   return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
}

-template <typename T, bool negative_zero_nan>
-__host__ __device__ T run_cast_from_f8(f8_t x)
+template <typename X, typename Y, bool negative_zero_nan>
+__host__ __device__ Y run_cast_from_f8(X x)
{
-   // check data type
-   constexpr bool is_half  = std::is_same<T, half_t>::value;
-   constexpr bool is_float = std::is_same<T, float>::value;
-
-   // fp8 exponent/mantissa layout
-   constexpr int f8_exp  = 4;
-   constexpr int f8_mant = 3;
+   // fp8/bf8 exponent/mantissa layout
+   constexpr int in_exp  = NumericUtils<X>::exp;
+   constexpr int in_mant = NumericUtils<X>::mant;

    // resulting type exponent/mantissa layout
-   constexpr int type_exp  = is_half ? 5 : 8;
-   constexpr int type_mant = is_half ? 10 : 23;
+   constexpr int out_exp  = NumericUtils<Y>::exp;
+   constexpr int out_mant = NumericUtils<Y>::mant;

    // prepare the codes
-   constexpr uint8_t nan_code = 0x80;
-   T fInf, fNegInf, fNaN, fNeg0;
-   if constexpr(is_half)
-   {
-       constexpr uint16_t ihInf    = 0x7C00;
-       constexpr uint16_t ihNegInf = 0xFC00;
-       constexpr uint16_t ihNaN    = 0x7C01;
-       constexpr uint16_t ihNeg0   = 0x8000;
-       fInf    = *(reinterpret_cast<const half_t*>(&ihInf));
-       fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
-       fNaN    = *(reinterpret_cast<const half_t*>(&ihNaN));
-       fNeg0   = *(reinterpret_cast<const half_t*>(&ihNeg0));
-   }
-   else if constexpr(is_float)
-   {
-       constexpr uint32_t ifInf    = 0x7F800000;
-       constexpr uint32_t ifNegInf = 0xFF800000;
-       constexpr uint32_t ifNaN    = 0x7F800001;
-       constexpr uint32_t ifNeg0   = 0x80000000;
-       fInf    = *(reinterpret_cast<const float*>(&ifInf));
-       fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
-       fNaN    = *(reinterpret_cast<const float*>(&ifNaN));
-       fNeg0   = *(reinterpret_cast<const float*>(&ifNeg0));
-   }
+   constexpr X nan_code = 0x80;
+   Y Inf, NegInf, NaN, Neg0;
+   using T_bitwise = typename NumericUtils<Y>::bitwise_type;
+
+   constexpr T_bitwise Inf_bitwise    = NumericUtils<Y>::Inf;
+   constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
+   constexpr T_bitwise NaN_bitwise    = NumericUtils<Y>::NaN;
+   constexpr T_bitwise Neg0_bitwise   = NumericUtils<Y>::Neg0;
+
+   Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
+   NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
+   NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
+   Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
+
+   // check if x is 0.0
+   if(x == 0)
+       return static_cast<Y>(0);

    // unpack the input
-   uint32_t sign     = x >> (f8_exp + f8_mant);
-   uint32_t mantissa = x & ((1 << f8_mant) - 1);
-   int exponent      = (x & 0x7F) >> f8_mant;
+   uint32_t sign     = x >> (in_exp + in_mant);
+   uint32_t mantissa = x & ((1 << in_mant) - 1);
+   int exponent      = (x & 0x7F) >> in_mant;

    constexpr int exp_low_cutoff =
-       (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
-   typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
+       (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
+   T_bitwise retval;

    if constexpr(negative_zero_nan)
    {
        if(x == nan_code)
-           return fNaN;
+           return NaN;
    }
    else
    {
        if(x == nan_code)
-           return fNeg0;
-       if(exponent == ((1 << f8_exp) - 1))
-           return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
+           return Neg0;
+       if(exponent == ((1 << in_exp) - 1))
+           return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }

+   if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
+   {
+       retval = x;
+       retval <<= 8;
+       return *(reinterpret_cast<const Y*>(&retval));
+   }
+
    // subnormal input
    if(exponent == 0)
    {
        // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
-       int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
-       mantissa <<= sh;
-       mantissa &= ((1 << f8_mant) - 1);
-       exponent += 1 - sh;
+       exponent++;
+       while(mantissa < (1 << in_mant))
+       {
+           mantissa <<= 1;
+           exponent--;
+       }
+       mantissa &= ((1 << in_mant) - 1);
    }
    exponent += exp_low_cutoff - 1;
-   mantissa <<= type_mant - f8_mant;
+   mantissa <<= out_mant - in_mant;

    // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
    if(exponent <= 0)
    {
-       mantissa |= 1 << type_mant;
+       mantissa |= 1 << out_mant;
        mantissa >>= 1 - exponent;
        exponent = 0;
    }

-   retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
-   return *(reinterpret_cast<const T*>(&retval));
+   retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
+   return *(reinterpret_cast<const Y*>(&retval));
}

} // namespace

-template <typename T, bool negative_zero_nan, bool clip, bool stoch>
-__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
+template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
+__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{
-   // check datatype
-   constexpr bool is_half  = std::is_same<T, half_t>::value;
-   constexpr bool is_float = std::is_same<T, float>::value;
-   static_assert(is_half || is_float, "Only half and float can be casted to f8.");
+   // check datatypes
+   constexpr bool is_half  = std::is_same<X, half_t>::value;
+   constexpr bool is_float = std::is_same<X, float>::value;
+   static_assert(is_half || is_float, "Only half and float can be casted.");

-   return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
+   return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
}

-template <typename T, bool negative_zero_nan>
-__host__ __device__ T cast_from_f8(f8_t x)
+template <typename X, typename Y, bool negative_zero_nan>
+__host__ __device__ Y cast_from_f8(X x)
{
    // check datatype
-   constexpr bool is_half  = std::is_same<T, half_t>::value;
-   constexpr bool is_float = std::is_same<T, float>::value;
+   constexpr bool is_half  = std::is_same<Y, half_t>::value;
+   constexpr bool is_float = std::is_same<Y, float>::value;
    static_assert(is_half || is_float, "only half and float are supported.");

-   // check if x is 0.0
-   if(x == 0)
-       return static_cast<T>(0);
-
-   return run_cast_from_f8<T, negative_zero_nan>(x);
+   return run_cast_from_f8<X, Y, negative_zero_nan>(x);
}

} // namespace ck::utils

-#endif
+#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
+#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
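The refactor above collapses the float/half special cases into one pair of routines parameterized by NumericUtils<X> and NumericUtils<Y>, so the same code now handles float or half on one side and f8 or bf8 on the other. A standalone sketch evaluating the shared re-biasing constant from the diff for all four source/destination pairings (assuming negative_zero_nan = true, so the trailing +1 and -1 cancel):

    // Standalone sketch (not CK code) of the generic exp_low_cutoff formula.
    #include <iostream>

    struct Fmt { int exp; int mant; };

    // (1 << (in.exp - 1)) - (1 << (out.exp - 1)) + 1 - 1, per the diff
    constexpr int exp_low_cutoff(Fmt in, Fmt out)
    {
        return (1 << (in.exp - 1)) - (1 << (out.exp - 1));
    }

    int main()
    {
        constexpr Fmt f32{8, 23}, f16{5, 10}, f8{4, 3}, bf8{5, 2};
        std::cout << exp_low_cutoff(f32, f8) << '\n';  // 120 (float -> f8)
        std::cout << exp_low_cutoff(f16, f8) << '\n';  // 8   (half  -> f8)
        std::cout << exp_low_cutoff(f32, bf8) << '\n'; // 112 (float -> bf8)
        std::cout << exp_low_cutoff(f16, bf8) << '\n'; // 0   (half  -> bf8)
    }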
include/ck/utility/inner_product.hpp
View file @
dcf15456
...
...
@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
                  c);
}

+template <>
+__device__ void
+inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
+{
+    inner_product(type_convert<float>(a), type_convert<float>(b), c);
+}
+
+template <>
+__device__ void
+inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
+{
+    inner_product(type_convert<float>(a), type_convert<float>(b), c);
+}
+
template <>
__device__ void
inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
...
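The two new inner_product specializations handle scalar bhalf_t and half_t operands by widening both inputs to float and delegating to the float overload. A minimal standalone sketch of that fallback (assuming, as an illustration of the float overload's behavior, a c += a * b accumulate contract):

    // Standalone sketch (not CK code) of the widen-then-delegate fallback.
    #include <iostream>

    template <typename T>
    void inner_product_via_float(const T& a, const T& b, float& c)
    {
        // widen both operands, multiply, accumulate into c
        c += static_cast<float>(a) * static_cast<float>(b);
    }

    int main()
    {
        float c = 1.0f;
        inner_product_via_float(0.5f, 4.0f, c); // stands in for half_t/bhalf_t operands
        std::cout << c << '\n'; // 3
    }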