Commit bb06d009 authored by ltqin

add dropout parameters in device

parent 5012068b
@@ -76,7 +76,12 @@ __global__ void
         const Block2CTileMap block_2_ctile_map,
         const index_t batch_count,
         const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-        const C0MatrixMask c0_matrix_mask)
+        const C0MatrixMask c0_matrix_mask,
+        const ushort p_dropout_in_16bits,
+        const float p_dropout,
+        const float rp_dropout,
+        const unsigned long long seed,
+        const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -97,13 +102,8 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
 
-    float p_dropout = 1 - 0.2;
-    const ushort p_dropout_in_16bits = 65536 * p_dropout;
-    float rp_dropout = 1.0 / p_dropout;
-    const unsigned long long seed = 0;
-    const index_t block_id = get_block_1d_id();
-    ck::philox ph(seed, 0, block_id * 4);
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);
 
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
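In the two hunks above, the dropout state moves from kernel-local constants to host-supplied arguments, and the Philox stream is positioned per thread instead of per block: the old `ck::philox ph(seed, 0, block_id * 4)` gave every thread in a block the same `(seed, subsequence, offset)` triple, while the new `ck::philox ph(seed, global_thread_id, offset)` derives a distinct subsequence from the global thread id and lets the host advance `offset` between launches. A minimal sketch of why a counter-based generator supports this, using a toy hash as a stand-in for `ck::philox` (the real Philox algorithm differs; this only illustrates the deterministic `(seed, subsequence, offset) -> value` mapping):

```cpp
#include <cstdint>
#include <cstdio>

// Toy counter-based generator standing in for ck::philox: a pure function of
// (seed, subsequence, offset), so equal inputs always reproduce equal draws
// and no generator state has to be carried between kernel launches.
uint64_t toy_counter_rng(uint64_t seed, uint64_t subsequence, uint64_t offset)
{
    uint64_t x = seed ^ (subsequence * 0x9E3779B97F4A7C15ull) ^ (offset + 0x2545F4914F6CDD1Dull);
    x ^= x >> 30; x *= 0xBF58476D1CE4E5B9ull; // splitmix64-style mixing
    x ^= x >> 27; x *= 0x94D049BB133111EBull;
    return x ^ (x >> 31);
}

int main()
{
    const uint64_t seed   = 0x12345678ull;
    const uint64_t offset = 0; // the host advances this between launches
    // Each "thread" draws from its own subsequence, as in the new kernel code.
    for(uint64_t global_thread_id = 0; global_thread_id < 4; ++global_thread_id)
        printf("thread %llu -> %016llx\n",
               (unsigned long long)global_thread_id,
               (unsigned long long)toy_counter_rng(seed, global_thread_id, offset));
    return 0;
}
```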
@@ -665,8 +665,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                  AccElementwiseOperation acc_element_op,
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
-                 float,
-                 std::tuple<unsigned long long, unsigned long long>)
+                 float p_drop,
+                 std::tuple<unsigned long long, unsigned long long> seeds)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
@@ -743,6 +743,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     y_grid_desc_m_o_);
             }
+
+            p_dropout_ = 1.f - p_drop;
+            p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+            rp_dropout_ = 1.f / p_dropout_;
+            seed_   = std::get<0>(seeds);
+            offset_ = std::get<1>(seeds);
+            // Print();
         }
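The constructor body added above is the host-side counterpart of the constants the kernel used to hard-code: the keep probability `p_dropout_ = 1 - p_drop`, a 16-bit fixed-point threshold (note the scale is now 65535, where the deleted kernel code used 65536), and the reciprocal `rp_dropout_` used for inverted-dropout rescaling. A standalone sketch of that arithmetic; the `r < threshold` comparison and the `std::mt19937` draw are assumptions for illustration, since the kernel's actual use of `p_dropout_in_16bits` is not part of this diff:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>

int main()
{
    const float p_drop    = 0.2f;
    const float p_dropout = 1.f - p_drop; // keep probability
    const uint16_t threshold =
        uint16_t(std::floor(p_dropout * 65535.0)); // 16-bit fixed-point threshold
    const float rp_dropout = 1.f / p_dropout;      // inverted-dropout rescale factor

    // Stand-in for the per-thread Philox draw in the kernel.
    std::mt19937 gen(42);
    std::uniform_int_distribution<int> dist(0, 65535);

    const float x = 1.0f; // example activation
    double sum    = 0.0;
    const int n   = 100000;
    for(int i = 0; i < n; ++i)
    {
        const uint16_t r = uint16_t(dist(gen));
        // Keep iff the draw falls below the threshold, then rescale so that
        // E[output] ~= input regardless of the drop probability.
        sum += (r < threshold) ? x * rp_dropout : 0.f;
    }
    printf("mean = %f (expected ~%f)\n", sum / n, x);
    return 0;
}
```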
@@ -821,6 +828,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+        float p_dropout_;
+        ushort p_dropout_in_16bits_;
+        GemmAccDataType rp_dropout_;
+        unsigned long long seed_;
+        unsigned long long offset_;
     };
 
     // Invoker
@@ -895,7 +908,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                     arg.block_2_ctile_map_,
                     arg.batch_count_,
                     arg.compute_base_ptr_of_batch_,
-                    arg.c0_matrix_mask_);
+                    arg.c0_matrix_mask_,
+                    arg.p_dropout_in_16bits_,
+                    arg.p_dropout_,
+                    arg.rp_dropout_,
+                    arg.seed_,
+                    arg.offset_);
             };
 
         // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -1036,7 +1054,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                              B1ElementwiseOperation b1_element_op,
                              CElementwiseOperation c_element_op,
                              float p_drop,
-                             std::tuple<unsigned long long, unsigned long long> seed)
+                             std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a,
                         p_b,
@@ -1068,7 +1086,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                         b1_element_op,
                         c_element_op,
                         p_drop,
-                        seed};
+                        seeds};
    }

    static auto MakeInvoker() { return Invoker{}; }
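At these call sites the previously unnamed tuple parameter is renamed to `seeds`; it is simply the `(seed, offset)` pair that the Argument constructor unpacks with `std::get`, here and in the polymorphic overload below. A minimal caller-side sketch (the values are placeholders and the surrounding MakeArgument parameters are elided in this diff):

```cpp
#include <tuple>

int main()
{
    const float p_drop = 0.2f;
    // (seed, offset) pair forwarded alongside p_drop when building the Argument.
    const std::tuple<unsigned long long, unsigned long long> seeds{0x12345678ull, 0ull};

    // Mirrors the Argument constructor shown earlier in this diff:
    const unsigned long long seed   = std::get<0>(seeds);
    const unsigned long long offset = std::get<1>(seeds);

    (void)p_drop;
    (void)seed;
    (void)offset;
    return 0;
}
```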
@@ -1108,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_drop,
-        std::tuple<unsigned long long, unsigned long long> seed) // override
+        std::tuple<unsigned long long, unsigned long long> seeds) // override
     {
         return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
                                           static_cast<const DataType*>(p_b),
@@ -1140,7 +1158,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                                           b1_element_op,
                                           c_element_op,
                                           p_drop,
-                                          seed);
+                                          seeds);
     }

     // polymorphic