gaoqiong / composable_kernel_ROCM · Commits

Commit 5b2eb1ab, authored Mar 20, 2024 by Adam Osewski (parent f8ca9048)

    Single flag per workgroup synchronization scheme.
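The change replaces the shared per-output-tile counter with one flag per workgroup: every block bumps the flag at its own blockIdx.x once its split-K partial is in the workspace; the block that owns the first K split polls the flags of its k_batch − 1 neighbours until their sum matches, accumulates the partials, and then resets all of those flags, which releases the waiting blocks to reuse their workspace. The sketch below is a minimal host-side model of that hand-off using std::thread and std::atomic; it is illustrative only (the fixed k_batch, the workspace contents, and the thread/variable names are assumptions, not code taken from the kernel).

    // Minimal host-side sketch (not CK device code) of the single-flag-per-workgroup hand-off.
    // Thread 0 plays the first-K-split block that reduces; threads 1..k_batch-1 are its neighbours.
    #include <atomic>
    #include <cstdint>
    #include <iostream>
    #include <thread>
    #include <vector>

    int main()
    {
        const int k_batch = 4;                       // workgroups cooperating on one [M,N] tile
        std::vector<std::atomic<uint32_t>> flags(k_batch);
        std::vector<int> workspace(k_batch, 0);      // one partial-result slot per workgroup
        for(auto& f : flags) { f.store(0); }

        int result = 0;
        std::vector<std::thread> pool;
        for(int wid = 0; wid < k_batch; ++wid)
        {
            pool.emplace_back([&, wid] {
                workspace[wid] = wid + 1;            // "partial result" of this workgroup
                flags[wid].fetch_add(1);             // FlagFinished(): bump my own flag

                if(wid == 0)                         // IsFirstKSplitBlock(): do the reduction
                {
                    const int neighbour_count = k_batch - 1;
                    uint32_t flag_sum = 0;
                    do                               // WaitForNeighbours(): poll neighbour flags
                    {
                        flag_sum = 0;
                        for(int i = 1; i <= neighbour_count; ++i)
                            flag_sum += flags[wid + i].load();
                    } while(flag_sum != static_cast<uint32_t>(neighbour_count));

                    for(int i = 0; i <= neighbour_count; ++i)   // AccumulatePartials()
                        result += workspace[wid + i];

                    for(int i = 0; i <= neighbour_count; ++i)   // Reset(): release the neighbours
                        flags[wid + i].store(0);
                }
                else                                 // WaitForReduction(): my flag drops back to 0
                {
                    while(flags[wid].load() != 0) {}
                }
            });
        }
        for(auto& t : pool) { t.join(); }
        std::cout << "reduced result: " << result << '\n';   // expect 1+2+3+4 = 10
        return 0;
    }

Because only the reducing block ever resets flags, the waiters can spin on a plain equality check against their own counter.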
Showing 4 changed files with 49 additions and 104 deletions:

    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp  (+9, -13)
    include/ck/utility/work_scheduling.hpp  (+29, -72)
    include/ck/utility/workgroup_barrier.hpp  (+4, -4)
    test/work_scheduling/test_strided_reduction_tile_loop.cpp  (+7, -15)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp  (view file @ 5b2eb1ab)
@@ -156,32 +156,28 @@ __global__ void
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
 
-        const index_t output_tile_idx =
-            __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
-        const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
-
         // if (changed group_id || next [M,N] tile)
         if(!b2c_tile_map.IsFirstKSplitBlock())
         {
             GridwiseGemm::StorePartials(p_workspace, results_buffer);
         }
 
-        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
+        work_scheduler.FlagFinished();
 
         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
             // Wait untill all other blocks for this [M,N] tile store their results.
-            index_t neighbour_count = work_scheduler.WaitForNeighbours(
-                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
+            index_t neighbour_count =
+                work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
 
             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
-            if(neighbour_count > 1)
-                GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count);
+            if(neighbour_count > 0)
+                GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1);
 
             // Signal waiting blocks that they can start use their workspace.
-            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.Reset(neighbour_count);
 
             const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
             const auto stride_e = gemm_desc_ptr[group_id].StrideE;
@@ -210,7 +206,7 @@ __global__ void
         }
         else if(work_scheduler.HasTile())
         {
-            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.WaitForReduction();
         }
     } while(work_scheduler.HasTile());
 #else
@@ -752,7 +748,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
                         Block2ETileMapKSplit::GetAccWorkspaceSize(
                             sizeof(typename GridwiseGemm::AccType), grid_size);
-        std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+        std::size_t flag_count = grid_size;
 
         if(stream_config.log_level_ > 0)
         {
@@ -993,7 +989,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         {
             grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
         }
-        int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
+        int flag_count = grid_size;
 
         // This would be the maximum needed workspace size. Since actual grid size, which determines
         // the amount of workspace bytes needed, may be less due to the number of available CUs in
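Both flag_count sites above drop the per-output-tile formula in favour of one counter per launched workgroup, so the flag buffer now scales with grid_size rather than with the number of [M,N] output tiles covered by the launch. A small sketch with made-up numbers (the grid_size, tiles_per_block and k_batch values are assumptions chosen only to show the arithmetic, not values from the commit):

    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Hypothetical launch configuration, for illustration only.
        const std::size_t grid_size       = 120; // workgroups launched
        const std::size_t tiles_per_block = 4;   // data tiles each workgroup loops over
        const std::size_t k_batch         = 8;   // split-K factor

        // Old scheme: one flag per [M,N] output tile covered by the launch.
        const std::size_t old_flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
        // New scheme: one flag per workgroup, independent of k_batch.
        const std::size_t new_flag_count = grid_size;

        std::cout << "old: " << old_flag_count << " flags, new: " << new_flag_count << " flags\n";
        // prints "old: 60 flags, new: 120 flags"
        return 0;
    }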
include/ck/utility/work_scheduling.hpp  (view file @ 5b2eb1ab)
@@ -32,8 +32,7 @@ enum struct WorkSchedulingPolicy
 class StridedReductionTileLoop
 {
     public:
-    __device__ StridedReductionTileLoop(index_t tile_count,
-                                        volatile uint32_t* const __restrict__ p_flags)
+    __device__ StridedReductionTileLoop(index_t tile_count, uint32_t* const __restrict__ p_flags)
         : tile_count_{tile_count},
          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
          tile_id_{get_block_1d_id() * tiles_per_block_},
@@ -54,62 +53,29 @@ class StridedReductionTileLoop
         return HasTile();
     }
 
-    __device__ index_t GetFlagCount(index_t k_tiles) const
-    {
-        // This is the number of MN-output tiles which we cover with workgroups.
-        // We launch k_tiles (k_batch) / tiles_per_block workgroups for each output tile.
-        return (get_grid_size() * tiles_per_block_ + k_tiles - 1) / k_tiles;
-    }
+    __device__ index_t GetFlagCount() const { return get_grid_size(); }
 
     ///
-    /// @brief      Calculate this workgroup flag index.
-    ///
-    /// @note       Note this scheduler intentionaly does not have flag index as its member, since
-    ///             current workgroup may process tiles across different MN-output tiles or
-    ///             acorss different GEMMs (grouped gemm).
-    ///
-    /// @param[in]  k_tiles                 The number of data tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) linear tile index (of current GEMM).
-    /// @param[in]  output_tile_idx_offset  The accumulated offset of output tiles from previous
-    ///                                     GEMMs.
+    /// @brief      Get this workgroup flag index.
     ///
     /// @return     The workgroup flag index.
     ///
-    __device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
-                                            index_t output_tile_idx,
-                                            index_t output_tile_idx_offset) const
-    {
-        return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
-    }
+    __device__ uint32_t GetWorkgroupFlagIdx() const { return static_cast<uint32_t>(blockIdx.x); }
 
     ///
     /// @brief      Flag each workgroup that has finished its work.
     ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index
-    /// @param[in]  output_tile_idx_offset  The output tile index offset
-    ///
-    __device__ void FlagFinished(index_t k_tiles,
-                                 index_t output_tile_idx,
-                                 index_t output_tile_idx_offset)
-    {
-        const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
-        finished_block_flags_.inc(fidx);
-    }
+    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
 
     ///
     /// @brief      Wait until each workgroup has finished its work.
     ///
     /// @param[in]  k_tiles     The number of tiles in the reduced dimension.
     /// @param[in]  k_tile_idx  The currently processed tile k index.
-    /// @param[in]  output_tile_idx         The output (MN) tile index
-    /// @param[in]  output_tile_idx_offset  The output tile index offset
     ///
     /// @return     The number of neighbours.
     ///
-    __device__ index_t WaitForNeighbours(index_t k_tiles,
-                                         index_t k_tile_idx,
-                                         index_t output_tile_idx,
-                                         index_t output_tile_idx_offset)
+    __device__ index_t WaitForNeighbours(index_t k_tiles, index_t k_tile_idx)
     {
         // We have to wait for all workgroups to finish their partial results.
         // First count how many "neighbour" workgroups we have to check.
@@ -139,57 +105,48 @@ class StridedReductionTileLoop
         if(neighbour_count > 0)
         {
-            // Also count this workgroup
-            neighbour_count++;
-            finished_block_flags_.wait_eq(
-                GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
-                neighbour_count);
+            index_t flag_sum = 0;
+            do
+            {
+                flag_sum = 0;
+                for(index_t i = 1; i <= neighbour_count; ++i)
+                {
+                    flag_sum += finished_block_flags_.ld(GetWorkgroupFlagIdx() + i);
+                }
+            } while(flag_sum != neighbour_count);
         }
 
         return neighbour_count;
     }
 
     ///
-    /// @brief      Wait until each workgroup has finished its work.
-    ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index
-    /// @param[in]  output_tile_idx_offset  The output tile index offset
+    /// @brief      Wait until reduction workgroup has finished its work.
     ///
-    __device__ void
-    WaitForReduction(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
+    __device__ void WaitForReduction()
     {
-        // Wait untill the counter has been reset.
-        finished_block_flags_.wait_eq(
-            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
+        // Wait untill my counter has been reset.
+        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
     }
 
     ///
     /// @brief      Reset flag counter to zero.
     ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index.
-    /// @param[in]  output_tile_idx_offset  The output tile index offset.
+    /// @param[in]  neighbour_count  The number of peer workgroups.
    ///
-    __device__ void Reset(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
+    __device__ void Reset(index_t neighbour_count)
    {
-        finished_block_flags_.reset(
-            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
+        for(index_t i = 0; i <= neighbour_count; ++i)
+        {
+            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
+        }
     }
 
     ///
     /// @brief      Gets the flag value.
     ///
-    /// @param[in]  k_tiles                 The number of tiles in the reduced dimension.
-    /// @param[in]  output_tile_idx         The output (MN) tile index.
-    /// @param[in]  output_tile_idx_offset  The output tile index offset.
-    ///
-    __device__ uint32_t GetFlagValue(index_t k_tiles,
-                                     index_t output_tile_idx,
-                                     index_t output_tile_idx_offset) const
-    {
-        return finished_block_flags_.ld(
-            GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
-    }
+    __device__ uint32_t GetFlagValue() const
+    {
+        return finished_block_flags_.ld(GetWorkgroupFlagIdx());
+    }
 
     const index_t tile_count_;
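The indexing change above is the core of the commit: the old scheduler derived a shared flag index from the output-tile position, while the new one simply uses the block id, and the GetWorkgroupFlagIdx() + i arithmetic in WaitForNeighbours and Reset implies that the split-K workgroups of one [M,N] tile sit at consecutive block ids. A plain C++ restatement of the two mappings (free functions written only for illustration; in CK these are StridedReductionTileLoop members and the inputs come from the tile map):

    #include <cstdint>

    // Old scheme: all split-K workgroups of one [M,N] tile share a single counter,
    // addressed by the linear output-tile index modulo the number of per-tile flags.
    uint32_t old_flag_idx(int output_tile_idx, int output_tile_idx_offset, int flag_count)
    {
        return static_cast<uint32_t>((output_tile_idx + output_tile_idx_offset) % flag_count);
    }

    // New scheme: every workgroup owns the counter at its own block id. The reducing block
    // at id b polls b + 1 ... b + neighbour_count and later resets b ... b + neighbour_count.
    uint32_t new_flag_idx(uint32_t block_id) { return block_id; }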
include/ck/utility/workgroup_barrier.hpp  (view file @ 5b2eb1ab)
@@ -5,7 +5,7 @@
 namespace ck {
 
 struct workgroup_barrier
 {
-    __device__ workgroup_barrier(volatile uint32_t* ptr) : base_ptr(ptr) {}
+    __device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
 
     __device__ uint32_t ld(uint32_t offset) const
     {

@@ -53,7 +53,7 @@ struct workgroup_barrier
     {
         if(threadIdx.x == 0)
         {
-            while(atomicCAS(const_cast<uint32_t*>(base_ptr + offset), compare, value) != compare) {}
+            while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
         }
         __syncthreads();
     }

@@ -68,7 +68,7 @@ struct workgroup_barrier
     {
         if(threadIdx.x == 0)
        {
-            atomicAdd(const_cast<uint32_t*>(base_ptr + offset), 1);
+            atomicAdd(base_ptr + offset, 1);
        }
         __syncthreads();
     }

@@ -82,6 +82,6 @@ struct workgroup_barrier
         __syncthreads();
     }
 
-    volatile uint32_t* base_ptr;
+    uint32_t* base_ptr;
 };
 } // namespace ck
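Dropping volatile from base_ptr also removes the const_cast that was needed at every atomicCAS/atomicAdd call, since the HIP atomic overloads take plain pointers. As a reference for what the scheduler expects from these primitives, here is a host-side analogue built on std::atomic; it is only a sketch under my own assumptions (the member names mirror the calls made in work_scheduling.hpp, while the real struct runs on a GMEM counter with thread 0 issuing the atomic and __syncthreads() fencing the rest of the block, and its atomicCAS-based member also swaps in a new value, which this sketch does not model):

    #include <atomic>
    #include <cstdint>

    struct workgroup_barrier_sketch
    {
        explicit workgroup_barrier_sketch(std::atomic<uint32_t>* ptr) : base_ptr(ptr) {}

        // ld: read the counter at `offset`.
        uint32_t ld(uint32_t offset) const { return base_ptr[offset].load(); }

        // wait_eq: spin until the counter at `offset` holds `value`.
        void wait_eq(uint32_t offset, uint32_t value) const
        {
            while(base_ptr[offset].load() != value) {}
        }

        // inc: bump the counter at `offset` by one.
        void inc(uint32_t offset) const { base_ptr[offset].fetch_add(1); }

        // reset: clear the counter at `offset` back to zero.
        void reset(uint32_t offset) const { base_ptr[offset].store(0); }

        std::atomic<uint32_t>* base_ptr;
    };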
test/work_scheduling/test_strided_reduction_tile_loop.cpp  (view file @ 5b2eb1ab)
@@ -166,22 +166,18 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
             //                  partial_result;
         }
 
-        const index_t output_tile_idx =
-            __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
-        const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
-        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
+        work_scheduler.FlagFinished();
 
         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
             // Wait untill all other blocks for this [M,N] tile store their results.
-            index_t neighbour_count = work_scheduler.WaitForNeighbours(
-                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
+            index_t neighbour_count =
+                work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());
 
             // Accumulate partial results. We can have different # of workgroups to reduce, thus we
             // read actual flag value.
-            for(index_t i = 1; i < neighbour_count; ++i)
+            for(index_t i = 1; i <= neighbour_count; ++i)
             {
                 // partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                 //                               i * MPerBlock * NPerBlock +
@@ -199,7 +195,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
             }
 
             // Signal waiting blocks that they can start use their workspace.
-            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.Reset(neighbour_count);
 
             // write result
             const index_t C_m_tile_offset = block_m_id * MPerBlock;
@@ -221,7 +217,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         }
         else if(work_scheduler.HasTile())
         {
-            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
+            work_scheduler.WaitForReduction();
         }
     } while(work_scheduler.HasTile());
@@ -328,11 +324,7 @@ struct GroupedGemmStridedTileLoopReduce
         gemm_descs_device_buf.ToDevice(gemm_descs.data());
 
         DeviceMem gemm_workspace, gemm_flags;
-        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
-        // This is the number of MN-output tiles which we cover with workgroups.
-        // We launch k_batch / tiles_per_block workgroups for each output tile.
-        const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
+        const index_t flag_count = grid_size;
 
         gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
         gemm_flags.Realloc(flag_count * sizeof(uint32_t));
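In the test the flag buffer is sized through CK's DeviceMem helper, as shown above. The sketch below does the same sizing with the raw HIP runtime for readers outside the test harness; it is illustrative only (the grid_size value is assumed, and the buffer is zero-initialised here on the assumption that the scheduler only ever increments the counters and resets them back to zero).

    // Host-side sketch of allocating flags for the new scheme: one uint32_t per workgroup.
    #include <hip/hip_runtime.h>
    #include <cstddef>
    #include <cstdint>

    int main()
    {
        const std::size_t grid_size  = 120;  // assumed launch size: one flag per workgroup
        const std::size_t flag_bytes = grid_size * sizeof(uint32_t);

        uint32_t* p_flags = nullptr;
        if(hipMalloc(reinterpret_cast<void**>(&p_flags), flag_bytes) != hipSuccess) { return 1; }
        if(hipMemset(p_flags, 0, flag_bytes) != hipSuccess) { return 1; }   // counters start at zero

        // ... launch the grouped GEMM kernel with p_flags as the work-scheduling flag buffer ...

        (void)hipFree(p_flags);
        return 0;
    }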