Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e954c206
Commit
e954c206
authored
Jan 29, 2024
by
Adam Osewski
Browse files
Clean up and change how neighbours are counted.
parent
7e71ea99
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
18 deletions
+44
-18
include/ck/utility/work_scheduling.hpp
include/ck/utility/work_scheduling.hpp
+44
-18
No files found.
include/ck/utility/work_scheduling.hpp
View file @
e954c206
...
@@ -79,8 +79,8 @@ class StridedReductionTileLoop
...
@@ -79,8 +79,8 @@ class StridedReductionTileLoop
index_t
output_tile_idx
,
index_t
output_tile_idx
,
index_t
output_tile_idx_offset
)
const
index_t
output_tile_idx_offset
)
const
{
{
//
return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
return
(
output_tile_idx
+
output_tile_idx_offset
)
%
GetFlagCount
(
k_tiles
);
return
output_tile_idx
+
output_tile_idx_offset
;
//
return output_tile_idx + output_tile_idx_offset;
}
}
///
///
...
@@ -93,7 +93,7 @@ class StridedReductionTileLoop
...
@@ -93,7 +93,7 @@ class StridedReductionTileLoop
__device__
void
__device__
void
FlagFinished
(
index_t
k_tiles
,
index_t
output_tile_idx
,
index_t
output_tile_idx_offset
)
FlagFinished
(
index_t
k_tiles
,
index_t
output_tile_idx
,
index_t
output_tile_idx_offset
)
{
{
/* [[maybe_unused]] */
const
auto
fidx
=
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
);
const
auto
fidx
=
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
);
finished_block_flags_
.
inc
(
fidx
);
finished_block_flags_
.
inc
(
fidx
);
}
}
...
@@ -101,21 +101,51 @@ class StridedReductionTileLoop
...
@@ -101,21 +101,51 @@ class StridedReductionTileLoop
/// @brief Wait until each workgroup has finished its work.
/// @brief Wait until each workgroup has finished its work.
///
///
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] k_tiles The number of tiles in the reduced dimension.
/// @param[in] k_tile_idx The currently processed tile k index.
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx The output (MN) tile index
/// @param[in] output_tile_idx_offset The output tile index offset
/// @param[in] output_tile_idx_offset The output tile index offset
///
///
__device__
void
/// @return The number of neighbours.
WaitForNeighbours
(
index_t
k_tiles
,
index_t
output_tile_idx
,
index_t
output_tile_idx_offset
)
///
__device__
index_t
WaitForNeighbours
(
index_t
k_tiles
,
index_t
k_tile_idx
,
index_t
output_tile_idx
,
index_t
output_tile_idx_offset
)
{
// We have to wait for all workgroups to finish their partial results.
// First count how many "neighbour" workgroups we have to check.
index_t
neighbour_count
=
0
;
if
(
tiles_per_block_
<
k_tiles
)
{
// Since we can have deviation (+1) in neighbours number
// we calculate how many workgroups are needed to process the k-tiles left.
neighbour_count
=
(
k_tiles
-
k_tile_idx
-
1
+
tiles_per_block_
-
1
)
/
tiles_per_block_
;
}
// If we have more tiles to process than the reduction dimension size,
// then the number of neighbours depends on first K-tile workgroup block tile idx.
else
{
if
(
block_tile_idx_
==
tiles_per_block_
)
{
// If we just finished work per workgroup then check at which k-idx we are.
neighbour_count
=
(
k_tile_idx
<
k_tiles
-
1
)
?
1
:
0
;
}
else
{
// If we have still tiles to process then it means that we already processed
// whole K-dim.
neighbour_count
=
0
;
}
}
if
(
neighbour_count
>
0
)
{
{
// Wait untill all workgroups finish
const
index_t
workgroups_per_dim
=
(
k_tiles
+
tiles_per_block_
-
1
)
/
tiles_per_block_
;
// We use < because for some cases we may have +1 more workgroups per dim.
// Ie when k_tiles = 5, tiles_per_block = 3.
finished_block_flags_
.
wait_lt
(
finished_block_flags_
.
wait_lt
(
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
),
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
),
workgroups_per_dim
);
neighbour_count
);
}
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset)
;
return
neighbour_count
;
}
}
///
///
...
@@ -131,8 +161,6 @@ class StridedReductionTileLoop
...
@@ -131,8 +161,6 @@ class StridedReductionTileLoop
// Wait untill the counter has been reset.
// Wait untill the counter has been reset.
finished_block_flags_
.
wait_eq
(
finished_block_flags_
.
wait_eq
(
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
),
0
);
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
),
0
);
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
}
}
///
///
...
@@ -146,8 +174,6 @@ class StridedReductionTileLoop
...
@@ -146,8 +174,6 @@ class StridedReductionTileLoop
{
{
finished_block_flags_
.
reset
(
finished_block_flags_
.
reset
(
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
));
GetWorkgroupFlagIdx
(
k_tiles
,
output_tile_idx
,
output_tile_idx_offset
));
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
}
}
///
///
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment