Commit 11eed39f authored by guangzlu's avatar guangzlu
Browse files

added dropou scale into device_batched_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp

parent 2ac0eefd
...@@ -27,6 +27,7 @@ template <typename GridwiseGemm, ...@@ -27,6 +27,7 @@ template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename FloatLSE, typename FloatLSE,
typename GemmAccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -68,6 +69,7 @@ __global__ void ...@@ -68,6 +69,7 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits, const ushort p_dropout_in_16bits,
GemmAccDataType p_dropout_rescale,
const unsigned long long seed) const unsigned long long seed)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -110,6 +112,7 @@ __global__ void ...@@ -110,6 +112,7 @@ __global__ void
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout_rescale,
ph); ph);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
...@@ -549,6 +552,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -549,6 +552,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
is_dropout_ = p_dropout > 0.0; // is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
} }
void Print() const void Print() const
...@@ -612,6 +617,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -612,6 +617,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float p_dropout_; float p_dropout_;
ushort p_dropout_in_16bits_; ushort p_dropout_in_16bits_;
GemmAccDataType p_dropout_rescale_;
unsigned long long seed_; unsigned long long seed_;
bool is_dropout_; bool is_dropout_;
}; };
...@@ -643,6 +649,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -643,6 +649,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
LSEDataType, LSEDataType,
GemmAccDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -684,6 +691,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -684,6 +691,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_, arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_); arg.seed_);
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment