"...composable_kernel.git" did not exist on "904cbe2a8fe1caca4635b2c12b818b93fa9edc5d"
Commit a5bad9f2 authored by danyao12's avatar danyao12
Browse files

aligned with prototype2 dropout

parent 36dc18e8
@@ -53,7 +53,7 @@ __global__ void
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif
kernel_batched_multiheadattention_backward_xdl_cshuffle_pt1(
kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -83,7 +83,7 @@ __global__ void
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_dropout,
const float p_drop,
const unsigned long long seed,
const unsigned long long offset)
{
@@ -138,7 +138,7 @@ __global__ void
ygrad_grid_desc_o0_m_o1,
block_2_ctile_map,
c0_matrix_mask,
p_dropout,
p_drop,
ph);
#else
ignore = p_a_grid;
@@ -158,6 +158,9 @@ __global__ void
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = p_drop;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
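
The hunks above rename the kernel entry point from kernel_batched_multiheadattention_backward_xdl_cshuffle_pt1 to kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1 and swap its precomputed keep-probability argument p_dropout for the raw drop probability p_drop; the #else branch (compiled for targets other than gfx908/gfx90a) now also has to consume p_drop, seed and offset. A hedged sketch of why those ignore = ... lines are needed, assuming ck::ignore behaves like std::ignore:

```cpp
// Illustrative stand-in for ck::ignore (assumed to behave like std::ignore): assigning a
// value to it does nothing, but it counts as a "use", so the compiler stops warning about
// unused kernel parameters when the kernel body is compiled out for unsupported GPUs.
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(T&&) const noexcept
    {
        return *this;
    }
};
inline constexpr ignore_t ignore{};

// Same shape as the #else branch in the diff: every parameter, including the newly added
// p_drop, seed and offset, is routed into the sink.
void kernel_stub(const float* p_a_grid, float p_drop, unsigned long long seed, unsigned long long offset)
{
    ignore = p_a_grid;
    ignore = p_drop;
    ignore = seed;
    ignore = offset;
}
```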
@@ -758,7 +761,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_biases;
@@ -779,10 +783,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
y_grid_desc_m_o_);
}
p_dropout_ = 1.f - p_drop;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
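
This hunk moves the dropout bookkeeping out of the host-side Argument constructor: instead of computing p_dropout_ = 1.f - p_drop and appending rp_dropout_ = 1.f / p_dropout_ to acc_element_op_, the argument now just stores the raw p_drop_ (plus the philox seed and offset) and leaves the derivation to the kernel. A stripped-down sketch of the new constructor, using non-CK placeholder types:

```cpp
#include <tuple>

// Stripped-down, non-CK sketch of the Argument change in this hunk: store p_drop as-is,
// keep the philox seed/offset, and no longer precompute p_dropout_ or patch the
// accumulator element-op on the host.
struct Argument
{
    Argument(float p_drop, std::tuple<unsigned long long, unsigned long long> seeds)
        : p_drop_{p_drop} // was: p_dropout_ = 1.f - p_drop; acc_element_op_.Append(1.f / p_dropout_);
    {
        seed_   = std::get<0>(seeds);
        offset_ = std::get<1>(seeds);
    }

    float p_drop_; // raw drop probability, forwarded unchanged to the kernel
    unsigned long long seed_;
    unsigned long long offset_;
};
```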
@@ -873,7 +873,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
float p_drop_;
unsigned long long seed_;
unsigned long long offset_;
};
@@ -896,7 +896,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multiheadattention_backward_xdl_cshuffle_pt1<
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1<
GridwiseGemm,
DataType,
ZDataType,
@@ -951,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_,
arg.p_drop_,
arg.seed_,
arg.offset_);
};
......
@@ -1259,11 +1259,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const YGradGridDesc_O0_M_O1& ygrad_grid_desc_o0_m_o1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout,
const float p_drop,
ck::philox& ph)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const bool is_dropout = p_drop > 0.0f;
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
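
With the gridwise Run() now receiving const float p_drop, the dropout constants are derived on the device: the keep probability p_dropout = 1 - p_drop, its reciprocal rp_dropout used for inverted-dropout rescaling, a 16-bit keep threshold p_dropout_in_16bits for comparing against philox output, an is_dropout flag, and a combined output scale scale_rp_dropout built from s_element_op.Value() * rp_dropout. A standalone sketch of that arithmetic (illustrative only; the exact comparison inside blockwise_dropout is not shown in this diff):

```cpp
#include <cmath>
#include <cstdint>

// Standalone sketch of the inverted-dropout arithmetic set up at the top of Run():
// an element is kept when a uniform 16-bit random value falls below the keep threshold,
// and kept elements are rescaled by 1 / (1 - p_drop) so the expectation is preserved.
inline float dropout_elem(float x, uint16_t rand16, float p_drop)
{
    const float    p_dropout           = 1.0f - p_drop;    // keep probability
    const float    rp_dropout          = 1.0f / p_dropout; // rescale factor for kept elements
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); // keep threshold

    return rand16 <= p_dropout_in_16bits ? x * rp_dropout : 0.0f;
}
```

In the kernel itself the per-element rescale is not applied this way; the 1/(1 - p_drop) factor is folded into scale_rp_dropout and applied by the output thread copies further down the diff.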
@@ -1670,9 +1674,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
s_element_op);
decltype(scale_rp_dropout)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
scale_rp_dropout);
//
// set up Y dot dY
@@ -1871,9 +1875,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scalar = 1.0f / std::sqrt(K);
// Initialize dQ
qgrad_thread_buf.Clear();
@@ -1966,14 +1967,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}
else
{
s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
static_for<0, s_slash_p_thread_buf.Size(), 1>{}([&](auto i) {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]);
});
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
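
This hunk drops the locally computed scalar = 1.0f / std::sqrt(K) and instead runs every element of s_slash_p_thread_buf through s_element_op, so the softmax-input scaling comes from the element-wise op supplied by the caller rather than being recomputed inside the kernel; the same op's Value() is multiplied by rp_dropout earlier to build scale_rp_dropout. A minimal sketch of a Scale-style element op with that interface (simplified; the real tensor_operation::element_wise::Scale may differ):

```cpp
// Minimal Scale-style element-wise op: multiplies by a fixed factor and exposes that
// factor via Value(), which is what allows scale_rp_dropout to be constructed as
// Scale(s_element_op.Value() * rp_dropout) in the hunk near the top of Run().
struct Scale
{
    explicit Scale(float scale) : scale_{scale} {}

    // y = scale * x, matching the s_element_op(dst, src) call pattern in the loop above
    void operator()(float& y, const float& x) const { y = scale_ * x; }

    float Value() const { return scale_; }

    float scale_;
};
```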
@@ -1983,25 +1986,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
if(is_dropout)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
}
block_sync_lds(); // wait for gemm1 LDS read
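
The restructured block above wraps both dropout paths in if(is_dropout), so when p_drop == 0 neither the philox RNG nor the Z-mask write to global memory runs; when dropout is enabled, the mask is written out only if a Z pointer was supplied. A host-side sketch of that control flow with plain standard-library types (not the CK buffers):

```cpp
#include <cstdint>
#include <random>
#include <vector>

// Host-side sketch of the restructured dropout control flow: skip everything when
// dropout is disabled, and record the keep/drop mask only when a Z buffer is supplied
// (mirroring the optional z_grid write in the diff). The 1/(1 - p_drop) rescale is not
// applied here; in the kernel it is folded into scale_rp_dropout on the output path.
void apply_dropout_if_enabled(std::vector<float>& p,
                              float p_drop,
                              std::mt19937& rng,
                              std::vector<uint8_t>* z_mask /* may be nullptr; sized like p */)
{
    const bool is_dropout = p_drop > 0.0f;
    if(!is_dropout)
        return; // p_drop == 0: leave P untouched and never touch the RNG

    std::uniform_real_distribution<float> uni(0.0f, 1.0f);
    const float keep = 1.0f - p_drop;

    for(std::size_t i = 0; i < p.size(); ++i)
    {
        const bool kept = uni(rng) < keep;
        p[i]            = kept ? p[i] : 0.0f;
        if(z_mask != nullptr)
            (*z_mask)[i] = kept ? 1 : 0;
    }
}
```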
@@ -2306,7 +2313,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
s_element_op};
scale_rp_dropout};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
......