Commit 346b4e8d authored by ltqin's avatar ltqin
Browse files

change parameter:remove p_dropout_in_16bits

parent c2d566ff
......@@ -270,8 +270,8 @@ int run(int argc, char* argv[])
ck::index_t N = 512;
ck::index_t K = 128;
ck::index_t O = 128;
ck::index_t G0 = 1;
ck::index_t G1 = 1;
ck::index_t G0 = 3;
ck::index_t G1 = 2;
float alpha = 1.f / std::sqrt(K);
......@@ -285,7 +285,6 @@ int run(int argc, char* argv[])
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
......
......@@ -82,7 +82,6 @@ __global__ void
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
const float p_dropout,
const unsigned long long seed,
const unsigned long long offset)
......@@ -137,7 +136,6 @@ __global__ void
ygrad_grid_desc_m0_o_m1,
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout,
ph);
#else
......@@ -778,10 +776,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
y_grid_desc_m_o_);
}
p_dropout_ = 1.f - p_drop;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
rp_dropout_ = 1.f / p_dropout_;
p_dropout_ = 1.f - p_drop;
float rp_dropout_ = 1.f / p_dropout_;
acc_element_op_.Append(rp_dropout_);
seed_ = std::get<0>(seeds);
......@@ -875,8 +871,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
ushort p_dropout_in_16bits_;
GemmAccDataType rp_dropout_;
unsigned long long seed_;
unsigned long long offset_;
};
......@@ -958,7 +952,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_,
arg.seed_,
arg.offset_);
......
......@@ -1169,11 +1169,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout,
ck::philox& ph)
{
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment