Commit 9cb25b86 authored by Adam Osewski

Use memory fence and volatile attribute for synchronization flags.

parent a74b2263
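
The pattern being applied here is the usual cross-workgroup signalling idiom: partial results are written to a global workspace, a __threadfence() makes those writes visible device-wide, and only then is the synchronization flag updated; the flag itself is accessed through a volatile pointer so the compiler re-reads it from memory on every poll, while HIP atomics (atomicCAS/atomicAdd) still expect a non-volatile pointer, hence the const_cast. Below is a minimal stand-alone sketch of that producer/consumer handshake, with illustrative names (fence_flag_sketch, p_workspace, p_flag) rather than the actual composable_kernel interfaces:

#include <hip/hip_runtime.h>
#include <cstdint>

// Block 0 produces a value, block 1 consumes it once the flag is raised.
__global__ void fence_flag_sketch(volatile float* p_workspace, volatile uint32_t* p_flag)
{
    if(blockIdx.x == 0)
    {
        if(threadIdx.x == 0)
        {
            p_workspace[0] = 42.f; // payload written to global memory
            __threadfence();       // make the payload visible before signalling
            // atomics take a non-volatile pointer, so cast the qualifier away
            atomicAdd(const_cast<uint32_t*>(p_flag), 1u);
        }
    }
    else if(blockIdx.x == 1)
    {
        if(threadIdx.x == 0)
        {
            // volatile forces a fresh load from memory on every iteration
            while(*p_flag == 0u) {}
            __threadfence(); // order the flag read before the workspace read
        }
        __syncthreads();                    // whole block waits on the polling thread
        const float value = p_workspace[0]; // now observes the producer's payload
        (void)value;
    }
}

In the kernel changed below, the first added __threadfence() plays the producer role (partial results must be visible before the tile's completion flag is bumped), and the second orders the reduced result ahead of work_scheduler.Reset(), which signals waiting blocks that the workspace slot can be reused.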
@@ -33,12 +33,12 @@ class StridedReductionTileLoop
 {
     public:
     __device__ StridedReductionTileLoop(index_t tile_count,
-                                        uint32_t* const __restrict__ p_flag_count)
+                                        volatile uint32_t* const __restrict__ p_flags)
         : tile_count_{tile_count},
           tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
           tile_id_{get_block_1d_id() * tiles_per_block_},
           block_tile_idx_{0},
-          finished_block_flags_{p_flag_count}
+          finished_block_flags_{p_flags}
     {
     }
@@ -5,7 +5,7 @@
 namespace ck {
 struct workgroup_barrier
 {
-    __device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {}
+    __device__ workgroup_barrier(volatile uint32_t* ptr) : base_ptr(ptr) {}
     __device__ uint32_t ld(uint32_t offset) const
     {
@@ -53,7 +53,7 @@ struct workgroup_barrier
     {
         if(threadIdx.x == 0)
         {
-            while(atomicCAS(base_ptr + offset, compare, value) != compare) {}
+            while(atomicCAS(const_cast<uint32_t*>(base_ptr + offset), compare, value) != compare) {}
         }
         __syncthreads();
     }
@@ -66,11 +66,11 @@ struct workgroup_barrier
     __device__ void inc(uint32_t offset)
     {
         __syncthreads();
         if(threadIdx.x == 0)
         {
-            atomicAdd(base_ptr + offset, 1);
+            atomicAdd(const_cast<uint32_t*>(base_ptr + offset), 1);
         }
         __syncthreads();
     }
     __device__ void reset(uint32_t offset)
@@ -82,6 +82,6 @@ struct workgroup_barrier
         __syncthreads();
     }
-    uint32_t* base_ptr;
+    volatile uint32_t* base_ptr;
 };
 } // namespace ck
@@ -48,8 +48,8 @@ struct GemmArgDesc
 template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
 __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
-                                                            float* p_workspace,
-                                                            uint32_t* p_flags,
+                                                            volatile float* p_workspace,
+                                                            volatile uint32_t* p_flags,
                                                             index_t tile_count,
                                                             index_t k_batch)
 {
@@ -139,6 +139,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
             partial_result;
     }
+    __threadfence();
     const index_t output_tile_idx =
         __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
@@ -160,6 +161,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                       i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
     }
+    __threadfence();
     // Signal waiting blocks that they can start using their workspace.
     work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);