"test/vscode:/vscode.git/clone" did not exist on "d4e5a3ea93d1e0da847ff96420efbd40862164a9"
Commit d8dc850e authored by Adam Osewski

Use buffer loads and proper cache coherence.

parent b398481e
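
In essence, the change replaces direct (previously volatile) pointer accesses with AMD raw buffer load/store intrinsics that carry an explicit cache-coherence policy: GLC for the cross-workgroup reduction workspace, default coherence for the A/B/C tiles, with __threadfence() ordering the workspace traffic. Below is a minimal sketch of the load side, assuming the Composable Kernel helpers used in the diff (make_wave_buffer_resource_with_default_range, llvm_amdgcn_raw_buffer_load_fp32, AmdBufferCoherenceEnum) are in scope; the wrapper function itself is illustrative and not part of this commit.

    // Illustrative only: shows the buffer-load pattern introduced in the diff below.
    __device__ float load_workspace_value(float* p_workspace_tile, index_t thread_id)
    {
        // Build a buffer resource descriptor covering the tile's base address.
        auto res = make_wave_buffer_resource_with_default_range<float>(p_workspace_tile);
        // GLC makes the read bypass the non-coherent L1 cache, so partial results
        // written by other workgroups (and ordered with __threadfence) are visible.
        return llvm_amdgcn_raw_buffer_load_fp32(
            res,
            thread_id * sizeof(float), // per-thread byte offset into the tile
            0,                         // wave-uniform byte offset
            static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
    }
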
@@ -48,8 +48,8 @@ struct GemmArgDesc
 template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
 __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
-                                                            volatile float* p_workspace,
-                                                            volatile uint32_t* p_flags,
+                                                            float* p_workspace,
+                                                            uint32_t* p_flags,
                                                             index_t tile_count,
                                                             index_t k_batch)
 {
@@ -88,9 +88,9 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
     const index_t N = p_gemm_descs[group_id].N;
     const index_t K = p_gemm_descs[group_id].K;
-    const auto p_A = p_gemm_descs[group_id].p_A;
-    const auto p_B = p_gemm_descs[group_id].p_B;
-    const auto p_C = p_gemm_descs[group_id].p_C;
+    auto p_A = const_cast<float*>(p_gemm_descs[group_id].p_A);
+    auto p_B = const_cast<float*>(p_gemm_descs[group_id].p_B);
+    auto p_C = p_gemm_descs[group_id].p_C;
 
     const auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
     BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(c_grid_desc_m_n, k_batch);
@@ -124,11 +124,29 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t B_k_tile_offset = k_batch_id * KPerBlock;
         const index_t B_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;
+
+        auto a_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_A + A_m_tile_offset * stride_a + A_k_tile_offset);
+        auto b_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_B + B_k_tile_offset * stride_b + B_n_tile_offset);
+
         for(index_t k = 0; k < KPerBlock; ++k)
         {
-            partial_result +=
-                p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k] *
-                p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset + B_thread_tile_n_idx];
+            float a_val = llvm_amdgcn_raw_buffer_load_fp32(
+                a_buffer_resource,
+                (A_thread_tile_m_idx * stride_a + k) * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+            float b_val = llvm_amdgcn_raw_buffer_load_fp32(
+                b_buffer_resource,
+                (k * stride_b + B_thread_tile_n_idx) * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+            partial_result += a_val * b_val;
+            // partial_result +=
+            //     p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k]
+            //     * p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset +
+            //     B_thread_tile_n_idx];
         }
     } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
@@ -136,10 +154,17 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
     if(!b2c_tile_map.IsFirstKSplitBlock())
     {
         // Assume we have MPerBlock x NPerBlock tile per each workgroup in contiguous memory.
-        p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
-            partial_result;
+        auto w_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_workspace + get_block_1d_id() * MPerBlock * NPerBlock);
+        llvm_amdgcn_raw_buffer_store_fp32(partial_result,
+                                          w_buffer_resource,
+                                          get_thread_local_1d_id() * sizeof(float),
+                                          0,
+                                          static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
+        // p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
+        //     partial_result;
     }
+    __threadfence();
 
     const index_t output_tile_idx =
         __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
@@ -158,10 +183,21 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         // read actual flag value.
         for(index_t i = 1; i < neighbour_count; ++i)
         {
-            partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
-                                          i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
+            // partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
+            //                               i * MPerBlock * NPerBlock +
+            //                               get_thread_local_1d_id()];
+            auto w_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+                p_workspace + get_block_1d_id() * MPerBlock * NPerBlock +
+                i * MPerBlock * NPerBlock);
+            float value = llvm_amdgcn_raw_buffer_load_fp32(
+                w_buffer_resource,
+                get_thread_local_1d_id() * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
+            partial_result += value;
         }
+        __threadfence();
 
         // Signal waiting blocks that they can start use their workspace.
         work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -171,8 +207,17 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t C_n_tile_offset = block_n_id * NPerBlock;
         const index_t C_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;
 
-        p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
-            C_thread_tile_n_idx] = partial_result;
+        auto c_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_C + C_m_tile_offset * stride_c + C_n_tile_offset);
+        llvm_amdgcn_raw_buffer_store_fp32(
+            partial_result,
+            c_buffer_resource,
+            (C_thread_tile_m_idx * stride_c + C_thread_tile_n_idx) * sizeof(float),
+            0,
+            static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+        // p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
+        //     C_thread_tile_n_idx] = partial_result;
     }
     else if(work_scheduler.HasTile())
     {
...
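
For context on why both GLC and __threadfence() appear above: a producing workgroup must make its partial tile globally visible before it signals completion, and the reducing workgroup must read those partials around the L1 cache after it observes the signal. The sketch below shows the producer side under those assumptions; the flag update is hypothetical and stands in for the WorkScheduler's actual signalling.

    // Sketch only: the flag handling below is hypothetical, not the commit's code.
    __device__ void publish_partial(float* p_workspace_slot, uint32_t* p_flag, float partial)
    {
        auto res = make_wave_buffer_resource_with_default_range<float>(p_workspace_slot);
        // GLC store: write around L1 so the reducing workgroup can observe the value.
        llvm_amdgcn_raw_buffer_store_fp32(partial,
                                          res,
                                          get_thread_local_1d_id() * sizeof(float),
                                          0,
                                          static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
        // Make the store visible device-wide before the flag is raised.
        __threadfence();
        if(get_thread_local_1d_id() == 0)
        {
            atomicAdd(p_flag, 1u); // hypothetical "partial ready" signal
        }
    }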