Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4ee40bcc
Commit
4ee40bcc
authored
Oct 12, 2024
by
letaoqin
Browse files
change warp_welford.hpp
parent
63214d01
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
50 deletions
+10
-50
include/ck_tile/ops/welford/warp/warp_welford.hpp
include/ck_tile/ops/welford/warp/warp_welford.hpp
+10
-50
No files found.
include/ck_tile/ops/welford/warp/warp_welford.hpp
View file @
4ee40bcc
...
@@ -44,9 +44,9 @@ struct WarpMergeWelford
...
@@ -44,9 +44,9 @@ struct WarpMergeWelford
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
constexpr
index_t
idim_p_lane
=
NDimP
-
1
;
const
auto
ps_idx
=
make_array
<
index_t
>
(
get_warp_id
(),
get_lane_id
());
//
const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
const
auto
rs_idx
=
//
const auto rs_idx =
mean_tensor
.
get_tile_distribution
().
calculate_rs_index_from_ps_index
(
ps_idx
);
//
mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
constexpr
index_t
thread_buf_size
=
MeanDistributedTensor_
::
get_thread_buffer_size
();
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
...
@@ -78,13 +78,15 @@ struct WarpMergeWelford
...
@@ -78,13 +78,15 @@ struct WarpMergeWelford
// reduction sweep forward
// reduction sweep forward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
constexpr
index_t
lid_delta
=
// xor
lid_over_rid_derivative
*
(
1
<<
(
nstage
-
istage
-
1
));
index_t
src_lane
=
(
__lane_id
())
^
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
_down
(
v_local_mean
,
lid_delta
);
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
_down
(
v_local_var
,
lid_delta
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
_down
(
v_local_count
,
lid_delta
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// welford merge
// welford merge
Merge
(
v_local_mean
,
Merge
(
v_local_mean
,
...
@@ -97,48 +99,6 @@ struct WarpMergeWelford
...
@@ -97,48 +99,6 @@ struct WarpMergeWelford
}
}
});
});
// cross-lane broadcast for replication
// only broadcast on R dimension correspond to lane
// (lane id maps to this R dimension)
if
constexpr
(
BroadcastLane
)
{
static_for
<
0
,
NDimR
,
1
>
{}([
&
](
auto
idim_r
)
{
// FIXME: nasty to use does_p_own_r_
if
constexpr
(
DstrEncodeDetail
::
does_p_own_r_
[
idim_p_lane
][
idim_r
])
{
const
index_t
r_id
=
rs_idx
[
idim_r
];
constexpr
index_t
r_length
=
DstrEncode
::
rs_lengths_
[
idim_r
];
constexpr
index_t
lid_over_rid_derivative
=
DstrEncodeDetail
::
ps_over_rs_derivative_
[
NDimP
-
1
][
idim_r
];
static_assert
(
is_power_of_two_integer
(
r_length
),
"wrong! only support power of 2 reduction"
);
constexpr
index_t
nstage
=
integer_log2_floor
(
r_length
);
// broadcast sweep backward
static_for
<
0
,
nstage
,
1
>
{}([
&
](
auto
istage
)
{
// do I hold reduced data?
const
bool
do_i_hold_reduced_data
=
r_id
<
(
1
<<
istage
);
constexpr
index_t
lid_delta
=
lid_over_rid_derivative
*
(
1
<<
istage
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle_up
(
v_local_mean
,
lid_delta
);
const
auto
v_remote_var
=
warp_shuffle_up
(
v_local_var
,
lid_delta
);
const
auto
v_remote_count
=
warp_shuffle_up
(
v_local_count
,
lid_delta
);
// decide whether to update local data with remote data
v_local_mean
=
do_i_hold_reduced_data
?
v_local_mean
:
v_remote_mean
;
v_local_var
=
do_i_hold_reduced_data
?
v_local_var
:
v_remote_var
;
v_local_count
=
do_i_hold_reduced_data
?
v_local_count
:
v_remote_count
;
});
}
});
}
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
if
constexpr
(
GetActualVariance
)
if
constexpr
(
GetActualVariance
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment