Commit d13c92bd authored by Anthony Chang

remove debug printfs

parent 825f7f02
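
All of the deleted blocks follow the same bring-up pattern: a printf (or debug::print_shared dump) wrapped in #if 0 and filtered to a handful of threads, used to inspect the intermediate dV/dP/dS/dQ tiles of the backward kernel. A minimal, self-contained sketch of that pattern is below; the kernel and buffer names are illustrative only and are not taken from the gridwise GEMM source.

#include <hip/hip_runtime.h>
#include <cstdio>

// Illustrative kernel (hypothetical names): shows the #if 0-gated,
// thread-filtered debug printf pattern that this commit strips out.
__global__ void debug_print_pattern(const float* thread_buf)
{
    const float v = thread_buf[hipThreadIdx_x];

#if 0 // flip to 1 locally while debugging; never left enabled in the shipped kernel
    // Only the first 4 lanes of each wavefront in block 0 print,
    // keeping device-side output readable.
    if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
    {
        printf("bid %d tid %d, v = %f\n",
               static_cast<int>(hipBlockIdx_x),
               static_cast<int>(hipThreadIdx_x),
               v);
    }
#endif

    (void)v; // real computation elided in this sketch
}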
@@ -1358,22 +1358,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
o_thread_data_nd_idx_on_grid[I4]),
tensor_operation::element_wise::PassThrough{});
#if 0
if(hipThreadIdx_x % 32 < 4)
{
printf("wid %zd tid %zd _n0_o0_n1_o1_n2_o2_o3_o4 %d %d %d %d %d %d %d %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
n_thread_data_nd_idx_on_grid[I0],
o_thread_data_nd_idx_on_grid[I0],
n_thread_data_nd_idx_on_grid[I1],
o_thread_data_nd_idx_on_grid[I1],
n_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I3],
o_thread_data_nd_idx_on_grid[I4]);
}
#endif
// p_thread_slice_copy_step will be in for loop
constexpr auto ygrad_block_slice_copy_step =
make_multi_index(VGradGemmTile_N_O_M::YGrad_M0, 0, 0);
@@ -1383,17 +1367,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// vgrad gemm output tile
const auto vgrad_block_slice_copy_step =
make_multi_index(VGradGemmTile_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
#if 0
if(hipThreadIdx_x == 0)
{
printf("bid %zd, n_grid = %d, o_grid = %d, step N0 = %d\n",
hipBlockIdx_x,
n_thread_data_idx_on_grid,
o_thread_data_idx_on_grid,
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(NPerBlock))[I0]);
}
#endif
//
// set up Y dot dY
@@ -1503,14 +1476,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds before accum\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
}
#endif
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
@@ -1575,23 +1540,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
});
#if 0
if (hipThreadIdx_x % 32 < 4 && hipBlockIdx_x == 0)
{
printf("bid %zd tid %zd, oblock_idx %d, y_thread_buf[0:3] = %f %f %f %f, ygrad_thread_buf[0:3] = %f %f %f %f\n",
hipBlockIdx_x,
hipThreadIdx_x,
oblock_idx,
(float)y_thread_buf[I0],
(float)y_thread_buf[I1],
(float)y_thread_buf[I2],
(float)y_thread_buf[I3],
(float)ygrad_thread_buf[I0],
(float)ygrad_thread_buf[I1],
(float)ygrad_thread_buf[I2],
(float)ygrad_thread_buf[I3]);
}
#endif
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(0, 0, 1, 0));
@@ -1607,14 +1555,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
block_sync_lds();
#if 1
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds after accum\n");
if(hipBlockIdx_x == 0)
{
debug::print_shared(y_dot_ygrad_block_accum_buf.p_data_, MPerBlock);
}
#endif
// distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
@@ -1623,19 +1563,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
#if 0
if(hipBlockIdx_x < 4 && hipThreadIdx_x % 32 < 4)
{
printf("bid %zd tid %zd, y_m0_m1_o0_o1 = %d, %d, %d, %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
y_thread_data_on_grid_idx[I0],
y_thread_data_on_grid_idx[I1],
y_thread_data_on_grid_idx[I2],
y_thread_data_on_grid_idx[I3]);
}
#endif
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf,
lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
@@ -1739,33 +1666,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
s_slash_p_thread_buf[I0],
s_slash_p_thread_buf[I1],
s_slash_p_thread_buf[I2],
s_slash_p_thread_buf[I3]);
}
#endif
// P_i: = softmax(S_i:)
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
s_slash_p_thread_buf[I0],
s_slash_p_thread_buf[I1],
s_slash_p_thread_buf[I2],
s_slash_p_thread_buf[I3]);
}
#endif
block_sync_lds(); // wait for gemm1 LDS read
SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
@@ -1788,21 +1691,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
make_tuple(p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
constexpr auto nwave_range =
make_tuple(p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
#if 0
if(hipThreadIdx_x % 64 == 0)
{
printf(
"VGrad P vgrad_gemm_loop_idx %d, wave_id = %d, mrepeat, nrepeat, mwave, "
"nwave = %d, %d, %d, %d, active %d\n",
vgrad_gemm_loop_idx.value,
(int)hipThreadIdx_x / 64,
p_nd_idx[I0].value,
p_nd_idx[I1].value,
p_nd_idx[I2].value,
p_nd_idx[I3].value,
p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range));
}
#endif
if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
p_thread_copy_vgpr_to_lds.Run(
@@ -1820,38 +1709,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // sync before write
ygrad_blockwise_copy.RunWrite(ygrad_block_desc_m0_o_m1, ygrad_block_buf);
#if 0
if (hipBlockIdx_x == 0)
{
debug::print_shared(
p_block_buf.p_data_,
index_t(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()));
}
#endif
#if 0
if (hipBlockIdx_x == 0)
{
debug::print_shared(ygrad_block_buf.p_data_,
index_t(ygrad_block_desc_m0_o_m1.GetElementSpaceSize()));
}
#endif
block_sync_lds(); // sync before read
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer %d inner %d tid %zd, dV[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
vgrad_gemm_loop_idx.value,
hipThreadIdx_x,
vgrad_thread_buf[I0],
vgrad_thread_buf[I1],
vgrad_thread_buf[I2],
vgrad_thread_buf[I3]);
}
#endif
}); // end gemm dV
// atomic_add dV
@@ -1880,18 +1741,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer j loop idx %d, tid %zd, dP[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
pgrad_thread_buf[I0],
pgrad_thread_buf[I1],
pgrad_thread_buf[I2],
pgrad_thread_buf[I3]);
}
#endif
// calculate dS from dP
auto& sgrad_thread_buf = pgrad_thread_buf;
@@ -1910,19 +1759,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
});
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer j loop idx %d, tid %zd, dS[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
hipThreadIdx_x,
sgrad_thread_buf[I0],
sgrad_thread_buf[I1],
sgrad_thread_buf[I2],
sgrad_thread_buf[I3]);
}
#endif
// gemm dQ
{
// TODO: explore using dynamic buffer for a1 thread buffer
@@ -1941,14 +1777,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for previous LDS read
qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("inner j loop idx %d, lds dQ gemm K matrix =", i.value);
if(hipBlockIdx_x == 0)
{
debug::print_shared(b1_block_buf.p_data_,
(index_t)b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
}
#endif
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
@@ -1959,18 +1788,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
#if 0
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("inner j loop idx %d, tid %zd, dS downcast[0:3] = %f, %f, %f, %f\n",
i.value,
hipThreadIdx_x,
(float)a1_thread_buf[I0],
(float)a1_thread_buf[I1],
(float)a1_thread_buf[I2],
(float)a1_thread_buf[I3]);
}
#endif
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
block_sync_lds();
@@ -2021,19 +1838,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// TODO ANT:
// shuffle dQ and write
#if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("tid %zd, dQ[0:3] = %f, %f, %f, %f\n",
hipThreadIdx_x,
qgrad_thread_buf[I0],
qgrad_thread_buf[I1],
qgrad_thread_buf[I2],
qgrad_thread_buf[I3]);
}
#endif
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -2218,14 +2022,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_shuffle_block_buf,
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
qgrad_grid_buf);
#if 0
if(hipThreadIdx_x == 0 && hipBlockIdx_x == 0) printf("lds dQ shuffle loop %d\n", access_id.value);
if(hipBlockIdx_x == 1)
{
debug::print_shared(c_shuffle_block_buf.p_data_,
(index_t)c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
}
#endif
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);