gaoqiong / composable_kernel, commit 6ed9a02d

Authored Sep 29, 2022 by Qianfeng Zhang
Parent: bbd19d89

Add batchnorm-backward device op

Showing 3 changed files with 988 additions and 0 deletions
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp (+49, -0)
include/ck/tensor_operation/gpu/device/device_batchnorm_backward_impl.hpp (+850, -0)
include/ck/tensor_operation/gpu/device/welford_helper.hpp (+89, -0)
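For reference, a batchnorm backward op with these inputs (x, dy, scale gamma, saved mean and saved inverse variance) and outputs (dx, dgamma, dbeta) conventionally computes, per invariant (channel) index with m reduced elements, the standard relations below; how the kernels actually decompose them is in the collapsed implementation file:

\hat{x} = (x - \mu)\,\mathrm{invVar}, \qquad \mathrm{invVar} = 1/\sqrt{\sigma^2 + \epsilon}
d\beta  = \textstyle\sum dy
d\gamma = \textstyle\sum dy\,\hat{x}
dx      = \frac{\gamma\,\mathrm{invVar}}{m}\left(m\,dy - d\beta - \hat{x}\,d\gamma\right)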
include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormBwd : public BaseOperator
{
    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                        const std::array<index_t, Rank> xStrides,
                        const std::array<index_t, Rank> dyStrides,
                        const std::array<index_t, Rank> dxStrides,
                        const std::array<int, NumBatchNormReduceDim> reduceDims,
                        const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                        const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
                        const std::array<ck::index_t, NumInvariantDim> bnBiasStrides,
                        const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* p_dy,
                        const void* p_scale,
                        const void* p_savedMean,
                        const void* p_savedInvVar,
                        double epsilon,
                        void* p_dx,
                        void* p_scaleDiff,
                        void* p_biasDiff) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<Rank, NumBatchNormReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
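This interface follows the usual composable_kernel device-op pattern: build a type-erased argument with MakeArgumentPointer, then launch through the invoker. A minimal host-side usage sketch for a packed NHWC tensor, reducing over N, H, W; here bnorm_op stands for any concrete DeviceBatchNormBwd<4, 3> (the real implementation's diff is collapsed below), the device pointers are assumed already allocated, and IsSupportedArgument/StreamConfig follow the pattern of the other device ops in this repository:

// Hypothetical usage sketch; bnorm_op is any concrete DeviceBatchNormBwd<4, 3>.
using ck::index_t;

const index_t N = 128, H = 16, W = 16, C = 256;

std::array<index_t, 4> xyLengths  = {N, H, W, C};
std::array<index_t, 4> strides    = {H * W * C, W * C, C, 1}; // packed NHWC
std::array<int, 3>     reduceDims = {0, 1, 2};                // reduce N, H, W
std::array<index_t, 1> bnLengths  = {C};
std::array<index_t, 1> bnStrides  = {1};

auto arg_ptr = bnorm_op.MakeArgumentPointer(xyLengths,
                                            strides, strides, strides, // x, dy, dx strides
                                            reduceDims,
                                            bnLengths, bnStrides, bnStrides, bnStrides,
                                            p_x, p_dy, p_scale,
                                            p_savedMean, p_savedInvVar,
                                            1e-5, // epsilon
                                            p_dx, p_scaleDiff, p_biasDiff);

auto invoker_ptr = bnorm_op.MakeInvokerPointer();

if(bnorm_op.IsSupportedArgument(arg_ptr.get()))
    invoker_ptr->Run(arg_ptr.get(), StreamConfig{});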
include/ck/tensor_operation/gpu/device/device_batchnorm_backward_impl.hpp (new file, mode 100644)

(This diff is collapsed; the 850 added lines are not shown.)
include/ck/tensor_operation/gpu/device/welford_helper.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForBlockwiseWelford
{
    GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration,
                                               long_index_t reduce_length)
        : numBlockTileIteration_{numBlockTileIteration}
    {
        count_in_last_tile = reduce_length % K_BlockTileSize;
    };

    __device__ index_t operator()(index_t thread_k_cluster_id) const
    {
        if(count_in_last_tile == 0)
            return (KThreadSliceSize * numBlockTileIteration_);
        else
        {
            index_t num_complete_slice  = count_in_last_tile / KThreadSliceSize;
            index_t count_in_last_slice = count_in_last_tile % KThreadSliceSize;

            if(thread_k_cluster_id < num_complete_slice)
                return (KThreadSliceSize * numBlockTileIteration_);
            else if(thread_k_cluster_id == num_complete_slice)
                return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice);
            else
                return (KThreadSliceSize * (numBlockTileIteration_ - 1));
        };
    };

    index_t numBlockTileIteration_;
    index_t count_in_last_tile;
};

template <index_t K_BlockTileSize, index_t KThreadSliceSize>
struct GetReduceCountPerThreadForMultiblockWelford
{
    GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize,
                                                index_t numBlockTileIteration,
                                                long_index_t reduce_length)
        : blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration}
    {
        last_block_reduce_length =
            reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1);
        numBlockTileIterationByLastBlock =
            (last_block_reduce_length + K_BlockTileSize - 1) / K_BlockTileSize;
    };

    __device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const
    {
        if(last_block_reduce_length == K_BlockTileSize * numBlockTileIteration_ ||
           block_local_id < blkGroupSize_ - 1)
            return (KThreadSliceSize * numBlockTileIteration_);

        index_t count_in_last_tile = last_block_reduce_length % K_BlockTileSize;

        if(count_in_last_tile == 0)
            return (KThreadSliceSize * numBlockTileIterationByLastBlock);
        else
        {
            index_t num_complete_slice = count_in_last_tile / KThreadSliceSize;

            if(thread_k_cluster_id < num_complete_slice)
                return (KThreadSliceSize * numBlockTileIterationByLastBlock);
            else if(thread_k_cluster_id == num_complete_slice)
                return (KThreadSliceSize * (numBlockTileIterationByLastBlock - 1) +
                        count_in_last_tile);
            else
                return (KThreadSliceSize * (numBlockTileIterationByLastBlock - 1));
        };
    };

    index_t blkGroupSize_;
    index_t numBlockTileIteration_;
    index_t last_block_reduce_length;
    index_t numBlockTileIterationByLastBlock;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
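These functors give each thread the exact number of elements it will have pushed through its Welford accumulator, which the final mean/variance correction needs when reduce_length is not a multiple of the block tile size. A host-side sanity check of the blockwise arithmetic, restated as a standalone sketch with illustrative values (the values and the program itself are assumptions, not part of the commit):

#include <cassert>
#include <iostream>

int main()
{
    constexpr int K_BlockTileSize  = 256;
    constexpr int KThreadSliceSize = 8; // 256 / 8 = 32 threads along K
    const long long reduce_length  = 1003;

    // Same quantities the functor's constructor and operator() derive.
    const int numBlockTileIteration = (reduce_length + K_BlockTileSize - 1) / K_BlockTileSize; // 4
    const int count_in_last_tile    = reduce_length % K_BlockTileSize;                         // 235
    const int num_complete_slice    = count_in_last_tile / KThreadSliceSize;                   // 29
    const int count_in_last_slice   = count_in_last_tile % KThreadSliceSize;                   // 3

    // Re-evaluate operator() for every thread and sum the per-thread counts.
    long long total = 0;
    for(int tid = 0; tid < K_BlockTileSize / KThreadSliceSize; ++tid)
    {
        int count;
        if(count_in_last_tile == 0 || tid < num_complete_slice)
            count = KThreadSliceSize * numBlockTileIteration;                             // full slices, 32
        else if(tid == num_complete_slice)
            count = KThreadSliceSize * (numBlockTileIteration - 1) + count_in_last_slice; // partial slice, 27
        else
            count = KThreadSliceSize * (numBlockTileIteration - 1);                       // no last-tile work, 24
        total += count;
    }

    assert(total == reduce_length); // the counts exactly cover the reduce dimension
    std::cout << "total reduce count = " << total << '\n';
}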