Commit 711663bc authored by qin letao

add P and dS dropout

parent e3a2651b
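
This change threads the dropout probability into the multi-head attention backward pass: the kernel gains a p_dropout argument, the gridwise struct constructs a BlockwiseDropout helper, ApplyDropout gains a using_sign_bit mode that negates dropped elements instead of zeroing them (so the dropout mask is carried in the sign of P itself), the D = sum(dY * Y) row accumulation is pre-scaled by p_dropout, and the dS computation branches on the sign of the stored P to handle kept and dropped elements separately.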
@@ -16,7 +16,7 @@ struct BlockwiseDropout
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

-    template <typename CThreadBuffer>
+    template <typename CThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
                                           ck::philox ph,
                                           const int repeat_index,
@@ -24,7 +24,8 @@ struct BlockwiseDropout
     {
         auto execute_dropout = [&](bool keep, DataType val) {
-            return keep ? val * p_dropout_rescale : float(0);
+            return keep ? val * p_dropout_rescale
+                        : (using_sign_bit ? -val * p_dropout_rescale : float(0));
         };

        constexpr int tmp_size = MRepeat * KRepeat;
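
A minimal host-side sketch of what the new using_sign_bit branch computes (the function name and the standalone setting are illustrative; in the kernel, keep comes from comparing Philox-generated random bits against the dropout threshold and rescale is p_dropout_rescale). The encoding is unambiguous because P is a softmax output and therefore non-negative:

    #include <cassert>

    // Mirrors the execute_dropout lambda: kept values are rescaled,
    // dropped values keep their magnitude but flip sign, so the drop
    // decision can be recovered later from the sign alone.
    float dropout_encode(bool keep, float val, float rescale, bool use_sign_bit)
    {
        return keep ? val * rescale
                    : (use_sign_bit ? -val * rescale : 0.0f);
    }

    int main()
    {
        const float rescale = 1.0f / 0.9f; // e.g. keep probability 0.9
        assert(dropout_encode(true, 0.5f, rescale, true) > 0.0f);  // kept -> positive
        assert(dropout_encode(false, 0.5f, rescale, true) < 0.0f); // dropped -> negative
        return 0;
    }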
@@ -130,6 +130,7 @@ __global__ void
             block_2_ctile_map,
             c0_matrix_mask,
             p_dropout_in_16bits,
+            p_dropout,
             rp_dropout,
             ph);
 #else
@@ -16,6 +16,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"

 namespace ck {
@@ -1121,6 +1122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
         const ushort p_dropout_in_16bits,
+        FloatGemmAcc p_dropout,
         FloatGemmAcc rp_dropout,
         ck::philox& ph)
     {
@@ -1357,6 +1359,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             decltype(thread_cluster_desc_m_n),
             decltype(thread_slice_desc_m_n)>{};

+        auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
+            p_dropout_in_16bits, rp_dropout};
+
         auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
             MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
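
If these follow the conventions of CK's other attention kernels, p_dropout_in_16bits is the keep threshold that Philox-generated 16-bit values are compared against and rp_dropout is the reciprocal keep probability used as the rescale factor; the helper only needs those two, while the raw p_dropout is consumed directly in the D accumulation below.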
@@ -1600,7 +1605,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
                 const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
                 y_dot_ygrad_block_accum_buf.AtomicAdd(
-                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
+                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // p_dropoutD1
             });

             block_sync_lds();
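
Pre-scaling the accumulated row sums by p_dropout is what lets the sign-encoded P values be used as-is in the dS computation further down. Writing D_i = sum_o dY_io * Y_io and assuming, as in CK's device-side setup code, that p_dropout is the keep probability with rp_dropout = 1 / p_dropout: a kept element stores P_ij * rp_dropout, so multiplying it by (dP_ij - p_dropout * D_i) gives P_ij * (rp_dropout * dP_ij - D_i), the correct gradient through the 1/p rescale, because rp_dropout * p_dropout = 1; a dropped element stores -P_ij * rp_dropout, so multiplying it by p_dropout * D_i gives -P_ij * D_i = P_ij * (0 - D_i), the correct gradient of an element zeroed in the forward pass.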
@@ -1717,6 +1722,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // scaling is already performed in the preceding statements with s_element_op
             blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

+            // P_dropped
+            blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
+                s_slash_p_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
+
             block_sync_lds(); // wait for gemm1 LDS read

             SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
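
The second template argument true selects the sign-bit mode added above, so the dropout decision stays encoded in place in s_slash_p_thread_buf (no separate mask buffer is needed) until the dS loop in the last hunk consumes it; the (gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop) pair presumably keeps the Philox offsets aligned with the forward pass so both passes draw the same mask.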
@@ -1807,8 +1816,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     constexpr auto m =
                         pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
                     // dS and P have the same thread buf layout
-                    sgrad_thread_buf(i) = s_slash_p_thread_buf[i] *
-                                          (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                    if(s_slash_p_thread_buf[i] >= 0)
+                    {
+                        sgrad_thread_buf(i) =
+                            s_slash_p_thread_buf[i] *
+                            (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                    }
+                    else
+                    {
+                        sgrad_thread_buf(i) =
+                            s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
+                    }
                 });

             // gemm dQ
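
Putting the pieces together, a small host-side model of the branch above (all names are illustrative: encoded_p stands for an element of s_slash_p_thread_buf after sign-bit dropout, scaled_d for a y_dot_ygrad entry already multiplied by p_dropout):

    #include <cstdio>

    // Kept elements (encoded_p >= 0) follow the standard dS = P * (dP - D) form;
    // dropped elements (encoded_p < 0) recover P * (0 - D) because the
    // rescale factors cancel: (-P / p) * (p * D) == -P * D.
    float sgrad(float encoded_p, float dgrad_p, float scaled_d)
    {
        return encoded_p >= 0.0f ? encoded_p * (dgrad_p - scaled_d)
                                 : encoded_p * scaled_d;
    }

    int main()
    {
        const float p = 0.9f; // keep probability
        const float P = 0.25f, dP = 0.4f, D = 0.1f;
        printf("dropped: %f (expect %f)\n",
               sgrad(-P / p, dP, p * D), -P * D); // both print -0.025000
        return 0;
    }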