"...resnet50_tensorflow.git" did not exist on "4701bfbc924539860f610fa4ceae484a7bf194c6"
Commit 9cb25b86 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use memory fence and volatile attribute for synchronization flags.

parent a74b2263
...@@ -33,12 +33,12 @@ class StridedReductionTileLoop ...@@ -33,12 +33,12 @@ class StridedReductionTileLoop
{ {
public: public:
__device__ StridedReductionTileLoop(index_t tile_count, __device__ StridedReductionTileLoop(index_t tile_count,
uint32_t* const __restrict__ p_flag_count) volatile uint32_t* const __restrict__ p_flags)
: tile_count_{tile_count}, : tile_count_{tile_count},
tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()}, tiles_per_block_{(tile_count_ + get_grid_size() - 1) / get_grid_size()},
tile_id_{get_block_1d_id() * tiles_per_block_}, tile_id_{get_block_1d_id() * tiles_per_block_},
block_tile_idx_{0}, block_tile_idx_{0},
finished_block_flags_{p_flag_count} finished_block_flags_{p_flags}
{ {
} }
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace ck { namespace ck {
struct workgroup_barrier struct workgroup_barrier
{ {
__device__ workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {} __device__ workgroup_barrier(volatile uint32_t* ptr) : base_ptr(ptr) {}
__device__ uint32_t ld(uint32_t offset) const __device__ uint32_t ld(uint32_t offset) const
{ {
...@@ -53,7 +53,7 @@ struct workgroup_barrier ...@@ -53,7 +53,7 @@ struct workgroup_barrier
{ {
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
while(atomicCAS(base_ptr + offset, compare, value) != compare) {} while(atomicCAS(const_cast<uint32_t*>(base_ptr + offset), compare, value) != compare) {}
} }
__syncthreads(); __syncthreads();
} }
...@@ -66,11 +66,11 @@ struct workgroup_barrier ...@@ -66,11 +66,11 @@ struct workgroup_barrier
__device__ void inc(uint32_t offset) __device__ void inc(uint32_t offset)
{ {
__syncthreads();
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
atomicAdd(base_ptr + offset, 1); atomicAdd(const_cast<uint32_t*>(base_ptr + offset), 1);
} }
__syncthreads();
} }
__device__ void reset(uint32_t offset) __device__ void reset(uint32_t offset)
...@@ -82,6 +82,6 @@ struct workgroup_barrier ...@@ -82,6 +82,6 @@ struct workgroup_barrier
__syncthreads(); __syncthreads();
} }
uint32_t* base_ptr; volatile uint32_t* base_ptr;
}; };
} // namespace ck } // namespace ck
...@@ -48,8 +48,8 @@ struct GemmArgDesc ...@@ -48,8 +48,8 @@ struct GemmArgDesc
template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock> template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
__global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs, __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
float* p_workspace, volatile float* p_workspace,
uint32_t* p_flags, volatile uint32_t* p_flags,
index_t tile_count, index_t tile_count,
index_t k_batch) index_t k_batch)
{ {
...@@ -139,6 +139,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p ...@@ -139,6 +139,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] = p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
partial_result; partial_result;
} }
__threadfence();
const index_t output_tile_idx = const index_t output_tile_idx =
__builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx()); __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
...@@ -160,6 +161,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p ...@@ -160,6 +161,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock + partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
i * MPerBlock * NPerBlock + get_thread_local_1d_id()]; i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
} }
__threadfence();
// Signal waiting blocks that they can start use their workspace. // Signal waiting blocks that they can start use their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset); work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment