Commit 4388b767 authored by danyao12's avatar danyao12
Browse files

optimize fwd v2 dropout to eliminate scratch

parent 00cb7e41
......@@ -198,55 +198,41 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
template <typename CThreadBuffer, typename ZThreadBuffer, typename Step, typename Offset>
__host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + Offset{}) =
execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer>
template <typename ZThreadBuffer,
typename Step,
typename Offset> // N3*N4=8
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
constexpr int tmp_size = MRepeat * KRepeat;
constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{});
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
}
ushort p_dropout_16bits;
......
......@@ -143,6 +143,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto M0 = MXdlPerWave;
constexpr auto M1 = Gemm0MWaves;
constexpr auto N1 = Gemm0NWaves;
constexpr auto M2 = MPerXdl;
constexpr auto N2 = mfma.num_groups_per_blk;
constexpr auto N3 = mfma.num_input_blks;
constexpr auto N4 = mfma.group_size;
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
make_naive_tensor_descriptor_packed(make_tuple(M0, I1, M1, N1, M2, N2, N3, N4));
return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
}
__host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
{
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
......@@ -430,7 +447,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
static constexpr auto z_shuffle_block_space_size = MPerBlock * NPerBlock;
// LDS allocation for Z shuffle in LDS
static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
static constexpr auto z_shuffle_block_space_size =
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize();
};
template <bool HasMainKBlockLoop,
......@@ -873,7 +894,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -881,9 +902,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum
n4)); // RegisterNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -896,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockId
m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -904,21 +925,21 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum
n4)); // RegisterNum
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto ZM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto ZN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto ZM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto ZN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto ZM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto ZN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto ZN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto ZN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(ZM0),
make_pass_through_transform(ZN0),
make_pass_through_transform(ZM1),
......@@ -946,7 +967,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
StaticBuffer<AddressSpaceEnum::Vgpr,
ushort,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
......@@ -956,19 +977,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ushort*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -980,8 +1001,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
......@@ -995,14 +1015,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ushort,
ushort,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
Sequence<m0, I1, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8,
1,
1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // MRepeat
0, // NRepeat
wave_id[I0], // MWaveId
......@@ -1022,7 +1042,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
......@@ -1082,6 +1102,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
......@@ -1091,8 +1124,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(ZM0, ZM1, ZM2)),
make_unmerge_transform(make_tuple(ZN0, ZN1, ZN2, ZN3, ZN4))),
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
......@@ -1132,53 +1165,67 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
constexpr auto position_offset = N3 * N4;
constexpr auto iterator_offset = n2 * n3 * n4;
if constexpr(IsDropout) // dropout
{
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
static_for<0, Acc0TileIterator::GetNumOfAccess(), iterator_offset>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id = z_random_matrix_offset + m_global * raw_n_padded +
n_global; // unique element global 1d id
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>(
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer),
decltype(n0),
decltype(position_offset)>(
ph, global_elem_id, z_tensor_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
z_shuffle_thread_copy_lds_to_vgpr.Run(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
z_block_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tensor_buffer),
false>(acc_thread_buf,
decltype(n0),
decltype(i)>(acc_thread_buf,
z_tensor_buffer);
// save z to global
if(p_z_grid)
{
// ignore = z_tensor_buffer;
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
}
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
// TODO: may convert to log domain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment