Commit 6d7a5784 authored by ltqin

add is_dropout parameter to gridwise

parent 63b37aa0
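
Summary of the change (as read from the diff below): a `bool is_dropout` flag, computed on the host as `p_drop > 0.0f`, is threaded from the device op's `Argument` through the kernel entry point into the gridwise `Run` method, where it now gates the dropout path so that the philox RNG work and the optional write of the z (mask) tensor are skipped entirely when dropout is disabled.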
@@ -83,6 +83,7 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const float p_dropout,
+            const bool is_dropout,
const unsigned long long seed,
const unsigned long long offset)
{
@@ -138,6 +139,7 @@ __global__ void
block_2_ctile_map,
c0_matrix_mask,
p_dropout,
+        is_dropout,
ph);
#else
ignore = p_a_grid;
@@ -157,6 +159,10 @@ __global__ void
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
+    ignore = p_dropout;
+    ignore = is_dropout;
+    ignore = seed;
+    ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
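
For reviewers unfamiliar with the `ignore = ...` lines: the kernel body is compiled only for gfx908/gfx90a targets, and on all other targets every argument must be consumed to silence unused-parameter warnings, which is why the new `is_dropout` argument needs a matching `ignore` assignment. A minimal sketch of the pattern, assuming an `ignore` sink object modeled on `std::ignore` (the struct and kernel below are illustrative, not CK's actual definitions):

```cpp
// Illustrative sink object, modeled on std::ignore; CK provides its own.
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(const T&) const { return *this; }
};
static constexpr ignore_t ignore{};

// Hypothetical kernel showing the guard pattern used above.
__global__ void my_kernel(const float p_dropout, const bool is_dropout)
{
#if(defined(__gfx908__) || defined(__gfx90a__))
    // real work happens only on supported targets
#else
    ignore = p_dropout;  // consume every argument so -Wunused-parameter
    ignore = is_dropout; // stays quiet on non-gfx908/gfx90a builds
#endif
}
```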
@@ -778,6 +784,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
}
p_dropout_ = 1.f - p_drop;
+            is_dropout_ = p_drop > 0.0f;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
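
On the arithmetic in this hunk (standard inverted dropout, inferred from the code rather than stated by the author): `p_dropout_` holds the keep probability, and its reciprocal `rp_dropout_` is appended to the accumulator element-wise op so that surviving elements are rescaled and the expected value of the attention matrix is unchanged. A worked example:

```cpp
// Worked example of the inverted-dropout scaling (assumed semantics).
float p_drop     = 0.1f;            // probability of dropping an element
float p_dropout  = 1.f - p_drop;    // keep probability, here 0.9
float rp_dropout = 1.f / p_dropout; // rescale factor, here ~1.1111f
// E[x * mask * rp_dropout] = x * p_dropout * (1 / p_dropout) = x,
// so no extra rescaling is needed at inference time.
```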
@@ -872,6 +879,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
+        bool is_dropout_;
unsigned long long seed_;
unsigned long long offset_;
};
@@ -954,6 +962,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_,
+                arg.is_dropout_,
arg.seed_,
arg.offset_);
};
@@ -1170,6 +1170,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout,
+        const bool is_dropout,
ck::philox& ph)
{
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
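
A note on the conversion on the line above (my reading; the diff itself only shows the quantization): the keep probability is mapped to a 16-bit integer threshold, presumably so it can be compared directly against 16-bit random values produced by the philox generator without per-element floating-point math. A small self-contained sketch, where `rand16` is a hypothetical philox output:

```cpp
#include <cmath>
#include <cstdint>

// Keep probability quantized to a 16-bit threshold, as on the line above.
uint16_t keep_threshold(float p_dropout)
{
    return uint16_t(std::floor(p_dropout * 65535.0)); // 0.9f -> 58981
}

// Hypothetical per-element test against a 16-bit philox sample.
bool keep_element(uint16_t rand16, uint16_t threshold)
{
    return rand16 <= threshold; // passes for ~p_dropout of uniform samples
}
```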
@@ -1492,7 +1493,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                9, // DstVectorDim
+            9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
@@ -1847,25 +1848,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
-            if(p_z_grid)
+            if(is_dropout)
{
-                // P_dropped
-                blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
-                                                        decltype(z_tenor_buffer),
-                                                        true>(
-                    s_slash_p_thread_buf, ph, z_tenor_buffer);
-                z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                                 z_tenor_buffer,
-                                                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                                 z_grid_buf);
-            }
-            else
-            {
-                // P_dropped
-                blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
-                    s_slash_p_thread_buf, ph);
+                if(p_z_grid)
+                {
+                    // P_dropped
+                    blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
+                                                            decltype(z_tenor_buffer),
+                                                            true>(
+                        s_slash_p_thread_buf, ph, z_tenor_buffer);
+                    z_thread_copy_vgpr_to_global.Run(
+                        z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+                        z_tenor_buffer,
+                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                        z_grid_buf);
+                }
+                else
+                {
+                    // P_dropped
+                    blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
+                        s_slash_p_thread_buf, ph);
+                }
}
block_sync_lds(); // wait for gemm1 LDS read
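
To summarize the restructuring in this final hunk: the existing `if(p_z_grid) ... else ...` logic is unchanged but is now nested inside `if(is_dropout)`, so when dropout is disabled the kernel performs no philox RNG calls and no z-tensor store, and `s_slash_p_thread_buf` passes through untouched. A condensed sketch of the new control flow, with hypothetical helpers standing in for the `blockwise_dropout` and thread-copy calls (assumptions, not CK APIs):

```cpp
// Hypothetical stand-ins for the blockwise_dropout / thread-copy calls.
void apply_dropout_record_mask(float* p, int n, unsigned short* z);
void apply_dropout(float* p, int n);
void store_z_to_global(const unsigned short* z, int n);

// Condensed sketch of the new control flow introduced by this hunk.
void dropout_step(float* p_buf, unsigned short* z_buf, int n,
                  bool is_dropout, bool has_z_grid)
{
    if(is_dropout)                  // skip all RNG work when p_drop == 0
    {
        if(has_z_grid)              // caller asked for the dropout mask
        {
            apply_dropout_record_mask(p_buf, n, z_buf);
            store_z_to_global(z_buf, n);
        }
        else
        {
            apply_dropout(p_buf, n); // drop in place, no mask kept
        }
    }
    // when !is_dropout, P passes through unchanged
}
```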