Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a38e362
Commit
1a38e362
authored
Feb 08, 2023
by
rocking
Browse files
Separate sweeponce flow and optimize the flow
parent
e12a6be2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
149 additions
and
71 deletions
+149
-71
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
...tion/gpu/grid/gridwise_normalization_welford_variance.hpp
+149
-71
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp
View file @
1a38e362
...
@@ -265,6 +265,84 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -265,6 +265,84 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_buf
(
I
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
});
});
// Separate sweep once and sweep twice pipeline
if
constexpr
(
SweepOnce
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
x_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
x_thread_buf
(
i
));
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
gamma_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
gamma_thread_buf
(
i
));
threadwise_welford
.
Run
(
x_thread_buf
[
i
],
mean_thread_buf
,
var_thread_buf
);
if
constexpr
(
i
!=
ThreadBufferNumber
-
1
)
{
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
gamma_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
}
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
int
count
=
threadwise_welford
.
cur_count_
;
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
count
);
});
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
auto
divisor
=
1
/
ck
::
math
::
sqrt
(
var_thread_buf
(
iM
)
+
epsilon
);
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
iK0
)
{
threadwise_beta_load
.
Run
(
beta_grid_desc_m_k
,
beta_global_val_buf
,
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
beta_thread_buf
(
iK0
));
if
constexpr
(
iK0
!=
ThreadBufferNumber
-
1
)
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
static_for
<
0
,
XSrcVectorSize
,
1
>
{}([
&
](
auto
iK1
)
{
constexpr
auto
offset_m_k
=
thread_buffer_desc_m_k
.
CalculateOffset
(
make_tuple
(
iM
,
iK1
));
// normalize
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
(
x_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
-
mean_thread_buf
(
iM
))
*
divisor
;
// gamma & beta
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
=
y_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
*
gamma_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{})
+
beta_thread_buf
(
iK0
)(
Number
<
offset_m_k
>
{});
});
});
});
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_y_store
.
Run
(
thread_buffer_desc_m_k
,
make_tuple
(
I0
,
I0
),
y_thread_buf
(
i
),
y_grid_desc_m_k
,
y_global_val_buf
);
if
constexpr
(
i
!=
ThreadBufferNumber
-
1
)
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
}
// end of sweep once
else
{
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -295,8 +373,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -295,8 +373,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_tail_m_k
);
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
for
(
index_t
reducedTiles
=
0
;
reducedTiles
<
num_k_block_tile_iteration
;
++
reducedTiles
)
{
if
constexpr
(
!
SweepOnce
)
{
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
threadwise_x_load
.
Run
(
x_grid_desc_m_k
,
...
@@ -306,7 +382,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -306,7 +382,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
x_thread_buf
(
i
));
x_thread_buf
(
i
));
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
});
}
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
ThreadBufferNumber
,
1
>
{}([
&
](
auto
i
)
{
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
...
@@ -369,7 +444,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -369,7 +444,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
y_thread_buf
(
i
),
y_thread_buf
(
i
),
y_grid_desc_m_k
,
y_grid_desc_m_k
,
y_global_val_buf
);
y_global_val_buf
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
thread_copy_fwd_step_m_k
);
});
});
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
2
*
thread_copy_bwd_step_m_k
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x_grid_desc_m_k
,
2
*
thread_copy_bwd_step_m_k
);
...
@@ -377,8 +453,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
...
@@ -377,8 +453,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
2
*
thread_copy_bwd_step_m_k
);
2
*
thread_copy_bwd_step_m_k
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_m_k
,
threadwise_beta_load
.
MoveSrcSliceWindow
(
beta_grid_desc_m_k
,
2
*
thread_copy_bwd_step_m_k
);
2
*
thread_copy_bwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
2
*
thread_copy_bwd_step_m_k
);
threadwise_y_store
.
MoveDstSliceWindow
(
y_grid_desc_m_k
,
2
*
thread_copy_bwd_step_m_k
);
}
}
}
// end of sweep twice
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment