Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9cb25b86
Commit
9cb25b86
authored
Feb 07, 2024
by
Adam Osewski
Browse files
Use memory fence and volatile attribute for synchronization flags.
parent
a74b2263
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
9 deletions
+11
-9
include/ck/utility/work_scheduling.hpp
include/ck/utility/work_scheduling.hpp
+2
-2
include/ck/utility/workgroup_barrier.hpp
include/ck/utility/workgroup_barrier.hpp
+5
-5
test/work_scheduling/test_strided_reduction_tile_loop.cpp
test/work_scheduling/test_strided_reduction_tile_loop.cpp
+4
-2
No files found.
include/ck/utility/work_scheduling.hpp
View file @
9cb25b86
...
...
@@ -33,12 +33,12 @@ class StridedReductionTileLoop
{
public:
__device__
StridedReductionTileLoop
(
index_t
tile_count
,
uint32_t
*
const
__restrict__
p_flag
_count
)
volatile
uint32_t
*
const
__restrict__
p_flag
s
)
:
tile_count_
{
tile_count
},
tiles_per_block_
{(
tile_count_
+
get_grid_size
()
-
1
)
/
get_grid_size
()},
tile_id_
{
get_block_1d_id
()
*
tiles_per_block_
},
block_tile_idx_
{
0
},
finished_block_flags_
{
p_flag
_count
}
finished_block_flags_
{
p_flag
s
}
{
}
...
...
include/ck/utility/workgroup_barrier.hpp
View file @
9cb25b86
...
...
@@ -5,7 +5,7 @@
namespace
ck
{
struct
workgroup_barrier
{
__device__
workgroup_barrier
(
uint32_t
*
ptr
)
:
base_ptr
(
ptr
)
{}
__device__
workgroup_barrier
(
volatile
uint32_t
*
ptr
)
:
base_ptr
(
ptr
)
{}
__device__
uint32_t
ld
(
uint32_t
offset
)
const
{
...
...
@@ -53,7 +53,7 @@ struct workgroup_barrier
{
if
(
threadIdx
.
x
==
0
)
{
while
(
atomicCAS
(
base_ptr
+
offset
,
compare
,
value
)
!=
compare
)
{}
while
(
atomicCAS
(
const_cast
<
uint32_t
*>
(
base_ptr
+
offset
)
,
compare
,
value
)
!=
compare
)
{}
}
__syncthreads
();
}
...
...
@@ -66,11 +66,11 @@ struct workgroup_barrier
__device__
void
inc
(
uint32_t
offset
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
atomicAdd
(
base_ptr
+
offset
,
1
);
atomicAdd
(
const_cast
<
uint32_t
*>
(
base_ptr
+
offset
)
,
1
);
}
__syncthreads
();
}
__device__
void
reset
(
uint32_t
offset
)
...
...
@@ -82,6 +82,6 @@ struct workgroup_barrier
__syncthreads
();
}
uint32_t
*
base_ptr
;
volatile
uint32_t
*
base_ptr
;
};
}
// namespace ck
test/work_scheduling/test_strided_reduction_tile_loop.cpp
View file @
9cb25b86
...
...
@@ -48,8 +48,8 @@ struct GemmArgDesc
template
<
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
>
__global__
void
grouped_gemm_naive_strided_tile_loop_reduce
(
const
GemmArgDesc
*
p_gemm_descs
,
float
*
p_workspace
,
uint32_t
*
p_flags
,
volatile
float
*
p_workspace
,
volatile
uint32_t
*
p_flags
,
index_t
tile_count
,
index_t
k_batch
)
{
...
...
@@ -139,6 +139,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
p_workspace
[
get_block_1d_id
()
*
MPerBlock
*
NPerBlock
+
get_thread_local_1d_id
()]
=
partial_result
;
}
__threadfence
();
const
index_t
output_tile_idx
=
__builtin_amdgcn_readfirstlane
(
b2c_tile_map
.
GetOutputTileIdx
());
...
...
@@ -160,6 +161,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
partial_result
+=
p_workspace
[(
get_block_1d_id
())
*
MPerBlock
*
NPerBlock
+
i
*
MPerBlock
*
NPerBlock
+
get_thread_local_1d_id
()];
}
__threadfence
();
// Signal waiting blocks that they can start using their workspace.
work_scheduler
.
Reset
(
k_batch
,
output_tile_idx
,
output_tile_idx_offset
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment