Commit 346b4e8d authored by ltqin's avatar ltqin
Browse files

change parameter:remove p_dropout_in_16bits

parent c2d566ff
...@@ -270,8 +270,8 @@ int run(int argc, char* argv[]) ...@@ -270,8 +270,8 @@ int run(int argc, char* argv[])
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = 128; ck::index_t K = 128;
ck::index_t O = 128; ck::index_t O = 128;
ck::index_t G0 = 1; ck::index_t G0 = 3;
ck::index_t G1 = 1; ck::index_t G1 = 2;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
...@@ -285,7 +285,6 @@ int run(int argc, char* argv[]) ...@@ -285,7 +285,6 @@ int run(int argc, char* argv[])
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
......
...@@ -82,7 +82,6 @@ __global__ void ...@@ -82,7 +82,6 @@ __global__ void
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
const float p_dropout, const float p_dropout,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset)
...@@ -137,7 +136,6 @@ __global__ void ...@@ -137,7 +136,6 @@ __global__ void
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout_in_16bits,
p_dropout, p_dropout,
ph); ph);
#else #else
...@@ -779,9 +777,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle ...@@ -779,9 +777,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
} }
p_dropout_ = 1.f - p_drop; p_dropout_ = 1.f - p_drop;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); float rp_dropout_ = 1.f / p_dropout_;
rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_); acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
...@@ -875,8 +871,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle ...@@ -875,8 +871,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_; float p_dropout_;
ushort p_dropout_in_16bits_;
GemmAccDataType rp_dropout_;
unsigned long long seed_; unsigned long long seed_;
unsigned long long offset_; unsigned long long offset_;
}; };
...@@ -958,7 +952,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle ...@@ -958,7 +952,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_, arg.p_dropout_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_);
......
...@@ -1169,10 +1169,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1169,10 +1169,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1, const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout, FloatGemmAcc p_dropout,
ck::philox& ph) ck::philox& ph)
{ {
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout; const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment