Commit 39f5d9d4 authored by Anthony Chang
Browse files

post-refactor fix for output coord in dV/dK

parent ad93b411
@@ -1416,7 +1416,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
 Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
-make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
+make_multi_index(
+I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
 auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
 vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
@@ -1457,7 +1458,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 const auto kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
 Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
-make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
+make_multi_index(
+I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
 auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype(
 kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment