gaoqiong / composable_kernel_ROCM · Commit 73adb83d

Do not synchronize when it's not necessary.

Authored Jun 27, 2024 by Adam Osewski
Parent: c5c95578

Showing 3 changed files with 36 additions and 12 deletions:

  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp  (+2, -0)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp  (+1, -2)
  include/ck/utility/work_scheduling.hpp  (+33, -10)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -133,6 +133,7 @@ __global__ void
     ...
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
+        work_scheduler.SetIsSyncNeeded(k_tiles, k_batch);
         // just accumulate results in registers!
         GridwiseGemm::template RunGEMM(p_a_grid,
     ...
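Note: the added SetIsSyncNeeded call is what lets a workgroup skip the split-K reduction protocol when it owns every k-tile of the current [M,N] tile. Below is a minimal, self-contained HIP sketch of that idea only; the kernel name segmented_sum, its parameters, and the atomicAdd fallback are hypothetical illustrations, not composable_kernel code (the real path uses the flag/wait/reset scheme from work_scheduling.hpp).

    #include <hip/hip_runtime.h>
    #include <cstdio>
    #include <vector>

    // Hypothetical kernel: each "piece" plays the role of one k-tile, and k_batch
    // consecutive pieces reduce into one output element. A block that owns all
    // k_batch pieces of its output stores its accumulator directly; a block that
    // shares an output falls back to a synchronizing path (an atomicAdd here,
    // the flag/wait/reset protocol in the real scheduler).
    __global__ void segmented_sum(
        const float* in, float* out, int k_batch, int pieces_per_block, int piece_count)
    {
        const int first_piece = static_cast<int>(blockIdx.x) * pieces_per_block;
        if(first_piece >= piece_count || threadIdx.x != 0)
            return;

        // Only the block's first output segment is handled, to keep the sketch short
        // (assumes pieces_per_block <= k_batch).
        const int out_idx   = first_piece / k_batch;
        const int k_idx     = first_piece % k_batch;
        const int my_pieces = min(min(pieces_per_block, k_batch - k_idx), piece_count - first_piece);

        float acc = 0.f;
        for(int p = 0; p < my_pieces; ++p)
            acc += in[first_piece + p];

        const bool sync_needed = (my_pieces != k_batch); // mirrors SetIsSyncNeeded(k_tiles, k_batch)
        if(!sync_needed)
            out[out_idx] = acc;            // sole owner of this output: plain store, no coordination
        else
            atomicAdd(&out[out_idx], acc); // shared output: take the synchronized path
    }

    int main()
    {
        const int k_batch = 4, outputs = 3, pieces = k_batch * outputs, pieces_per_block = 2;
        std::vector<float> h_in(pieces, 1.f);

        float *d_in = nullptr, *d_out = nullptr;
        hipMalloc(&d_in, pieces * sizeof(float));
        hipMalloc(&d_out, outputs * sizeof(float));
        hipMemcpy(d_in, h_in.data(), pieces * sizeof(float), hipMemcpyHostToDevice);
        hipMemset(d_out, 0, outputs * sizeof(float));

        const int blocks = (pieces + pieces_per_block - 1) / pieces_per_block;
        segmented_sum<<<blocks, 64>>>(d_in, d_out, k_batch, pieces_per_block, pieces);

        std::vector<float> h_out(outputs);
        hipMemcpy(h_out.data(), d_out, outputs * sizeof(float), hipMemcpyDeviceToHost);
        for(int i = 0; i < outputs; ++i)
            std::printf("out[%d] = %.1f\n", i, h_out[i]); // 4.0 for each output
        hipFree(d_in);
        hipFree(d_out);
        return 0;
    }

With pieces_per_block equal to k_batch every block becomes the sole owner of its output and the fallback path is never taken; that is exactly the case this commit stops synchronizing for.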
@@ -874,6 +875,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             << std::string(ALayout::name)[0] << ","
             << std::string(BLayout::name)[0] << ","
             << std::string(ELayout::name)[0] << ","
             << NumGemmKPrefetchStage << ", "
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
     ...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/utility/common_header.hpp"
 #include "ck/utility/loop_scheduler.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

 namespace ck {
     ...
include/ck/utility/work_scheduling.hpp

@@ -33,11 +33,13 @@ class StridedReductionTileLoop
 {
     public:
     __device__ StridedReductionTileLoop(index_t tile_count,
                                         uint32_t* const __restrict__ p_flags)
-        : tile_count_{tile_count},
-          tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
-          tile_id_{get_block_1d_id() * tiles_per_block_},
-          block_tile_idx_{0},
-          finished_block_flags_{p_flags}
+        : tile_count_{__builtin_amdgcn_readfirstlane(tile_count)},
+          tiles_per_block_{__builtin_amdgcn_readfirstlane((tile_count_ + get_grid_size() - 1) /
+                                                          get_grid_size())},
+          tile_id_{__builtin_amdgcn_readfirstlane(get_block_1d_id() * tiles_per_block_)},
+          block_tile_idx_{__builtin_amdgcn_readfirstlane(0)},
+          finished_block_flags_{p_flags},
+          is_sync_needed_{1}
     {
     }
     ...
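Note: every initializer is now routed through __builtin_amdgcn_readfirstlane. On AMD GPUs this broadcasts lane 0's value across the wavefront; since these quantities depend only on the block id and grid size they are already uniform, so the builtin does not change the result but lets the compiler keep them in scalar registers (SGPRs) instead of replicating them per lane. A minimal, hypothetical HIP kernel showing the same idiom (not code from this repository):

    #include <hip/hip_runtime.h>

    // Hypothetical kernel: tiles_per_block and first_tile depend only on gridDim
    // and blockIdx, so they are identical for every lane of the wavefront. Wrapping
    // them in __builtin_amdgcn_readfirstlane marks them as wavefront-uniform and
    // keeps them in scalar registers, mirroring the constructor change above.
    __global__ void uniform_index_demo(const float* in, float* out, int tile_count)
    {
        const int grid_size       = static_cast<int>(gridDim.x);
        const int tiles_per_block = __builtin_amdgcn_readfirstlane((tile_count + grid_size - 1) / grid_size);
        const int first_tile      = __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.x) * tiles_per_block);

        // Per-lane work still uses ordinary (potentially divergent) values.
        const int i = first_tile + static_cast<int>(threadIdx.x);
        if(i < tile_count)
            out[i] = 2.f * in[i];
    }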
@@ -80,11 +82,18 @@ class StridedReductionTileLoop
     ///
     /// @brief Flag each workgroup that has finished its work.
     ///
-    __device__ void FlagFinished() { finished_block_flags_.inc(GetWorkgroupFlagIdx()); }
+    __device__ void FlagFinished()
+    {
+        if(is_sync_needed_)
+            finished_block_flags_.inc(GetWorkgroupFlagIdx());
+    }

     ///
     /// @brief Wait until each workgroup has finished its work.
     ///
     /// @note This function assumes it's called by the WGP which processes the first
     ///       k-tile.
     ///
     /// @param[in] k_tiles      The number of tiles in the reduced dimension.
     /// @param[in] k_tile_idx   The currently processed tile k index.
     ///
     ...
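Note: finished_block_flags_ is a workgroup_barrier over the p_flags array, one counter per flag slot. To make the inc / wait_eq / ld / reset calls in this diff concrete, here is a rough, hypothetical stand-in for such a counter (the actual ck::workgroup_barrier is defined elsewhere in the repository and may differ, for example in the memory ordering around the spin loop):

    #include <hip/hip_runtime.h>

    // Hypothetical stand-in for the flag array behind finished_block_flags_,
    // only to illustrate the calls used in StridedReductionTileLoop.
    struct simple_workgroup_flags
    {
        uint32_t* p_flags; // one counter per flag slot (here: per [M,N] output tile)

        // A producer workgroup bumps the counter after storing its partial result.
        __device__ void inc(uint32_t idx) { atomicAdd(&p_flags[idx], 1u); }

        // Non-blocking read of the current counter value.
        __device__ uint32_t ld(uint32_t idx) { return atomicAdd(&p_flags[idx], 0u); }

        // Spin until the counter equals `value` (e.g. 0 after the reducer resets it).
        __device__ void wait_eq(uint32_t idx, uint32_t value)
        {
            while(atomicAdd(&p_flags[idx], 0u) != value)
            {
            }
        }

        // The reducing workgroup clears the slot so it can be reused.
        __device__ void reset(uint32_t idx) { atomicExch(&p_flags[idx], 0u); }
    };

Read together with the guards added in this commit, a sole-owner workgroup never increments or spins on its flag slot at all; that is the synchronization being skipped.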
@@ -95,6 +104,10 @@ class StridedReductionTileLoop
         // We have to wait for all workgroups to finish their partial results.
         // First count how many "neighbour" workgroups we have to check.
         index_t neighbour_count = 0;

+        if(!is_sync_needed_)
+            return neighbour_count;
+
         if(tiles_per_block_ < k_tiles)
         {
             // Since we can have deviation (+/-1) in neighbours number
     ...
@@ -135,8 +148,9 @@ class StridedReductionTileLoop
     ///
     __device__ void WaitForReduction()
     {
-        // Wait untill my counter has been reset.
-        finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
+        if(is_sync_needed_)
+            // Wait untill my counter has been reset.
+            finished_block_flags_.wait_eq(GetWorkgroupFlagIdx(), 0);
     }

     ///
     ...
@@ -146,9 +160,12 @@ class StridedReductionTileLoop
     ///
     __device__ void Reset(index_t neighbour_count)
    {
-        for(index_t i = 0; i <= neighbour_count; ++i)
+        if(is_sync_needed_)
         {
-            finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
+            for(index_t i = 0; i <= neighbour_count; ++i)
+            {
+                finished_block_flags_.reset(GetWorkgroupFlagIdx() + i);
+            }
         }
     }
     ...
@@ -160,11 +177,17 @@ class StridedReductionTileLoop
         return finished_block_flags_.ld(GetWorkgroupFlagIdx());
     }

+    __device__ void SetIsSyncNeeded(index_t next_k_tiles, index_t k_tiles)
+    {
+        is_sync_needed_ = __builtin_amdgcn_readfirstlane(next_k_tiles == k_tiles ? 0 : 1);
+    }
+
     const index_t tile_count_;
     const index_t tiles_per_block_;
     index_t tile_id_;
     index_t block_tile_idx_;
     workgroup_barrier finished_block_flags_;
+    index_t is_sync_needed_;
 };

 } // namespace ck
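Note: the decision itself is just a comparison of how many k-tiles the workgroup claimed against the split-K factor. Below is a small host-side sketch of that arithmetic; the contiguous-chunk tile assignment and the simplified GetNextKTiles behaviour are assumptions for illustration, and only the tiles_per_block_ formula is taken from the constructor above.

    #include <algorithm>
    #include <cstdio>

    // Host-side sketch of the decision rule. Assumptions: tiles are laid out as
    // mn_tiles consecutive groups of k_batch k-tiles, each block owns a contiguous
    // chunk of tiles_per_block tiles, and only the first chunk per block is shown.
    int main()
    {
        const int grid_size = 2; // workgroups
        const int mn_tiles  = 3; // [M,N] output tiles
        const int k_batch   = 4; // split-K factor: k-tiles per [M,N] tile

        const int tile_count      = mn_tiles * k_batch;
        const int tiles_per_block = (tile_count + grid_size - 1) / grid_size; // as in the constructor

        for(int block = 0; block < grid_size; ++block)
        {
            const int first = block * tiles_per_block;
            const int last  = std::min(first + tiles_per_block, tile_count);

            // k-tiles of the current [M,N] tile this block can process back to back.
            const int k_idx        = first % k_batch;
            const int next_k_tiles = std::min(last - first, k_batch - k_idx);

            // Mirrors SetIsSyncNeeded(next_k_tiles, k_batch): sync only when the
            // block does not cover the whole k range of its [M,N] tile by itself.
            const int is_sync_needed = (next_k_tiles == k_batch) ? 0 : 1;

            std::printf("block %d: tiles [%d, %d)  next_k_tiles=%d  is_sync_needed=%d\n",
                        block, first, last, next_k_tiles, is_sync_needed);
        }
        return 0;
    }

In this example block 0 covers all four k-tiles of its first [M,N] tile (is_sync_needed = 0), while block 1 starts mid-tile and still has to go through the flag protocol.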