Commit b753bbc5 authored by carlushuang
Browse files

fix a bug in final reduce

parent d19487eb
...@@ -168,9 +168,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -168,9 +168,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction) StreamKReductionStrategy::Reduction)
{ {
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_); char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
workspace_semaphore = karg.block_mapping.get_workspace_size_for_acc(
workspace_semaphore + karg.block_mapping.get_workspace_size_for_acc(
sizeof(typename GridwiseGemm::FloatAcc)); sizeof(typename GridwiseGemm::FloatAcc));
auto preprocess = [&]() { auto preprocess = [&]() {
hipGetErrorString( hipGetErrorString(
......
...@@ -562,7 +562,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -562,7 +562,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed( constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
...@@ -572,7 +571,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -572,7 +571,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); 0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_n_reverse = constexpr auto partial_acc_load_step_n_reverse =
make_multi_index(0, make_multi_index(0,
-1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) * -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CBlockTransferScalarPerVector_NWaveNPerXDL); CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_m = constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0); make_multi_index(cluster_length_reduce.At(I0), 0);
...@@ -653,17 +652,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -653,17 +652,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// block synchronization // block synchronization
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
// if(threadIdx.x == 0) { #if 0
if(threadIdx.x == 0) {
// if(reduction_idx == 0){ // if(reduction_idx == 0){
// printf("(miter:%d, niter:%d, cluster red:%d,%d)bid:%d, rid:%d, os:%d-%d(%d), // printf("(cluster red:%d,%d)bid:%d, rid:%d, os:%d-%d(%d), spatial:%d-%d, tid:%d, %d, %d\n",
// spatial:%d-%d, tid:%d, %d, %d\n", // cluster_length_reduce.At(I0).value,
// MReduceIters(), NReduceIters(), cluster_length_reduce.At(I0).value,
// cluster_length_reduce.At(I1).value, static_cast<int>(blockIdx.x), // cluster_length_reduce.At(I1).value, static_cast<int>(blockIdx.x),
// reduction_idx, tile_acc_offset_start, tile_acc_offset_end, // reduction_idx, tile_acc_offset_start, tile_acc_offset_end,
// tile_acc_offset_end - tile_acc_offset_start, spatial_idx[I0], // tile_acc_offset_end - tile_acc_offset_start, spatial_idx[I0],
// spatial_idx[I1], static_cast<int>(threadIdx.x), thread_m_cluster_id, // spatial_idx[I1], static_cast<int>(threadIdx.x), thread_m_cluster_id,
// thread_n_cluster_id); // thread_n_cluster_id);
// } printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
using Accumulation = ck::detail:: using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>; AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
...@@ -698,11 +702,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -698,11 +702,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
}); });
} }
if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc, acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
acc_buf, acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf); c_grid_buf);
}
if constexpr(NReduceIters != 1) if constexpr(NReduceIters != 1)
{ {
if constexpr(i_n_reduce != (NReduceIters - 1)) if constexpr(i_n_reduce != (NReduceIters - 1))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment