Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
73d26a88
Commit
73d26a88
authored
Feb 07, 2023
by
rocking
Browse files
Check the vector size and remove redundant var
parent
b53db56e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
17 deletions
+18
-17
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
...tion/gpu/grid/gridwise_normalization_welford_variance.hpp
+18
-17
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
View file @
73d26a88
...
...
@@ -43,6 +43,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
(
YDstVectorDim
==
1
&&
KThreadSliceSize
%
YDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static_assert
(
XSrcVectorSize
==
YDstVectorSize
);
static_assert
(
XSrcVectorSize
==
GammaSrcVectorSize
);
static_assert
(
XSrcVectorSize
==
BetaSrcVectorSize
);
static
constexpr
bool
reorder_thread_cluster
=
(
XSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
...
...
@@ -77,10 +81,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
constexpr
index_t
K_BlockTileStepSize
=
KThreadClusterSize
*
XSrcVectorSize
;
static
constexpr
auto
XThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
GammaThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
BetaThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
YThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
static
constexpr
auto
ThreadBufferNumber
=
Number
<
KThreadSliceSize
/
XSrcVectorSize
>
{};
__device__
static
int
GetKPerThread
(
const
GridDesc_M_K
&
x_grid_desc_m_k
,
int
thread_k_cluster_id
)
...
...
@@ -93,7 +94,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
if
(
kPerBlockTail
>
0
)
{
static_for
<
0
,
X
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
int
thread_max_len
=
(
thread_k_cluster_id
+
1
)
*
XSrcVectorSize
+
K_BlockTileStepSize
*
i
;
int
delta
=
thread_max_len
-
kPerBlockTail
;
...
...
@@ -132,7 +133,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
MThreadSliceSize
*
XSrcVectorSize
,
true
>
{};
},
Number
<
X
ThreadBufferNumber
>
{});
Number
<
ThreadBufferNumber
>
{});
auto
gamma_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
...
...
@@ -141,7 +142,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
MThreadSliceSize
*
GammaSrcVectorSize
,
true
>
{};
},
Number
<
Gamma
ThreadBufferNumber
>
{});
Number
<
ThreadBufferNumber
>
{});
auto
beta_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
...
...
@@ -150,7 +151,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
MThreadSliceSize
*
BetaSrcVectorSize
,
true
>
{};
},
Number
<
Beta
ThreadBufferNumber
>
{});
Number
<
ThreadBufferNumber
>
{});
auto
y_thread_buf
=
generate_tuple
(
[
&
](
auto
)
{
...
...
@@ -159,7 +160,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
MThreadSliceSize
*
YDstVectorSize
,
true
>
{};
},
Number
<
Y
ThreadBufferNumber
>
{});
Number
<
ThreadBufferNumber
>
{});
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
var_thread_buf
;
...
...
@@ -266,7 +267,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
static_for
<
0
,
X
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
...
...
@@ -286,7 +287,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
auto
thread_copy_tail_m_k
=
(
num_k_block_tile_iteration
-
1
)
*
X
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
(
num_k_block_tile_iteration
-
1
)
*
ThreadBufferNumber
*
thread_copy_fwd_step_m_k
;
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_bwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_m_k
,
thread_copy_tail_m_k
);
...
...
@@ -297,7 +298,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
{
if
constexpr
(
!
SweepOnce
)
{
static_for
<
0
,
X
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
...
...
@@ -307,7 +308,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
}
static_for
<
0
,
Gamma
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
...
...
@@ -320,7 +321,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
X
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK1
));
...
...
@@ -338,7 +339,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
});
static_for
<
0
,
Beta
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_beta_load
.
Run
(
beta_grid_desc_m_k
,
beta_global_val_buf
,
thread_buffer_desc_m_k
,
...
...
@@ -349,7 +350,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
X
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK1
));
...
...
@@ -362,7 +363,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
});
});
static_for
<
0
,
Y
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
(
i
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment