"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "03aa52bc5c2d3fabfb7e24bc522dd726a4e0a984"
Commit 95a9bacf authored by guangzlu's avatar guangzlu
Browse files

bwd dropout refactor

parent 5e319fb8
...@@ -82,7 +82,9 @@ __global__ void ...@@ -82,7 +82,9 @@ __global__ void
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const float p_drop, const float p_drop,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -135,7 +137,10 @@ __global__ void ...@@ -135,7 +137,10 @@ __global__ void
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph); ph,
g_idx,
MRaw,
NRaw);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -951,7 +956,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -951,7 +956,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_drop_, arg.p_drop_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
...@@ -1262,7 +1262,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1262,7 +1262,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const float p_drop, const float p_drop,
ck::philox& ph) ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw)
{ {
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout); const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
...@@ -1941,6 +1944,48 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1941,6 +1944,48 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// save z to global // save z to global
if(p_z_grid) if(p_z_grid)
{ {
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
// P_dropped // P_dropped
static_for<0, n0, 1>{}([&](auto i) { static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
...@@ -1948,7 +1993,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1948,7 +1993,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
true, true,
decltype(n0), decltype(n0),
decltype(i)>( decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer); s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
...@@ -1966,10 +2011,52 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1966,10 +2011,52 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
else else
{ {
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
ignore = z_grid_buf; ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>( blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph); s_slash_p_thread_buf, ph, global_elem_id);
} }
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment