Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
05ee41c3
Commit
05ee41c3
authored
Nov 30, 2022
by
Rosty Geyyer
Browse files
Merge branch 'develop' into lwpck-471
parents
37116c98
ad541ad6
Changes
436
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4248 additions
and
252 deletions
+4248
-252
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
...ation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+21
-3
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
...ration/gpu/device/impl/device_batchnorm_backward_impl.hpp
+874
-0
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+9
-2
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+40
-38
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+1583
-0
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp
...gpu/device/impl/device_elementwise_normalization_impl.hpp
+592
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+15
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+19
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+15
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
...e_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
+229
-90
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
...e_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
+7
-8
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
...r_operation/gpu/device/impl/device_normalization_impl.hpp
+80
-72
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
...or_operation/gpu/device/impl/device_reduce_multiblock.hpp
+24
-0
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
.../tensor_operation/gpu/device/impl/device_softmax_impl.hpp
+87
-19
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...k/tensor_operation/gpu/element/element_wise_operation.hpp
+1
-0
include/ck/tensor_operation/gpu/element/quantization_operation.hpp
...k/tensor_operation/gpu/element/quantization_operation.hpp
+124
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+26
-10
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
...ultiblock_reduce_second_half_batchnorm_backward_final.hpp
+498
-0
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
...orm_multiblock/gridwise_multiblock_welford_first_half.hpp
+3
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
View file @
05ee41c3
...
...
@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
View file @
05ee41c3
...
...
@@ -150,7 +150,10 @@ template <typename ADataType,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
NumGemmKPrefetchStage
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceBatchedGemmXdl
:
public
DeviceBatchedGemm
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -323,7 +326,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BBlockLdsAddExtraN
,
Sequence
<
2
,
3
,
0
,
1
,
7
,
5
,
4
,
6
>
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
CThreadTransferDstScalarPerVector
,
NumGemmKPrefetchStage
,
LoopSched
,
PipelineVer
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
...
...
@@ -622,6 +628,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceBatchedGemmXdl"
<<
"<"
...
...
@@ -629,7 +641,13 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
">"
;
<<
">"
<<
" NumGemmKPrefetchStage: "
<<
NumGemmKPrefetchStage
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
0 → 100644
View file @
05ee41c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
XDataType
,
typename
DxDataType
,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
DscaleDbiasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
bool
UseMultiblockInK
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XDyDxVectorDim
,
index_t
XSrcVectorSize
,
index_t
DySrcVectorSize
,
index_t
DxDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
DscaleDbiasDstVectorSize
,
index_t
MeanVarSrcVectorSize
>
struct
DeviceBatchNormBwdImpl
:
public
DeviceBatchNormBwd
<
XDataType
,
DxDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
// ---- compile-time validation of the tuning parameters ----
static_assert(Rank <= 6, "Bigger Rank size is not supported!");

// Every thread of the block must belong to exactly one (M, K) cluster position.
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
              "Invalid thread cluster size assignments!");

// XDyDxVectorDim selects which per-thread slice (0 -> M slice, 1 -> K slice) has to be
// divisible by the x / dy / dx vector sizes.
static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
               MThreadSliceSize % DySrcVectorSize == 0 &&
               MThreadSliceSize % DxDstVectorSize == 0) ||
                  (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                   KThreadSliceSize % DySrcVectorSize == 0 &&
                   KThreadSliceSize % DxDstVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");

// Number of dimensions that are NOT reduced.
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

// Extent of the tile handled by one workgroup along M and along K.
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
/// Builds the padded 2-D [M, K] descriptor used for the x / dy / dx tensors.
///
/// Dimensions [0, NumInvariantDim) of the Rank-D tensor are merged into M and dimensions
/// [NumInvariantDim, Rank) are merged into K. M is then right-padded to a multiple of
/// M_BlockTileSize and K is right-padded to
/// K_BlockTileSize * numBlockTileIteration * blkGroupSize.
///
/// @param xyLengths             lengths of the Rank-D tensor
/// @param xyStrides             strides of the Rank-D tensor
/// @param blkGroupSize          number of workgroups sharing one M tile along K
/// @param numBlockTileIteration number of K tiles each workgroup iterates over
static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
                               const std::array<index_t, Rank>& xyStrides,
                               int blkGroupSize,
                               int numBlockTileIteration)
{
    using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
    using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

    // Rank-D descriptor straight from the caller's lengths and strides.
    const auto lengths_nd =
        generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
    const auto strides_nd =
        generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});
    const auto desc_nd = make_naive_tensor_descriptor(lengths_nd, strides_nd);

    // Lengths of the two dimension groups that get merged.
    const auto m_dim_lengths =
        generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});
    const auto k_dim_lengths =
        generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
                       Number<NumBatchNormReduceDim>{});

    const auto desc_m_k =
        transform_tensor_descriptor(desc_nd,
                                    make_tuple(make_merge_transform(m_dim_lengths),
                                               make_merge_transform(k_dim_lengths)),
                                    make_tuple(InvariantDims{}, ReduceDims{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto m_length = desc_m_k.GetLength(Number<0>{});
    const auto k_length = desc_m_k.GetLength(Number<1>{});

    // Amount of K covered by a single workgroup.
    const int k_per_block = K_BlockTileSize * numBlockTileIteration;

    const auto m_pad = math::integer_least_multiple(m_length, M_BlockTileSize) - m_length;
    const auto k_pad = k_per_block * blkGroupSize - k_length;

    return transform_tensor_descriptor(desc_m_k,
                                       make_tuple(make_right_pad_transform(m_length, m_pad),
                                                  make_right_pad_transform(k_length, k_pad)),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}));
}
/// Builds the packed [M, G] descriptor (G = blkGroupSize) for the per-workgroup partial
/// results. Only M is padded (up to a multiple of M_BlockTileSize); G is passed through.
///
/// @param invariantLength length of the M dimension
/// @param blkGroupSize    length of the G dimension
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
{
    const auto packed_desc =
        make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

    const auto m_pad =
        math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

    return transform_tensor_descriptor(
        packed_desc,
        make_tuple(make_right_pad_transform(invariantLength, m_pad),
                   make_pass_through_transform(blkGroupSize)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
}
/// Builds the packed [M, K] descriptor through which the partial results are read back for
/// the final reduction; here K is the number of partial results per row (blkGroupSize).
/// M is right-padded to a multiple of M_BlockTileSize and K to a multiple of
/// KThreadClusterSize.
///
/// @param invariantLength length of the M dimension
/// @param blkGroupSize    number of partial results per row
static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
{
    const auto packed_desc =
        make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

    const auto m_pad =
        math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
    const auto k_pad =
        math::integer_least_multiple(blkGroupSize, KThreadClusterSize) - blkGroupSize;

    return transform_tensor_descriptor(
        packed_desc,
        make_tuple(make_right_pad_transform(invariantLength, m_pad),
                   make_right_pad_transform(blkGroupSize, k_pad)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
}
static
auto
MakeScaleBiasMeanVar1dDescriptor
(
const
std
::
array
<
index_t
,
NumInvariantDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>&
strides
)
{
const
auto
tupleLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
const
auto
tupleStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
strides
[
I
];
},
Number
<
NumInvariantDim
>
{});
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
auto
grid_desc_m
=
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
tupleLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
grid_desc_m_padded
);
};
// Descriptor types are deduced by evaluating the factory functions on dummy arguments
// inside decltype; only the resulting types are used, never the values.
using XYGridDesc_M_K      = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));

// Mean / variance vectors share the layout of the scale / bias vectors.
using MeanVarGridDesc_M = ScaleBiasGridDesc_M;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
dyStrides
,
const
std
::
array
<
index_t
,
Rank
>
dxStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnDscaleDbiasStrides
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
DyDataType
*
p_dy
,
const
ScaleDataType
*
p_scale
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
double
epsilon
,
DxDataType
*
p_dx
,
DscaleDbiasDataType
*
p_dscale
,
DscaleDbiasDataType
*
p_dbias
)
:
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnDscaleDbiasStrides_
(
bnDscaleDbiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
p_dy_
(
p_dy
),
p_scale_
(
p_scale
),
p_savedMean_
(
p_savedMean
),
p_savedInvVar_
(
p_savedInvVar
),
dy_elementwise_op_
(
dy_elementwise_op
),
p_dx_
(
p_dx
),
p_dscale_
(
p_dscale
),
p_dbias_
(
p_dbias
)
{
xyLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xStrides
,
reduceDims
);
dyStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
dyStrides
,
reduceDims
);
dxStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
dxStrides
,
reduceDims
);
std
::
tie
(
invariant_length
,
reduce_length
)
=
get_2d_lengths
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
haveSavedMeanInvVar_
=
(
p_savedMean_
!=
nullptr
&&
p_savedInvVar_
!=
nullptr
);
if
(
UseMultiblockInK
)
{
int
iterations
=
1
;
while
(
true
)
{
int
testBlkGroupSize
=
(
reduce_length
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 128
if
(
testBlkGroupSize
<=
128
)
break
;
iterations
++
;
};
blkGroupSize
=
(
reduce_length
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
numBlockTileIteration
=
iterations
;
}
else
{
blkGroupSize
=
1
;
numBlockTileIteration
=
(
reduce_length
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
};
gridSize
=
(
invariant_length
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
blkGroupSize
;
x_grid_desc_m_k
=
MakeXY2dDescriptor
(
xyLengths_
,
xStrides_
,
blkGroupSize
,
numBlockTileIteration
);
dy_grid_desc_m_k
=
MakeXY2dDescriptor
(
xyLengths_
,
dyStrides_
,
blkGroupSize
,
numBlockTileIteration
);
dx_grid_desc_m_k
=
MakeXY2dDescriptor
(
xyLengths_
,
dxStrides_
,
blkGroupSize
,
numBlockTileIteration
);
scale_grid_desc_m
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnScaleStrides
);
dscale_dbias_grid_desc_m
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnDscaleDbiasStrides
);
mean_var_grid_desc_m
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnMeanVarStrides
);
}
AccDataType
epsilon_
;
bool
haveSavedMeanInvVar_
;
std
::
array
<
index_t
,
Rank
>
xyLengths_
;
std
::
array
<
index_t
,
Rank
>
xStrides_
;
std
::
array
<
index_t
,
Rank
>
dyStrides_
;
std
::
array
<
index_t
,
Rank
>
dxStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnDscaleDbiasStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides_
;
const
XDataType
*
p_x_
;
const
DyDataType
*
p_dy_
;
const
ScaleDataType
*
p_scale_
;
const
MeanVarDataType
*
p_savedMean_
;
const
MeanVarDataType
*
p_savedInvVar_
;
const
DyElementwiseOp
dy_elementwise_op_
;
DxDataType
*
p_dx_
;
DscaleDbiasDataType
*
p_dscale_
;
DscaleDbiasDataType
*
p_dbias_
;
long_index_t
invariant_length
;
long_index_t
reduce_length
;
int
blkGroupSize
;
int
numBlockTileIteration
;
size_t
gridSize
;
XYGridDesc_M_K
x_grid_desc_m_k
;
XYGridDesc_M_K
dy_grid_desc_m_k
;
XYGridDesc_M_K
dx_grid_desc_m_k
;
ScaleBiasGridDesc_M
scale_grid_desc_m
;
ScaleBiasGridDesc_M
dscale_dbias_grid_desc_m
;
MeanVarGridDesc_M
mean_var_grid_desc_m
;
void
*
workspace_mean
;
void
*
workspace_variance
;
void
*
workspace_count
;
void
*
workspace_savedMean
;
void
*
workspace_savedInvVar
;
void
*
workspace_reduce_dscale
;
void
*
workspace_reduce_dbias
;
};
/// Returns the number of workspace bytes the given argument requires.
///
/// A workspace is needed only on the multi-block path (UseMultiblockInK with more than one
/// workgroup along K). Every sub-buffer is budgeted as its payload plus 64 extra bytes.
///
/// @param pArg pointer to an Argument created by this device op
/// @return required workspace size in bytes (0 when no workspace is needed)
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
    const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

    if(!(UseMultiblockInK && pArg_->blkGroupSize > 1))
        return 0;

    // Element counts: one partial result per (row, workgroup) pair, one final result per row.
    const size_t num_partials =
        static_cast<size_t>(pArg_->invariant_length * pArg_->blkGroupSize);
    const size_t num_results = static_cast<size_t>(pArg_->invariant_length);

    size_t total_bytes = 0;

    // partial reduced results for dscale and for dbias
    total_bytes += num_partials * sizeof(DscaleDbiasDataType) + 64;
    total_bytes += num_partials * sizeof(DscaleDbiasDataType) + 64;

    if(!pArg_->haveSavedMeanInvVar_)
    {
        // welford intermediate mean, variance and count
        total_bytes += num_partials * sizeof(MeanVarDataType) + 64;
        total_bytes += num_partials * sizeof(MeanVarDataType) + 64;
        total_bytes += num_partials * sizeof(int32_t) + 64;

        // welford result mean and inv_variance
        total_bytes += num_results * sizeof(MeanVarDataType) + 64;
        total_bytes += num_results * sizeof(MeanVarDataType) + 64;
    }

    return total_bytes;
}
/// Stores the workspace base pointer in the argument and carves it into sub-buffers.
///
/// Layout (each sub-buffer size rounded up to a multiple of 64 bytes), in order:
///   dscale partials | dbias partials | welford mean | welford variance | welford count |
///   result mean | result inv_variance
///
/// The sub-buffers are laid out only on the multi-block path, and the five Welford-related
/// ones only when no saved mean / inv-variance were supplied -- mirroring exactly what
/// GetWorkSpaceSize budgets. Pointers to sub-buffers that are not budgeted are set to
/// nullptr instead of pointing past the allocation.
///
/// @param pArg        pointer to an Argument created by this device op
/// @param p_workspace base address of a buffer of at least GetWorkSpaceSize(pArg) bytes
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
    Argument* pArg_ = dynamic_cast<Argument*>(pArg);

    pArg_->p_workspace_ = p_workspace;

    // The first sub-buffer always starts at the base address; everything else is
    // assigned below only when it actually exists.
    pArg_->workspace_reduce_dscale = p_workspace;
    pArg_->workspace_reduce_dbias  = nullptr;
    pArg_->workspace_mean          = nullptr;
    pArg_->workspace_variance      = nullptr;
    pArg_->workspace_count         = nullptr;
    pArg_->workspace_savedMean     = nullptr;
    pArg_->workspace_savedInvVar   = nullptr;

    // No workspace is budgeted on the single-block path (p_workspace may be null), so no
    // pointer arithmetic must be performed on it.
    if(!(UseMultiblockInK && pArg_->blkGroupSize > 1))
        return;

    // Sizes are computed in size_t: invariant_length * blkGroupSize * sizeof(T) can exceed
    // the range of a 32-bit index_t for large problems.
    const auto round_up_64 = [](size_t bytes) { return (bytes + 63) / 64 * 64; };

    const size_t num_partials = static_cast<size_t>(pArg_->invariant_length) *
                                static_cast<size_t>(pArg_->blkGroupSize);
    const size_t num_results = static_cast<size_t>(pArg_->invariant_length);

    char* cursor = static_cast<char*>(p_workspace);

    // buffer for the partial reduced result for dscale
    pArg_->workspace_reduce_dscale = cursor;
    cursor += round_up_64(num_partials * sizeof(DscaleDbiasDataType));

    // buffer for the partial reduced result for dbias
    pArg_->workspace_reduce_dbias = cursor;
    cursor += round_up_64(num_partials * sizeof(DscaleDbiasDataType));

    // The remaining buffers are budgeted only when mean / inv-variance must be computed.
    if(pArg_->haveSavedMeanInvVar_)
        return;

    // buffer for welford intermediate mean
    pArg_->workspace_mean = cursor;
    cursor += round_up_64(num_partials * sizeof(MeanVarDataType));

    // buffer for welford intermediate variance
    pArg_->workspace_variance = cursor;
    cursor += round_up_64(num_partials * sizeof(MeanVarDataType));

    // buffer for welford intermediate count
    pArg_->workspace_count = cursor;
    cursor += round_up_64(num_partials * sizeof(int32_t));

    // buffer for welford result mean
    pArg_->workspace_savedMean = cursor;
    cursor += round_up_64(num_results * sizeof(MeanVarDataType));

    // buffer for welford result inv_variance
    pArg_->workspace_savedInvVar = cursor;
}
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0
;
const
auto
mean_var_count_grid_desc_m_g
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFirstReduceOutputMG2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
const
auto
dscale_dbias_grid_desc_m_g
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFirstReduceOutputMG2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
const
auto
mean_var_count_grid_desc_m_k
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFinalReduceInputMK2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
const
auto
dscale_dbias_grid_desc_m_k
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFinalReduceInputMK2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
DscaleDbiasGridDesc_M_G
=
decltype
(
dscale_dbias_grid_desc_m_g
);
using
DscaleDbiasGridDesc_M_K
=
decltype
(
dscale_dbias_grid_desc_m_k
);
using
GridwiseWelfordSecondHalfReduceFirstHalf_
=
GridwiseWelfordSecondHalfReduceFirstHalf
<
XDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
DscaleDbiasGridDesc_M_G
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
,
DySrcVectorSize
,
MeanVarSrcVectorSize
>
;
using
GridwiseReduceSecondHalfBatchNormBwdFinal_
=
GridwiseReduceSecondHalfBatchNormBackwardFinal
<
XDataType
,
DyDataType
,
DxDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
MeanVarGridDesc_M
,
ScaleBiasGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
,
DySrcVectorSize
,
DxDstVectorSize
,
ScaleSrcVectorSize
,
DscaleDbiasDstVectorSize
,
MeanVarSrcVectorSize
>
;
if
(
UseMultiblockInK
&&
arg
.
blkGroupSize
>
1
)
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForMultiblockWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
arg
.
reduce_length
);
if
(
!
arg
.
haveSavedMeanInvVar_
)
{
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
>
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count
));
};
const
auto
kern_welford_second_half_reduce_first_half
=
kernel_welford_second_half_reduce_first_half
<
GridwiseWelfordSecondHalfReduceFirstHalf_
,
XDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
DscaleDbiasGridDesc_M_G
>
;
const
auto
kern_reduce_second_half_batchnorm_backward_final
=
kernel_reduce_second_half_batchnorm_backward_final
<
GridwiseReduceSecondHalfBatchNormBwdFinal_
,
XDataType
,
DyDataType
,
DxDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
MeanVarGridDesc_M
,
ScaleBiasGridDesc_M
>
;
index_t
numDscaleDbiasBlockTileIteration
=
(
arg
.
blkGroupSize
+
KThreadClusterSize
-
1
)
/
KThreadClusterSize
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_reduce_first_half
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
arg
.
dy_grid_desc_m_k
,
arg
.
mean_var_grid_desc_m
,
mean_var_count_grid_desc_m_k
,
dscale_dbias_grid_desc_m_g
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
numDscaleDbiasBlockTileIteration
,
arg
.
epsilon_
,
arg
.
haveSavedMeanInvVar_
,
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedMean_
:
nullptr
,
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedInvVar_
:
nullptr
,
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_mean
),
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_variance
),
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
int32_t
*>
(
arg
.
workspace_count
),
arg
.
dy_elementwise_op_
,
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_savedMean
),
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_savedInvVar
),
arg
.
p_x_
,
arg
.
p_dy_
,
static_cast
<
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dscale
),
static_cast
<
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dbias
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_reduce_second_half_batchnorm_backward_final
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
arg
.
dy_grid_desc_m_k
,
arg
.
dx_grid_desc_m_k
,
dscale_dbias_grid_desc_m_k
,
arg
.
mean_var_grid_desc_m
,
arg
.
scale_grid_desc_m
,
arg
.
dscale_dbias_grid_desc_m
,
arg
.
blkGroupSize
,
arg
.
reduce_length
,
arg
.
numBlockTileIteration
,
numDscaleDbiasBlockTileIteration
,
static_cast
<
const
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dscale
),
static_cast
<
const
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dbias
),
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedMean_
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_savedMean
),
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedInvVar_
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_savedInvVar
),
arg
.
p_x_
,
arg
.
p_dy_
,
arg
.
p_scale_
,
arg
.
dy_elementwise_op_
,
arg
.
p_dx_
,
arg
.
p_dscale_
,
arg
.
p_dbias_
);
}
else
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForBlockwiseWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
numBlockTileIteration
,
arg
.
reduce_length
);
using
GridwiseBatchNormBackwardWithBlockwiseWelford_
=
GridwiseBatchNormBackwardWithBlockwiseWelford
<
XDataType
,
DyDataType
,
DxDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasGridDesc_M
,
MeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
,
DySrcVectorSize
,
DxDstVectorSize
,
ScaleSrcVectorSize
,
DscaleDbiasDstVectorSize
,
MeanVarSrcVectorSize
>
;
const
auto
kern_batchnorm_bwd
=
kernel_batchnorm_backward_with_blockwise_welford
<
GridwiseBatchNormBackwardWithBlockwiseWelford_
,
XDataType
,
DyDataType
,
DxDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasGridDesc_M
,
MeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_batchnorm_bwd
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
arg
.
dy_grid_desc_m_k
,
arg
.
dx_grid_desc_m_k
,
arg
.
scale_grid_desc_m
,
arg
.
dscale_dbias_grid_desc_m
,
arg
.
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
arg
.
reduce_length
,
arg
.
numBlockTileIteration
,
arg
.
epsilon_
,
arg
.
p_x_
,
arg
.
p_dy_
,
arg
.
p_scale_
,
arg
.
haveSavedMeanInvVar_
,
arg
.
p_savedMean_
,
arg
.
p_savedInvVar_
,
arg
.
dy_elementwise_op_
,
arg
.
p_dx_
,
arg
.
p_dscale_
,
arg
.
p_dbias_
);
};
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
pArg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
pArg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
pArg
)
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
XDyDxVectorDim
==
0
)
{
if
(
pArg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
dyStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
dxStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
DySrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
DxDstVectorSize
!=
0
)
return
false
;
}
else
{
if
(
pArg_
->
xStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
dyStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
dxStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
DySrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
DxDstVectorSize
!=
0
)
return
false
;
};
if
(
pArg_
->
bnScaleStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
ScaleSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnDscaleDbiasStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
DscaleDbiasDstVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
ScaleSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
DscaleDbiasDstVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
haveSavedMeanInvVar_
)
{
if
(
pArg_
->
bnMeanVarStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
MeanVarSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
MeanVarSrcVectorSize
!=
0
)
return
false
;
};
bool
is_valid
=
true
;
static_for
<
0
,
NumInvariantDim
,
1
>
{}([
&
](
auto
I
)
{
if
(
pArg_
->
xyLengths_
[
I
]
!=
pArg_
->
bnScaleBiasMeanVarLengths_
[
I
])
is_valid
=
false
;
});
if
(
!
is_valid
)
return
false
;
return
true
;
};
/// Type-erased factory for Argument: casts the untyped buffer pointers to the data types of
/// this instance and forwards everything to the Argument constructor.
///
/// Note that the Argument constructor takes (dy_elementwise_op, epsilon) in the opposite
/// order from this function's (epsilon, dy_elementwise_op).
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                    const std::array<index_t, Rank> xStrides,
                    const std::array<index_t, Rank> dyStrides,
                    const std::array<index_t, Rank> dxStrides,
                    const std::array<int, NumBatchNormReduceDim> reduceDims,
                    const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                    const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
                    const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
                    const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
                    const void* p_x,
                    const void* p_dy,
                    const void* p_scale,
                    const void* p_savedMean,
                    const void* p_savedInvVar,
                    double epsilon,
                    const DyElementwiseOp dy_elementwise_op,
                    void* p_dx,
                    void* p_dscale,
                    void* p_dbias) override
{
    // Typed views of the input buffers.
    const XDataType* x_in                 = static_cast<const XDataType*>(p_x);
    const DyDataType* dy_in               = static_cast<const DyDataType*>(p_dy);
    const ScaleDataType* scale_in         = static_cast<const ScaleDataType*>(p_scale);
    const MeanVarDataType* saved_mean_in  = static_cast<const MeanVarDataType*>(p_savedMean);
    const MeanVarDataType* saved_ivar_in  = static_cast<const MeanVarDataType*>(p_savedInvVar);

    // Typed views of the output buffers.
    DxDataType* dx_out              = static_cast<DxDataType*>(p_dx);
    DscaleDbiasDataType* dscale_out = static_cast<DscaleDbiasDataType*>(p_dscale);
    DscaleDbiasDataType* dbias_out  = static_cast<DscaleDbiasDataType*>(p_dbias);

    return std::make_unique<Argument>(xyLengths,
                                      xStrides,
                                      dyStrides,
                                      dxStrides,
                                      reduceDims,
                                      bnScaleBiasMeanVarLengths,
                                      bnScaleStrides,
                                      bnDscaleDbiasStrides,
                                      bnMeanVarStrides,
                                      x_in,
                                      dy_in,
                                      scale_in,
                                      saved_mean_in,
                                      saved_ivar_in,
                                      dy_elementwise_op,
                                      epsilon,
                                      dx_out,
                                      dscale_out,
                                      dbias_out);
}
/// Type-erased factory for the Invoker of this device op.
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
    return std::unique_ptr<BaseInvoker>(std::make_unique<Invoker>());
}
/// Returns a human-readable description of this instance built from its tuning
/// parameters (block size, thread cluster / slice sizes, vector dim and vector sizes).
std::string GetTypeString() const override
{
    std::ostringstream description;

    // clang-format off
    description << "DeviceBatchNormBwdImpl<" << BlockSize << ","
                << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
                << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
                << "XDyDxVectorDim_" << XDyDxVectorDim << ","
                << "VectorSize_X" << XSrcVectorSize
                << "_scale_" << ScaleSrcVectorSize
                << "_bias_" << DscaleDbiasDstVectorSize
                << "_mean_var_" << MeanVarSrcVectorSize
                << "_Dx_" << DxDstVectorSize << ">";
    // clang-format on

    return description.str();
}
};
// struct DeviceBatchNormBwdImpl
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
05ee41c3
...
...
@@ -42,8 +42,15 @@ template <typename XDataType,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
Rank
,
NumBatchNormReduceDim
,
YElementwiseOp
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
05ee41c3
...
...
@@ -67,6 +67,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
static
constexpr
ck
::
index_t
NDimSpatial
=
2
;
using
DeviceOp
=
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
;
...
...
@@ -107,17 +109,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static
constexpr
auto
BBlockLdsN0PerBlock
=
NPerBlock
/
BBlockLdsN1PerBlock
;
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -390,13 +392,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
...
...
@@ -473,11 +475,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
index_t
k_batch_
;
};
...
...
@@ -682,13 +684,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -724,13 +726,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
0 → 100644
View file @
05ee41c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
ConvolutionBackwardDataSpecialization
ConvBackwardDataSpecialization
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
index_t
M1PerThread
,
index_t
N1PerThread
,
index_t
KPerThread
,
typename
M1N1ThreadClusterM1Xs
,
typename
M1N1ThreadClusterN1Xs
,
typename
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
typename
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
,
typename
ABlockTransferSrcVectorTensorContiguousDimOrder
,
typename
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
,
typename
BBlockTransferThreadSliceLengths_K0_N0_N1_K1
,
typename
BBlockTransferThreadClusterLengths_K0_N0_N1_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
,
typename
BBlockTransferSrcVectorTensorContiguousDimOrder
,
typename
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceConvNdBwdDataNwcKxcNwk_Dl
:
public
DeviceConvBwdData
<
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>>
,
InDataType
,
WeiDataType
,
OutDataType
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
// Shorthand alias for this device operator type.
using
DeviceOp
=
DeviceConvNdBwdDataNwcKxcNwk_Dl
;
// GEMM operand element types: A is taken from the convolution output tensor
// type, B from the weight tensor type and C from the convolution input
// tensor type.
using
ADataType
=
OutDataType
;
using
BDataType
=
WeiDataType
;
using
CDataType
=
InDataType
;
// TODO make A/B datatype different
// NOTE(review): ABDataType is currently InDataType rather than ADataType /
// BDataType; presumably A and B are assumed to share this type - confirm.
using
ABDataType
=
InDataType
;
// Compile-time integer constants Number<0> .. Number<7>.
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
/// Builds the GEMM grid descriptors for a 1D (single spatial dimension W)
/// backward-data convolution. Enabled only when NDim == 1.
///
/// Returns a tuple of three descriptors:
///   - A [K0, M, K1] built from the convolution output tensor,
///   - B [K0, N, K1] built from the weight tensor,
///   - C [M, N]      built from the convolution input tensor.
///
/// Every spatial vector is read at index 0 only. `tildes[0]` is the XTilde
/// index that is frozen in the weight and input descriptors of the general
/// (non Filter1x1Stride1Pad0) path, so each value describes a different GEMM.
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
    ck::index_t N,
    ck::index_t K,
    ck::index_t C,
    std::vector<ck::index_t> input_spatial_lengths,
    std::vector<ck::index_t> filter_spatial_lengths,
    std::vector<ck::index_t> output_spatial_lengths,
    std::vector<ck::index_t> conv_filter_strides,
    std::vector<ck::index_t> conv_filter_dilations,
    std::vector<ck::index_t> input_left_pads,
    std::vector<ck::index_t> input_right_pads,
    std::vector<ck::index_t> tildes)
{
    using namespace ck;

    const index_t idx_xtilde = tildes[0];

    const index_t Wi            = input_spatial_lengths[0];
    const index_t Wo            = output_spatial_lengths[0];
    const index_t X             = filter_spatial_lengths[0];
    const index_t InLeftPadW    = input_left_pads[0];
    const index_t InRightPadW   = input_right_pads[0];
    const index_t ConvStrideW   = conv_filter_strides[0];
    const index_t ConvDilationW = conv_filter_dilations[0];

    // K is split as K0 * K1 (K1 is the compile-time parameter of the operator).
    const auto K0 = K / K1;

    const auto in_n_wi_c_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));

    // NOTE: the `else` below must stay: both branches return differently typed
    // tuples and the return type is deduced from the branch that is kept.
    if constexpr(ConvBackwardDataSpecialization ==
                 ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
    {
        // A: output tensor, viewed as [N * Wo, K] -> [K0, N * Wo, K1]
        const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
            make_tuple(make_pass_through_transform(N * Wo),
                       make_unmerge_transform(make_tuple(K0, K1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

        // B: weight tensor, viewed as [K, C] -> [K0, C, K1]
        const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(K, C)),
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        // C: input tensor
        const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        const auto c_m_n_grid_desc = transform_tensor_descriptor(
            in_n_x_wo_c_grid_desc,
            make_tuple(make_freeze_transform(I0),
                       make_merge_transform(make_tuple(N, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));

        return make_tuple(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc);
    }
    else
    {
        const auto out_n_wo_k_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K));
        const auto wei_k_x_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(K, X, C));

        const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        const auto XTilde = ConvStrideW / GcdStrideDilationW;

        const auto XDot = math::integer_divide_ceil(X, XTilde);

        const auto WTilde =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

        // only work on the part of WTilde that contributes to the non-padding
        // area of the input tensor
        const auto IWTildeSliceBegin = math::integer_divide_floor(
            math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

        const auto IWTildeSliceEnd = math::min(
            WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

        const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

        // GemmK is different for each GEMM
        const auto XDotSlice = math::integer_divide_ceil(X - idx_xtilde, XTilde);

        // A: output tensor
        const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
            out_n_wo_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Wo, I0, I0),
                       make_pass_through_transform(K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
            out_n_wop_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(
                           make_tuple(XDot, WTilde),
                           make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
                       make_pass_through_transform(K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
            out_n_xdot_wtilde_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_slice_transform(XDot, I0, XDotSlice),
                       make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                       make_unmerge_transform(make_tuple(K0, K1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));

        const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
            out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
            make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
                       make_merge_transform(make_tuple(N, WTildeSlice)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // B: weight tensor
        const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
            wei_k_x_c_grid_desc,
            make_tuple(make_pass_through_transform(K),
                       make_embed_transform(
                           make_tuple(XDot, XTilde),
                           make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
            wei_k_xdot_xtilde_c_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_slice_transform(XDot, I0, XDotSlice),
                       make_freeze_transform(idx_xtilde),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));

        const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
            wei_k0_k1_xdotslice_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
                       make_pass_through_transform(C),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // C: input tensor
        const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
            in_n_wip_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(XTilde, WTilde),
                                            make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
            in_n_xtilde_wtilde_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_freeze_transform(idx_xtilde),
                       make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_m_n_grid_desc = transform_tensor_descriptor(
            in_n_wtildeslice_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return make_tuple(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc);
    }
} // function end
/// Builds the GEMM grid descriptors for a 2D (spatial dimensions H, W)
/// backward-data convolution. Enabled only when NDim == 2.
///
/// Returns a tuple of three descriptors:
///   - A [K0, M, K1] built from the convolution output tensor,
///   - B [K0, N, K1] built from the weight tensor,
///   - C [M, N]      built from the convolution input tensor.
///
/// Every spatial vector is read as {H, W} = {[0], [1]}. `tildes` holds the
/// {YTilde, XTilde} indices that are frozen in the weight and input
/// descriptors of the general (non Filter1x1Stride1Pad0) path, so each pair
/// describes a different GEMM.
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
    ck::index_t N,
    ck::index_t K,
    ck::index_t C,
    std::vector<ck::index_t> input_spatial_lengths,
    std::vector<ck::index_t> filter_spatial_lengths,
    std::vector<ck::index_t> output_spatial_lengths,
    std::vector<ck::index_t> conv_filter_strides,
    std::vector<ck::index_t> conv_filter_dilations,
    std::vector<ck::index_t> input_left_pads,
    std::vector<ck::index_t> input_right_pads,
    std::vector<ck::index_t> tildes)
{
    using namespace ck;

    const index_t idx_ytilde = tildes[0];
    const index_t idx_xtilde = tildes[1];

    const index_t Hi = input_spatial_lengths[0];
    const index_t Wi = input_spatial_lengths[1];

    const index_t Ho = output_spatial_lengths[0];
    const index_t Wo = output_spatial_lengths[1];

    const index_t Y = filter_spatial_lengths[0];
    const index_t X = filter_spatial_lengths[1];

    const index_t InLeftPadH = input_left_pads[0];
    const index_t InLeftPadW = input_left_pads[1];

    const index_t InRightPadH = input_right_pads[0];
    const index_t InRightPadW = input_right_pads[1];

    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];

    const index_t ConvDilationH = conv_filter_dilations[0];
    const index_t ConvDilationW = conv_filter_dilations[1];

    // K is split as K0 * K1 (K1 is the compile-time parameter of the operator).
    const auto K0 = K / K1;

    const auto out_n_ho_wo_k_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
    const auto wei_k_y_x_c_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
    const auto in_n_hi_wi_c_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

    // NOTE: the `else` below must stay: both branches return differently typed
    // tuples and the return type is deduced from the branch that is kept.
    if constexpr(ConvBackwardDataSpecialization ==
                 ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
    {
        // A: output tensor, viewed as [N * Ho * Wo, K] -> [K0, N * Ho * Wo, K1]
        const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
            make_tuple(make_pass_through_transform(N * Ho * Wo),
                       make_unmerge_transform(make_tuple(K0, K1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

        // B: weight tensor, viewed as [K, C] -> [K0, C, K1]
        const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(K, C)),
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        // C: input tensor
        const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
                       make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto c_m_n_grid_desc = transform_tensor_descriptor(
            in_n_y_ho_x_wo_c_grid_desc,
            make_tuple(make_freeze_transform(I0),
                       make_freeze_transform(I0),
                       make_merge_transform(make_tuple(N, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
            make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));

        return make_tuple(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc);
    }
    else
    {
        const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        const auto YTilde = ConvStrideH / GcdStrideDilationH;
        const auto XTilde = ConvStrideW / GcdStrideDilationW;

        const auto YDot = math::integer_divide_ceil(Y, YTilde);
        const auto XDot = math::integer_divide_ceil(X, XTilde);

        const auto HTilde =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
        const auto WTilde =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

        // only work on HTilde and WTilde that contribute to non-padding area of input tensor
        const auto IHTildeSliceBegin = math::integer_divide_floor(
            math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
        const auto IWTildeSliceBegin = math::integer_divide_floor(
            math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

        const auto IHTildeSliceEnd = math::min(
            HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
        const auto IWTildeSliceEnd = math::min(
            WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

        const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
        const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

        // GemmK is different for each GEMM
        const auto YDotSlice = math::integer_divide_ceil(Y - idx_ytilde, YTilde);
        const auto XDotSlice = math::integer_divide_ceil(X - idx_xtilde, XTilde);

        // A: output tensor
        const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
            out_n_ho_wo_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Ho, I0, I0),
                       make_pad_transform(Wo, I0, I0),
                       make_pass_through_transform(K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
            out_n_hop_wop_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(
                           make_tuple(YDot, HTilde),
                           make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
                       make_embed_transform(
                           make_tuple(XDot, WTilde),
                           make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
                       make_pass_through_transform(K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
            transform_tensor_descriptor(
                out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_slice_transform(YDot, I0, YDotSlice),
                           make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                           make_slice_transform(XDot, I0, XDotSlice),
                           make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                           make_unmerge_transform(make_tuple(K0, K1))),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5, 6>{}));

        const auto a_k0_m_k1_grid_desc = transform_tensor_descriptor(
            out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
            make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                       make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // B: weight tensor
        const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
            wei_k_y_x_c_grid_desc,
            make_tuple(make_pass_through_transform(K),
                       make_embed_transform(
                           make_tuple(YDot, YTilde),
                           make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
                       make_embed_transform(
                           make_tuple(XDot, XTilde),
                           make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        // Lower dims are listed as YDot(1), XDot(3), YTilde(2), XTilde(4) to
        // match the order of the transforms (slices first, then freezes).
        const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
            wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_slice_transform(YDot, I0, YDotSlice),
                       make_slice_transform(XDot, I0, XDotSlice),
                       make_freeze_transform(idx_ytilde),
                       make_freeze_transform(idx_xtilde),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<3>{},
                       Sequence<2>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0, 1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<>{},
                       Sequence<>{},
                       Sequence<4>{}));

        const auto b_k0_n_k1_grid_desc = transform_tensor_descriptor(
            wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                       make_pass_through_transform(C),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // C: input tensor
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(YTilde, HTilde),
                                            make_tuple(ConvDilationH, ConvStrideH)),
                       make_embed_transform(make_tuple(XTilde, WTilde),
                                            make_tuple(ConvDilationW, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
            in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_freeze_transform(idx_ytilde),
                       make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                       make_freeze_transform(idx_xtilde),
                       make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<>{},
                       Sequence<1>{},
                       Sequence<>{},
                       Sequence<2>{},
                       Sequence<3>{}));

        const auto c_m_n_grid_desc = transform_tensor_descriptor(
            in_n_htildeslice_wtildeslice_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return make_tuple(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc);
    }
} // function end
// Builds the three GEMM-view tensor descriptors for the 3D (NDim == 3) backward-data
// convolution and returns them as a tuple, in this order:
//   [0] out_gemmk0_gemmm_gemmk1_grid_desc - "A", a view of the output tensor (N, Do, Ho, Wo, K)
//   [1] wei_gemmk0_gemmn_gemmk1_grid_desc - "B", a view of the weight tensor (K, Z, Y, X, C)
//   [2] in_gemmm_gemmn_grid_desc          - "C", a view of the input tensor  (N, Di, Hi, Wi, C)
//
// Parameters:
//   N, K, C                 - sizes used for the batch / K / C dimensions of the packed descriptors.
//   input_spatial_lengths   - element 0/1/2 is read as Di/Hi/Wi.
//   filter_spatial_lengths  - element 0/1/2 is read as Z/Y/X.
//   output_spatial_lengths  - element 0/1/2 is read as Do/Ho/Wo.
//   conv_filter_strides     - element 0/1/2 is read as the D/H/W stride.
//   conv_filter_dilations   - element 0/1/2 is read as the D/H/W dilation.
//   input_left_pads / input_right_pads - element 0/1/2 is read as the D/H/W padding.
//   tildes                  - {i_ztilde, i_ytilde, i_xtilde}; ignored by the
//                             Filter1x1Stride1Pad0 branch, used by the general branch to
//                             freeze the ZTilde/YTilde/XTilde dimensions and to size the
//                             ZDot/YDot/XDot slices.
//
// All vectors are indexed at [0..2] without bounds checks, so each must hold at least
// three elements.
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
    ck::index_t N,
    ck::index_t K,
    ck::index_t C,
    std::vector<ck::index_t> input_spatial_lengths,
    std::vector<ck::index_t> filter_spatial_lengths,
    std::vector<ck::index_t> output_spatial_lengths,
    std::vector<ck::index_t> conv_filter_strides,
    std::vector<ck::index_t> conv_filter_dilations,
    std::vector<ck::index_t> input_left_pads,
    std::vector<ck::index_t> input_right_pads,
    std::vector<ck::index_t> tildes)
{
    using namespace ck;

    // Which (z, y, x) "tilde" index this set of descriptors is built for.
    const index_t i_ztilde = tildes[0];
    const index_t i_ytilde = tildes[1];
    const index_t i_xtilde = tildes[2];

    const index_t Di = input_spatial_lengths[0];
    const index_t Hi = input_spatial_lengths[1];
    const index_t Wi = input_spatial_lengths[2];

    const index_t Do = output_spatial_lengths[0];
    const index_t Ho = output_spatial_lengths[1];
    const index_t Wo = output_spatial_lengths[2];

    const index_t Z = filter_spatial_lengths[0];
    const index_t Y = filter_spatial_lengths[1];
    const index_t X = filter_spatial_lengths[2];

    const index_t InLeftPadD = input_left_pads[0];
    const index_t InLeftPadH = input_left_pads[1];
    const index_t InLeftPadW = input_left_pads[2];

    const index_t InRightPadD = input_right_pads[0];
    const index_t InRightPadH = input_right_pads[1];
    const index_t InRightPadW = input_right_pads[2];

    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];

    const index_t ConvDilationD = conv_filter_dilations[0];
    const index_t ConvDilationH = conv_filter_dilations[1];
    const index_t ConvDilationW = conv_filter_dilations[2];

    // K1 comes from the enclosing scope. Integer division: presumably K is expected to be
    // a multiple of K1 - TODO confirm against the caller / IsSupportedArgument.
    const auto K0 = K / K1;

    // Packed (contiguous) descriptors of the three tensors, shared by both branches below.
    const auto out_n_do_ho_wo_k_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
    const auto wei_k_z_y_x_c_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
    const auto in_n_di_hi_wi_c_grid_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

    if constexpr(ConvBackwardDataSpecialization ==
                 ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
    {
        // Specialised path: the tilde indices and the padding/dilation values are not used here.

        // A: output tensor, viewed as (N*Do*Ho*Wo, K) and reshaped to (K0, M, K1)
        const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
            make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
                       make_unmerge_transform(make_tuple(K0, K1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}));

        // B: weight tensor, viewed as (K, C) and reshaped to (K0, N, K1)
        const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
            make_naive_tensor_descriptor_packed(make_tuple(K, C)),
            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        // C: input tensor. Each spatial dimension is embedded as (1, out_len) with strides
        // (1, conv_stride); the length-1 dimensions are then frozen at index 0.
        const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
                       make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
                       make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1, 2>{},
                       Sequence<3, 4>{},
                       Sequence<5, 6>{},
                       Sequence<7>{}));

        // Merge (N, Do, Ho, Wo) into GemmM; C becomes GemmN.
        const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_grid_desc,
            make_tuple(make_freeze_transform(I0),
                       make_freeze_transform(I0),
                       make_freeze_transform(I0),
                       make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<1>{},
                       Sequence<3>{},
                       Sequence<5>{},
                       Sequence<0, 2, 4, 6>{},
                       Sequence<7>{}),
            make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));

        return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                          wei_gemmk0_gemmn_gemmk1_grid_desc,
                          in_gemmm_gemmn_grid_desc);
    }
    else
    {
        // General path: per-dimension decomposition driven by gcd(stride, dilation).
        const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
        const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        // Number of distinct tilde indices per dimension (the caller iterates over these).
        const auto ZTilde = ConvStrideD / GcdStrideDilationD;
        const auto YTilde = ConvStrideH / GcdStrideDilationH;
        const auto XTilde = ConvStrideW / GcdStrideDilationW;

        const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
        const auto YDot = math::integer_divide_ceil(Y, YTilde);
        const auto XDot = math::integer_divide_ceil(X, XTilde);

        const auto DTilde =
            Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
        const auto HTilde =
            Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
        const auto WTilde =
            Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

        // only work on DTilde, HTilde and WTilde that contribute to non-padding area of
        // input tensor
        const auto IDTildeSliceBegin = math::integer_divide_floor(
            math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
        const auto IHTildeSliceBegin = math::integer_divide_floor(
            math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
        const auto IWTildeSliceBegin = math::integer_divide_floor(
            math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);

        const auto IDTildeSliceEnd = math::min(
            DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
        const auto IHTildeSliceEnd = math::min(
            HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
        const auto IWTildeSliceEnd = math::min(
            WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

        const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
        const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
        const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;

        // GemmK is different for each GEMM
        const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
        const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
        const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

        // A: output tensor
        // Zero-width pad transforms on Do/Ho/Wo (left and right pad are both I0).
        const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
            out_n_do_ho_wo_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Do, I0, I0),
                       make_pad_transform(Ho, I0, I0),
                       make_pad_transform(Wo, I0, I0),
                       make_pass_through_transform(K)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // Split every spatial dimension into (dot, tilde); the dot coefficient is negative.
        const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
            transform_tensor_descriptor(
                out_n_dop_hop_wop_k_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(
                               make_tuple(ZDot, DTilde),
                               make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
                           make_embed_transform(
                               make_tuple(YDot, HTilde),
                               make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
                           make_embed_transform(
                               make_tuple(XDot, WTilde),
                               make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
                           make_pass_through_transform(K)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

        // Restrict dot/tilde dimensions to the slices computed above and split K into (K0, K1).
        const auto
            out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
                transform_tensor_descriptor(
                    out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
                    make_tuple(make_pass_through_transform(N),
                               make_slice_transform(ZDot, I0, ZDotSlice),
                               make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
                               make_slice_transform(YDot, I0, YDotSlice),
                               make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                               make_slice_transform(XDot, I0, XDotSlice),
                               make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                               make_unmerge_transform(make_tuple(K0, K1))),
                    make_tuple(Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4>{},
                               Sequence<5>{},
                               Sequence<6>{},
                               Sequence<7>{}),
                    make_tuple(Sequence<0>{},
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4>{},
                               Sequence<5>{},
                               Sequence<6>{},
                               Sequence<7, 8>{}));

        // GemmK0 = merge(ZDotSlice, YDotSlice, XDotSlice, K0); GemmM = merge(N, D/H/W tilde slices).
        const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
            out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
            make_tuple(
                make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
                make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
                make_pass_through_transform(K1)),
            make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // B: weight tensor
        // Split every filter dimension into (dot, tilde).
        const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
            transform_tensor_descriptor(
                wei_k_z_y_x_c_grid_desc,
                make_tuple(make_pass_through_transform(K),
                           make_embed_transform(
                               make_tuple(ZDot, ZTilde),
                               make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
                           make_embed_transform(
                               make_tuple(YDot, YTilde),
                               make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
                           make_embed_transform(
                               make_tuple(XDot, XTilde),
                               make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

        // Slice the dot dimensions, freeze the tilde dimensions at the requested indices
        // (frozen dimensions map to an empty Sequence<> and disappear from the result).
        const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
            transform_tensor_descriptor(
                wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                           make_slice_transform(ZDot, I0, ZDotSlice),
                           make_slice_transform(YDot, I0, YDotSlice),
                           make_slice_transform(XDot, I0, XDotSlice),
                           make_freeze_transform(i_ztilde),
                           make_freeze_transform(i_ytilde),
                           make_freeze_transform(i_xtilde),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<3>{},
                           Sequence<5>{},
                           Sequence<2>{},
                           Sequence<4>{},
                           Sequence<6>{},
                           Sequence<7>{}),
                make_tuple(Sequence<0, 1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<>{},
                           Sequence<>{},
                           Sequence<>{},
                           Sequence<5>{}));

        // Same GemmK0 merge order as for A, so the two operands agree on K.
        const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
            wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
            make_tuple(
                make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
                make_pass_through_transform(C),
                make_pass_through_transform(K1)),
            make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // C: input tensor
        // Apply the left/right padding to each spatial dimension.
        const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // Split each padded spatial dimension into (filter-tilde, output-tilde).
        const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
            transform_tensor_descriptor(
                in_n_dip_hip_wip_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(ZTilde, DTilde),
                                                make_tuple(ConvDilationD, ConvStrideD)),
                           make_embed_transform(make_tuple(YTilde, HTilde),
                                                make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(XTilde, WTilde),
                                                make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

        // Freeze the filter-tilde indices and keep only the non-padding tilde slices.
        const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
            transform_tensor_descriptor(
                in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
                make_tuple(make_pass_through_transform(N),
                           make_freeze_transform(i_ztilde),
                           make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
                           make_freeze_transform(i_ytilde),
                           make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
                           make_freeze_transform(i_xtilde),
                           make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{},
                           Sequence<6>{},
                           Sequence<7>{}),
                make_tuple(Sequence<0>{},
                           Sequence<>{},
                           Sequence<1>{},
                           Sequence<>{},
                           Sequence<2>{},
                           Sequence<>{},
                           Sequence<3>{},
                           Sequence<4>{}));

        // GemmM = merge(N, DTildeSlice, HTildeSlice, WTildeSlice); GemmN = C.
        const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
            in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
            make_tuple(
                make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                          wei_gemmk0_gemmn_gemmk1_grid_desc,
                          in_gemmm_gemmn_grid_desc);
    }
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
1
>
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
0
});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
2
>
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
3
>
(
1
,
1
,
1
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
});
}
// Tuple type returned by the descriptor factory for this operator's spatial rank.
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

// Element types of that tuple: [I0] = A descriptor, [I1] = B descriptor, [I2] = C descriptor.
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

// GridwiseGemm
// The grid-level GEMM this device op dispatches to, instantiated with the descriptor
// types above and the tuning parameters of the enclosing scope. The C data operation is
// fixed to InMemoryDataOperationEnum::Set.
using GridwiseGemm =
    GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
                                 ADataType,
                                 AccDataType,
                                 CDataType,
                                 InMemoryDataOperationEnum::Set,
                                 AGridDesc_K0_M_K1,
                                 BGridDesc_K0_N_K1,
                                 CGridDesc_M_N,
                                 MPerBlock,
                                 NPerBlock,
                                 K0PerBlock,
                                 K1,
                                 M1PerThread,
                                 N1PerThread,
                                 KPerThread,
                                 M1N1ThreadClusterM1Xs,
                                 M1N1ThreadClusterN1Xs,
                                 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                 ABlockTransferThreadClusterArrangeOrder,
                                 ABlockTransferSrcAccessOrder,
                                 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                 ABlockTransferSrcVectorTensorContiguousDimOrder,
                                 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                 BBlockTransferThreadClusterArrangeOrder,
                                 BBlockTransferSrcAccessOrder,
                                 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                 BBlockTransferSrcVectorTensorContiguousDimOrder,
                                 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                 CThreadTransferSrcDstAccessOrder,
                                 CThreadTransferSrcDstVectorDim,
                                 CThreadTransferDstScalarPerVector>;

// Descriptor / tile-map types produced by GridwiseGemm from the A/B/C descriptors.
// These are the element types of the per-GEMM containers held by Argument and the types
// handed to the kernel by Invoker.
using AGridDesc_K0_M0_M1_K1 =
    decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
    decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
    decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
    decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_in_grid
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
output_spatial_lengths_
{
output_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
CreateABCDesc
<
NDimSpatial
>
();
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
void
CreateABCDesc
()
{
const
index_t
ConvStrideW
=
conv_filter_strides_
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
0
];
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
index_t
X
=
filter_spatial_lengths_
[
0
];
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
XDotSlice
<=
0
)
{
continue
;
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
]))
{
a_grid_desc_k0_m0_m1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
descs
[
I0
]));
b_grid_desc_k0_n0_n1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
descs
[
I1
]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]));
}
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
void
CreateABCDesc
()
{
const
index_t
ConvStrideH
=
conv_filter_strides_
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides_
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations_
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
1
];
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
index_t
Y
=
filter_spatial_lengths_
[
0
];
const
index_t
X
=
filter_spatial_lengths_
[
1
];
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
]))
{
a_grid_desc_k0_m0_m1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
descs
[
I0
]));
b_grid_desc_k0_n0_n1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
descs
[
I1
]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]));
}
}
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
void
CreateABCDesc
()
{
const
index_t
ConvStrideD
=
conv_filter_strides_
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides_
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides_
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations_
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations_
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations_
[
2
];
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
index_t
Z
=
filter_spatial_lengths_
[
0
];
const
index_t
Y
=
filter_spatial_lengths_
[
1
];
const
index_t
X
=
filter_spatial_lengths_
[
2
];
for
(
index_t
i_ztilde
=
0
;
i_ztilde
<
ZTilde
;
++
i_ztilde
)
{
for
(
index_t
i_ytilde
=
0
;
i_ytilde
<
YTilde
;
++
i_ytilde
)
{
for
(
index_t
i_xtilde
=
0
;
i_xtilde
<
XTilde
;
++
i_xtilde
)
{
// check slice is valid
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
-
i_ztilde
,
ZTilde
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_ytilde
,
YTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xtilde
,
XTilde
);
if
(
ZDotSlice
*
YDotSlice
*
XDotSlice
<=
0
)
{
continue
;
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
]))
{
a_grid_desc_k0_m0_m1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
descs
[
I0
]));
b_grid_desc_k0_n0_n1_k1_container_
.
push_back
(
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
descs
[
I1
]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]));
}
}
}
}
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
std
::
vector
<
AGridDesc_K0_M0_M1_K1
>
a_grid_desc_k0_m0_m1_k1_container_
;
std
::
vector
<
BGridDesc_K0_N0_N1_K1
>
b_grid_desc_k0_n0_n1_k1_container_
;
std
::
vector
<
CGridDesc_M0_M10_M11_N0_N10_N11
>
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
;
std
::
vector
<
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
// element-wise op
OutElementwiseOperation
a_element_op_
;
WeiElementwiseOperation
b_element_op_
;
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_strides_
;
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck
::
index_t
>
input_left_pads_
;
std
::
vector
<
ck
::
index_t
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_container_{"
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_container_{ "
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I2
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I3
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I4
)
<<
", "
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I5
)
<<
" ) "
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
]))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_container_
[
i
].
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
]);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
,
auto
has_double_tail_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_double_loop
=
has_double_tail_k_block_loop
;
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M0_M1_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
DefaultBlock2CTileMap
>
,
has_main_loop
,
has_double_loop
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m0_m1_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n0_n1_k1_container_
[
i
],
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
],
arg
.
block_2_ctile_map_container_
[
i
]);
};
const
auto
K0
=
arg
.
a_grid_desc_k0_m0_m1_k1_container_
[
i
].
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
// Placeholder compile-time check of the template parameters: currently accepts every
// configuration unconditionally.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool all_parameters_accepted = true;
    return all_parameters_accepted;
}
// Returns true when this kernel instance can run the problem described by `arg`.
//
// Checks, in order:
//   1. the device name is "gfx906" or "gfx1030";
//   2. for the Filter1x1Stride1Pad0 specialization: every spatial dimension has filter
//      length 1, stride 1 and zero left/right padding;
//   3. matrix A (output) source-vector constraints against K;
//   4. matrix B (weight) source-vector constraints against C;
//   5. C is divisible by CThreadTransferDstScalarPerVector (logs to std::cout if not);
//   6. every generated GEMM passes GridwiseGemm::CheckValidity.
static bool IsSupportedArgument(const Argument& arg)
{
    // check device
    if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
    {
        return false;
    }

    if constexpr(ConvBackwardDataSpecialization ==
                 ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
    {
        // check if it's 1x1, stride=1 pad = 0 conv
        for(int i = 0; i < NDimSpatial; i++)
        {
            if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
                 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
            {
                return false;
            }
        }
    }

    // matrix A
    {
        auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};

        // A may only be vectorised along K0 / K1, never along M0 / M1.
        if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
        {
            return false;
        }
        if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
        {
            return false;
        }

        const index_t K = arg.Conv_K_;

        if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
        {
            return false;
        }
    }

    // matrix B
    {
        auto srcLoadLengths   = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
        auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};

        // B may only be vectorised along N0 / N1, never along K0 / K1.
        if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
        {
            return false;
        }
        if(srcLoadLengths[I1] % srcVectorLengths[I1] != 0 ||
           srcLoadLengths[I2] % srcVectorLengths[I2] != 0)
        {
            return false;
        }

        // B's GemmN dimension is the convolution's C dimension, so the N0/N1 vector
        // width has to divide C. (This previously read arg.Conv_K_ into `C`, which
        // tested the wrong extent.)
        const index_t C = arg.Conv_C_;

        if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
        {
            return false;
        }
    }

    // vector store C matrix into global memory
    if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
    {
        std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
                  << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
        return false;
    }

    // Gridwise GEMM size
    for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
    {
        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                        arg.b_grid_desc_k0_n_k1_container_[i],
                                        arg.c_grid_desc_m_n_container_[i]))
        {
            return false;
        }
    }

    return true;
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
// Builds an Argument by value from typed tensor pointers and the convolution problem
// description. Every parameter is forwarded, in the same order, to Argument's constructor.
static auto MakeArgument(InDataType* p_in_grid,
                         const WeiDataType* p_wei_grid,
                         const OutDataType* p_out_grid,
                         ck::index_t N,
                         ck::index_t K,
                         ck::index_t C,
                         std::vector<ck::index_t> input_spatial_lengths,
                         std::vector<ck::index_t> filter_spatial_lengths,
                         std::vector<ck::index_t> output_spatial_lengths,
                         std::vector<ck::index_t> conv_filter_strides,
                         std::vector<ck::index_t> conv_filter_dilations,
                         std::vector<ck::index_t> input_left_pads,
                         std::vector<ck::index_t> input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op)
{
    return Argument{p_in_grid,
                    p_wei_grid,
                    p_out_grid,
                    N,
                    K,
                    C,
                    input_spatial_lengths,
                    filter_spatial_lengths,
                    output_spatial_lengths,
                    conv_filter_strides,
                    conv_filter_dilations,
                    input_left_pads,
                    input_right_pads,
                    in_element_op,
                    wei_element_op,
                    out_element_op};
}
// Returns a default-constructed Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
// Type-erased factory: casts the untyped tensor pointers to InDataType / WeiDataType /
// OutDataType and returns a heap-allocated Argument behind the BaseArgument interface.
// The input tensor is the only non-const pointer.
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
                    const void* p_wei_grid,
                    const void* p_out_grid,
                    ck::index_t N,
                    ck::index_t K,
                    ck::index_t C,
                    std::vector<ck::index_t> input_spatial_lengths,
                    std::vector<ck::index_t> filter_spatial_lengths,
                    std::vector<ck::index_t> output_spatial_lengths,
                    std::vector<ck::index_t> conv_filter_strides,
                    std::vector<ck::index_t> conv_filter_dilations,
                    std::vector<ck::index_t> input_left_pads,
                    std::vector<ck::index_t> input_right_pads,
                    InElementwiseOperation in_element_op,
                    WeiElementwiseOperation wei_element_op,
                    OutElementwiseOperation out_element_op) override
{
    return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
                                      static_cast<const WeiDataType*>(p_wei_grid),
                                      static_cast<const OutDataType*>(p_out_grid),
                                      N,
                                      K,
                                      C,
                                      input_spatial_lengths,
                                      filter_spatial_lengths,
                                      output_spatial_lengths,
                                      conv_filter_strides,
                                      conv_filter_dilations,
                                      input_left_pads,
                                      input_right_pads,
                                      in_element_op,
                                      wei_element_op,
                                      out_element_op);
}
// Type-erased factory: returns a heap-allocated Invoker behind the BaseInvoker interface.
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
    return std::make_unique<Invoker>(Invoker{});
}
// Produces a human-readable identifier for this instance:
// "DeviceConvNdBwdDataNwcKxcNwk_Dl<BlockSize, MPerBlock, NPerBlock, K0PerBlock>", followed by
// " Filter1x1Stride1Pad0" when that specialization is compiled in.
std::string GetTypeString() const override
{
    std::stringstream name;

    // clang-format off
    name << "DeviceConvNdBwdDataNwcKxcNwk_Dl" << "<";
    name << BlockSize << ", ";
    name << MPerBlock << ", ";
    name << NPerBlock << ", ";
    name << K0PerBlock << ">";
    // clang-format on

    constexpr bool is_filter1x1_stride1_pad0 =
        (ConvBackwardDataSpecialization ==
         ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0);

    if constexpr(is_filter1x1_stride1_pad0)
    {
        name << " Filter1x1Stride1Pad0";
    }

    return name.str();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp
0 → 100644
View file @
05ee41c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
namespace ck {
// GPU entry point for the fused "elementwise + normalization" operation.
// The kernel itself does no work beyond declaring the dynamically sized shared-memory
// buffer `p_x_lds` and forwarding every argument to GridwiseElementwiseReduction::Run.
// NOTE(review): the byte size of `p_x_lds` is whatever the launcher requests as dynamic
// shared memory — presumably Argument::x_lds_size_; confirm at the launch site.
template <typename GridwiseElementwiseReduction,
          typename InDataTypePointerTuple, // Tuple of pointer types, one per input
          typename XDataType,              // Datatype of X
          typename GammaDataType,          // Datatype of Gamma
          typename BetaDataType,           // Datatype of Beta
          typename YDataType,              // Datatype of Y
          typename AccDataType,            // Datatype of epsilon (accumulation type, per its name)
          typename XElementwiseOperation,  // Operation of input
          typename YElementwiseOperation,  // Operation of output of normalization
          typename InGrid2dDescTuple,      // Descriptor tuple of inputs
          typename GridDesc_M_K>           // Descriptor type shared by X, Gamma, Beta and Y
__global__ void kernel_elementwise_layernorm(
    const InGrid2dDescTuple in_grid_2d_desc_tuple,          // Descriptor tuple of inputs
    const GridDesc_M_K x_grid_desc_m_k,                     // Descriptor of X
    const GridDesc_M_K gamma_grid_desc_m_k,                 // Descriptor of gamma
    const GridDesc_M_K beta_grid_desc_m_k,                  // Descriptor of beta
    const GridDesc_M_K y_grid_desc_m_k,                     // Descriptor of Y
    index_t num_k_block_tile_iteration,                     // Number of K block-tile iterations
    AccDataType epsilon,                                    // Epsilon, passed through to Run
    const InDataTypePointerTuple p_in_global_tuple,         // Ptr tuple of input matrices
    const GammaDataType* const __restrict__ p_gamma_global, // Ptr of gamma
    const BetaDataType* const __restrict__ p_beta_global,   // Ptr of beta
    YDataType* const __restrict__ p_y_global,               // Ptr of y
    const XElementwiseOperation x_elementwise_op,           // Operation of input
    const YElementwiseOperation y_elementwise_op)           // Operation of output of normalization
{
    // Shared-memory array with no compile-time size; handed to Run as the buffer for X.
    extern __shared__ XDataType p_x_lds[];

    GridwiseElementwiseReduction::Run(in_grid_2d_desc_tuple,      // Descriptor tuple of inputs
                                      x_grid_desc_m_k,            // Descriptor of X
                                      gamma_grid_desc_m_k,        // Descriptor of Gamma
                                      beta_grid_desc_m_k,         // Descriptor of Beta
                                      y_grid_desc_m_k,            // Descriptor of Y
                                      num_k_block_tile_iteration, // K block-tile iterations
                                      epsilon,                    // epsilon
                                      p_in_global_tuple,          // Ptr tuple of inputs
                                      p_x_lds,                    // Ptr of X (shared memory)
                                      p_gamma_global,             // Ptr of gamma
                                      p_beta_global,              // Ptr of beta
                                      p_y_global,                 // Ptr of Y
                                      x_elementwise_op,           // Operation of input
                                      y_elementwise_op); // Operation of output of normalization
};
} // namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Y = LayerNorm(A + B, Beta, Gamma)
template
<
typename
InDataTypeTuple
,
// Datatype of inputs
typename
GammaDataType
,
// Datatype of gamma
typename
BetaDataType
,
// Datatype of beta
typename
AccDataType
,
//
typename
YDataType
,
//
typename
XElementwiseOperation
,
//
typename
YElementwiseOperation
,
//
index_t
Rank
,
//
index_t
NumReduceDim
,
//
index_t
BlockSize
,
//
index_t
MThreadClusterSize
,
// Num of threads in a block on M direction
index_t
KThreadClusterSize
,
// Num of threads in a block on K direction
index_t
MThreadSliceSize
,
// Each thread calculate rows
index_t
KThreadSliceSize
,
// Each thread calculate columns
index_t
XYSrcVectorDim
,
// Dimension to do reduce
index_t
XSrcVectorSize
,
// Size to fetch source x
index_t
GammaSrcVectorDim
,
// Dimension for gamma to do reduce
index_t
GammaSrcVectorSize
,
// Size to fetch source gamma
index_t
BetaSrcVectorDim
,
// Dimension for beta to do reduce
index_t
BetaSrcVectorSize
,
// Size to fetch source beta
index_t
YDstVectorSize
>
// Size to write destination Y
struct
DeviceElementwiseNormalizationImpl
:
public
DeviceElementwiseNormalization
<
InDataTypeTuple
,
GammaDataType
,
BetaDataType
,
AccDataType
,
YDataType
,
XElementwiseOperation
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
{
// Number of input tensors, taken from the size of the input datatype tuple.
static constexpr int NumInput = InDataTypeTuple::Size();

// The intermediate X (result of the elementwise op) uses the same datatype as the output Y.
using XDataType = YDataType;

// Gamma and beta are read with vector accesses inside a thread's K slice, so the slice
// length must be a whole number of vectors.
static_assert(
    (KThreadSliceSize % GammaSrcVectorSize == 0),
    "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

static_assert(
    (KThreadSliceSize % BetaSrcVectorSize == 0),
    "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

static constexpr index_t M_BlockTileSize =
    MThreadClusterSize * MThreadSliceSize; // num of rows calculated in a block
static constexpr index_t K_BlockTileSize =
    KThreadClusterSize * KThreadSliceSize; // num of columns calculated in a block
static
auto
GenerateInDataTypePointerTuple
()
{
return
generate_tuple
(
[
&
](
auto
I
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
InDataTypeTuple
{}[
I
])
>
;
return
static_cast
<
const
DataType
*>
(
nullptr
);
},
Number
<
NumInput
>
{});
};
using
InDataTypePointerTuple
=
decltype
(
GenerateInDataTypePointerTuple
());
// Builds the padded 2-D (M, K) descriptor used by the kernel for a tensor with the given
// lengths and strides.
//
//   inLengths / inStrides   per-dimension lengths and strides; the first `Rank` entries are
//                           read. The leading Rank - NumReduceDim dimensions are treated as
//                           invariant and merged into M, the remaining ones as reduced and
//                           merged into K.
//   blkGroupSize            multiplier applied to the per-block K extent when padding K.
//   numBlockTileIteration   number of K_BlockTileSize tiles each block walks along K.
//
// M is right-padded up to a multiple of M_BlockTileSize; K is right-padded up to
// K_BlockTileSize * numBlockTileIteration * blkGroupSize.
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                const std::vector<index_t>& inStrides,
                                int blkGroupSize,
                                int numBlockTileIteration)
{
    constexpr index_t NumInvariantDim  = Rank - NumReduceDim;
    static constexpr index_t numSrcDim = Rank;
    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

    const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
    const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

    const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

    // Collapse the Rank-dimensional descriptor into an (M, K) view.
    const auto in_grid_desc_m_k = [&]() {
        if constexpr(reduceAllDim)
        {
            // Every dimension is reduced: merge all of them into one, then unmerge that
            // into (1, total) so that M == 1 and K holds the whole tensor.
            const auto one_dim_inDesc = transform_tensor_descriptor(
                inDesc,
                make_tuple(make_merge_transform(tupleSrcLengths)),
                make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                make_tuple(Sequence<0>{}));

            return transform_tensor_descriptor(one_dim_inDesc,
                                               make_tuple(make_unmerge_transform(make_tuple(
                                                   1, one_dim_inDesc.GetLength(Number<0>{})))),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0, 1>{}));
        }
        else
        {
            // Dimensions [0, NumInvariantDim) form M, dimensions [NumInvariantDim, Rank) form K.
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
            const auto invariantDimLengths =
                make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

            return transform_tensor_descriptor(
                inDesc,
                make_tuple(make_merge_transform(invariantDimLengths),
                           make_merge_transform(reduceDimLengths)),
                make_tuple(InvariantDims{}, ReduceDims{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }();

    const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
    const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

    // K extent covered by one block over all of its tile iterations.
    const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
    const auto inPad_M =
        math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
    const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

    auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
        in_grid_desc_m_k,
        make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                   make_right_pad_transform(reduceLength, inPad_K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    return (in_grid_desc_m_k_padded);
};
// Produces a tuple of `TupleSize` descriptors, each built by MakeSrc2dDescriptor from dummy
// arguments ({1}, {1}, 1, 1). The alias below uses only the resulting type.
template <index_t TupleSize>
static auto GenerateSrcGrid2dDescTuple(Number<TupleSize>)
{
    return generate_tuple([&](auto) { return MakeSrc2dDescriptor({1}, {1}, 1, 1); },
                          Number<TupleSize>{});
};

// Tuple of 2-D descriptors, one per input tensor.
using InGrid2dDescTuple = decltype(GenerateSrcGrid2dDescTuple(Number<NumInput>{}));

// Type of a single padded (M, K) descriptor, obtained from a dummy MakeSrc2dDescriptor call.
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
// Two instantiations of the same gridwise implementation that differ only in the final
// boolean template argument (false here, true below). Invoker::Run selects between them
// at run time via Argument::sweep_once_.
// Note that XYSrcVectorDim is passed twice: once for the X source and once in the position
// preceding YDstVectorSize.
using GridwiseReduceLayernormGeneric =
    GridwiseElementwiseLayernormWelfordVariance_mk_to_mk<InDataTypePointerTuple,
                                                         XDataType,
                                                         GammaDataType,
                                                         BetaDataType,
                                                         YDataType,
                                                         AccDataType,
                                                         XElementwiseOperation,
                                                         YElementwiseOperation,
                                                         InGrid2dDescTuple,
                                                         GridDesc_M_K,
                                                         BlockSize,
                                                         MThreadClusterSize,
                                                         KThreadClusterSize,
                                                         MThreadSliceSize,
                                                         KThreadSliceSize,
                                                         XYSrcVectorDim,
                                                         XSrcVectorSize,
                                                         GammaSrcVectorDim,
                                                         GammaSrcVectorSize,
                                                         BetaSrcVectorDim,
                                                         BetaSrcVectorSize,
                                                         XYSrcVectorDim,
                                                         YDstVectorSize,
                                                         false>;

// Same as above with the final argument set to true; used when sweep_once_ is set.
using GridwiseReduceLayernormSweepOnce =
    GridwiseElementwiseLayernormWelfordVariance_mk_to_mk<InDataTypePointerTuple,
                                                         XDataType,
                                                         GammaDataType,
                                                         BetaDataType,
                                                         YDataType,
                                                         AccDataType,
                                                         XElementwiseOperation,
                                                         YElementwiseOperation,
                                                         InGrid2dDescTuple,
                                                         GridDesc_M_K,
                                                         BlockSize,
                                                         MThreadClusterSize,
                                                         KThreadClusterSize,
                                                         MThreadSliceSize,
                                                         KThreadSliceSize,
                                                         XYSrcVectorDim,
                                                         XSrcVectorSize,
                                                         GammaSrcVectorDim,
                                                         GammaSrcVectorSize,
                                                         BetaSrcVectorDim,
                                                         BetaSrcVectorSize,
                                                         XYSrcVectorDim,
                                                         YDstVectorSize,
                                                         true>;
// Host-side bundle of everything the kernel launch needs: device pointers, elementwise
// functors, epsilon, reordered lengths/strides, launch geometry and the 2-D descriptors.
struct Argument : public BaseArgument
{
    // lengths          tensor lengths shared by all inputs, gamma, beta and Y
    // inStridesArray   one stride vector per input tensor
    // gammaStrides / betaStrides / yStrides   strides of gamma, beta and Y
    // reduceDims       dimensions to normalize over
    // in_dev_buffers   untyped device pointers, one per input; cast per InDataTypeTuple
    Argument(const std::vector<index_t> lengths,
             const std::array<std::vector<index_t>, NumInput> inStridesArray,
             const std::vector<index_t> gammaStrides,
             const std::vector<index_t> betaStrides,
             const std::vector<index_t> yStrides,
             const std::vector<index_t> reduceDims,
             XElementwiseOperation x_elementwise_op,
             YElementwiseOperation y_elementwise_op,
             AccDataType epsilon,
             const std::array<const void*, NumInput> in_dev_buffers,
             const GammaDataType* p_gamma,
             const BetaDataType* p_beta,
             YDataType* p_y)
        : epsilon_(epsilon),
          p_gamma_(p_gamma),
          p_beta_(p_beta),
          p_y_(p_y),
          x_elementwise_op_(x_elementwise_op),
          y_elementwise_op_(y_elementwise_op)
    {
        // Reorder lengths and strides with respect to reduceDims.
        // NOTE(review): presumably places invariant dims first and reduced dims last, which
        // is what MakeSrc2dDescriptor expects — confirm in shuffle_tensor_dimensions.
        Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
        for(int i = 0; i < NumInput; i++)
        {
            inStridesArray_[i] =
                shuffle_tensor_dimensions<Rank, NumReduceDim>(inStridesArray[i], reduceDims);
        }

        yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
        // X has no strides of its own in the interface; it reuses Y's strides.
        xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);

        gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
        betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);

        // Cast each untyped input pointer to `const T*` of the matching input datatype.
        in_dev_buffers_ = generate_tuple(
            [&](auto I) {
                using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
                return static_cast<const DataType*>(in_dev_buffers[I.value]);
            },
            Number<NumInput>{});

        long_index_t invariant_total_length;
        long_index_t reduce_total_length;

        std::tie(invariant_total_length, reduce_total_length) =
            get_2d_lengths<Rank, NumReduceDim>(Lengths_);

        blkGroupSize_ = 1;
        // Ceiling division: number of K tiles needed to cover the reduced length.
        numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;

        // One block per M tile (times blkGroupSize_, which is fixed to 1 above).
        gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                    M_BlockTileSize * blkGroupSize_;

        in_grid_2d_desc_tuple_ = generate_tuple(
            [&](auto I) {
                return MakeSrc2dDescriptor(
                    Lengths_, inStridesArray_[I.value], blkGroupSize_, numBlockTileIteration_);
            },
            Number<NumInput>{});

        x_grid_desc_m_k_ =
            MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);

        gamma_grid_desc_m_k_ =
            MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);

        beta_grid_desc_m_k_ =
            MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);

        y_grid_desc_m_k_ =
            MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);

        // True when the padded K extent fits into a single block tile.
        sweep_once_ =
            x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;

        if(!sweep_once_) // if not sweep once, compute memory size for matrix X in lds for
                         // store Intermediate results
        {
            // NOTE(review): reduce_total_length is long_index_t and is narrowed to int here;
            // confirm the product cannot exceed the int range for supported shapes.
            int block_TileSize = M_BlockTileSize * reduce_total_length;
            x_lds_size_        = block_TileSize * sizeof(XDataType);
        }
        else
            x_lds_size_ = 0;
    }

    AccDataType epsilon_;

    InDataTypePointerTuple in_dev_buffers_; // typed input pointers
    const GammaDataType* p_gamma_;
    const BetaDataType* p_beta_;
    YDataType* p_y_;

    // Lengths and strides after shuffle_tensor_dimensions.
    std::vector<index_t> Lengths_;
    std::array<std::vector<index_t>, NumInput> inStridesArray_;
    std::vector<index_t> xStrides_;
    std::vector<index_t> gammaStrides_;
    std::vector<index_t> betaStrides_;
    std::vector<index_t> yStrides_;

    XElementwiseOperation x_elementwise_op_;
    YElementwiseOperation y_elementwise_op_;

    int blkGroupSize_;          // always 1 in this implementation
    int numBlockTileIteration_; // K tiles per block
    size_t gridSize_;           // number of blocks to launch

    InGrid2dDescTuple in_grid_2d_desc_tuple_;
    GridDesc_M_K x_grid_desc_m_k_;
    GridDesc_M_K gamma_grid_desc_m_k_;
    GridDesc_M_K beta_grid_desc_m_k_;
    GridDesc_M_K y_grid_desc_m_k_;
    bool sweep_once_; // selects the SweepOnce kernel variant in Invoker::Run
    int x_lds_size_;  // bytes for X when not sweeping once, else 0
};
// Launches kernel_elementwise_layernorm for a prepared Argument.
struct Invoker : public BaseInvoker
{
    // Picks the SweepOnce or Generic kernel instantiation based on arg.sweep_once_, launches
    // it with arg.gridSize_ blocks of BlockSize threads, and returns the value reported by
    // launch_and_time_kernel.
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        const auto kernel_main =
            arg.sweep_once_ ? kernel_elementwise_layernorm<GridwiseReduceLayernormSweepOnce,
                                                           InDataTypePointerTuple,
                                                           XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
                                                           AccDataType,
                                                           XElementwiseOperation,
                                                           YElementwiseOperation,
                                                           InGrid2dDescTuple,
                                                           GridDesc_M_K>
                            : kernel_elementwise_layernorm<GridwiseReduceLayernormGeneric,
                                                           InDataTypePointerTuple,
                                                           XDataType,
                                                           GammaDataType,
                                                           BetaDataType,
                                                           YDataType,
                                                           AccDataType,
                                                           XElementwiseOperation,
                                                           YElementwiseOperation,
                                                           InGrid2dDescTuple,
                                                           GridDesc_M_K>;

        float avg_time = 0;
        // NOTE(review): arg.x_lds_size_ is presumably the dynamic shared-memory byte count
        // backing the kernel's `extern __shared__` buffer — confirm in launch_and_time_kernel.
        avg_time += launch_and_time_kernel(stream_config,
                                           kernel_main,
                                           dim3(arg.gridSize_),
                                           dim3(BlockSize),
                                           arg.x_lds_size_,
                                           arg.in_grid_2d_desc_tuple_,
                                           arg.x_grid_desc_m_k_,
                                           arg.gamma_grid_desc_m_k_,
                                           arg.beta_grid_desc_m_k_,
                                           arg.y_grid_desc_m_k_,
                                           arg.numBlockTileIteration_,
                                           arg.epsilon_,
                                           arg.in_dev_buffers_,
                                           arg.p_gamma_,
                                           arg.p_beta_,
                                           arg.p_y_,
                                           arg.x_elementwise_op_,
                                           arg.y_elementwise_op_);

        return (avg_time);
    };

    // Type-erased overload. The dynamic_cast result is dereferenced without a check, so
    // p_arg must point to this operation's Argument.
    float Run(const BaseArgument* p_arg,
              const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
    };
};
// Host-side check of whether this instance's vector-access configuration is compatible with
// the problem described by `p_arg`.
//
// Returns false when:
//   * `p_arg` is not this operation's Argument (or is null);
//   * the vectorised dimension of any input is not contiguous (stride != 1), or its length is
//     not a multiple of XSrcVectorSize;
//   * the last dimension's length is not a multiple of YDstVectorSize;
//   * gamma/beta vector sizes are incompatible with their strides, KThreadSliceSize or the
//     relevant dimension length.
// All indices refer to Argument's shuffled lengths/strides.
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
    const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

    // dynamic_cast yields nullptr on a type mismatch; treat that as "unsupported" rather
    // than dereferencing it.
    if(p_arg_ == nullptr)
    {
        return false;
    }

    constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    if constexpr(XYSrcVectorDim == 0)
    {
        if constexpr(NumInvariantDim == 0)
        {
            // Vectorising along the invariant dimension is impossible when there is none.
            return false;
        }
        else
        {
            for(int i = 0; i < NumInput; i++)
            {
                if(p_arg_->inStridesArray_[i][NumInvariantDim - 1] != 1)
                    return false;
            }

            // Argument has no `invariant_lowest_length` member; the length of the fastest
            // invariant dimension is Lengths_[NumInvariantDim - 1].
            if(p_arg_->Lengths_[NumInvariantDim - 1] % XSrcVectorSize != 0)
                return false;
        }
    }
    else
    {
        for(int i = 0; i < NumInput; i++)
        {
            if(p_arg_->inStridesArray_[i][Rank - 1] != 1)
                return false;
        }

        if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
            return false;
    }

    if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
    {
        return false;
    }

    // A non-contiguous last dimension only allows scalar access; a contiguous one requires
    // the vector size to divide the per-thread K slice.
    auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
        return isLastDimensionCoalesced ? (KThreadSliceSize % scalarPerVector == 0)
                                        : (scalarPerVector == 1);
    };

    if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
        return false;

    if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
        return false;

    // NOTE(review): the gamma/beta checks below branch on XYSrcVectorDim rather than on
    // GammaSrcVectorDim / BetaSrcVectorDim; kept as in the original — confirm intent.
    if constexpr(XYSrcVectorDim == 0)
    {
        // fastest dim is not reduced
        if constexpr(NumInvariantDim != 0)
        {
            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
                return false;

            // NOTE(review): gamma is checked against the last dimension's length while beta
            // is checked against the fastest invariant dimension; kept as in the original.
            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
                return false;

            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(p_arg_->Lengths_[NumInvariantDim - 1] % BetaSrcVectorSize != 0)
                return false;
        }
    }
    else
    {
        // fastest dim is reduced
        if(p_arg_->gammaStrides_[Rank - 1] != 1)
            return false;

        if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
            return false;

        if(p_arg_->betaStrides_[Rank - 1] != 1)
            return false;

        if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
            return false;
    }

    return true;
}
// Type-erased factory: casts the gamma/beta/Y pointers to their typed forms and returns a
// heap-allocated Argument behind the BaseArgument interface.
// Note that the parameter order here differs from Argument's constructor: the elementwise
// functors come last in this signature but precede `epsilon` in the constructor call.
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
                    const std::array<std::vector<index_t>, NumInput> inStridesArray,
                    const std::vector<index_t> gammaStrides,
                    const std::vector<index_t> betaStrides,
                    const std::vector<index_t> yStrides,
                    const std::vector<index_t> reduceDims,
                    AccDataType epsilon,
                    const std::array<const void*, NumInput> in_dev_buffers,
                    const void* p_gamma,
                    const void* p_beta,
                    void* p_y,
                    XElementwiseOperation x_elementwise_op,
                    YElementwiseOperation y_elementwise_op) override
{
    return std::make_unique<Argument>(lengths,
                                      inStridesArray,
                                      gammaStrides,
                                      betaStrides,
                                      yStrides,
                                      reduceDims,
                                      x_elementwise_op,
                                      y_elementwise_op,
                                      epsilon,
                                      in_dev_buffers,
                                      static_cast<const GammaDataType*>(p_gamma),
                                      static_cast<const BetaDataType*>(p_beta),
                                      static_cast<YDataType*>(p_y));
};
// Type-erased factory: returns a heap-allocated Invoker behind the BaseInvoker interface.
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
    return std::make_unique<Invoker>();
};
// Produces a human-readable identifier for this instance that encodes the block size, the
// M/K thread cluster and slice sizes, the XY vector dimension and all vector sizes.
std::string GetTypeString() const override
{
    std::stringstream name;

    // clang-format off
    name << "DeviceElementwiseNormalizationImpl<" << BlockSize << ","
         << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
         << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
         << "XYSrcVectorDim_" << XYSrcVectorDim << ","
         << "VectorSize_X" << XSrcVectorSize
         << "_Gamma" << GammaSrcVectorSize
         << "_Beta" << BetaSrcVectorSize
         << "_Y" << YDstVectorSize << ">";
    // clang-format on

    return name.str();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
05ee41c3
...
...
@@ -141,7 +141,8 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
DeviceGemmMultipleD_Xdl_CShuffle
:
public
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
DsLayout
,
...
...
@@ -282,7 +283,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CShuffleNXdlPerWavePerShuffle
,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
,
PipelineVer
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
...
...
@@ -664,6 +666,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemmMultipleD_Xdl_CShuffle"
<<
"<"
...
...
@@ -674,7 +682,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
;
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
05ee41c3
...
...
@@ -56,7 +56,9 @@ template <typename ADataType,
bool
BBlockLdsAddExtraN
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
NumPrefetch
=
1
>
ck
::
index_t
NumPrefetch
=
1
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
v1
>
struct
DeviceGemmXdl
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -230,7 +232,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
NumPrefetch
>
;
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -523,6 +527,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemmXdl"
<<
"<"
...
...
@@ -535,7 +545,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
">"
;
<<
">"
<<
" NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
05ee41c3
...
...
@@ -64,7 +64,8 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
DeviceGemm_Xdl_CShuffle
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
...
...
@@ -393,7 +394,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -656,6 +658,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffle"
<<
"<"
...
...
@@ -665,7 +673,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
">"
;
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];;
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv
nd
_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_
grouped_
conv_bwd_weight_
g
nwc_
g
kxc_
g
nwk_xdl_cshuffle.hpp
View file @
05ee41c3
...
...
@@ -4,13 +4,14 @@
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_
grouped_
conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/host_utility/device_prop.hpp"
...
...
@@ -20,6 +21,108 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace {

// Maps a batch (group) index to element offsets into the A, B and C tensors, assuming each
// tensor is laid out with a fixed stride between consecutive batches. The stride is widened
// to long_index_t before multiplying so the product is computed in the wide type.
struct ComputePtrOffsetOfStridedBatch
{
    // Per-tensor distance, in elements, between two consecutive batches.
    index_t BatchStrideA_;
    index_t BatchStrideB_;
    index_t BatchStrideC_;

    // Offset of batch `g_idx` within A.
    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideA_) * g_idx;
    }

    // Offset of batch `g_idx` within B.
    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideB_) * g_idx;
    }

    // Offset of batch `g_idx` within C.
    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
    {
        return static_cast<long_index_t>(BatchStrideC_) * g_idx;
    }
};

} // namespace
// Batched GEMM kernel entry point for backward-weight convolution.
// The grid is split evenly into `batch_count` groups of blocks; each block derives its batch
// index from its 1-D block id, offsets the A/B/C base pointers by the per-batch offsets from
// `compute_ptr_offset_of_batch`, and then runs GridwiseGemm on that batch.
// The real body is compiled only for host passes and for gfx908 / gfx90a device passes; for
// other device targets the kernel is an empty stub that merely references its arguments.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_B_K0_M_K1,
          typename BGridDesc_B_K0_N_K1,
          typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap,
          typename ComputePtrOffsetOfBatch,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_xdlops_bwd_weight(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
            const index_t batch_count,
            const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
            const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    // Blocks are assigned to batches in contiguous ranges of equal size.
    // NOTE(review): assumes the grid size is a multiple of batch_count — confirm at launch.
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    // Element offsets of this block's batch inside each tensor.
    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

    // Statically sized shared-memory scratch; the byte count reported by GridwiseGemm is
    // converted into a number of FloatAB elements.
    __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_c_grid + c_batch_offset,
                                                  p_shared,
                                                  a_b_k0_m_k1_grid_desc,
                                                  b_b_k0_n_k1_grid_desc,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op,
                                                  block_2_ctile_map);
#else
    // Unsupported device target: touch every parameter so none is reported as unused.
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = a_b_k0_m_k1_grid_desc;
    ignore = b_b_k0_n_k1_grid_desc;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = c_element_op;
    ignore = batch_count;
    ignore = block_2_ctile_map;
    ignore = compute_ptr_offset_of_batch;

    // The results are discarded; these calls only reference the offset accessors.
    compute_ptr_offset_of_batch.GetAPtrOffset(0);
    compute_ptr_offset_of_batch.GetBPtrOffset(0);
    compute_ptr_offset_of_batch.GetCPtrOffset(0);
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
...
...
@@ -57,21 +160,21 @@ template <ck::index_t NDimSpatial,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
>
struct
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Xdl_CShuffle
:
public
DeviceConvBwdWeight
<
struct
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Xdl_CShuffle
:
public
Device
Grouped
ConvBwdWeight
<
NDimSpatial
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>>
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
G
NWC
,
ck
::
tensor_layout
::
convolution
::
G
NHWC
,
ck
::
tensor_layout
::
convolution
::
G
NDHWC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
G
KXC
,
ck
::
tensor_layout
::
convolution
::
G
KYXC
,
ck
::
tensor_layout
::
convolution
::
G
KZYXC
>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>>
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
G
NWK
,
ck
::
tensor_layout
::
convolution
::
G
NHWK
,
ck
::
tensor_layout
::
convolution
::
G
NDHWK
>>
,
InDataType
,
WeiDataType
,
OutDataType
,
...
...
@@ -79,7 +182,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
WeiElementwiseOperation
,
OutElementwiseOperation
>
{
using
DeviceOp
=
DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Xdl_CShuffle
;
using
DeviceOp
=
Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Xdl_CShuffle
;
using
ADataType
=
OutDataType
;
using
BDataType
=
InDataType
;
...
...
@@ -117,17 +220,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
static
constexpr
auto
BBlockLdsN1Padding
=
4
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -269,17 +372,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -436,17 +539,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
batch_k
)
{
using
namespace
ck
;
...
...
@@ -664,8 +767,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
template
<
index_t
Dim
>
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
index_t
>&
shape
,
const
std
::
vector
<
index_t
>&
stride
,
static
auto
MakeDescriptor_M0
(
const
std
::
array
<
index_t
,
Dim
>&
shape
,
const
std
::
array
<
index_t
,
Dim
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
{
...
...
@@ -759,16 +862,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
...
...
@@ -783,11 +887,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
compute_ptr_offset_of_batch_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
out_element_op
},
b_element_op_
{
in_element_op
},
c_element_op_
{
wei_element_op
},
Conv_G_
{
G
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
...
...
@@ -819,6 +925,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
end
(
output_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
end
(
input_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
),
index_t
{
1
},
std
::
multiplies
<>
{});
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
...
...
@@ -836,21 +962,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
N01_
;
InElementwiseOperation
a_element_op_
;
OutElementwiseOperation
b_element_op_
;
WeiElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_G_
;
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads_
;
index_t
k_batch_
;
};
...
...
@@ -873,14 +1007,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{
"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
...
...
@@ -891,7 +1023,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
...
...
@@ -900,17 +1032,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_xdlops_bwd_weight
<
const
auto
kernel
=
kernel_
batched_
gemm_xdlops_bwd_weight
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
OutElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -921,13 +1054,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
Conv_G_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
if
(
has_main_k0_block_loop
)
...
...
@@ -998,16 +1133,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -1016,6 +1152,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
G
,
N
,
K
,
C
,
...
...
@@ -1040,16 +1177,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
ck
::
index_t
G
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
,
...
...
@@ -1058,6 +1196,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
N
,
K
,
C
,
...
...
@@ -1086,7 +1225,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceConv
Nd
BwdWeight
NwcKxcN
wk_Xdl_CShuffle"
str
<<
"Device
Grouped
ConvBwdWeight
GnwcGkxcGn
wk_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
05ee41c3
...
...
@@ -22,6 +22,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/library/utility/numeric.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -410,10 +411,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
{
const
index_t
N
=
r_g_n_wos_lengths
[
1
];
const
index_t
NHoWo
=
N
*
std
::
accumulate
(
r_g_n_wos_lengths
.
begin
()
+
2
,
r_g_n_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
r_g_n_wos_lengths
.
begin
()
+
2
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
auto
r_grid_desc_mraw
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
NHoWo
));
...
...
@@ -435,10 +435,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
const
index_t
WoStride
=
r_g_n_wos_strides
[
NDimSpatial
+
2
];
const
index_t
NHoWo
=
N
*
std
::
accumulate
(
r_g_n_wos_lengths
.
begin
()
+
2
,
r_g_n_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
r_g_n_wos_lengths
.
begin
()
+
2
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
auto
r_grid_desc_mraw
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
),
make_tuple
(
WoStride
));
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
View file @
05ee41c3
...
...
@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_
layernorm
_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_
normalization
_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -24,7 +24,7 @@ template <typename GridwiseReduction,
typename
AccDataType
,
typename
AccElementwiseOperation
,
typename
GridDesc_M_K
>
__global__
void
kernel_
layernorm
(
const
GridDesc_M_K
x_grid_desc_m_k
,
__global__
void
kernel_
normalization
(
const
GridDesc_M_K
x_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
const
GridDesc_M_K
y_grid_desc_m_k
,
...
...
@@ -54,7 +54,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
// Y =
LayerNorm
(X, Beta, Gamma)
// Y =
Normalization
(X, Beta, Gamma)
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
...
...
@@ -168,7 +168,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using
GridDesc_M_K
=
decltype
(
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseReduceLayernormGeneric
=
Gridwise
Layernorm
WelfordVariance_mk_to_mk
<
XDataType
,
Gridwise
Normalization
WelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
...
...
@@ -189,8 +189,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
XYSrcVectorDim
,
YDstVectorSize
,
false
>
;
using
Gridwise
ReduceLayernorm
SweepOnce
=
Gridwise
Layernorm
WelfordVariance_mk_to_mk
<
XDataType
,
using
Gridwise
Normalization
SweepOnce
=
Gridwise
Normalization
WelfordVariance_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
...
...
@@ -295,7 +295,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel_main
=
arg
.
isSweeponce_
?
kernel_
layernorm
<
GridwiseReduceLayernorm
SweepOnce
,
?
kernel_
normalization
<
GridwiseNormalization
SweepOnce
,
XDataType
,
GammaDataType
,
BetaDataType
,
...
...
@@ -303,7 +303,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
AccDataType
,
AccElementwiseOperation
,
GridDesc_M_K
>
:
kernel_
layernorm
<
GridwiseReduceLayernormGeneric
,
:
kernel_
normalization
<
GridwiseReduceLayernormGeneric
,
XDataType
,
GammaDataType
,
BetaDataType
,
...
...
@@ -426,8 +426,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
,
void
*
p_saveMean
,
void
*
p_saveInvVar
,
AccElementwiseOperation
acc_elementwise_op
)
override
{
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore
=
p_saveMean
;
ignore
=
p_saveInvVar
;
return
std
::
make_unique
<
Argument
>
(
lengths
,
xStrides
,
gammaStrides
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp
View file @
05ee41c3
...
...
@@ -226,6 +226,30 @@ struct DeviceReduceMultiBlock
in_elementwise_op_
{
in_elementwise_op
},
acc_elementwise_op_
{
acc_elementwise_op
}
{
if
(
Rank
!=
inLengths
.
size
()
||
Rank
!=
inStrides
.
size
()
||
NumReduceDim
!=
reduceDims
.
size
())
{
throw
std
::
runtime_error
(
"One of inLengths/inStrides/reduceDims has invalid size!"
"
\n
Expected size inLengths: "
+
std
::
to_string
(
Rank
)
+
", inStrides: "
+
std
::
to_string
(
Rank
)
+
", reduceDims: "
+
std
::
to_string
(
NumReduceDim
)
+
"
\n
But have inLengths: "
+
std
::
to_string
(
inLengths
.
size
())
+
", inStrides: "
+
std
::
to_string
(
inStrides
.
size
())
+
", reduceDims: "
+
std
::
to_string
(
reduceDims
.
size
()));
}
for
(
std
::
size_t
i
=
0
;
i
<
reduceDims
.
size
();
++
i
)
{
if
(
reduceDims
[
i
]
<
0
||
reduceDims
[
i
]
>=
Rank
)
{
throw
std
::
runtime_error
(
"Provided reduce dimension exceed input tensor Rank!"
"
\n
Have reduceDims["
+
std
::
to_string
(
i
)
+
"]: "
+
std
::
to_string
(
reduceDims
[
i
]));
}
}
inLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inLengths
,
reduceDims
);
inStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inStrides
,
reduceDims
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp
View file @
05ee41c3
...
...
@@ -42,6 +42,7 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
{
static
constexpr
index_t
kRank
=
Rank
;
static
constexpr
index_t
kNumReduceDim
=
NumReduceDim
;
static
constexpr
index_t
kNumInvariantDim
=
Rank
-
NumReduceDim
;
virtual
index_t
GetRank
()
const
override
{
return
kRank
;
}
...
...
@@ -168,6 +169,30 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
in_elementwise_op_
{
in_elementwise_op
},
acc_elementwise_op_
{
acc_elementwise_op
}
{
if
(
Rank
!=
inLengths
.
size
()
||
Rank
!=
inStrides
.
size
()
||
NumReduceDim
!=
reduceDims
.
size
())
{
throw
std
::
runtime_error
(
"One of inLengths/inStrides/reduceDims has invalid size!"
"
\n
Expected size inLengths: "
+
std
::
to_string
(
Rank
)
+
", inStrides: "
+
std
::
to_string
(
Rank
)
+
", reduceDims: "
+
std
::
to_string
(
NumReduceDim
)
+
"
\n
But have inLengths: "
+
std
::
to_string
(
inLengths
.
size
())
+
", inStrides: "
+
std
::
to_string
(
inStrides
.
size
())
+
", reduceDims: "
+
std
::
to_string
(
reduceDims
.
size
()));
}
for
(
std
::
size_t
i
=
0
;
i
<
reduceDims
.
size
();
++
i
)
{
if
(
reduceDims
[
i
]
<
0
||
reduceDims
[
i
]
>=
Rank
)
{
throw
std
::
runtime_error
(
"Provided reduce dimension exceed input tensor Rank!"
"
\n
Have reduceDims["
+
std
::
to_string
(
i
)
+
"]: "
+
std
::
to_string
(
reduceDims
[
i
]));
}
}
inLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inLengths
,
reduceDims
);
inStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumReduceDim
>
(
inStrides
,
reduceDims
);
...
...
@@ -257,40 +282,78 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
};
};
bool
IsSupportedArgument
(
const
Base
Argument
*
p_arg
)
override
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
const
Argument
*
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
constexpr
(
InSrcVectorDim
==
0
)
{
if
constexpr
(
NumInvariantDim
==
0
)
if
constexpr
(
k
NumInvariantDim
==
0
)
{
return
false
;
}
else
{
if
(
p_arg_
->
inStrides_
[
NumInvariantDim
-
1
]
!=
1
)
if
(
arg
.
inStrides_
[
kNumInvariantDim
-
1
]
!=
1
&&
InSrcVectorSize
!=
1
)
{
return
false
;
if
(
p_arg_
->
invariant_lowest_length_
%
InSrcVectorSize
!=
0
)
}
if
(
arg
.
invariant_lowest_length_
%
InSrcVectorSize
!=
0
)
{
return
false
;
};
}
}
}
else
{
if
(
p_arg_
->
inStrides_
[
Rank
-
1
]
!=
1
)
if
(
arg
.
inStrides_
[
Rank
-
1
]
!=
1
&&
InSrcVectorSize
!=
1
)
{
return
false
;
}
if
(
arg
.
inLengths_
[
Rank
-
1
]
%
InSrcVectorSize
!=
0
)
{
return
false
;
}
}
if
(
p_arg_
->
inLengths_
[
Rank
-
1
]
%
InSrcVectorSize
!=
0
)
// To improve
if
(
kNumInvariantDim
>
0
&&
arg
.
invariant_lowest_length_
%
OutDstVectorSize
!=
0
)
{
return
false
;
}
;
}
if
(
p_arg_
->
invariant_lowest_length_
%
OutDstVectorSize
!=
0
)
if
(
arg
.
inLengths_
[
Rank
-
1
]
%
OutDstVectorSize
!=
0
)
{
return
false
;
}
return
true
;
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
AccDataType
alpha
,
const
AccDataType
beta
,
const
InDataType
*
in_dev
,
OutDataType
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
AccElementwiseOp
acc_elementwise_op
)
{
return
Argument
{
inLengths
,
inStrides
,
reduceDims
,
alpha
,
beta
,
in_dev
,
out_dev
,
in_elementwise_op
,
acc_elementwise_op
};
};
//
// @brief Makes a pointer to Argument class.
//
...
...
@@ -330,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
acc_elementwise_op
);
};
// Returns a value-initialized Invoker object (by value, not through a pointer).
static auto MakeInvoker()
{
    return Invoker{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
...
...
@@ -340,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceReduceSoftmax<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcVectorDim_"
<<
InSrcVectorDim
<<
"_InSrcVectorSize_"
<<
InSrcVectorSize
<<
"_OutDstVectorSize_"
<<
OutDstVectorSize
<<
">"
;
str
<<
"DeviceReduceSoftmax<"
<<
Rank
<<
","
<<
NumReduceDim
<<
","
<<
BlockSize
<<
","
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
<<
"InSrcVectorDim_"
<<
InSrcVectorDim
<<
"_InSrcVectorSize_"
<<
InSrcVectorSize
<<
"_OutDstVectorSize_"
<<
OutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
View file @
05ee41c3
...
...
@@ -7,6 +7,7 @@
#include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
include/ck/tensor_operation/gpu/element/quantization_operation.hpp
0 → 100644
View file @
05ee41c3
#pragma once
#include "ck/utility/data_type.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
// Requantization functor for activations that are piecewise linear functions
// (e.g. ReLU, leaky ReLU).
//
// Pipeline: int32 input -> float -> activation (applied in place) -> multiply by
// requantScale_ -> clamp to [-128, 127].
template <typename Activation>
struct Activation_Mul_Clamp
{
    Activation_Mul_Clamp(float requantScale, Activation activationOp)
        : requantScale_(requantScale), activationOp_(activationOp)
    {
    }

    // int8 output: the clamped float result is converted to int8_t.
    __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
    {
        float acc = ck::type_convert<float>(x);
        // The activation is invoked with the same variable as output and input.
        activationOp_(acc, acc);

        const float clamped = math::clamp(requantScale_ * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    // float output: identical computation, but the clamped value is stored as float.
    // The value may be converted to int8 elsewhere, after this functor has run.
    __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
    {
        float acc = ck::type_convert<float>(x);
        activationOp_(acc, acc);

        y = math::clamp(requantScale_ * acc, -128.f, 127.f);
    }

    float requantScale_;      // scale applied after the activation
    Activation activationOp_; // activation functor, called as activationOp_(out, in)
};
// Per-channel convolution quantization combined with an activation that is a piecewise
// linear function (e.g. ReLU, leaky ReLU).
//
// Unlike Activation_Mul_Clamp, the requantization scale is not stored in the functor; it is
// supplied on every call as the third operand.
template <typename Activation>
struct Activation_Mul2_Clamp
{
    Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}

    // y = int8(clamp(requantScale * activation(float(x)), -128, 127))
    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x, const float& requantScale) const
    {
        float acc = ck::type_convert<float>(x);
        // The activation is invoked with the same variable as output and input.
        activationOp_(acc, acc);

        const float clamped = math::clamp(requantScale * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    Activation activationOp_; // activation functor, called as activationOp_(out, in)
};
// Bias-add + requantization functor for activations that are piecewise linear functions
// (e.g. ReLU, leaky ReLU).
//
// Pipeline: (x + bias) in int32 -> float -> activation (applied in place) -> multiply by
// requantScale_ -> clamp to [-128, 127] -> int8.
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
    Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
        : requantScale_(requantScale), activationOp_(activationOp)
    {
    }

    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
    {
        // The sum is formed in int32 first and only then converted to float.
        float acc = ck::type_convert<float>(x + bias);
        // The activation is invoked with the same variable as output and input.
        activationOp_(acc, acc);

        const float clamped = math::clamp(requantScale_ * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    float requantScale_;      // scale applied after the activation
    Activation activationOp_; // activation functor, called as activationOp_(out, in)
};
// Per-channel convolution quantization with bias-add, combined with an activation that is a
// piecewise linear function (e.g. ReLU, leaky ReLU).
//
// The requantization scale is supplied on every call as the fourth operand instead of being
// stored in the functor.
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
    Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}

    // y = int8(clamp(requantScale * activation(float(x + bias)), -128, 127))
    __host__ __device__ constexpr void operator()(int8_t& y,
                                                  const int32_t& x,
                                                  const int32_t& bias,
                                                  const float& requantScale) const
    {
        // The sum is formed in int32 first and only then converted to float.
        float acc = ck::type_convert<float>(x + bias);
        // The activation is invoked with the same variable as output and input.
        activationOp_(acc, acc);

        const float clamped = math::clamp(requantScale * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    Activation activationOp_; // activation functor, called as activationOp_(out, in)
};
// Bias-add + requantization functor for activations that are NOT piecewise linear functions
// (e.g. TanH, Sigmoid).
//
// Two scales are used: requantScale1_ is applied before the activation and requantScale2_
// after it. Pipeline: (x + bias) in int32 -> float -> * requantScale1_ -> activation
// (applied in place) -> * requantScale2_ -> clamp to [-128, 127] -> int8.
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
    Add_Mul_Activation_Mul_Clamp(float requantScale1,
                                 float requantScale2,
                                 Activation activationOp)
        : requantScale1_(requantScale1),
          requantScale2_(requantScale2),
          activationOp_(activationOp)
    {
    }

    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
    {
        // The sum is formed in int32, converted to float, then scaled before the activation.
        float acc = requantScale1_ * ck::type_convert<float>(x + bias);
        // The activation is invoked with the same variable as output and input.
        activationOp_(acc, acc);

        const float clamped = math::clamp(requantScale2_ * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    float requantScale1_;     // scale applied before the activation
    float requantScale2_;     // scale applied after the activation
    Activation activationOp_; // activation functor, called as activationOp_(out, in)
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
05ee41c3
...
...
@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp"
namespace
ck
{
...
...
@@ -193,21 +194,36 @@ struct Relu
}
};
// Fast GeLU: Y = FastGelu(X)
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
//
// The computation is always carried out in fp32: the input is converted to float,
// passed through GetFastGeLU(), and the result is converted to the output type.
struct FastGelu
{
    // Evaluates the tanh-based GeLU approximation for one fp32 value.
    //
    // With u = 2 * x * (0.035677 * x^2 + 0.797885), i.e. twice the tanh argument
    // (0.797885 ~= sqrt(2/pi) and 0.035677 ~= 0.797885 * 0.044715), the identity
    // tanh(u/2) = 2 / (1 + exp(-u)) - 1 gives cdf = 0.5 + 0.5 * tanh(u/2).
    //
    // NOTE(review): exp() is not a constant expression in standard C++; `constexpr`
    // is kept exactly as in the source.
    __host__ __device__ static constexpr float GetFastGeLU(float x)
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);

        return x * cdf;
    }

    // True for the element types operator() accepts for both X and Y.
    template <typename T>
    static inline constexpr bool is_valid_param_type_v =
        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
        || std::is_same_v<T, ck::int4_t>
#endif
        ;

    // y: output, receives FastGelu(x) converted to Y.
    // x: input, converted to float before evaluation.
    // Using an unsupported X or Y type is rejected at compile time.
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);

        const float tmp_y = GetFastGeLU(type_convert<float>(x));
        y                 = type_convert<Y>(tmp_y);
    }
};
...
...
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp
0 → 100644
View file @
05ee41c3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
// Kernel entry point for the final stage of multiblock BatchNorm backward.
//
// The kernel contains no logic of its own: every argument is forwarded, unchanged and
// in the same order, to GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run().
//
// Descriptors:
//   x/dy/dx_grid_desc_m_k       - descriptors for the x, dy and dx tensors
//   dscale_dbias_grid_desc_m_k  - descriptor shared by p_reduce_dscale / p_reduce_dbias
//   mean_var_grid_desc_m        - descriptor shared by p_mean / p_inv_var
//   scale_grid_desc_m           - descriptor for p_scale
//   bias_grid_desc_m            - forwarded as the 7th argument of Run(), i.e. the slot
//                                 the gridwise struct in this file names
//                                 dscale_dbias_grid_desc_m and uses for p_dscale / p_dbias
// Scalars:
//   blkgroup_size, reduce_size, num_xy_k_block_tile_iteration,
//   num_dscale_dbias_k_block_tile_iteration - passed through as-is
// Pointers:
//   p_reduce_dscale, p_reduce_dbias, p_mean, p_inv_var, p_x, p_dy, p_scale - inputs (const)
//   p_dx, p_dscale, p_dbias                                                - outputs
//   dy_elementwise_op - functor passed through by value
template <typename GridwiseReduceSecondHalfBatchNormBackwardFinal_,
          typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename ScaleBiasGridDesc_M>
__global__ void kernel_reduce_second_half_batchnorm_backward_final(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const XYGridDesc_M_K dx_grid_desc_m_k,
    const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    index_t blkgroup_size,
    long_index_t reduce_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_dscale_dbias_k_block_tile_iteration,
    const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
    const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
    const MeanVarDataType* const __restrict__ p_mean,
    const MeanVarDataType* const __restrict__ p_inv_var,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    const ScaleDataType* const __restrict__ p_scale,
    const DyElementwiseOp dy_elementwise_op,
    DxDataType* const __restrict__ p_dx,
    DscaleDbiasDataType* const __restrict__ p_dscale,
    DscaleDbiasDataType* const __restrict__ p_dbias)
{
    GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k,
                                                         dy_grid_desc_m_k,
                                                         dx_grid_desc_m_k,
                                                         dscale_dbias_grid_desc_m_k,
                                                         mean_var_grid_desc_m,
                                                         scale_grid_desc_m,
                                                         bias_grid_desc_m,
                                                         blkgroup_size,
                                                         reduce_size,
                                                         num_xy_k_block_tile_iteration,
                                                         num_dscale_dbias_k_block_tile_iteration,
                                                         p_reduce_dscale,
                                                         p_reduce_dbias,
                                                         p_mean,
                                                         p_inv_var,
                                                         p_x,
                                                         p_dy,
                                                         p_scale,
                                                         dy_elementwise_op,
                                                         p_dx,
                                                         p_dscale,
                                                         p_dbias);
}
// Gridwise implementation of the last stage of multiblock BatchNorm backward.
//
// Run() performs two steps:
//   1. finishes the dscale / dbias reductions from per-block partial results and stores
//      the reduced values;
//   2. computes dx elementwise from x, dy, mean, inv-variance, scale and the reduced
//      dscale / dbias.
//
// Tiling: a block covers M_BlockTileSize x K_BlockTileSize elements per iteration, with
// MThreadClusterSize x KThreadClusterSize threads each owning an
// MThreadSliceSize x KThreadSliceSize slice. XDyDxVectorDim selects the dimension
// (0 = M, 1 = K) whose slice size must be divisible by the x/dy/dx vector sizes.
template <typename XDataType,
          typename DyDataType,
          typename DxDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename ScaleBiasGridDesc_M,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyDxVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t DxDstVectorSize,
          index_t ScaleSrcVectorSize,
          index_t DscaleDbiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseReduceSecondHalfBatchNormBackwardFinal
{
    static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0 &&
                   MThreadSliceSize % DxDstVectorSize == 0) ||
                      (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0 &&
                       KThreadSliceSize % DxDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    // When vectorizing along M (dimension 0) both the thread-buffer access order and the
    // thread-cluster arrangement order are swapped to <1, 0>.
    static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // Per-thread descriptors used by the threadwise reduction: (MThreadSliceSize, 1) -> (MThreadSliceSize).
    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          ck::reduce::Add,
                                                          false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_1,
                                                 ThreadReduceDstDesc_M,
                                                 ck::reduce::Add,
                                                 false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // clang-format off
    // Two of the steps of Multiblock BatchNorm Backward
    // Step 1: second half of the reduction: dbias = sum(dy), dscale = sum(dy * (x - mean) * inv-variance)
    // Step 2: elementwise dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)
    // clang-format on
    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& dy_grid_desc_m_k,
                               const XYGridDesc_M_K& dx_grid_desc_m_k,
                               const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const ScaleBiasGridDesc_M& scale_grid_desc_m,
                               const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m,
                               index_t blkgroup_size,
                               long_index_t reduce_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_dscale_dbias_k_block_tile_iteration,
                               const DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
                               const DscaleDbiasDataType* const __restrict__ p_reduce_dbias,
                               const MeanVarDataType* const __restrict__ p_mean,
                               const MeanVarDataType* const __restrict__ p_inv_var,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               const ScaleDataType* const __restrict__ p_scale,
                               const DyElementwiseOp dy_elementwise_op,
                               DxDataType* const __restrict__ p_dx,
                               DscaleDbiasDataType* const __restrict__ p_dscale,
                               DscaleDbiasDataType* const __restrict__ p_dbias)
    {
        // LDS scratch (one AccDataType per thread) used by the blockwise reductions.
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        // Per-thread VGPR buffers: one value per owned M row, or one per owned (M, K) element.
        using RowBuffer =
            StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>;
        using TileBuffer = StaticBuffer<AddressSpaceEnum::Vgpr,
                                        AccDataType,
                                        MThreadSliceSize * KThreadSliceSize,
                                        true>;

        RowBuffer reduce_dscale_thread_buf; // partial dscale values as loaded
        RowBuffer reduce_dbias_thread_buf;  // partial dbias values as loaded

        RowBuffer dscale_thread_buf; // running / final dscale
        RowBuffer dbias_thread_buf;  // running / final dbias

        TileBuffer x_thread_buf;
        TileBuffer dy_thread_buf;
        TileBuffer dx_thread_buf;

        RowBuffer mean_thread_buf;
        RowBuffer inv_var_thread_buf;
        RowBuffer scale_thread_buf;

        // Blocks are organised in groups of blkgroup_size: the group index selects the
        // M tile, the index within the group selects the K range handled in step 2.
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // First M row owned by this thread; shared by every transfer below.
        const auto m_origin =
            blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize;

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        // clang-format off
        // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance)
        // clang-format on

        // Loads an (MThreadSliceSize x 1) column of partial results per call.
        auto partial_result_load =
            ThreadwiseTensorSliceTransfer_v2<DscaleDbiasDataType,
                                             AccDataType,
                                             DscaleDbiasGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_1),
                                             ThreadBufferLengths_M_1,
                                             Sequence<0, 1>,
                                             1,
                                             1,
                                             1,
                                             true>(
                dscale_dbias_grid_desc_m_k, make_multi_index(m_origin, thread_k_cluster_id));

        auto final_result_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               DscaleDbiasDataType,
                                               decltype(thread_buffer_desc_m),
                                               ScaleBiasGridDesc_M,
                                               PassThroughOp,
                                               ThreadBufferLengths_M,
                                               Sequence<0>,
                                               0,
                                               DscaleDbiasDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dscale_dbias_grid_desc_m, make_multi_index(m_origin), PassThroughOp{});

        const auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

        const auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize());

        auto dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize());

        auto dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize());

        // Each iteration advances by one column per thread along K.
        constexpr auto dscale_dbias_thread_copy_step_m_k =
            make_multi_index(0, KThreadClusterSize);

        // Zero the accumulators before summing the partial results.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            dscale_thread_buf(I) = type_convert<AccDataType>(0.0f);
            dbias_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

        for(index_t tile = 0; tile < num_dscale_dbias_k_block_tile_iteration; ++tile)
        {
            partial_result_load.Run(dscale_dbias_grid_desc_m_k,
                                    reduce_dscale_global_buf,
                                    thread_buffer_desc_m_1,
                                    make_tuple(I0, I0),
                                    reduce_dscale_thread_buf);

            partial_result_load.Run(dscale_dbias_grid_desc_m_k,
                                    reduce_dbias_global_buf,
                                    thread_buffer_desc_m_1,
                                    make_tuple(I0, I0),
                                    reduce_dbias_thread_buf);

            ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf);
            ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf);

            partial_result_load.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k,
                                                   dscale_dbias_thread_copy_step_m_k);
        }

        // Combine the per-thread sums across the block. Both reductions use the same LDS
        // scratch buffer, so a barrier separates every pair of consecutive reductions.
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I));
        });

        final_result_store.Run(thread_buffer_desc_m,
                               make_tuple(I0),
                               dscale_thread_buf,
                               dscale_dbias_grid_desc_m,
                               dscale_global_buf);

        final_result_store.Run(thread_buffer_desc_m,
                               make_tuple(I0),
                               dbias_thread_buf,
                               dscale_dbias_grid_desc_m,
                               dbias_global_buf);

        // clang-format off
        // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance)
        // clang-format on

        // Number of K elements handled by one block over all of its iterations.
        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        // First K column owned by this thread for x / dy / dx.
        const auto k_origin =
            workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize;

        auto x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                       AccDataType,
                                                       XYGridDesc_M_K,
                                                       decltype(thread_buffer_desc_m_k),
                                                       ThreadBufferLengths_M_K,
                                                       ThreadBufferDimAccessOrder,
                                                       XDyDxVectorDim,
                                                       XSrcVectorSize,
                                                       1,
                                                       true>(
            x_grid_desc_m_k, make_multi_index(m_origin, k_origin));

        auto dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                                        AccDataType,
                                                        XYGridDesc_M_K,
                                                        decltype(thread_buffer_desc_m_k),
                                                        ThreadBufferLengths_M_K,
                                                        ThreadBufferDimAccessOrder,
                                                        XDyDxVectorDim,
                                                        DySrcVectorSize,
                                                        1,
                                                        true>(
            dy_grid_desc_m_k, make_multi_index(m_origin, k_origin));

        auto dx_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                           DxDataType,
                                                           decltype(thread_buffer_desc_m_k),
                                                           XYGridDesc_M_K,
                                                           PassThroughOp,
                                                           ThreadBufferLengths_M_K,
                                                           ThreadBufferDimAccessOrder,
                                                           XDyDxVectorDim,
                                                           DxDstVectorSize,
                                                           InMemoryDataOperationEnum::Set,
                                                           1,
                                                           true>(
            dx_grid_desc_m_k, make_multi_index(m_origin, k_origin), PassThroughOp{});

        auto scale_load = ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
                                                           AccDataType,
                                                           ScaleBiasGridDesc_M,
                                                           decltype(thread_buffer_desc_m),
                                                           ThreadBufferLengths_M,
                                                           Sequence<0>,
                                                           0,
                                                           ScaleSrcVectorSize,
                                                           1,
                                                           true>(scale_grid_desc_m,
                                                                 make_multi_index(m_origin));

        // Used for both mean and inv-variance, which share one descriptor.
        auto mean_var_load = ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                              AccDataType,
                                                              MeanVarGridDesc_M,
                                                              decltype(thread_buffer_desc_m),
                                                              ThreadBufferLengths_M,
                                                              Sequence<0>,
                                                              0,
                                                              MeanVarSrcVectorSize,
                                                              1,
                                                              true>(mean_var_grid_desc_m,
                                                                    make_multi_index(m_origin));

        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

        auto dx_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx, dx_grid_desc_m_k.GetElementSpaceSize());

        const auto scale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());

        const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean, mean_var_grid_desc_m.GetElementSpaceSize());

        const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize());

        // Per-row operands are loaded once, before the K loop.
        scale_load.Run(scale_grid_desc_m,
                       scale_global_buf,
                       thread_buffer_desc_m,
                       make_tuple(I0),
                       scale_thread_buf);

        mean_var_load.Run(mean_var_grid_desc_m,
                          mean_global_buf,
                          thread_buffer_desc_m,
                          make_tuple(I0),
                          mean_thread_buf);

        mean_var_load.Run(mean_var_grid_desc_m,
                          inv_var_global_buf,
                          thread_buffer_desc_m,
                          make_tuple(I0),
                          inv_var_thread_buf);

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        // reduce_size in accumulator precision, and its reciprocal.
        const AccDataType reduce_size_acc = type_convert<AccDataType>(reduce_size);
        const AccDataType inv_reduce_size = type_convert<AccDataType>(1.0) / reduce_size_acc;

        for(index_t tile = 0; tile < num_xy_k_block_tile_iteration; ++tile)
        {
            x_load.Run(x_grid_desc_m_k,
                       x_global_buf,
                       thread_buffer_desc_m_k,
                       make_tuple(I0, I0),
                       x_thread_buf);

            dy_load.Run(dy_grid_desc_m_k,
                        dy_global_buf,
                        thread_buffer_desc_m_k,
                        make_tuple(I0, I0),
                        dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // Row-wise factor: 1/N * inv-variance * scale.
                const AccDataType row_factor =
                    inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM];

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
                    constexpr auto elem = Number<offset>{};

                    // dy is transformed in place by the user-supplied elementwise op.
                    dy_elementwise_op(dy_thread_buf(elem), dy_thread_buf[elem]);

                    const AccDataType norm_x =
                        (x_thread_buf[elem] - mean_thread_buf[iM]) * inv_var_thread_buf[iM];

                    const AccDataType dscale_term = norm_x * dscale_thread_buf[iM];

                    dx_thread_buf(elem) =
                        row_factor * (reduce_size_acc * dy_thread_buf[elem] -
                                      dbias_thread_buf[iM] - dscale_term);
                });
            });

            dx_store.Run(thread_buffer_desc_m_k,
                         make_tuple(I0, I0),
                         dx_thread_buf,
                         dx_grid_desc_m_k,
                         dx_global_buf);

            // Advance all three windows by one block tile along K.
            x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
            dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, xy_thread_copy_step_m_k);
        }
    }
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp
View file @
05ee41c3
...
...
@@ -93,6 +93,9 @@ struct GridwiseMultiblockWelfordFirstHalf
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
// clang-format off
// First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
// clang-format on
__device__
static
void
Run
(
const
XGridDesc_M_K
&
x_grid_desc_m_k
,
const
MeanVarCountGridDesc_M_G
&
mean_var_count_grid_desc_m_g
,
const
GetReduceCountPerThreadFunctor
&
get_reduce_count_per_thread
,
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment