Commit fe9c974e authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

[Performance] Use Flash-Attention-2 method to implement output iterative update in infer/forward

parent 9a423017
......@@ -88,9 +88,8 @@ struct BlockwiseSoftmax
__host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
{
// find max value
static_for<0, MRepeat, 1>{}([&](auto I) {
max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
});
static_for<0, MRepeat, 1>{}(
[&](auto I) { max_value_buf(I) = ck::NumericLimits<AccDataType>::Lowest(); });
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
......@@ -118,6 +117,47 @@ struct BlockwiseSoftmax
});
}
template <typename CThreadBuffer, typename WorkspaceBuffer>
__host__ __device__ void CalculateRowMax(CThreadBuffer& in_thread_buf,
WorkspaceBuffer& reduce_work_buf)
{
// find max value
static_for<0, MRepeat, 1>{}(
[&](auto I) { max_value_buf(I) = ck::NumericLimits<AccDataType>::Lowest(); });
ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
block_sync_lds();
});
};
template <typename CThreadBuffer, typename WorkspaceBuffer>
__host__ __device__ void CalculateRowExpSum(CThreadBuffer& in_thread_buf,
WorkspaceBuffer& reduce_work_buf,
BufferType& max_value_buf_new)
{
// calculate exp for elements, P=exp(s-max)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - max_value_buf_new(iM));
});
});
// sum data
static_for<0, MRepeat, 1>{}([&](auto I) {
sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
});
ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
static_for<0, MRepeat, 1>{}([&](auto I) {
BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
block_sync_lds();
});
}
template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf)
......
......@@ -1147,6 +1147,21 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
0),
tensor_operation::element_wise::PassThrough{}};
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
do
{
auto n_block_data_idx_on_grid =
......@@ -1262,11 +1277,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
// softmax
// calculate current max
blockwise_softmax.CalculateRowMax(acc_thread_buf, workspace_buf);
// current max
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
// accumulated max
running_max_new = mathext::max(max, running_max);
// calculate current exp_sum
blockwise_softmax.CalculateRowExpSum(acc_thread_buf, workspace_buf, running_max_new);
// current exp_sum
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
// accumulated exp_sum
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + sum;
constexpr auto iterator_offset = Number<16 * DropoutStep>{};
constexpr auto iterator_step = Number<m0 * n0 * n1 * n2 * n3 * n4 / 16 / DropoutStep>{};
......@@ -1325,11 +1352,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
});
}
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
......@@ -1394,31 +1416,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
}
} // end gemm1
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
math::exp(running_max[iM] - running_max_new[iM]) * c + acc1;
// Formula by Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new
});
......@@ -1436,6 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
c_thread_buf(I) = c_thread_buf[I] / running_sum[iM];
});
});
// Calculate max + ln(sum) and write out
if constexpr(IsLseStoring)
......
......@@ -914,6 +914,21 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I0], wave_m_n_id[I1], 0, 0, wave_m_n_id[I0], 0));
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
index_t gemm1_k_block_outer_index = 0;
do
{
......@@ -1058,16 +1073,22 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
}
}
// softmax
// calculate current max
blockwise_softmax.CalculateRowMax(acc_thread_buf, workspace_buf);
// current max
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
// accumulated max
running_max_new = mathext::max(max, running_max);
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
// calculate current exp_sum
blockwise_softmax.CalculateRowExpSum(acc_thread_buf, workspace_buf, running_max_new);
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
// current exp_sum
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
// accumulated exp_sum
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + sum;
// gemm1
{
......@@ -1133,31 +1154,14 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
}
} // end gemm1
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new = math::exp(running_max[iM] - running_max_new[iM]) * c +
acc1; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new
});
......@@ -1175,6 +1179,13 @@ struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
c_thread_buf(I) = c_thread_buf[I] / running_sum[iM];
});
});
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment