Commit 6fd1490b authored by ltqin

change dropout to one parameter

parent 6e676287
@@ -81,8 +81,7 @@ __global__ void
             const index_t batch_count,
             const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
             const C0MatrixMask c0_matrix_mask,
-            const float p_dropout,
-            const bool is_dropout,
+            const float p_drop,
             const unsigned long long seed,
             const unsigned long long offset)
 {
@@ -137,8 +136,7 @@ __global__ void
         ygrad_grid_desc_m0_o_m1,
         block_2_ctile_map,
         c0_matrix_mask,
-        p_dropout,
-        is_dropout,
+        p_drop,
         ph);
 #else
     ignore = p_a_grid;
@@ -158,8 +156,7 @@ __global__ void
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
-    ignore = p_dropout;
-    ignore = is_dropout;
+    ignore = p_drop;
     ignore = seed;
     ignore = offset;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
@@ -761,7 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
               z_grid_desc_g_m_n_,
               b1_grid_desc_g_n_k_,
               c_grid_desc_g_m_n_,
-              type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+              type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
+          p_drop_{p_drop}
         {
             // TODO: implement bias addition
             ignore = p_acc0_biases;
@@ -782,9 +780,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                     y_grid_desc_m_o_);
             }
-            p_dropout_  = 1.f - p_drop;
-            is_dropout_ = p_drop > 0.0f;
             seed_   = std::get<0>(seeds);
             offset_ = std::get<1>(seeds);
@@ -875,8 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
-        float p_dropout_;
-        bool is_dropout_;
+        float p_drop_;
         unsigned long long seed_;
         unsigned long long offset_;
     };
@@ -958,8 +952,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
                 arg.batch_count_,
                 arg.compute_base_ptr_of_batch_,
                 arg.c0_matrix_mask_,
-                arg.p_dropout_,
-                arg.is_dropout_,
+                arg.p_drop_,
                 arg.seed_,
                 arg.offset_);
         };
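With this change the host-side Argument keeps only the raw drop probability p_drop_, and the derived quantities (keep probability, its reciprocal, the is-dropout flag) are recomputed inside the gridwise kernel (next hunk). A minimal standalone sketch of that derivation, where DropoutConsts and derive_dropout_consts are hypothetical names grouping the individual kernel-local constants from the diff:

    // Standalone sketch (not CK code): the constants the Argument constructor
    // used to precompute on the host, now derived from the single p_drop value.
    #include <cmath>
    #include <cstdint>

    struct DropoutConsts
    {
        float keep_prob;        // was p_dropout_ = 1.f - p_drop
        float rescale;          // rp_dropout = 1 / keep_prob (inverted-dropout scale)
        uint16_t threshold_u16; // p_dropout_in_16bits, compared against philox draws
        bool active;            // was is_dropout_ = p_drop > 0.0f
    };

    inline DropoutConsts derive_dropout_consts(float p_drop) // assumes 0 <= p_drop < 1
    {
        DropoutConsts c{};
        c.keep_prob     = 1.0f - p_drop;
        c.rescale       = 1.0f / c.keep_prob;
        c.threshold_u16 = static_cast<uint16_t>(std::floor(c.keep_prob * 65535.0));
        c.active        = p_drop > 0.0f;
        return c;
    }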
@@ -1169,12 +1169,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
-        FloatGemmAcc p_dropout,
-        const bool is_dropout,
+        const float p_drop,
         ck::philox& ph)
     {
+        const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
+        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
         const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
-        const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
+        const bool is_dropout = p_drop > 0.0f;
         const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
                                                                      rp_dropout);
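The 16-bit constant p_dropout_in_16bits maps the keep probability onto the range of a 16-bit random draw. How the kernel consumes it lies outside this diff, but the inverted-dropout idea it supports can be sketched as follows (rng16 stands in for a 16-bit value drawn from the ck::philox stream; apply_dropout is a hypothetical name):

    // Hedged sketch of inverted dropout using the constants above; the exact
    // comparison in the gridwise kernel is not shown in this commit.
    #include <cstdint>

    inline float apply_dropout(float x, uint16_t rng16, uint16_t threshold_u16, float rescale)
    {
        // Keep the element when its random draw falls at or below the keep
        // threshold; rescale kept elements by 1/keep_prob so the expected
        // value of the output matches the no-dropout case.
        return (rng16 <= threshold_u16) ? x * rescale : 0.0f;
    }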