Commit b753bbc5 authored by carlushuang

fix a bug in final reduce
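
In the gridwise stream-K final reduce, the step that rewinds the partial-accumulator load offset back to column 0 was scaled by (MReduceIters - 1) instead of (NReduceIters - 1), so the walk over reduction tiles landed on the wrong offset whenever the two iteration counts differ. While here, the partial-accumulator store is guarded so a thread only writes when its starting column fits inside NPerBlock, the stale debug printf block is moved under #if 0, and the workspace-semaphore pointer setup in the device op is collapsed into one expression.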

parent d19487eb
@@ -168,9 +168,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
                           StreamKReductionStrategy::Reduction)
         {
-            char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
-            workspace_semaphore =
-                workspace_semaphore + karg.block_mapping.get_workspace_size_for_acc(
+            char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_) +
+                                        karg.block_mapping.get_workspace_size_for_acc(
                                            sizeof(typename GridwiseGemm::FloatAcc));
             auto preprocess = [&]() {
                 hipGetErrorString(
......
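
For orientation, the pointer arithmetic above assumes the stream-K workspace puts the partial-accumulator region first and the semaphores directly after it. A minimal sketch of that layout; the struct and member names here are illustrative, not CK API:

```cpp
#include <cstddef>

// Illustrative mirror of the layout the hunk above computes into.
// acc_bytes stands in for block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)).
struct StreamKWorkspaceView
{
    char*       base;      // karg.p_workspace_
    std::size_t acc_bytes; // bytes reserved for partial accumulators

    char* acc_region() const { return base; }                   // tiles to reduce
    char* semaphore_region() const { return base + acc_bytes; } // workspace_semaphore
};
```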
@@ -562,7 +562,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
             make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
         constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
             make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
@@ -572,7 +571,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
         constexpr auto partial_acc_load_step_n_reverse =
             make_multi_index(0,
-                             -1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
+                             -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                                  CBlockTransferScalarPerVector_NWaveNPerXDL);
         constexpr auto partial_acc_load_step_m =
             make_multi_index(cluster_length_reduce.At(I0), 0);
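
This hunk is the actual fix: the reverse step has to undo the (NReduceIters - 1) forward N steps taken since the last M step, but it was scaled by (MReduceIters - 1), so the load offset only returned to column 0 when the two counts happened to match. A self-contained sketch of the offset walk, with made-up sizes standing in for the tile/cluster configuration:

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    // Assumed values for illustration; the real ones come from the
    // reduce-cluster shape and CBlockTransferScalarPerVector_NWaveNPerXDL.
    constexpr int MReduceIters   = 4;
    constexpr int NReduceIters   = 2; // bug only shows when this != MReduceIters
    constexpr int cluster_len_n  = 8; // cluster_length_reduce.At(I1)
    constexpr int scalar_per_vec = 4; // CBlockTransferScalarPerVector_NWaveNPerXDL

    constexpr int step_n = cluster_len_n * scalar_per_vec;
    // Fixed reverse step: rewind exactly the N distance covered per M row.
    constexpr int step_n_reverse = -1 * cluster_len_n * (NReduceIters - 1) * scalar_per_vec;
    // Buggy version rewound an M-scaled distance instead:
    // constexpr int step_n_reverse = -1 * (MReduceIters - 1) * cluster_len_n * scalar_per_vec;

    int n_offset = 0;
    for(int im = 0; im < MReduceIters; ++im)
    {
        for(int in = 0; in < NReduceIters - 1; ++in)
            n_offset += step_n;     // partial_acc_load_step_n
        n_offset += step_n_reverse; // partial_acc_load_step_n_reverse
        assert(n_offset == 0);      // back at column 0 before stepping in M
    }
    std::printf("final N offset: %d (expected 0)\n", n_offset);
}
```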
@@ -653,17 +652,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 // block synchronization
                 wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
-                // if(threadIdx.x == 0) {
+#if 0
+                if(threadIdx.x == 0) {
                 // if(reduction_idx == 0){
-                //     printf("(miter:%d, niter:%d, cluster red:%d,%d)bid:%d, rid:%d, os:%d-%d(%d), spatial:%d-%d, tid:%d, %d, %d\n",
-                //            MReduceIters(), NReduceIters(), cluster_length_reduce.At(I0).value,
+                //     printf("(cluster red:%d,%d)bid:%d, rid:%d, os:%d-%d(%d), spatial:%d-%d, tid:%d, %d, %d\n",
+                //            cluster_length_reduce.At(I0).value,
                 //            cluster_length_reduce.At(I1).value, static_cast<int>(blockIdx.x),
                 //            reduction_idx, tile_acc_offset_start, tile_acc_offset_end,
                 //            tile_acc_offset_end - tile_acc_offset_start, spatial_idx[I0],
                 //            spatial_idx[I1], static_cast<int>(threadIdx.x), thread_m_cluster_id,
                 //            thread_n_cluster_id);
                 // }
+                    printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
+                           reduction_idx,
+                           __builtin_amdgcn_readfirstlane(tile_acc_offset_start),
+                           __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
+                           __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
+                           __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
+                }
+#endif
                 using Accumulation = ck::detail::
                     AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
@@ -698,11 +702,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                         });
                     }
+                    if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
+                       NPerBlock)
+                    {
                         acc_store.Run(acc_thread_buf_store_desc,
                                       make_tuple(I0, I0, I0, I0),
                                       acc_buf,
                                       c_grid_desc_mblock_mperblock_nblock_nperblock,
                                       c_grid_buf);
+                    }
                     if constexpr(NReduceIters != 1)
                     {
                         if constexpr(i_n_reduce != (NReduceIters - 1))
......
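
The last hunk adds the guard that gives the commit its name: a thread in the reduce cluster only calls acc_store.Run when the first column of its vector lies inside NPerBlock. A standalone sketch with assumed sizes showing which threads the guard filters out:

```cpp
#include <cstdio>

int main()
{
    // Assumed sizes: a reduce cluster 64 threads wide along N, each thread
    // storing a 4-wide vector into a tile that is only 128 columns wide.
    constexpr int NPerBlock       = 128;
    constexpr int scalar_per_vec  = 4; // CBlockTransferScalarPerVector_NWaveNPerXDL
    constexpr int n_cluster_width = 64;

    for(int thread_n_cluster_id = 0; thread_n_cluster_id < n_cluster_width; ++thread_n_cluster_id)
    {
        // The new guard: without it, threads 32..63 would start their
        // stores at columns 128..252, past the edge of the tile.
        if(thread_n_cluster_id * scalar_per_vec < NPerBlock)
        {
            // acc_store.Run(...) executes for this thread
        }
        else
        {
            std::printf("thread %2d skips store (start col %3d >= NPerBlock %d)\n",
                        thread_n_cluster_id,
                        thread_n_cluster_id * scalar_per_vec,
                        NPerBlock);
        }
    }
}
```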