Commit bb06d009 authored by ltqin

add drop parameter in device

parent 5012068b
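This change threads dropout state from the device op down into the kernel: the caller now supplies a drop probability and a Philox {seed, offset} pair instead of the kernel hard-coding them. A minimal caller-side sketch of the two new trailing arguments follows; the surrounding MakeArgument parameters are elided in this diff, so the call itself is only indicated in a comment.

#include <cstdio>
#include <tuple>

int main()
{
    // Hypothetical values: 0.2f matches the drop probability that was
    // previously hard-coded inside the kernel and removed by this commit.
    const float p_drop = 0.2f;

    // Philox state: a user-chosen seed plus a stream offset. With
    // counter-based RNGs the offset is typically advanced between launches
    // so repeated calls draw fresh randomness (an assumption here, not
    // something this diff shows).
    const auto seeds = std::make_tuple(0ull, 0ull);

    // auto arg = device_op.MakeArgument(..., p_drop, seeds);
    // (the rest of the parameter list is elided in this diff)
    printf("p_drop=%.2f seed=%llu offset=%llu\n",
           p_drop, std::get<0>(seeds), std::get<1>(seeds));
}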
@@ -76,7 +76,12 @@ __global__ void
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-           const C0MatrixMask c0_matrix_mask)
+           const C0MatrixMask c0_matrix_mask,
+           const ushort p_dropout_in_16bits,
+           const float p_dropout,
+           const float rp_dropout,
+           const unsigned long long seed,
+           const unsigned long long offset)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -97,13 +102,8 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

-    float p_dropout                  = 1 - 0.2;
-    const ushort p_dropout_in_16bits = 65536 * p_dropout;
-    float rp_dropout                 = 1.0 / p_dropout;
-    const unsigned long long seed = 0;
-    const index_t block_id        = get_block_1d_id();
-    ck::philox ph(seed, 0, block_id * 4);
+    const index_t global_thread_id = get_thread_global_1d_id();
+    ck::philox ph(seed, global_thread_id, offset);

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b_grid + b_batch_offset,
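The rewritten kernel keys ck::philox by (seed, global thread id, offset) instead of the old (0, 0, block_id * 4), so every thread gets its own reproducible stream and the host controls both the seed and where in the stream a launch starts. Below is a stand-alone sketch of that keying idea, using splitmix64 as a stand-in for Philox; the real ck::philox is a counter-based generator, and this hash-based substitute only illustrates the (seed, subsequence, offset) -> independent-stream mapping.

#include <cstdint>
#include <cstdio>

// splitmix64: a small stateless mixer, used here as a stand-in for a
// counter-based RNG round. Not the ck::philox implementation.
static uint64_t splitmix64(uint64_t x)
{
    x += 0x9E3779B97F4A7C15ull;
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ull;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBull;
    return x ^ (x >> 31);
}

// One draw from the stream identified by (seed, subsequence, offset):
// identical inputs always reproduce the same output, and distinct thread
// ids (subsequences) yield statistically independent values.
static uint64_t draw(uint64_t seed, uint64_t subsequence, uint64_t offset)
{
    return splitmix64(seed ^ splitmix64(subsequence ^ splitmix64(offset)));
}

int main()
{
    printf("%llx\n", (unsigned long long)draw(42, /*thread*/ 0, /*offset*/ 0));
    printf("%llx\n", (unsigned long long)draw(42, /*thread*/ 1, /*offset*/ 0));
    // Same key as the first call -> same value, which is what makes the
    // dropout mask reproducible between forward and backward passes.
    printf("%llx\n", (unsigned long long)draw(42, /*thread*/ 0, /*offset*/ 0));
}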
@@ -665,8 +665,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                  AccElementwiseOperation acc_element_op,
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
-                 float,
-                 std::tuple<unsigned long long, unsigned long long>)
+                 float p_drop,
+                 std::tuple<unsigned long long, unsigned long long> seeds)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_b1_grid_{p_b1_grid},
@@ -743,6 +743,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     y_grid_desc_m_o_);
         }
+
+        p_dropout_           = 1.f - p_drop;
+        p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+        rp_dropout_          = 1.f / p_dropout_;
+        seed_                = std::get<0>(seeds);
+        offset_              = std::get<1>(seeds);
+
         // Print();
     }
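The constructor folds the user-facing drop probability into the three derived quantities the kernel consumes: the keep probability, its 16-bit fixed-point threshold, and the 1/keep rescale factor. A host-only sketch of that arithmetic is below; note the new code uses floor(p * 65535), which always fits a ushort, whereas the old in-kernel 65536 * p does not fit at p = 1.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    // Mirrors the Argument constructor in this commit (host-side sketch).
    const float p_drop = 0.2f; // drop probability supplied by the caller

    const float    p_dropout           = 1.f - p_drop; // keep probability
    const uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
    const float    rp_dropout          = 1.f / p_dropout; // inverted-dropout rescale

    // keep=0.800 -> threshold_u16=52428 (~0.8 * 65535), rescale=1.25
    printf("keep=%.3f threshold_u16=%u rescale=%.4f\n",
           p_dropout, p_dropout_in_16bits, rp_dropout);
}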
@@ -821,6 +828,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+        float p_dropout_;
+        ushort p_dropout_in_16bits_;
+        GemmAccDataType rp_dropout_;
+        unsigned long long seed_;
+        unsigned long long offset_;
     };

     // Invoker
@@ -895,7 +908,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
                 arg.compute_base_ptr_of_batch_,
-                arg.c0_matrix_mask_);
+                arg.c0_matrix_mask_,
+                arg.p_dropout_in_16bits_,
+                arg.p_dropout_,
+                arg.rp_dropout_,
+                arg.seed_,
+                arg.offset_);
         };

     // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
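The invoker now forwards the precomputed threshold and rescale factor into the kernel launch. The kernel body that consumes them is not part of this diff, but the usual way such a pair is used is the inverted-dropout test sketched below: compare a raw 16-bit Philox lane against the threshold and scale survivors by 1/keep so the output's expectation matches the input. The function name and the comparison direction are assumptions, not code from this repository.

#include <cstdint>

// Hypothetical per-element dropout using the arguments this commit threads
// through: rand_u16 stands in for one 16-bit lane of a ck::philox draw.
inline float apply_dropout(float x,
                           uint16_t rand_u16,      // uniform in [0, 65535]
                           uint16_t threshold_u16, // p_dropout_in_16bits_
                           float rp_dropout)       // 1 / keep probability
{
    // Keep the element iff the draw lands below the keep-probability
    // threshold; rescale survivors so E[output] == input.
    return rand_u16 <= threshold_u16 ? x * rp_dropout : 0.f;
}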
@@ -1036,7 +1054,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op,
                             float p_drop,
-                            std::tuple<unsigned long long, unsigned long long> seed)
+                            std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a,
                         p_b,
@@ -1068,7 +1086,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                         b1_element_op,
                         c_element_op,
                         p_drop,
-                        seed};
+                        seeds};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -1108,7 +1126,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                         B1ElementwiseOperation b1_element_op,
                         CElementwiseOperation c_element_op,
                         float p_drop,
-                        std::tuple<unsigned long long, unsigned long long> seed) // override
+                        std::tuple<unsigned long long, unsigned long long> seeds) // override
     {
         return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
                                           static_cast<const DataType*>(p_b),
@@ -1140,7 +1158,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle
                                           b1_element_op,
                                           c_element_op,
                                           p_drop,
-                                          seed);
+                                          seeds);
     }

     // polymorphic