Unverified Commit f404e93a authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #1039 from ROCmSoftwarePlatform/mha-train-develop-outaccu

Use Flash-Attention-2 method to do output accumulation for infer/forward
parents 9a423017 fe9c974e
...@@ -88,9 +88,8 @@ struct BlockwiseSoftmax ...@@ -88,9 +88,8 @@ struct BlockwiseSoftmax
__host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf) __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
{ {
// find max value // find max value
static_for<0, MRepeat, 1>{}([&](auto I) { static_for<0, MRepeat, 1>{}(
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>(); [&](auto I) { max_value_buf(I) = ck::NumericLimits<AccDataType>::Lowest(); });
});
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) { static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
...@@ -118,6 +117,47 @@ struct BlockwiseSoftmax ...@@ -118,6 +117,47 @@ struct BlockwiseSoftmax
}); });
} }
template <typename CThreadBuffer, typename WorkspaceBuffer>
__host__ __device__ void CalculateRowMax(CThreadBuffer& in_thread_buf,
WorkspaceBuffer& reduce_work_buf)
{
// find max value
static_for<0, MRepeat, 1>{}(
[&](auto I) { max_value_buf(I) = ck::NumericLimits<AccDataType>::Lowest(); });
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
block_sync_lds();
});
};
template <typename CThreadBuffer, typename WorkspaceBuffer>
__host__ __device__ void CalculateRowExpSum(CThreadBuffer& in_thread_buf,
WorkspaceBuffer& reduce_work_buf,
BufferType& max_value_buf_new)
{
// calculate exp for elements, P=exp(s-max)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf_new(iM));
});
});
// sum data
static_for<0, MRepeat, 1>{}([&](auto I) {
sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
block_sync_lds();
});
}
template <typename CThreadBuffer, typename LSEBuffer> template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf, __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf) const LSEBuffer& lse_thread_buf)
......
...@@ -1147,6 +1147,21 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1147,6 +1147,21 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
0), 0),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
do do
{ {
auto n_block_data_idx_on_grid = auto n_block_data_idx_on_grid =
...@@ -1262,11 +1277,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1262,11 +1277,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} }
// softmax
// calculate current max
blockwise_softmax.CalculateRowMax(acc_thread_buf, workspace_buf);
// current max
SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& max = blockwise_softmax.max_value_buf;
// accumulated max
running_max_new = mathext::max(max, running_max);
// calculate current exp_sum
blockwise_softmax.CalculateRowExpSum(acc_thread_buf, workspace_buf, running_max_new);
// current exp_sum
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc_thread_buf, workspace_buf); // accumulated exp_sum
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + sum;
constexpr auto iterator_offset = Number<16 * DropoutStep>{}; constexpr auto iterator_offset = Number<16 * DropoutStep>{};
constexpr auto iterator_step = Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{}; constexpr auto iterator_step = Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{};
...@@ -1325,11 +1352,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1325,11 +1352,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
}); });
} }
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
// gemm1 // gemm1
{ {
// TODO: explore using dynamic buffer for a1 thread buffer // TODO: explore using dynamic buffer for a1 thread buffer
...@@ -1394,31 +1416,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1394,31 +1416,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
} }
} // end gemm1 } // end gemm1
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{}; auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new = FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + math::exp(running_max[iM] - running_max_new[iM]) * c + acc1;
math::exp(max[iM] - running_max_new[iM]) * acc1) / // Formula by Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new c_thread_buf(I) = c_new; // O_new
}); });
...@@ -1436,6 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -1436,6 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
c_thread_buf(I) = c_thread_buf[I] / running_sum[iM];
});
});
// Calculate max + ln(sum) and write out // Calculate max + ln(sum) and write out
if constexpr(IsLseStoring) if constexpr(IsLseStoring)
......
...@@ -914,6 +914,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -914,6 +914,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I0], wave_m_n_id[I1], 0, 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I0], wave_m_n_id[I1], 0, 0, wave_m_n_id[I0], 0));
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
do do
{ {
...@@ -1058,16 +1073,22 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -1058,16 +1073,22 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
} }
} }
// softmax // calculate current max
blockwise_softmax.CalculateRowMax(acc_thread_buf, workspace_buf);
// current max
SoftmaxBuf& max = blockwise_softmax.max_value_buf; SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; // accumulated max
running_max_new = mathext::max(max, running_max);
blockwise_softmax.Run(acc_thread_buf, workspace_buf); // calculate current exp_sum
blockwise_softmax.CalculateRowExpSum(acc_thread_buf, workspace_buf, running_max_new);
// TODO: may convert to log domain // current exp_sum
running_max_new = mathext::max(max, running_max); SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum; // accumulated exp_sum
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + sum;
// gemm1 // gemm1
{ {
...@@ -1133,31 +1154,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -1133,31 +1154,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
} }
} // end gemm1 } // end gemm1
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{}; auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new = FloatGemmAcc c_new = math::exp(running_max[iM] - running_max_new[iM]) * c +
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + acc1; // Formula by Dao et al.,
math::exp(max[iM] - running_max_new[iM]) * acc1) / // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new c_thread_buf(I) = c_new; // O_new
}); });
...@@ -1175,6 +1179,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle ...@@ -1175,6 +1179,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
c_thread_buf(I) = c_thread_buf[I] / running_sum[iM];
});
});
// shuffle C and write out // shuffle C and write out
{ {
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment