Commit 11eed39f authored by guangzlu's avatar guangzlu
Browse files

Added dropout scale to device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp

parent 2ac0eefd
......@@ -27,6 +27,7 @@ template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename FloatLSE,
typename GemmAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
......@@ -68,6 +69,7 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
GemmAccDataType p_dropout_rescale,
const unsigned long long seed)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -110,6 +112,7 @@ __global__ void
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_rescale,
ph);
#else
ignore = p_a_grid;
......@@ -549,6 +552,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
}
void Print() const
......@@ -612,6 +617,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float p_dropout_;
ushort p_dropout_in_16bits_;
GemmAccDataType p_dropout_rescale_;
unsigned long long seed_;
bool is_dropout_;
};
......@@ -643,6 +649,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
ADataType, // TODO: distinguish A/B datatype
CDataType,
LSEDataType,
GemmAccDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......@@ -684,6 +691,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_);
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment