Commit 644a45e9 authored by guangzlu's avatar guangzlu
Browse files

fixed bug of global pos in gridwise

parent 907413ac
......@@ -1559,7 +1559,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
I1, // NBlockId
m0, // MRepeat
I1, // NRepeat
m1, // MWaveId
......@@ -1976,24 +1976,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer);
decltype(i)>(s_slash_p_thread_buf,
ph,
global_elem_id +
i.value * id_step,
z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......@@ -2043,16 +2046,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
......
......@@ -1067,6 +1067,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// printf("at 1 global_elem_id is %d \n", global_elem_id);
// }
index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
if(p_z_grid)
{
......@@ -1076,7 +1078,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false,
decltype(n0),
decltype(i)>(
acc_thread_buf, ph, global_elem_id, z_tenor_buffer);
acc_thread_buf, ph, global_elem_id + i.value * id_step, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment