Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
28d87372
Commit
28d87372
authored
Nov 15, 2023
by
rocking
Browse files
Add kernel function, prepare to implement
parent
fba188c6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
194 additions
and
4 deletions
+194
-4
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
...ation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
+125
-4
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
...n/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
+69
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp
View file @
28d87372
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <vector>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
...
@@ -17,6 +18,42 @@
...
@@ -17,6 +18,42 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
// GPU entry point for the normalization backward-x pass.  It is a thin
// forwarding shim: every tensor descriptor and global-memory pointer is passed
// straight through to the gridwise implementation chosen at compile time.
//
// Template parameters:
//   GridwiseNormalizationBwd - gridwise implementation supplying a static Run()
//   DYDataType               - element type of the upstream gradient dy
//   XDataType                - element type of the forward input x
//   GammaDataType            - element type of the scale tensor gamma
//   MeanInvStdDataType       - element type of the saved mean / inverse-stddev
//   DXDataType               - element type of the output gradient dx
//   GridDesc_M_K             - 2-D (M, K) tensor descriptor type shared by all operands
template <typename GridwiseNormalizationBwd,
          typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType,
          typename GridDesc_M_K>
__global__ void
kernel_normalization_bwd_x(const GridDesc_M_K dy_grid_desc_m_k,
                           const GridDesc_M_K x_grid_desc_m_k,
                           const GridDesc_M_K gamma_grid_desc_m_k,
                           const GridDesc_M_K mean_grid_desc_m_k,
                           const GridDesc_M_K inv_std_grid_desc_m_k,
                           const GridDesc_M_K dx_grid_desc_m_k,
                           index_t num_k_block_tile_iteration,
                           const DYDataType* const __restrict__ p_dy_global,
                           const XDataType* const __restrict__ p_x_global,
                           const GammaDataType* const __restrict__ p_gamma_global,
                           const MeanInvStdDataType* const __restrict__ p_mean_global,
                           const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                           DXDataType* const __restrict__ p_dx_global)
{
    GridwiseNormalizationBwd::Run(dy_grid_desc_m_k,
                                  x_grid_desc_m_k,
                                  gamma_grid_desc_m_k,
                                  mean_grid_desc_m_k,
                                  inv_std_grid_desc_m_k,
                                  dx_grid_desc_m_k,
                                  num_k_block_tile_iteration,
                                  p_dy_global,
                                  p_x_global,
                                  p_gamma_global,
                                  p_mean_global,
                                  p_inv_std_global,
                                  p_dx_global);
} // fixed: removed stray ';' after the function body (empty declaration)
template
<
typename
DYDataType
,
template
<
typename
DYDataType
,
typename
XDataType
,
typename
XDataType
,
...
@@ -131,6 +168,56 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
...
@@ -131,6 +168,56 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
using
GridDesc_M_K
=
decltype
(
Make2dDescriptor
({
1
},
{
1
},
1
));
using
GridDesc_M_K
=
decltype
(
Make2dDescriptor
({
1
},
{
1
},
1
));
using
GridwiseNormalizationBwdXGeneric
=
GridwiseNormalizationBwdX_mk_to_mk
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DXDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
DYSrcVectorDim
,
DYSrcVectorSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
DXDstVectorDim
,
DXDstVectorSize
,
false
>
;
using
GridwiseNormalizationBwdXSweepOnce
=
GridwiseNormalizationBwdX_mk_to_mk
<
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
ComputeDataType
,
DXDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
DYSrcVectorDim
,
DYSrcVectorSize
,
XSrcVectorDim
,
XSrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
MeanInvStdSrcVectorDim
,
MeanInvStdSrcVectorSize
,
DXDstVectorDim
,
DXDstVectorSize
,
true
>
;
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
Argument
(
const
std
::
vector
<
index_t
>
lengths
,
...
@@ -214,12 +301,46 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
...
@@ -214,12 +301,46 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
auto
KernelSelector
(
bool
isSweepOnce
)
{
return
isSweepOnce
?
kernel_normalization_bwd_x
<
GridwiseNormalizationBwdXSweepOnce
,
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
DXDataType
,
GridDesc_M_K
>
:
kernel_normalization_bwd_x
<
GridwiseNormalizationBwdXGeneric
,
DYDataType
,
XDataType
,
GammaDataType
,
MeanInvStdDataType
,
DXDataType
,
GridDesc_M_K
>
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
// TODO
const
auto
kernel_main
=
KernelSelector
(
arg
.
isSweeponce_
);
ignore
=
arg
;
ignore
=
stream_config
;
return
launch_and_time_kernel
(
stream_config
,
return
0
;
kernel_main
,
dim3
(
arg
.
gridSize_
),
dim3
(
BlockSize
),
0
,
arg
.
dy_grid_desc_m_k_
,
arg
.
x_grid_desc_m_k_
,
arg
.
gamma_grid_desc_m_k_
,
arg
.
mean_grid_desc_m_k_
,
arg
.
inv_std_grid_desc_m_k_
,
arg
.
dx_grid_desc_m_k_
,
arg
.
numBlockTileIteration_
,
arg
.
p_dy_
,
arg
.
p_x_
,
arg
.
p_gamma_
,
arg
.
p_mean_
,
arg
.
p_invStd_
,
arg
.
p_dx_
);
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
float
Run
(
const
BaseArgument
*
p_arg
,
...
...
include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_x.hpp
0 → 100644
View file @
28d87372
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
namespace
ck
{
// Tensor Shape
// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1]
// Flow:
// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size):
// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True)
// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True)
// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size
// c = -b * x_mean - db * inv_std / reduce_size
// dx = inv_std * dy * gamma + b * x + c
// return dx
// Gridwise normalization backward pass w.r.t. x, operating on (M, K)-described
// tensors (see the flow description above: dy, x = [M, K]; gamma = [1, K];
// mean, inv_std = [M, 1]).  Template parameters configure the thread-block
// geometry (BlockSize, *ThreadClusterSize, *ThreadSliceSize), the per-operand
// vector access dimension/size, and SweepOnce, which selects between this
// struct's two instantiations in the device-level impl (its exact semantics
// are not yet observable here -- the body is a stub; confirm once implemented).
template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DXDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t DYSrcVectorDim,
          index_t DYSrcVectorSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t MeanInvStdSrcVectorDim,
          index_t MeanInvStdSrcVectorSize,
          index_t DXDstVectorDim,
          index_t DXDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationBwdX_mk_to_mk
{
    // Kernel body: intended to compute dx from dy, x, gamma, mean and inv_std
    // per the normalization_backward_x flow documented at the top of this file.
    // Parameters mirror the __global__ wrapper in the device impl: one (M, K)
    // descriptor per operand, the number of K-tile iterations per block, and
    // the raw global-memory pointers (inputs const-qualified, dx writable).
    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M_K& dx_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DXDataType* const __restrict__ p_dx_global)
    {
        // TODO: not yet implemented -- this commit only adds the kernel
        // skeleton ("prepare to implement"); dx is currently never written.
    }
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment