gaoqiong / composable_kernel · Commits

Commit 8745c0ca, authored Feb 10, 2023 by rocking

    Update naive variance kernel

parent 3b6f9c16
Showing 1 changed file with 291 additions and 155 deletions.

include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp  (+291, -155)
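From the diff below, the commit appears to do three related things: it renames the kernel's accumulation type and output element-wise operation (AccDataType -> ComputeDataType, AccElementwiseOperation -> YElementwiseOperation), it splits each thread's KThreadSliceSize elements along K into ThreadBufferNumber = KThreadSliceSize / XSrcVectorSize vector-sized chunks that are loaded and reduced chunk by chunk, and it separates a single-pass (SweepOnce) pipeline, which keeps x and gamma resident in registers, from the two-pass pipeline used when the reduced dimension spans several block tiles.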
@@ -4,9 +4,8 @@
 #pragma once

 #include "ck/utility/data_type.hpp"
-#include "ck/utility/reduction_common.hpp"
 #include "ck/utility/reduction_operator.hpp"
 #include "ck/utility/reduction_functions_accumulate.hpp"
 #include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
 #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
@@ -19,8 +18,8 @@ template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename YDataType,
-          typename AccDataType,
-          typename AccElementwiseOperation,
+          typename ComputeDataType,
+          typename YElementwiseOperation,
           typename GridDesc_M_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
@@ -46,6 +45,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
         "Invalid thread slice sizes and/or vector sizes configuration, please check!");

+    static_assert(XSrcVectorSize == YDstVectorSize);
+    static_assert(XSrcVectorSize == GammaSrcVectorSize);
+    static_assert(XSrcVectorSize == BetaSrcVectorSize);
+
     static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

     using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
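Taken together, the new static_asserts appear to require the X, gamma, beta and Y vector widths to be identical, which is what allows the single per-chunk thread-buffer shape Sequence<MThreadSliceSize, XSrcVectorSize> introduced in the next hunk to be reused for every operand.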
@@ -59,19 +62,23 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
     static constexpr auto thread_cluster_desc =
         make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

+    using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;
+    static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
+
     using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
-        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
+        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

-    using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
+    using BlockwiseSumReduce = PartitionedBlockwiseReduction<ComputeDataType,
                                                              BlockSize,
                                                              ThreadClusterLengths_M_K,
                                                              ThreadClusterArrangeOrder,
                                                              reduce::Add,
                                                              true>;

-    using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
+    using ThreadwiseSumReduce = ThreadwiseReduction<ComputeDataType,
                                                     ThreadReduceSrcDesc_M_K,
                                                     ThreadReduceDstDesc_M,
                                                     reduce::Add,
@@ -81,64 +88,77 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
+    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
+
+    static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};

     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
                                const GridDesc_M_K& gamma_grid_desc_m_k,
                                const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
-                               AccDataType epsilon,
+                               ComputeDataType epsilon,
                                const XDataType* const __restrict__ p_x_global,
                                const GammaDataType* const __restrict__ p_gamma_global,
                                const BetaDataType* const __restrict__ p_beta_global,
                                YDataType* const __restrict__ p_y_global,
-                               const AccElementwiseOperation acc_elementwise_op)
+                               const YElementwiseOperation y_elementwise_op)
     {
-        if constexpr(SweepOnce)
-        {
-            num_k_block_tile_iteration = 1;
-        }
-
         // LDS
-        __shared__ AccDataType p_reduce_work_buffer[BlockSize];
-
-        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

         auto reduce_work_buf =
             make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            x_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
-            beta_thread_buf = gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
-            y_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
-            x_square_thread_buf = y_thread_buf;
-
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
-            mean_square_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
-            var_thread_buf = mean_square_thread_buf;
-
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            mean_thread_buf(I)        = reduce::Add::template GetIdentityValue<AccDataType>();
-            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
-        });
+        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+
+        auto x_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * XSrcVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+
+        auto gamma_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * GammaSrcVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+
+        auto beta_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * BetaSrcVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+
+        auto y_thread_buf = generate_tuple(
+            [&](auto) {
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    ComputeDataType,
+                                    MThreadSliceSize * YDstVectorSize,
+                                    true>{};
+            },
+            Number<ThreadBufferNumber>{});
+
+        auto& x_square_thread_buf = y_thread_buf;
+
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true> mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            mean_square_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
+            var_thread_buf = mean_square_thread_buf;

         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
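For intuition about the new chunked buffering, the quantities above can be checked with a small host-side C++ sketch; the cluster, slice and vector sizes below are hypothetical example values chosen for illustration, not values taken from the commit.

#include <cstdio>

// Hypothetical tile configuration, for illustration only.
constexpr int KThreadClusterSize = 64; // threads along K in one block
constexpr int KThreadSliceSize   = 8;  // elements each thread owns along K
constexpr int XSrcVectorSize     = 2;  // elements moved per vector load

// Mirrors the quantities used by the kernel.
constexpr int K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize; // 512
constexpr int K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;   // 128
constexpr int ThreadBufferNumber  = KThreadSliceSize / XSrcVectorSize;     // 4

static_assert(ThreadBufferNumber * K_BlockTileStepSize == K_BlockTileSize,
              "ThreadBufferNumber forward steps of one vector chunk cover a whole K block tile");

int main()
{
    std::printf("K_BlockTileSize=%d, step=%d, chunks per tile=%d\n",
                K_BlockTileSize, K_BlockTileStepSize, ThreadBufferNumber);
}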
@@ -149,12 +169,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_k_cluster_id = thread_cluster_idx[I1];

-        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
-                                                                  AccDataType,
+                                                                  ComputeDataType,
                                                                   GridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
@@ -166,11 +182,11 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             x_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+                             thread_k_cluster_id * XSrcVectorSize));

         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
-                                             AccDataType,
+                                             ComputeDataType,
                                              GridDesc_M_K,
                                              decltype(thread_buffer_desc_m_k),
                                              ThreadBufferLengths_M_K,
@@ -182,11 +198,11 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             gamma_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+                             thread_k_cluster_id * GammaSrcVectorSize));

         auto threadwise_beta_load =
             ThreadwiseTensorSliceTransfer_v2<BetaDataType,
-                                             AccDataType,
+                                             ComputeDataType,
                                              GridDesc_M_K,
                                              decltype(thread_buffer_desc_m_k),
                                              ThreadBufferLengths_M_K,
@@ -198,14 +214,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             beta_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize));
+                             thread_k_cluster_id * BetaSrcVectorSize));

         auto threadwise_y_store =
-            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                                YDataType,
                                                decltype(thread_buffer_desc_m_k),
                                                GridDesc_M_K,
-                                               AccElementwiseOperation,
+                                               YElementwiseOperation,
                                                ThreadBufferLengths_M_K,
                                                ThreadBufferDimAccessOrder,
                                                YDstVectorDim,
@@ -216,13 +232,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
             y_grid_desc_m_k,
             make_multi_index(block_global_id * M_BlockTileSize +
                                  thread_m_cluster_id * MThreadSliceSize,
-                             thread_k_cluster_id * KThreadSliceSize),
-            acc_elementwise_op);
+                             thread_k_cluster_id * YDstVectorSize),
+            y_elementwise_op);

-        // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_m_k =
-            make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
+        // Copy x from Cache
+        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
         constexpr auto thread_copy_bwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
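With this change, one forward step of the copy window advances by a single vector chunk (K_BlockTileStepSize) rather than a whole K block tile, so ThreadBufferNumber forward steps cover one tile; the backward step is still a full tile, which is presumably why the second sweep in the new code below steps back by 2 * thread_copy_bwd_step_m_k after finishing each tile.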
@@ -239,121 +252,244 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         // FIXME: Should not hack the transform from deviceOP
         int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];

-        index_t reducedTiles = 0;
-
-        do
-        {
-            threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
-                                  thread_buffer_desc_m_k,
-                                  make_tuple(I0, I0),
-                                  x_thread_buf);
-
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    x_square_thread_buf(Number<offset_m_k>{}) =
-                        x_thread_buf(Number<offset_m_k>{}) * x_thread_buf(Number<offset_m_k>{});
-                });
-            });
-
-            ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf);
-            ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf);
-
-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
-
-            ++reducedTiles;
-        } while(reducedTiles < num_k_block_tile_iteration);
-
-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            if constexpr(I > 0)
-                block_sync_lds();
-
-            BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
-            mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
-
-            block_sync_lds();
-
-            BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
-            mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
-
-            // var(x) = E[x^2] - E[x]^2
-            var_thread_buf(I) =
-                mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
-        });
-
-        auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-
-        threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
-        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
-
-        reducedTiles = 0;
-        do
-        {
-            threadwise_x_load.Run(x_grid_desc_m_k,
-                                  x_global_val_buf,
-                                  thread_buffer_desc_m_k,
-                                  make_tuple(I0, I0),
-                                  x_thread_buf);
-
-            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
-                                      gamma_global_val_buf,
-                                      thread_buffer_desc_m_k,
-                                      make_tuple(I0, I0),
-                                      gamma_thread_buf);
-
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-
-                    // normalize
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_thread_buf(iM) + epsilon);
-
-                    // gamma
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
-                });
-            });
-
-            threadwise_beta_load.Run(beta_grid_desc_m_k,
-                                     beta_global_val_buf,
-                                     thread_buffer_desc_m_k,
-                                     make_tuple(I0, I0),
-                                     beta_thread_buf);
-
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-
-                    // beta
-                    y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
-                });
-            });
-
-            threadwise_y_store.Run(thread_buffer_desc_m_k,
-                                   make_tuple(I0, I0),
-                                   y_thread_buf,
-                                   y_grid_desc_m_k,
-                                   y_global_val_buf);
-
-            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
-
-            ++reducedTiles;
-        } while(reducedTiles < num_k_block_tile_iteration);
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+            mean_thread_buf(I)        = reduce::Add::template GetIdentityValue<ComputeDataType>();
+            mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
+        });
+
+        // Separate sweep once and sweep twice pipeline
+        if constexpr(SweepOnce)
+        {
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_x_load.Run(x_grid_desc_m_k,
+                                      x_global_val_buf,
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
+                                      x_thread_buf(i));
+
+                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          gamma_thread_buf(i));
+
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                        x_square_thread_buf(i)(Number<offset_m_k>{}) =
+                            x_thread_buf(i)(Number<offset_m_k>{}) *
+                            x_thread_buf(i)(Number<offset_m_k>{});
+                    });
+                });
+
+                ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
+                ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
+
+                if constexpr(i != ThreadBufferNumber - 1)
+                {
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                    threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                             thread_copy_fwd_step_m_k);
+                }
+            });
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+                if constexpr(I > 0)
+                    block_sync_lds();
+
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
+                mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
+
+                block_sync_lds();
+
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
+                mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
+
+                // var(x) = E[x^2] - E[x]^2
+                var_thread_buf(I) =
+                    mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+            });
+
+            // y = (x - E[x]) / sqrt(var[x] + epsilon)
+            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
+
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                    threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                             beta_global_val_buf,
+                                             thread_buffer_desc_m_k,
+                                             make_tuple(I0, I0),
+                                             beta_thread_buf(iK0));
+
+                    if constexpr(iK0 != ThreadBufferNumber - 1)
+                        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                                thread_copy_fwd_step_m_k);
+
+                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                        constexpr auto offset_m_k =
+                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                        // normalize
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                            divisor;
+
+                        // gamma & beta
+                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                                gamma_thread_buf(iK0)(Number<offset_m_k>{}) +
+                            beta_thread_buf(iK0)(Number<offset_m_k>{});
+                    });
+                });
+            });
+
+            static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                       make_tuple(I0, I0),
+                                       y_thread_buf(i),
+                                       y_grid_desc_m_k,
+                                       y_global_val_buf);
+
+                if constexpr(i != ThreadBufferNumber - 1)
+                    threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
+            });
+        } // end of sweep once
+        else
+        {
+            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
+            {
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_x_load.Run(x_grid_desc_m_k,
+                                          x_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          x_thread_buf(i));
+
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+
+                    static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                        static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
+                            constexpr auto offset_m_k =
+                                thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                            x_square_thread_buf(i)(Number<offset_m_k>{}) =
+                                x_thread_buf(i)(Number<offset_m_k>{}) *
+                                x_thread_buf(i)(Number<offset_m_k>{});
+                        });
+                    });
+
+                    ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
+                    ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf);
+                });
+            }
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
+                if constexpr(I > 0)
+                    block_sync_lds();
+
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I));
+                mean_thread_buf(I) = mean_thread_buf(I) / reduce_length;
+
+                block_sync_lds();
+
+                BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I));
+                mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
+
+                // var(x) = E[x^2] - E[x]^2
+                var_thread_buf(I) =
+                    mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+            });
+
+            auto thread_copy_tail_m_k =
+                (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
+
+            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
+            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
+
+            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
+            {
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_x_load.Run(x_grid_desc_m_k,
+                                          x_global_val_buf,
+                                          thread_buffer_desc_m_k,
+                                          make_tuple(I0, I0),
+                                          x_thread_buf(i));
+
+                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                });
+
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_gamma_load.Run(gamma_grid_desc_m_k,
+                                              gamma_global_val_buf,
+                                              thread_buffer_desc_m_k,
+                                              make_tuple(I0, I0),
+                                              gamma_thread_buf(i));
+
+                    threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                             thread_copy_fwd_step_m_k);
+                });
+
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
+
+                    static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                        static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                            constexpr auto offset_m_k =
+                                thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                            // normalize
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                                (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                                divisor;
+
+                            // gamma
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                                y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                                gamma_thread_buf(iK0)(Number<offset_m_k>{});
+                        });
+                    });
+                });
+
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_beta_load.Run(beta_grid_desc_m_k,
+                                             beta_global_val_buf,
+                                             thread_buffer_desc_m_k,
+                                             make_tuple(I0, I0),
+                                             beta_thread_buf(i));
+
+                    threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                            thread_copy_fwd_step_m_k);
+                });
+
+                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
+                    static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
+                        static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+                            constexpr auto offset_m_k =
+                                thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+
+                            // beta
+                            y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                                y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                                beta_thread_buf(iK0)(Number<offset_m_k>{});
+                        });
+                    });
+                });
+
+                static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
+                    threadwise_y_store.Run(thread_buffer_desc_m_k,
+                                           make_tuple(I0, I0),
+                                           y_thread_buf(i),
+                                           y_grid_desc_m_k,
+                                           y_global_val_buf);
+
+                    threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
+                });
+
+                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
+                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
+                                                         2 * thread_copy_bwd_step_m_k);
+                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
+                                                        2 * thread_copy_bwd_step_m_k);
+                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
+            }
+        } // end of sweep twice
     }
 };
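As a plain reference for the per-row arithmetic the kernel performs (naive two-moment variance followed by the affine normalization), here is a hedged host-side C++ sketch; it is not part of composable_kernel and only illustrates the formulas var(x) = E[x^2] - E[x]^2 and y = (x - E[x]) / sqrt(var[x] + epsilon) * gamma + beta that appear in the diff's comments.

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference for one reduced row of length K (illustrative only).
void naive_variance_normalize(const std::vector<float>& x,
                              const std::vector<float>& gamma,
                              const std::vector<float>& beta,
                              std::vector<float>& y,
                              float epsilon)
{
    const std::size_t K = x.size();
    y.resize(K);

    // One pass accumulates E[x] and E[x^2], as the kernel does per thread and per block.
    float sum = 0.f, sum_sq = 0.f;
    for(std::size_t k = 0; k < K; ++k)
    {
        sum    += x[k];
        sum_sq += x[k] * x[k];
    }

    const float mean        = sum / K;
    const float mean_square = sum_sq / K;
    // var(x) = E[x^2] - E[x]^2
    const float var     = mean_square - mean * mean;
    const float divisor = 1.f / std::sqrt(var + epsilon);

    // y = (x - E[x]) / sqrt(var[x] + epsilon) * gamma + beta
    for(std::size_t k = 0; k < K; ++k)
        y[k] = (x[k] - mean) * divisor * gamma[k] + beta[k];
}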