"vscode:/vscode.git/clone" did not exist on "325a5de3a9acc97534a4446ce9dd4147efcd61a0"
Commit d8dc850e authored by Adam Osewski

Use buffer loads and proper cache coherence.

parent b398481e
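
For readers unfamiliar with the mechanism: instead of dereferencing raw pointers, the kernel below now issues loads and stores through AMD buffer-resource descriptors, and each access carries an explicit cache-coherence policy. A minimal sketch of the declarations the hunks below rely on, following the Composable Kernel pattern (the exact definitions live in the library's amd_buffer_addressing.hpp and may differ in detail):

#include <cstdint>

using index_t = int32_t; // CK's default index type
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));

// Raw buffer accesses lower to llvm.amdgcn.raw.buffer.* instructions. Arguments:
// 128-bit resource descriptor, per-lane byte offset, scalar byte offset, cache policy.
__device__ float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");

__device__ void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
                                  int32x4_t srsrc,
                                  index_t voffset,
                                  index_t soffset,
                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");

// Cache policy for the last argument. GLC keeps the access coherent at the L2
// level instead of the workgroup-private vector L1; this is what lets the diff
// drop the volatile qualifiers on the workspace and flag pointers.
enum struct AmdBufferCoherenceEnum
{
    DefaultCoherence = 0,
    GLC              = 1,
    SLC              = 2,
    GLC_SLC          = 3,
};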
@@ -48,8 +48,8 @@ struct GemmArgDesc
 template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
 __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
-                                                            volatile float* p_workspace,
-                                                            volatile uint32_t* p_flags,
+                                                            float* p_workspace,
+                                                            uint32_t* p_flags,
                                                             index_t tile_count,
                                                             index_t k_batch)
 {
@@ -88,9 +88,9 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t N = p_gemm_descs[group_id].N;
         const index_t K = p_gemm_descs[group_id].K;
-        const auto p_A = p_gemm_descs[group_id].p_A;
-        const auto p_B = p_gemm_descs[group_id].p_B;
-        const auto p_C = p_gemm_descs[group_id].p_C;
+        auto p_A = const_cast<float*>(p_gemm_descs[group_id].p_A);
+        auto p_B = const_cast<float*>(p_gemm_descs[group_id].p_B);
+        auto p_C = p_gemm_descs[group_id].p_C;
 
         const auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
         BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(c_grid_desc_m_n, k_batch);
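
The const_cast above is a mechanical consequence of the switch: p_A and p_B come out of the descriptor as pointers to const, while the buffer-resource helper used below presumably takes a mutable pointer, so constness is dropped at the call site rather than changing GemmArgDesc. p_C needs no cast because the output pointer is already non-const.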
@@ -124,11 +124,29 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t B_k_tile_offset = k_batch_id * KPerBlock;
         const index_t B_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;
 
+        auto a_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_A + A_m_tile_offset * stride_a + A_k_tile_offset);
+        auto b_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_B + B_k_tile_offset * stride_b + B_n_tile_offset);
+
         for(index_t k = 0; k < KPerBlock; ++k)
         {
-            partial_result +=
-                p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k] *
-                p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset + B_thread_tile_n_idx];
+            float a_val = llvm_amdgcn_raw_buffer_load_fp32(
+                a_buffer_resource,
+                (A_thread_tile_m_idx * stride_a + k) * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+            float b_val = llvm_amdgcn_raw_buffer_load_fp32(
+                b_buffer_resource,
+                (k * stride_b + B_thread_tile_n_idx) * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+            partial_result += a_val * b_val;
+            // partial_result +=
+            //     p_A[(A_m_tile_offset + A_thread_tile_m_idx) * stride_a + A_k_tile_offset + k]
+            //     * p_B[(B_k_tile_offset + k) * stride_b + B_n_tile_offset +
+            //     B_thread_tile_n_idx];
         }
     } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
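
Two details in this hunk are easy to miss. First, each resource is based at the tile origin, so the voffset handed to the load is only the intra-tile byte offset, e.g. (A_thread_tile_m_idx * stride_a + k) * sizeof(float) for A; the A_m_tile_offset and A_k_tile_offset terms from the old indexing are folded into the descriptor base once, outside the k-loop. Second, the descriptor itself is a 4-dword record. A hypothetical sketch of what the helper assembles, reusing the types from the sketch above (the real make_wave_buffer_resource_with_default_range lives elsewhere in the repository, and the config dword is architecture-dependent):

// Hypothetical sketch only; field layout follows the gfx9 buffer-descriptor format.
union BufferResourceSketch
{
    int32x4_t content;
    struct
    {
        const void* address; // dwords 0-1: 64-bit base address
        uint32_t range;      // dword 2: buffer size in bytes; out-of-range lanes read 0
        uint32_t config;     // dword 3: data-format / stride bits
    };
};

template <typename T>
__device__ int32x4_t make_wave_buffer_resource_with_default_range_sketch(T* p)
{
    BufferResourceSketch res;
    res.address = p;
    res.range   = 0xFFFFFFFF; // "default range": effectively unbounded, no OOB guard
    res.config  = 0x00020000; // assumption: gfx9-style config dword
    return res.content;
}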
@@ -136,10 +154,17 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
     if(!b2c_tile_map.IsFirstKSplitBlock())
     {
         // Assume we have one MPerBlock x NPerBlock tile per workgroup in contiguous memory.
-        p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
-            partial_result;
+        auto w_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_workspace + get_block_1d_id() * MPerBlock * NPerBlock);
+        llvm_amdgcn_raw_buffer_store_fp32(partial_result,
+                                          w_buffer_resource,
+                                          get_thread_local_1d_id() * sizeof(float),
+                                          0,
+                                          static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
+        // p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
+        //     partial_result;
     }
     __threadfence();
 
     const index_t output_tile_idx =
         __builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
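
The GLC policy on this store is what actually replaces volatile: the partial result is written through to the L2, which is coherent across workgroups, rather than parked in the producing workgroup's L1. The __threadfence() that follows orders the workspace write before the flag update done by work_scheduler, giving the usual producer/consumer handshake (outline only, with hypothetical helper names; the real flag protocol lives inside work_scheduler):

// producer, i.e. a non-first k-split block:
//     buffer_store(partial_result, workspace, GLC); // lands in coherent L2
//     __threadfence();                              // store ordered before the signal
//     signal the per-tile flag in p_flags
//
// consumer, i.e. the first k-split block for this tile:
//     wait on p_flags until all k-split neighbours have signalled
//     __threadfence();
//     partial_result += buffer_load(workspace, GLC); // observe their partials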
@@ -158,10 +183,21 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         // read actual flag value.
         for(index_t i = 1; i < neighbour_count; ++i)
         {
-            partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
-                                          i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
+            // partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
+            //                               i * MPerBlock * NPerBlock +
+            //                               get_thread_local_1d_id()];
+            auto w_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+                p_workspace + get_block_1d_id() * MPerBlock * NPerBlock +
+                i * MPerBlock * NPerBlock);
+            float value = llvm_amdgcn_raw_buffer_load_fp32(
+                w_buffer_resource,
+                get_thread_local_1d_id() * sizeof(float),
+                0,
+                static_cast<index_t>(AmdBufferCoherenceEnum::GLC));
+            partial_result += value;
         }
         __threadfence();
 
         // Signal waiting blocks that they can start using their workspace.
         work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
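
Symmetrically, GLC on these loads keeps the reducing block from satisfying the read out of a stale line in its own L1: iteration i pulls in the MPerBlock x NPerBlock partial tile that workgroup get_block_1d_id() + i stored with GLC above.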
@@ -171,8 +207,17 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         const index_t C_n_tile_offset = block_n_id * NPerBlock;
         const index_t C_thread_tile_n_idx = get_thread_local_1d_id() % NPerBlock;
 
-        p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
-            C_thread_tile_n_idx] = partial_result;
+        auto c_buffer_resource = make_wave_buffer_resource_with_default_range<float>(
+            p_C + C_m_tile_offset * stride_c + C_n_tile_offset);
+        llvm_amdgcn_raw_buffer_store_fp32(
+            partial_result,
+            c_buffer_resource,
+            (C_thread_tile_m_idx * stride_c + C_thread_tile_n_idx) * sizeof(float),
+            0,
+            static_cast<index_t>(AmdBufferCoherenceEnum::DefaultCoherence));
+        // p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
+        //     C_thread_tile_n_idx] = partial_result;
     }
     else if(work_scheduler.HasTile())
     {
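
Note that the final store of C goes back to DefaultCoherence. Unlike the workspace, C is only consumed after the kernel completes, when the end-of-kernel cache writeback makes all results visible anyway, so there is no need to bypass the L1 on the output path.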