Commit 9dc3e49b authored by letaoqin's avatar letaoqin
Browse files

recover code for bwd v2

parent fd107062
...@@ -2252,49 +2252,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -2252,49 +2252,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc(); if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset, static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize()); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize()); D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) { static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds // load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf); d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite( d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf); D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds(); block_sync_lds();
// read data form lds // read data form lds
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2, d0_thread_copy_lds_to_vgpr.Run(
make_tuple(I0, I0, I0, I0, I0), D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
d0_block_buf, make_tuple(I0, I0, I0, I0, I0),
D0Operator::d0_thread_desc_, d0_block_buf,
make_tuple(I0, I0, I0, I0, I0), D0Operator::d0_thread_desc_,
d0_thread_buf); make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) { // bias add
constexpr index_t c_offset = static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i)); constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]); s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
}); });
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow( d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0)); d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
} }
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
......
...@@ -2326,9 +2326,32 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2326,9 +2326,32 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// scale // do MNK padding or upper triangular masking
static_for<0, s_slash_p_thread_buf.Size(), 1>{}( if constexpr(MaskOutUpperTriangle || PadN)
[&](auto i) { s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]); }); {
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_element_op(s_slash_p_thread_buf(i),
masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i]);
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
...@@ -2383,26 +2406,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -2383,26 +2406,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
} }
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
bool masked_flag = c0_matrix_mask.IsMaskedElement(m_global, n_global);
s_slash_p_thread_buf(i) = masked_flag ? -ck::NumericLimits<float>::Infinity()
: s_slash_p_thread_buf[i];
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment