"driver/src/driver.cpp" did not exist on "87d8740bf5d8030f9e4e54c9b7e64f353a6f944e"
Commit 6d7a5784 authored by ltqin's avatar ltqin
Browse files

add is_dropout parameter to gridwise

parent 63b37aa0
...@@ -83,6 +83,7 @@ __global__ void ...@@ -83,6 +83,7 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const float p_dropout, const float p_dropout,
const bool is_dropout,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
{ {
...@@ -138,6 +139,7 @@ __global__ void ...@@ -138,6 +139,7 @@ __global__ void
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout, p_dropout,
is_dropout,
ph); ph);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
...@@ -157,6 +159,10 @@ __global__ void ...@@ -157,6 +159,10 @@ __global__ void
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask; ignore = c0_matrix_mask;
ignore = p_dropout;
ignore = is_dropout;
ignore = seed;
ignore = offset;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -778,6 +784,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -778,6 +784,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
} }
p_dropout_ = 1.f - p_drop; p_dropout_ = 1.f - p_drop;
is_dropout_ = p_drop > 0.0f;
float rp_dropout_ = 1.f / p_dropout_; float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_); acc_element_op_.Append(rp_dropout_);
...@@ -872,6 +879,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -872,6 +879,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_; float p_dropout_;
bool is_dropout_;
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
}; };
...@@ -954,6 +962,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -954,6 +962,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_dropout_, arg.p_dropout_,
arg.is_dropout_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_);
}; };
......
...@@ -1170,6 +1170,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1170,6 +1170,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout, FloatGemmAcc p_dropout,
const bool is_dropout,
ck::philox& ph) ck::philox& ph)
{ {
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0)); const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
...@@ -1847,6 +1848,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1847,6 +1848,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global // save z to global
if(is_dropout)
{
if(p_z_grid) if(p_z_grid)
{ {
// P_dropped // P_dropped
...@@ -1855,7 +1858,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1855,7 +1858,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
true>( true>(
s_slash_p_thread_buf, ph, z_tenor_buffer); s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
...@@ -1867,6 +1871,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1867,6 +1871,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>( blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph); s_slash_p_thread_buf, ph);
} }
}
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment