Commit 4ae9919e authored by Anthony Chang's avatar Anthony Chang
Browse files

strictly follow natural indexing for traversing P tile to avoid jumping accesses (no snake pattern)

parent b67a58c0
......@@ -15,7 +15,7 @@ Outputs:
*/
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-variable" // TODO ANT: remove
#define PRINT_HOST 0
......@@ -96,7 +96,7 @@ using DeviceGemmInstance =
TensorSpecY,
1,
256,
128, // MPerBlock
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
......@@ -106,7 +106,7 @@ using DeviceGemmInstance =
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
2, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
......
......@@ -1042,7 +1042,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
Sequence<0, 1, 2, 3>,
decltype(p_block_slice_lengths_m0_n0_m1_n1)>{};
decltype(p_block_slice_lengths_m0_n0_m1_n1),
false>{};
constexpr auto ygrad_block_desc_m0_o_m1 =
VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
......@@ -1369,6 +1370,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
constexpr auto nwave_range = make_tuple(
p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
#if 0
if(hipThreadIdx_x % 64 == 0)
{
printf(
"VGrad P vgrad_gemm_loop_idx %d, wave_id = %d, mrepeat, nrepeat, mwave, "
"nwave = %d, %d, %d, %d, active %d\n",
vgrad_gemm_loop_idx.value,
(int)hipThreadIdx_x / 64,
p_nd_idx[I0].value,
p_nd_idx[I1].value,
p_nd_idx[I2].value,
p_nd_idx[I3].value,
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
}
#endif
if (p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
p_thread_copy_vgpr_to_lds.Run(
......@@ -1406,7 +1422,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // sync before read
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_acc_thread_buf);
#if 1
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer %d inner %d tid %zd, dV[0:3] = %f, %f, %f, %f\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment