"...composable_kernel_rocm.git" did not exist on "b2888adfbe103ae3d9006af87d5871b69cbf00ba"
Commit 9dc3e49b authored by letaoqin's avatar letaoqin
Browse files

recover code for bwd v2

parent fd107062
...@@ -2251,6 +2251,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2251,6 +2251,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0_grid != nullptr)
{ {
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc(); static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
...@@ -2276,7 +2278,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2276,7 +2278,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, d0_thread_copy_lds_to_vgpr.Run(
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
d0_block_buf, d0_block_buf,
D0Operator::d0_thread_desc_, D0Operator::d0_thread_desc_,
...@@ -2294,7 +2297,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2294,7 +2297,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}); });
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
} }
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
......
...@@ -2326,9 +2326,32 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2326,9 +2326,32 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// scale // do MNK padding or upper triangular masking
static_for<0, s_slash_p_thread_buf.Size(), 1>{}( if constexpr(MaskOutUpperTriangle || PadN)
[&](auto i) { s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]); }); {
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
...@@ -2383,26 +2406,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2383,26 +2406,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
} }
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_slash_p_thread_buf(i) = masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i];
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment