Commit 4388b767 authored by danyao12

optimize fwd v2 dropout to eliminate scratch

parent 00cb7e41
@@ -198,55 +198,41 @@ struct BlockwiseDropout
         });
     }
 
-    template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
+    template <typename CThreadBuffer, typename ZThreadBuffer, typename Step, typename Offset>
     __host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
                                                ZThreadBuffer& z_thread_buf)
     {
         auto execute_dropout = [&](bool keep, DataType val) {
-            if constexpr(using_sign_bit)
-                return keep ? val : -val;
-            else
-                return keep ? val * p_dropout_rescale : float(0);
+            return keep ? val * p_dropout_rescale : float(0);
         };
 
-        int tmp_index = 0;
-        static_for<0, MRepeat, 1>{}([&](auto iM) {
-            static_for<0, KRepeat, 1>{}([&](auto iK) {
-                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-                in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
-                                                        in_thread_buf(offset));
-                tmp_index = tmp_index + 1;
-            });
+        constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
+
+        static_for<0, tmp_size, 1>{}([&](auto i) {
+            in_thread_buf(i + Offset{}) =
+                execute_dropout(z_thread_buf(i) <= p_dropout_16bits, in_thread_buf(i + Offset{}));
         });
     }
 
     // get raw z matrix with random number for shuffle
-    template <typename ZThreadBuffer>
+    template <typename ZThreadBuffer,
+              typename Step,
+              typename Offset> // N3*N4=8
     __host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
                                                     index_t element_global_1d_id,
                                                     ZThreadBuffer& z_thread_buf)
     {
-        constexpr int tmp_size = MRepeat * KRepeat;
+        constexpr int tmp_size = MRepeat * KRepeat / Step{}.value;
         int philox_calls = tmp_size / 4;
 
         ushort tmp[tmp_size];
 
         for(int i = 0; i < philox_calls; i++)
         {
-            ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
+            ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * Offset{});
         }
 
-        block_sync_lds();
-
-        int tmp_index = 0;
-        static_for<0, MRepeat, 1>{}([&](auto iM) {
-            static_for<0, KRepeat, 1>{}([&](auto iK) {
-                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
-                z_thread_buf(offset) = tmp[tmp_index];
-                tmp_index = tmp_index + 1;
-            });
-        });
+        static_for<0, tmp_size, 1>{}([&](auto i) { z_thread_buf(i) = tmp[i.value]; });
     }
 
     ushort p_dropout_16bits;
...
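For orientation, here is a minimal standalone sketch of the chunked flow these two functions now implement: random ushorts are produced for one slice of MRepeat * KRepeat / Step elements at a time and dropout is applied at a running Offset into the accumulator, instead of covering the whole thread tile in a single call. This is plain C++, not the CK code: the helper names generate_z_slice / apply_dropout_slice, the fixed MRepeat / KRepeat / Step values, and the hashed stand-in for ck::philox::get_random_4x16 are illustrative assumptions.

#include <array>
#include <cstdint>
#include <cstdio>

// Toy analogue of the refactored dropout path: one slice per call, at a given offset.
constexpr int MRepeat = 2;
constexpr int KRepeat = 8;
constexpr int Step    = 2;                        // plays the role of NRepeat
constexpr int Slice   = MRepeat * KRepeat / Step; // elements handled per call

// Stand-in for the philox generator: a cheap hash producing 4 values per "call".
void generate_z_slice(std::array<uint16_t, Slice>& z, int element_global_1d_id)
{
    for(int i = 0; i < Slice / 4; ++i)
        for(int j = 0; j < 4; ++j)
            z[i * 4 + j] =
                static_cast<uint16_t>(((element_global_1d_id + i * 8 + j) * 2654435761u) >> 16);
}

void apply_dropout_slice(std::array<float, MRepeat * KRepeat>& acc,
                         const std::array<uint16_t, Slice>& z,
                         int offset,
                         uint16_t p_dropout_16bits,
                         float p_dropout_rescale)
{
    for(int i = 0; i < Slice; ++i)
        acc[offset + i] =
            (z[i] <= p_dropout_16bits) ? acc[offset + i] * p_dropout_rescale : 0.f;
}

int main()
{
    std::array<float, MRepeat * KRepeat> acc;
    acc.fill(1.f);
    std::array<uint16_t, Slice> z{};

    for(int step = 0; step < Step; ++step) // one pass per NRepeat slice
    {
        generate_z_slice(z, /*element_global_1d_id=*/step * Slice);
        apply_dropout_slice(acc, z, /*offset=*/step * Slice, /*keep threshold*/ 0x8000, 1.25f);
    }
    for(float v : acc)
        std::printf("%.2f ", v);
    std::printf("\n");
}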
@@ -143,6 +143,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }
 
+    __host__ __device__ static constexpr auto GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
+
+        constexpr auto M0 = MXdlPerWave;
+        constexpr auto M1 = Gemm0MWaves;
+        constexpr auto N1 = Gemm0NWaves;
+        constexpr auto M2 = MPerXdl;
+        constexpr auto N2 = mfma.num_groups_per_blk;
+        constexpr auto N3 = mfma.num_input_blks;
+        constexpr auto N4 = mfma.group_size;
+
+        constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            make_naive_tensor_descriptor_packed(make_tuple(M0, I1, M1, N1, M2, N2, N3, N4));
+
+        return z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
+    }
+
     __host__ __device__ static constexpr auto GetPaddedSize(const index_t size)
     {
         constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
@@ -430,7 +447,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         static constexpr auto c_block_space_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
-        static constexpr auto z_shuffle_block_space_size = MPerBlock * NPerBlock;
+        // LDS allocation for Z shuffle in LDS
+        static constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        static constexpr auto z_shuffle_block_space_size =
+            z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize();
     };
 
     template <bool HasMainKBlockLoop,
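The allocation change above is where the "eliminate scratch" in the commit message comes from: the Z shuffle buffer now covers a single NRepeat slice of the tile rather than the full MPerBlock x NPerBlock region. A rough compile-time sketch of the saving, under the usual CK assumptions that MPerBlock = MXdlPerWave * Gemm0MWaves * MPerXdl and that num_groups_per_blk * num_input_blks * group_size equals NPerXdl for the selected mfma; the concrete numbers below are illustrative placeholders, not values taken from the commit.

// Illustrative numbers only (e.g. a 128x128 tile, 32x32 xdlops, 2x2 waves); the real values
// come from the tuning parameters of the kernel instance.
constexpr int MPerBlock   = 128;
constexpr int NPerBlock   = 128;
constexpr int MXdlPerWave = 2; // MRepeat
constexpr int NXdlPerWave = 2; // NRepeat
constexpr int Gemm0MWaves = 2;
constexpr int Gemm0NWaves = 2;
constexpr int MPerXdl     = 32;
constexpr int NPerXdl     = 32;

// Old scratch: one ushort per element of the whole P tile.
constexpr int old_z_lds_elems = MPerBlock * NPerBlock; // 16384

// New scratch: the (M0, 1, M1, N1, M2, N2, N3, N4) descriptor covers one NRepeat slice,
// with N2 * N3 * N4 == NPerXdl for the selected mfma.
constexpr int new_z_lds_elems =
    MXdlPerWave * 1 * Gemm0MWaves * Gemm0NWaves * MPerXdl * NPerXdl; // 8192

static_assert(new_z_lds_elems == old_z_lds_elems / NXdlPerWave,
              "Z shuffle scratch shrinks by a factor of NRepeat");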
@@ -873,7 +894,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         // z matrix threadwise desc
         constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
             make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
-                                                           n0, // NRepeat
+                                                           I1, // NRepeat
                                                            m1, // MWaveId
                                                            n1, // NWaveId
                                                            m2, // MPerXdl
@@ -881,9 +902,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                                                            n3, // NInputNum
                                                            n4)); // RegisterNum
 
-        constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
+        constexpr auto z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
             make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
-                                                           n0, // NRepeat
+                                                           I1, // NRepeat
                                                            m1, // MWaveId
                                                            n1, // NWaveId
                                                            m2, // MPerXdl
@@ -896,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
                                                            I1, // NBlockId
                                                            m0, // MRepeat
-                                                           n0, // NRepeat
+                                                           I1, // NRepeat
                                                            m1, // MWaveId
                                                            n1, // NWaveId
                                                            m2, // MPerXdl
@@ -904,21 +925,21 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                                                            n3, // NInputNum
                                                            n4)); // RegisterNum
 
-        constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+        constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+            GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
 
-        constexpr auto ZM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-        constexpr auto ZN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-        constexpr auto ZM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-        constexpr auto ZN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-        constexpr auto ZM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-        constexpr auto ZN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-        constexpr auto ZN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-        constexpr auto ZN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
+        constexpr auto ZM0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
+        constexpr auto ZN0 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
+        constexpr auto ZM1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
+        constexpr auto ZN1 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
+        constexpr auto ZM2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
+        constexpr auto ZN2 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
+        constexpr auto ZN3 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
+        constexpr auto ZN4 = z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
 
-        constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
+        constexpr auto z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
             transform_tensor_descriptor(
-                z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                 make_tuple(make_pass_through_transform(ZM0),
                            make_pass_through_transform(ZN0),
                            make_pass_through_transform(ZM1),
@@ -946,7 +967,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      ushort,
-                     z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
+                     z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
                      true>
             z_tensor_buffer;
         z_tensor_buffer.Clear();
@@ -956,19 +977,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<ushort*>(p_shared),
-            z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
+            z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
 
         const auto wave_id = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
 
-        auto z_tmp_thread_copy_vgpr_to_lds =
-            ThreadwiseTensorSliceTransfer_v1r3<ushort,
-                                               ushort,
-                                               decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-                                               decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
-                                               tensor_operation::element_wise::PassThrough,
-                                               Sequence<m0, // MRepeat
-                                                        n0, // NRepeat
-                                                        m1, // MWaveId
-                                                        n1, // NWaveId
-                                                        m2, // MPerXdl
+        auto z_tmp_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
+            ushort,
+            ushort,
+            decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+            decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<m0, // MRepeat
+                     I1, // NRepeat
+                     m1, // MWaveId
+                     n1, // NWaveId
+                     m2, // MPerXdl
@@ -980,8 +1001,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             1, // DstScalarPerVector
             InMemoryDataOperationEnum::Set,
             1, // DstScalarStrideInVector
-            true>{
-                z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+            true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                   make_multi_index(0, // MRepeat
                                    0, // NRepeat
                                    wave_id[I0], // MWaveId
@@ -995,14 +1015,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
             ushort,
             ushort,
-            decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
-            decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
-            Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
+            decltype(z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
+            decltype(z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
+            Sequence<m0, I1, m1, n1, m2, n2, n3, n4, I1>,
             Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
             8,
             1,
             1,
-            true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
+            true>{z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
                   make_multi_index(0, // MRepeat
                                    0, // NRepeat
                                    wave_id[I0], // MWaveId
@@ -1022,7 +1042,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                                           Sequence<I1, // MBlockId
                                                    I1, // NBlockID
                                                    m0, // MRepeat
-                                                   n0, // NRepeat
+                                                   I1, // NRepeat
                                                    m1, // MWaveId
                                                    n1, // NWaveId
                                                    m2, // MPerXdl
@@ -1082,6 +1102,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         constexpr auto c_thread_lengths =
             blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
 
+        // 8d block_desc in block scope
+        constexpr auto c_block_lengths =
+            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+        constexpr auto M0 = c_block_lengths[I0];
+        constexpr auto N0 = c_block_lengths[I1];
+        constexpr auto M1 = c_block_lengths[I2];
+        constexpr auto N1 = c_block_lengths[I3];
+        constexpr auto M2 = c_block_lengths[I4];
+        constexpr auto N2 = c_block_lengths[I5];
+        constexpr auto N3 = c_block_lengths[I6];
+        constexpr auto N4 = c_block_lengths[I7];
+
         // works like multi-dimension static_for (static_ford), but provides both the linear
         // index as well as n-d index
         using Acc0TileIterator = SpaceFillingCurve<
@@ -1091,8 +1124,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             false>; // SnakeCurved
 
         constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_unmerge_transform(make_tuple(ZM0, ZM1, ZM2)),
-                       make_unmerge_transform(make_tuple(ZN0, ZN1, ZN2, ZN3, ZN4))),
+            make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
+                       make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
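As a sanity check on the switch from the Z* lengths to the block-descriptor lengths, here is a standalone sketch of the unmerge mapping that block_idx_to_m_n_adaptor performs: an 8-d (M0, N0, M1, N1, M2, N2, N3, N4) index is folded back into block-local (m, n) coordinates by treating (M0, M1, M2) and (N0, N1, N2, N3, N4) as row-major digit groups, mirroring the Sequence<0, 2, 4> / Sequence<1, 3, 5, 6, 7> mapping above. The lengths and the helper name calculate_bottom_index are illustrative placeholders, not the values or API produced by the blockwise GEMM.

#include <array>
#include <cstdio>

// Illustrative lengths for (M0, N0, M1, N1, M2, N2, N3, N4).
constexpr std::array<int, 8> len = {2, 2, 2, 2, 32, 4, 2, 4};

// m is unmerged over (M0, M1, M2), n over (N0, N1, N2, N3, N4), row-major in each group.
std::array<int, 2> calculate_bottom_index(const std::array<int, 8>& idx)
{
    const int m = (idx[0] * len[2] + idx[2]) * len[4] + idx[4];
    const int n =
        (((idx[1] * len[3] + idx[3]) * len[5] + idx[5]) * len[6] + idx[6]) * len[7] + idx[7];
    return {m, n};
}

int main()
{
    const auto mn = calculate_bottom_index({1, 0, 1, 1, 7, 2, 1, 3});
    std::printf("m_local=%d n_local=%d\n", mn[0], mn[1]); // m_local=103 n_local=55
}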
@@ -1132,53 +1165,67 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         blockwise_softmax.Run(acc_thread_buf, workspace_buf);
 
+        constexpr auto position_offset = N3 * N4;
+        constexpr auto iterator_offset = n2 * n3 * n4;
+
         if constexpr(IsDropout) // dropout
         {
-            auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
-            auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
-            auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+            static_for<0, Acc0TileIterator::GetNumOfAccess(), iterator_offset>{}([&](auto i) {
+                auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
+                auto m_local =
+                    block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                auto n_local =
+                    block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                 auto m_global = m_local + m_block_data_idx_on_grid;
                 auto n_global = n_local + n_block_data_idx_on_grid;
                 auto global_elem_id = z_random_matrix_offset + m_global * raw_n_padded +
                                       n_global; // unique element global 1d id
 
-            blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>(
+                blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer),
+                                                                  decltype(n0),
+                                                                  decltype(position_offset)>(
                     ph, global_elem_id, z_tensor_buffer);
 
                 z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                   make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                                   z_tensor_buffer,
-                                                  z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                                                  z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                                   z_block_buf);
 
                 z_shuffle_thread_copy_lds_to_vgpr.Run(
-                    z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
+                    z_shuffle_block_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
                     z_block_buf,
-                    z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
+                    z_shuffle_thread_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
                     make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
                     z_tensor_buffer);
 
                 blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
                                                              decltype(z_tensor_buffer),
-                                                             false>(acc_thread_buf,
-                                                                    z_tensor_buffer);
+                                                             decltype(n0),
+                                                             decltype(i)>(acc_thread_buf,
+                                                                          z_tensor_buffer);
 
                 // save z to global
                 if(p_z_grid)
                 {
-                    // ignore = z_tensor_buffer;
                     z_thread_copy_vgpr_to_global.Run(
                         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                         z_tensor_buffer,
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                         z_grid_buf);
 
                     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                        make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                        make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
                 }
+            });
+
+            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
+
+            z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
         }
 
         // TODO: may convert to log domain
...
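To make the restructured dropout branch easier to follow, here is a condensed stand-in for its control flow, with plain integers in place of the CK copy objects; only the slice-window bookkeeping is modeled, and the loop bounds are illustrative assumptions rather than values from the kernel. Each pass handles one NRepeat slice (generate Z, stage it through the small LDS buffer, apply dropout, store that slice of Z), then the destination window advances by one NRepeat; after the loop the window is rewound by n0 and stepped to the next block, matching the two trailing MoveDstSliceWindow calls above.

#include <cstdio>

int main()
{
    // Stand-ins for compile-time quantities in the kernel (illustrative values only).
    const int n0              = 2;  // NRepeat
    const int num_access      = 16; // Acc0TileIterator::GetNumOfAccess()
    const int iterator_offset = 8;  // n2 * n3 * n4: accesses covered by one NRepeat slice

    int nrepeat_window = 0; // NRepeat slice window of z_thread_copy_vgpr_to_global
    int nblock_window  = 0; // NBlock window of the same copy

    // One pass per NRepeat slice: generate Z, stage it through the small LDS buffer,
    // apply dropout, write that slice of Z to global, then advance by one NRepeat.
    for(int i = 0; i < num_access; i += iterator_offset)
    {
        std::printf("slice at access %d -> store Z at NRepeat window %d\n", i, nrepeat_window);
        nrepeat_window += 1; // MoveDstSliceWindow(+1 on the NRepeat dimension)
    }

    // After the loop: rewind the NRepeat window by n0 and step to the next N block.
    nrepeat_window -= n0;
    nblock_window += 1;

    std::printf("final windows: NRepeat=%d NBlock=%d\n", nrepeat_window, nblock_window);
}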