Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dc0bae32
Commit
dc0bae32
authored
Feb 01, 2023
by
Adam Osewski
Browse files
Merge branch 'develop' into aosewski/wavelet_omniperf
parents
68474822
ba40c2ce
Changes
474
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
950 additions
and
36 deletions
+950
-36
include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp
...ck/tensor_operation/gpu/device/device_multiple_reduce.hpp
+2
-2
include/ck/tensor_operation/gpu/device/device_normalization.hpp
...e/ck/tensor_operation/gpu/device/device_normalization.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_permute.hpp
include/ck/tensor_operation/gpu/device/device_permute.hpp
+0
-1
include/ck/tensor_operation/gpu/device/device_reduce.hpp
include/ck/tensor_operation/gpu/device/device_reduce.hpp
+28
-8
include/ck/tensor_operation/gpu/device/device_softmax.hpp
include/ck/tensor_operation/gpu/device/device_softmax.hpp
+4
-6
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
...ation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
...u/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+1
-5
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+10
-1
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
...ration/gpu/device/impl/device_batchnorm_backward_impl.hpp
+874
-0
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+9
-2
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+5
-2
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
...device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+2
-0
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+2
-0
No files found.
include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp
View file @
dc0bae32
...
...
@@ -32,8 +32,8 @@ struct DeviceMultipleReduce : public BaseOperator
const
std
::
array
<
index_t
,
NumOutputDim
>
outLengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NumOutputDim
>
,
NumReduction
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
const
std
::
array
<
const
void
*
,
NumReduction
>
alphas
,
const
std
::
array
<
const
void
*
,
NumReduction
>
betas
,
const
std
::
array
<
double
,
NumReduction
>
alphas
,
const
std
::
array
<
double
,
NumReduction
>
betas
,
const
void
*
in_dev
,
const
std
::
array
<
void
*
,
NumReduction
>
out_dev_buffers
,
const
InElementwiseOperationTuple
in_elementwise_op_tuple
,
...
...
include/ck/tensor_operation/gpu/device/device_normalization.hpp
View file @
dc0bae32
...
...
@@ -28,7 +28,7 @@ struct DeviceNormalization : public BaseOperator
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
yStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
AccDataTyp
e
epsilon
,
doubl
e
epsilon
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
...
...
include/ck/tensor_operation/gpu/device/device_permute.hpp
View file @
dc0bae32
...
...
@@ -4,7 +4,6 @@
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
...
...
include/ck/tensor_operation/gpu/device/device_reduce.hpp
View file @
dc0bae32
...
...
@@ -13,10 +13,16 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Rank
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
>
typename
AccElementwiseOperation
,
bool
PropagateNan
,
bool
OutputIndex
>
struct
DeviceReduce
:
public
BaseOperator
{
static
constexpr
index_t
NumOutDim
=
(
Rank
-
NumReduceDim
==
0
)
?
1
:
Rank
-
NumReduceDim
;
...
...
@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
const
std
::
array
<
index_t
,
NumOutDim
>
outLengths
,
const
std
::
array
<
index_t
,
NumOutDim
>
outStrides
,
const
std
::
array
<
int
,
NumReduceDim
>
reduceDims
,
float
alpha
,
float
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
const
void
*
in_index_dev
,
void
*
out_dev
,
...
...
@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
Rank
,
template
<
typename
InDataType
,
typename
AccDataType
,
typename
OutDataType
,
index_t
Rank
,
index_t
NumReduceDim
,
typename
ReduceOperation
,
typename
InElementwiseOperation
,
typename
AccElementwiseOperation
>
using
DeviceReducePtr
=
std
::
unique_ptr
<
DeviceReduce
<
Rank
,
NumReduceDim
,
InElementwiseOperation
,
AccElementwiseOperation
>>
;
typename
AccElementwiseOperation
,
bool
PropagateNan
,
bool
OutputIndex
>
using
DeviceReducePtr
=
std
::
unique_ptr
<
DeviceReduce
<
InDataType
,
AccDataType
,
OutDataType
,
Rank
,
NumReduceDim
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
PropagateNan
,
OutputIndex
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_softmax.hpp
View file @
dc0bae32
...
...
@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
// @param[in] inLengths Input tensor extent(s) from high to low dimension
// @param[in] inStrides Input tensor stride(s) from high to low dimension
// @param[in] reduceDims The dimension(s) the normalization operation is applied
// @param[in] alpha Typeless pointer in host memory storing the alpha scaling
// value as type AccDataType
// @param[in] beta Typeless pointer in host memory storing the beta scaling
// value as type AccDataType
// @param[in] alpha double type value
// @param[in] beta double type value
// @param[in] in_dev Typeless const pointer in device memory storing the input
// tensor
// @param out_dev Typeless pointer in device memory storing the output tensor
...
...
@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inLengths
,
const
std
::
vector
<
index_t
>
inStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
void
*
alpha
,
const
void
*
beta
,
double
alpha
,
double
beta
,
const
void
*
in_dev
,
void
*
out_dev
,
InElementwiseOp
in_elementwise_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
View file @
dc0bae32
...
...
@@ -700,7 +700,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
dc0bae32
...
...
@@ -579,6 +579,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
BatchStrideD1s
,
BatchStrideE1
}
{
#if DEBUG_LOG
std
::
cout
<<
"a0_grid_desc_m_k_{"
<<
a0_grid_desc_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a0_grid_desc_m_k_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"b0_grid_desc_n_k_{"
<<
b0_grid_desc_n_k_
.
GetLength
(
I0
)
<<
", "
...
...
@@ -601,6 +602,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
<<
std
::
endl
;
std
::
cout
<<
"e1_grid_desc_m_n_{"
<<
e1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
e1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
#endif
static_for
<
0
,
NumD0Tensor
,
1
>
{}([
&
](
auto
i
)
{
using
D0Layout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
D0sLayout
>>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
dc0bae32
...
...
@@ -657,7 +657,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
"arg.Batch_ = "
<<
arg
.
Batch_
<<
std
::
endl
;
...
...
@@ -674,8 +674,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
<< "}"
<< std::endl;
std
::
cout
<<
"arg.reduce_grid_desc_m_{ "
<<
arg
.
reduce_grid_desc_m_
.
GetLength
(
I0
)
<<
"}"
<<
std
::
endl
;
}
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
dc0bae32
...
...
@@ -485,19 +485,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std
::
cout
<<
"a_grid_desc_g_m_k_: "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I0
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I1
)
<<
", "
<<
a_grid_desc_g_m_k_
.
GetLength
(
I2
)
<<
'\n'
;
// a_grid_desc_g_m_k_.Print();
std
::
cout
<<
"b_grid_desc_g_n_k_: "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b_grid_desc_g_n_k_.Print();
std
::
cout
<<
"b1_grid_desc_g_n_k_: "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I0
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I1
)
<<
", "
<<
b1_grid_desc_g_n_k_
.
GetLength
(
I2
)
<<
'\n'
;
// b1_grid_desc_g_n_k_.Print();
std
::
cout
<<
"c_grid_desc_g_m_n_: "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I0
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I1
)
<<
", "
<<
c_grid_desc_g_m_n_
.
GetLength
(
I2
)
<<
'\n'
;
// c_grid_desc_g_m_n_.Print();
}
// pointers
...
...
@@ -636,7 +632,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if
0
#if
DEBUG_LOG
arg
.
Print
();
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
View file @
dc0bae32
...
...
@@ -373,7 +373,8 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
c_element_op_
{
c_element_op
},
kraw_
{
K
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
...
...
@@ -401,6 +402,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
kraw_
;
};
// Invoker
...
...
@@ -410,6 +412,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
...
...
@@ -422,6 +425,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
@@ -528,6 +532,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
arg
.
kraw_
%
K1
!=
0
)
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
0 → 100644
View file @
dc0bae32
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
XDataType
,
typename
DxDataType
,
typename
DyDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
DscaleDbiasDataType
,
typename
MeanVarDataType
,
typename
DyElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
bool
UseMultiblockInK
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
XDyDxVectorDim
,
index_t
XSrcVectorSize
,
index_t
DySrcVectorSize
,
index_t
DxDstVectorSize
,
index_t
ScaleSrcVectorSize
,
index_t
DscaleDbiasDstVectorSize
,
index_t
MeanVarSrcVectorSize
>
struct
DeviceBatchNormBwdImpl
:
public
DeviceBatchNormBwd
<
XDataType
,
DxDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
"Invalid thread cluster size assignments!"
);
static_assert
((
XDyDxVectorDim
==
0
&&
MThreadSliceSize
%
XSrcVectorSize
==
0
&&
MThreadSliceSize
%
DySrcVectorSize
==
0
&&
MThreadSliceSize
%
DxDstVectorSize
==
0
)
||
(
XDyDxVectorDim
==
1
&&
KThreadSliceSize
%
XSrcVectorSize
==
0
&&
KThreadSliceSize
%
DySrcVectorSize
==
0
&&
KThreadSliceSize
%
DxDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeXY2dDescriptor
(
const
std
::
array
<
index_t
,
Rank
>&
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>&
xyStrides
,
int
blkGroupSize
,
int
numBlockTileIteration
)
{
const
auto
tupleXYLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
Rank
>
{});
const
auto
tupleXYStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
xyStrides
[
I
];
},
Number
<
Rank
>
{});
const
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleXYLengths
,
tupleXYStrides
);
const
auto
grid_desc_m_k
=
[
&
]()
{
using
InvariantDims
=
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
;
using
ReduceDims
=
typename
arithmetic_sequence_gen
<
NumInvariantDim
,
Rank
,
1
>::
type
;
const
auto
reduceDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
NumInvariantDim
+
I
];
},
Number
<
NumBatchNormReduceDim
>
{});
const
auto
invariantDimLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
xyLengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
return
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
invariantDimLengths
),
make_merge_transform
(
reduceDimLengths
)),
make_tuple
(
InvariantDims
{},
ReduceDims
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}();
const
auto
invariantLength
=
grid_desc_m_k
.
GetLength
(
Number
<
0
>
{});
const
auto
reduceLength
=
grid_desc_m_k
.
GetLength
(
Number
<
1
>
{});
const
int
workSizePerBlock
=
K_BlockTileSize
*
numBlockTileIteration
;
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
workSizePerBlock
*
blkGroupSize
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeMultiblockFirstReduceOutputMG2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
grid_desc_m_g
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
blkGroupSize
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_g_padded
=
transform_tensor_descriptor
(
grid_desc_m_g
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_pass_through_transform
(
blkGroupSize
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_g_padded
);
};
static
auto
MakeMultiblockFinalReduceInputMK2dDescriptor
(
int
invariantLength
,
int
blkGroupSize
)
{
const
auto
reduceLength
=
blkGroupSize
;
const
auto
grid_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
invariantLength
,
reduceLength
));
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
const
auto
kPad
=
math
::
integer_least_multiple
(
reduceLength
,
KThreadClusterSize
)
-
reduceLength
;
auto
grid_desc_m_k_padded
=
transform_tensor_descriptor
(
grid_desc_m_k
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
),
make_right_pad_transform
(
reduceLength
,
kPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
(
grid_desc_m_k_padded
);
};
static
auto
MakeScaleBiasMeanVar1dDescriptor
(
const
std
::
array
<
index_t
,
NumInvariantDim
>&
lengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>&
strides
)
{
const
auto
tupleLengths
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
NumInvariantDim
>
{});
const
auto
tupleStrides
=
generate_tuple
([
&
](
auto
I
)
{
return
strides
[
I
];
},
Number
<
NumInvariantDim
>
{});
auto
raw_grid_desc
=
make_naive_tensor_descriptor
(
tupleLengths
,
tupleStrides
);
auto
grid_desc_m
=
transform_tensor_descriptor
(
raw_grid_desc
,
make_tuple
(
make_merge_transform
(
tupleLengths
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
NumInvariantDim
,
1
>::
type
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
invariantLength
=
grid_desc_m
.
GetLength
(
Number
<
0
>
{});
const
auto
mPad
=
math
::
integer_least_multiple
(
invariantLength
,
M_BlockTileSize
)
-
invariantLength
;
auto
grid_desc_m_padded
=
transform_tensor_descriptor
(
grid_desc_m
,
make_tuple
(
make_right_pad_transform
(
invariantLength
,
mPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
(
grid_desc_m_padded
);
};
using
XYGridDesc_M_K
=
decltype
(
MakeXY2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
ScaleBiasGridDesc_M
=
decltype
(
MakeScaleBiasMeanVar1dDescriptor
({
1
},
{
1
}));
using
MeanVarGridDesc_M
=
ScaleBiasGridDesc_M
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
dyStrides
,
const
std
::
array
<
index_t
,
Rank
>
dxStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnDscaleDbiasStrides
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
DyDataType
*
p_dy
,
const
ScaleDataType
*
p_scale
,
const
MeanVarDataType
*
p_savedMean
,
const
MeanVarDataType
*
p_savedInvVar
,
const
DyElementwiseOp
dy_elementwise_op
,
double
epsilon
,
DxDataType
*
p_dx
,
DscaleDbiasDataType
*
p_dscale
,
DscaleDbiasDataType
*
p_dbias
)
:
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnDscaleDbiasStrides_
(
bnDscaleDbiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
p_dy_
(
p_dy
),
p_scale_
(
p_scale
),
p_savedMean_
(
p_savedMean
),
p_savedInvVar_
(
p_savedInvVar
),
dy_elementwise_op_
(
dy_elementwise_op
),
p_dx_
(
p_dx
),
p_dscale_
(
p_dscale
),
p_dbias_
(
p_dbias
)
{
xyLengths_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths
,
reduceDims
);
xStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
xStrides
,
reduceDims
);
dyStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
dyStrides
,
reduceDims
);
dxStrides_
=
shuffle_tensor_dimensions
<
Rank
,
NumBatchNormReduceDim
>
(
dxStrides
,
reduceDims
);
std
::
tie
(
invariant_length
,
reduce_length
)
=
get_2d_lengths
<
Rank
,
NumBatchNormReduceDim
>
(
xyLengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
haveSavedMeanInvVar_
=
(
p_savedMean_
!=
nullptr
&&
p_savedInvVar_
!=
nullptr
);
if
(
UseMultiblockInK
)
{
int
iterations
=
1
;
while
(
true
)
{
int
testBlkGroupSize
=
(
reduce_length
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
// we want the blkGroupSize be not more than 128
if
(
testBlkGroupSize
<=
128
)
break
;
iterations
++
;
};
blkGroupSize
=
(
reduce_length
+
(
K_BlockTileSize
*
iterations
)
-
1
)
/
(
K_BlockTileSize
*
iterations
);
numBlockTileIteration
=
iterations
;
}
else
{
blkGroupSize
=
1
;
numBlockTileIteration
=
(
reduce_length
+
K_BlockTileSize
-
1
)
/
K_BlockTileSize
;
};
gridSize
=
(
invariant_length
+
M_BlockTileSize
-
1
)
/
M_BlockTileSize
*
blkGroupSize
;
x_grid_desc_m_k
=
MakeXY2dDescriptor
(
xyLengths_
,
xStrides_
,
blkGroupSize
,
numBlockTileIteration
);
dy_grid_desc_m_k
=
MakeXY2dDescriptor
(
xyLengths_
,
dyStrides_
,
blkGroupSize
,
numBlockTileIteration
);
dx_grid_desc_m_k
=
MakeXY2dDescriptor
(
xyLengths_
,
dxStrides_
,
blkGroupSize
,
numBlockTileIteration
);
scale_grid_desc_m
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnScaleStrides
);
dscale_dbias_grid_desc_m
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnDscaleDbiasStrides
);
mean_var_grid_desc_m
=
MakeScaleBiasMeanVar1dDescriptor
(
bnScaleBiasMeanVarLengths
,
bnMeanVarStrides
);
}
AccDataType
epsilon_
;
bool
haveSavedMeanInvVar_
;
std
::
array
<
index_t
,
Rank
>
xyLengths_
;
std
::
array
<
index_t
,
Rank
>
xStrides_
;
std
::
array
<
index_t
,
Rank
>
dyStrides_
;
std
::
array
<
index_t
,
Rank
>
dxStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnDscaleDbiasStrides_
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnMeanVarStrides_
;
const
XDataType
*
p_x_
;
const
DyDataType
*
p_dy_
;
const
ScaleDataType
*
p_scale_
;
const
MeanVarDataType
*
p_savedMean_
;
const
MeanVarDataType
*
p_savedInvVar_
;
const
DyElementwiseOp
dy_elementwise_op_
;
DxDataType
*
p_dx_
;
DscaleDbiasDataType
*
p_dscale_
;
DscaleDbiasDataType
*
p_dbias_
;
long_index_t
invariant_length
;
long_index_t
reduce_length
;
int
blkGroupSize
;
int
numBlockTileIteration
;
size_t
gridSize
;
XYGridDesc_M_K
x_grid_desc_m_k
;
XYGridDesc_M_K
dy_grid_desc_m_k
;
XYGridDesc_M_K
dx_grid_desc_m_k
;
ScaleBiasGridDesc_M
scale_grid_desc_m
;
ScaleBiasGridDesc_M
dscale_dbias_grid_desc_m
;
MeanVarGridDesc_M
mean_var_grid_desc_m
;
void
*
workspace_mean
;
void
*
workspace_variance
;
void
*
workspace_count
;
void
*
workspace_savedMean
;
void
*
workspace_savedInvVar
;
void
*
workspace_reduce_dscale
;
void
*
workspace_reduce_dbias
;
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
size_t
workspace_size
=
0
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize
>
1
)
{
// workspace for the partial reduced result for dscale
workspace_size
+=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
DscaleDbiasDataType
)
+
64
;
// workspace for the partial reduced result for dbias
workspace_size
+=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
DscaleDbiasDataType
)
+
64
;
if
(
!
pArg_
->
haveSavedMeanInvVar_
)
{
// workspace for welford intermediate mean
workspace_size
+=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
int32_t
)
+
64
;
// workspace for welford result mean
workspace_size
+=
pArg_
->
invariant_length
*
sizeof
(
MeanVarDataType
)
+
64
;
// workspace for welford result inv_variance
workspace_size
+=
pArg_
->
invariant_length
*
sizeof
(
MeanVarDataType
)
+
64
;
};
}
return
(
workspace_size
);
};
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
)
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
index_t
space_sz
;
// setup buffer for the partial reduced result for dscale
pArg_
->
workspace_reduce_dscale
=
pArg_
->
p_workspace_
;
space_sz
=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
DscaleDbiasDataType
);
space_sz
=
math
::
integer_least_multiple
(
space_sz
,
64
);
// setup buffer for the partial reduced result for dbias
pArg_
->
workspace_reduce_dbias
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_reduce_dscale
)
+
space_sz
;
if
(
UseMultiblockInK
&&
pArg_
->
blkGroupSize
>
1
)
{
space_sz
=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
DscaleDbiasDataType
);
space_sz
=
math
::
integer_least_multiple
(
space_sz
,
64
);
// setup buffer for welford intermediate mean
pArg_
->
workspace_mean
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_reduce_dbias
)
+
space_sz
;
space_sz
=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
MeanVarDataType
);
space_sz
=
math
::
integer_least_multiple
(
space_sz
,
64
);
// setup buffer for welford intermediate varirance
pArg_
->
workspace_variance
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_mean
)
+
space_sz
;
space_sz
=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
MeanVarDataType
);
space_sz
=
math
::
integer_least_multiple
(
space_sz
,
64
);
// setup buffer for welford intermediate count
pArg_
->
workspace_count
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_variance
)
+
space_sz
;
space_sz
=
pArg_
->
invariant_length
*
pArg_
->
blkGroupSize
*
sizeof
(
int32_t
);
space_sz
=
math
::
integer_least_multiple
(
space_sz
,
64
);
// setup buffer for welford result mean
pArg_
->
workspace_savedMean
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_count
)
+
space_sz
;
space_sz
=
pArg_
->
invariant_length
*
sizeof
(
MeanVarDataType
);
space_sz
=
math
::
integer_least_multiple
(
space_sz
,
64
);
// setup buffer for welford result inv_variance
pArg_
->
workspace_savedInvVar
=
reinterpret_cast
<
char
*>
(
pArg_
->
workspace_savedMean
)
+
space_sz
;
};
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0
;
const
auto
mean_var_count_grid_desc_m_g
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFirstReduceOutputMG2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
const
auto
dscale_dbias_grid_desc_m_g
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFirstReduceOutputMG2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
const
auto
mean_var_count_grid_desc_m_k
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFinalReduceInputMK2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
const
auto
dscale_dbias_grid_desc_m_k
=
DeviceBatchNormBwdImpl
::
MakeMultiblockFinalReduceInputMK2dDescriptor
(
arg
.
invariant_length
,
arg
.
blkGroupSize
);
using
MeanVarCountGridDesc_M_G
=
decltype
(
mean_var_count_grid_desc_m_g
);
using
MeanVarCountGridDesc_M_K
=
decltype
(
mean_var_count_grid_desc_m_k
);
using
DscaleDbiasGridDesc_M_G
=
decltype
(
dscale_dbias_grid_desc_m_g
);
using
DscaleDbiasGridDesc_M_K
=
decltype
(
dscale_dbias_grid_desc_m_k
);
using
GridwiseWelfordSecondHalfReduceFirstHalf_
=
GridwiseWelfordSecondHalfReduceFirstHalf
<
XDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
DscaleDbiasGridDesc_M_G
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
,
DySrcVectorSize
,
MeanVarSrcVectorSize
>
;
using
GridwiseReduceSecondHalfBatchNormBwdFinal_
=
GridwiseReduceSecondHalfBatchNormBackwardFinal
<
XDataType
,
DyDataType
,
DxDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
MeanVarGridDesc_M
,
ScaleBiasGridDesc_M
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
,
DySrcVectorSize
,
DxDstVectorSize
,
ScaleSrcVectorSize
,
DscaleDbiasDstVectorSize
,
MeanVarSrcVectorSize
>
;
if
(
UseMultiblockInK
&&
arg
.
blkGroupSize
>
1
)
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForMultiblockWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
arg
.
reduce_length
);
if
(
!
arg
.
haveSavedMeanInvVar_
)
{
using
GridwiseMultiblockWelfordFirstHalf_
=
GridwiseMultiblockWelfordFirstHalf
<
XDataType
,
AccDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
>
;
const
auto
kern_multiblock_welford_first_half
=
kernel_multiblock_welford_first_half
<
GridwiseMultiblockWelfordFirstHalf_
,
XDataType
,
MeanVarDataType
,
XYGridDesc_M_K
,
MeanVarCountGridDesc_M_G
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_multiblock_welford_first_half
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
mean_var_count_grid_desc_m_g
,
get_reduce_count_per_thread
,
arg
.
numBlockTileIteration
,
arg
.
p_x_
,
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_mean
),
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_variance
),
static_cast
<
int32_t
*>
(
arg
.
workspace_count
));
};
const
auto
kern_welford_second_half_reduce_first_half
=
kernel_welford_second_half_reduce_first_half
<
GridwiseWelfordSecondHalfReduceFirstHalf_
,
XDataType
,
DyDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
MeanVarGridDesc_M
,
MeanVarCountGridDesc_M_K
,
DscaleDbiasGridDesc_M_G
>
;
const
auto
kern_reduce_second_half_batchnorm_backward_final
=
kernel_reduce_second_half_batchnorm_backward_final
<
GridwiseReduceSecondHalfBatchNormBwdFinal_
,
XDataType
,
DyDataType
,
DxDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
DscaleDbiasGridDesc_M_K
,
MeanVarGridDesc_M
,
ScaleBiasGridDesc_M
>
;
index_t
numDscaleDbiasBlockTileIteration
=
(
arg
.
blkGroupSize
+
KThreadClusterSize
-
1
)
/
KThreadClusterSize
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_welford_second_half_reduce_first_half
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
arg
.
dy_grid_desc_m_k
,
arg
.
mean_var_grid_desc_m
,
mean_var_count_grid_desc_m_k
,
dscale_dbias_grid_desc_m_g
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
numDscaleDbiasBlockTileIteration
,
arg
.
epsilon_
,
arg
.
haveSavedMeanInvVar_
,
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedMean_
:
nullptr
,
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedInvVar_
:
nullptr
,
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_mean
),
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_variance
),
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
const
int32_t
*>
(
arg
.
workspace_count
),
arg
.
dy_elementwise_op_
,
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_savedMean
),
arg
.
haveSavedMeanInvVar_
?
nullptr
:
static_cast
<
MeanVarDataType
*>
(
arg
.
workspace_savedInvVar
),
arg
.
p_x_
,
arg
.
p_dy_
,
static_cast
<
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dscale
),
static_cast
<
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dbias
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_reduce_second_half_batchnorm_backward_final
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
arg
.
dy_grid_desc_m_k
,
arg
.
dx_grid_desc_m_k
,
dscale_dbias_grid_desc_m_k
,
arg
.
mean_var_grid_desc_m
,
arg
.
scale_grid_desc_m
,
arg
.
dscale_dbias_grid_desc_m
,
arg
.
blkGroupSize
,
arg
.
reduce_length
,
arg
.
numBlockTileIteration
,
numDscaleDbiasBlockTileIteration
,
static_cast
<
const
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dscale
),
static_cast
<
const
DscaleDbiasDataType
*>
(
arg
.
workspace_reduce_dbias
),
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedMean_
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_savedMean
),
arg
.
haveSavedMeanInvVar_
?
arg
.
p_savedInvVar_
:
static_cast
<
const
MeanVarDataType
*>
(
arg
.
workspace_savedInvVar
),
arg
.
p_x_
,
arg
.
p_dy_
,
arg
.
p_scale_
,
arg
.
dy_elementwise_op_
,
arg
.
p_dx_
,
arg
.
p_dscale_
,
arg
.
p_dbias_
);
}
else
{
using
GetReduceCountPerThreadFunctor
=
GetReduceCountPerThreadForBlockwiseWelford
<
K_BlockTileSize
,
KThreadSliceSize
>
;
GetReduceCountPerThreadFunctor
get_reduce_count_per_thread
(
arg
.
numBlockTileIteration
,
arg
.
reduce_length
);
using
GridwiseBatchNormBackwardWithBlockwiseWelford_
=
GridwiseBatchNormBackwardWithBlockwiseWelford
<
XDataType
,
DyDataType
,
DxDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasGridDesc_M
,
MeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
XDyDxVectorDim
,
XSrcVectorSize
,
DySrcVectorSize
,
DxDstVectorSize
,
ScaleSrcVectorSize
,
DscaleDbiasDstVectorSize
,
MeanVarSrcVectorSize
>
;
const
auto
kern_batchnorm_bwd
=
kernel_batchnorm_backward_with_blockwise_welford
<
GridwiseBatchNormBackwardWithBlockwiseWelford_
,
XDataType
,
DyDataType
,
DxDataType
,
AccDataType
,
ScaleDataType
,
DscaleDbiasDataType
,
MeanVarDataType
,
DyElementwiseOp
,
XYGridDesc_M_K
,
ScaleBiasGridDesc_M
,
MeanVarGridDesc_M
,
GetReduceCountPerThreadFunctor
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kern_batchnorm_bwd
,
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
arg
.
x_grid_desc_m_k
,
arg
.
dy_grid_desc_m_k
,
arg
.
dx_grid_desc_m_k
,
arg
.
scale_grid_desc_m
,
arg
.
dscale_dbias_grid_desc_m
,
arg
.
mean_var_grid_desc_m
,
get_reduce_count_per_thread
,
arg
.
reduce_length
,
arg
.
numBlockTileIteration
,
arg
.
epsilon_
,
arg
.
p_x_
,
arg
.
p_dy_
,
arg
.
p_scale_
,
arg
.
haveSavedMeanInvVar_
,
arg
.
p_savedMean_
,
arg
.
p_savedInvVar_
,
arg
.
dy_elementwise_op_
,
arg
.
p_dx_
,
arg
.
p_dscale_
,
arg
.
p_dbias_
);
};
return
(
avg_time
);
};
float
Run
(
const
BaseArgument
*
pArg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
pArg
),
stream_config
);
};
};
bool
IsSupportedArgument
(
const
BaseArgument
*
pArg
)
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
XDyDxVectorDim
==
0
)
{
if
(
pArg_
->
xStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
dyStrides_
[
NumInvariantDim
-
1
]
!=
1
||
pArg_
->
dxStrides_
[
NumInvariantDim
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
DySrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
NumInvariantDim
-
1
]
%
DxDstVectorSize
!=
0
)
return
false
;
}
else
{
if
(
pArg_
->
xStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
dyStrides_
[
Rank
-
1
]
!=
1
||
pArg_
->
dxStrides_
[
Rank
-
1
]
!=
1
)
return
false
;
if
(
pArg_
->
xyLengths_
[
Rank
-
1
]
%
XSrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
DySrcVectorSize
!=
0
||
pArg_
->
xyLengths_
[
Rank
-
1
]
%
DxDstVectorSize
!=
0
)
return
false
;
};
if
(
pArg_
->
bnScaleStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
ScaleSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnDscaleDbiasStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
DscaleDbiasDstVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
ScaleSrcVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
DscaleDbiasDstVectorSize
!=
0
)
return
false
;
if
(
pArg_
->
haveSavedMeanInvVar_
)
{
if
(
pArg_
->
bnMeanVarStrides_
[
NumInvariantDim
-
1
]
!=
1
&&
MeanVarSrcVectorSize
!=
1
)
return
false
;
if
(
pArg_
->
bnScaleBiasMeanVarLengths_
[
NumInvariantDim
-
1
]
%
MeanVarSrcVectorSize
!=
0
)
return
false
;
};
bool
is_valid
=
true
;
static_for
<
0
,
NumInvariantDim
,
1
>
{}([
&
](
auto
I
)
{
if
(
pArg_
->
xyLengths_
[
I
]
!=
pArg_
->
bnScaleBiasMeanVarLengths_
[
I
])
is_valid
=
false
;
});
if
(
!
is_valid
)
return
false
;
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
dyStrides
,
const
std
::
array
<
index_t
,
Rank
>
dxStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnDscaleDbiasStrides
,
const
std
::
array
<
ck
::
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
p_dy
,
const
void
*
p_scale
,
const
void
*
p_savedMean
,
const
void
*
p_savedInvVar
,
double
epsilon
,
const
DyElementwiseOp
dy_elementwise_op
,
void
*
p_dx
,
void
*
p_dscale
,
void
*
p_dbias
)
override
{
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
dyStrides
,
dxStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnDscaleDbiasStrides
,
bnMeanVarStrides
,
static_cast
<
const
XDataType
*>
(
p_x
),
static_cast
<
const
DyDataType
*>
(
p_dy
),
static_cast
<
const
ScaleDataType
*>
(
p_scale
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedMean
),
static_cast
<
const
MeanVarDataType
*>
(
p_savedInvVar
),
dy_elementwise_op
,
epsilon
,
static_cast
<
DxDataType
*>
(
p_dx
),
static_cast
<
DscaleDbiasDataType
*>
(
p_dscale
),
static_cast
<
DscaleDbiasDataType
*>
(
p_dbias
));
};
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
};
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchNormBwdImpl<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"XDyDxVectorDim_"
<<
XDyDxVectorDim
<<
","
;
str
<<
"VectorSize_X"
<<
XSrcVectorSize
<<
"_scale_"
<<
ScaleSrcVectorSize
<<
"_bias_"
<<
DscaleDbiasDstVectorSize
<<
"_mean_var_"
<<
MeanVarSrcVectorSize
<<
"_Dx_"
<<
DxDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
dc0bae32
...
...
@@ -42,8 +42,15 @@ template <typename XDataType,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
Rank
,
NumBatchNormReduceDim
,
YElementwiseOp
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
dc0bae32
...
...
@@ -488,7 +488,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
using
Argument
=
DeviceOp
::
Argument
;
void
ShowInfo
(
const
Argument
&
arg
)
void
Print
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
...
...
@@ -508,7 +508,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
ShowInfo
(
arg
);
if
(
stream_config
.
log_level_
>
0
)
{
Print
(
arg
);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
dc0bae32
...
...
@@ -549,6 +549,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
#if DEBUG_LOG
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
...
...
@@ -581,6 +582,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I5
)
<<
" ) "
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
dc0bae32
...
...
@@ -644,7 +644,7 @@ struct
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
dc0bae32
...
...
@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
dc0bae32
...
...
@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
DeviceOp
{}.
GetTypeString
()
<<
std
::
endl
;
std
::
cout
<<
"N "
<<
arg
.
Conv_N_
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
dc0bae32
...
...
@@ -465,7 +465,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if
0
#if
DEBUG_LOG
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
dc0bae32
...
...
@@ -400,6 +400,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
{
std
::
cout
<<
"num_batches_of_GEMM = "
<<
arg
.
num_subbatches_
<<
std
::
endl
;
std
::
cout
<<
"a_grid_desc_k0_m_k1{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
...
...
@@ -413,6 +414,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
std
::
cout
<<
"c_grid_desc_m_n{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
View file @
dc0bae32
...
...
@@ -1272,6 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_grid_desc_k0_m_k1_container_
.
size
();
i
++
)
{
#if DEBUG_LOG
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
...
...
@@ -1304,6 +1305,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<<
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_container_
[
i
].
GetLength
(
I5
)
<<
" ) "
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
...
...
Prev
1
2
3
4
5
6
7
8
9
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment