Commit 2ac0eefd authored by guangzlu

added dropout rescale into grouped_gemm_softmax_gemm

parent 6926effa
@@ -8,7 +8,7 @@
namespace ck {
-template <typename ThreadSliceDesc_M_K>
+template <typename DataType, typename ThreadSliceDesc_M_K>
struct BlockwiseDropout
{
static constexpr auto I0 = Number<0>{};
@@ -18,13 +18,14 @@ struct BlockwiseDropout
template <typename CThreadBuffer>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
-ushort p_dropout_16bits,
ck::philox ph,
const int repeat_index,
const int total_repeats)
{
-auto if_dropout = [](bool keep, float val) { return keep ? val : float(0); };
+auto execute_dropout = [&](bool keep, DataType val) {
+    return keep ? val * p_dropout_rescale : float(0);
+};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8;
@@ -45,11 +46,14 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
-    if_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
+    execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
}
+ushort p_dropout_16bits;
+DataType p_dropout_rescale;
};
} // namespace ck
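
For context, the lambda swap above implements standard inverted dropout: surviving elements are multiplied by the rescale factor 1 / (1 - p) so the expected value of the tensor is unchanged. The following standalone sketch is illustrative only, not the CK implementation: it uses a host-side RNG in place of ck::philox, a plain std::vector instead of the thread buffer, and a hypothetical function name, while mirroring the 16-bit keep-threshold comparison used above.

```cpp
// Illustrative host-side sketch of the inverted-dropout rescale (not CK code):
// a uniform 16-bit sample below keep_prob * 65535 keeps the element, and kept
// elements are scaled by 1 / keep_prob, mirroring execute_dropout above.
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

std::vector<float> dropout_rescale_reference(const std::vector<float>& in,
                                             float p_dropout,
                                             std::mt19937& rng)
{
    const float keep_prob             = 1.f - p_dropout;
    const float rescale               = 1.f / keep_prob;
    const std::uint16_t keep_threshold = static_cast<std::uint16_t>(keep_prob * 65535.f);
    std::uniform_int_distribution<std::uint32_t> dist(0, 65535);

    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        const bool keep = dist(rng) < keep_threshold;
        out[i]          = keep ? in[i] * rescale : 0.f;
    }
    return out;
}
```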
@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
+typename GemmAccDataType,
typename GroupKernelArg,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -45,6 +46,7 @@ __global__ void
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const ushort p_dropout_in_16bits,
+GemmAccDataType p_dropout_rescale,
const unsigned long long seed)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
@@ -52,7 +54,7 @@ __global__ void
const index_t block_id = get_block_1d_id();
-ck::philox ph(seed, 0, block_id);
+ck::philox ph(seed, 0, block_id * 4);
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
@@ -111,6 +113,7 @@ __global__ void
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
+p_dropout_rescale,
ph);
#else
ignore = group_kernel_args;
@@ -642,6 +645,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+p_dropout_ = 1.f / p_dropout_;
+p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
}
std::vector<GroupKernelArg> group_kernel_args_;
@@ -659,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float p_dropout_;
ushort p_dropout_in_16bits_;
unsigned long long seed_;
+GemmAccDataType p_dropout_rescale_;
bool is_dropout_;
};
@@ -695,6 +701,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
@@ -718,6 +725,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.b1_element_op_,
arg.c_element_op_,
arg.p_dropout_in_16bits_,
+arg.p_dropout_rescale_,
arg.seed_);
};
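
As a sanity check on the Argument fields set above, here is a small hypothetical host program (p_dropout = 0.2 is an arbitrary example value, and a plain float stands in for GemmAccDataType / type_convert) that reproduces the derivation: keep probability 0.8, a 16-bit keep threshold of 52428, and a rescale factor of 1.25.

```cpp
// Hypothetical host-side check of the dropout parameter derivation shown above
// (plain float stands in for GemmAccDataType / type_convert).
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const float p_dropout = 0.2f; // example dropout probability

    const float p_keep = 1.f - p_dropout;                              // 0.8
    const std::uint16_t p_dropout_in_16bits =
        static_cast<std::uint16_t>(std::floor(p_keep * 65535.0));      // 52428
    const float p_dropout_rescale = 1.f / p_keep;                      // 1.25

    std::printf("threshold = %u, rescale = %f\n", p_dropout_in_16bits, p_dropout_rescale);
    return 0;
}
```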
@@ -383,6 +383,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
+FloatGemmAcc p_dropout_rescale,
ck::philox ph)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -728,7 +729,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
-auto blockwise_dropout = BlockwiseDropout<decltype(thread_slice_desc_m_n)>{};
+auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
+    p_dropout_in_16bits, p_dropout_rescale};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
@@ -873,11 +875,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
if constexpr(IsDropout) // dropout
{
-blockwise_dropout.ApplyDropout(acc_thread_buf,
-                               p_dropout_in_16bits,
-                               ph,
-                               gemm1_k_block_outer_index,
-                               num_gemm1_k_block_outer_loop);
+blockwise_dropout.ApplyDropout(
+    acc_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
}
// TODO: may convert to log domain
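
The reason the rescale factor is folded into the dropout application above is that it keeps the dropped attention scores unbiased: when kept values are multiplied by 1 / (1 - p), the expected value of each element equals its original value. A quick standalone Monte-Carlo check of this property (illustrative only, using a std::bernoulli_distribution rather than ck::philox):

```cpp
// Illustrative Monte-Carlo check (not CK code): with rescale = 1 / (1 - p),
// the mean of a dropped-and-rescaled element converges to the original value.
#include <cstdio>
#include <random>

int main()
{
    const float p_dropout = 0.2f;
    const float keep_prob = 1.f - p_dropout;
    const float rescale   = 1.f / keep_prob;
    const float x         = 3.0f; // arbitrary input element

    std::mt19937 rng(42);
    std::bernoulli_distribution keep(keep_prob);

    double sum           = 0.0;
    const int num_trials = 1000000;
    for(int i = 0; i < num_trials; ++i)
        sum += keep(rng) ? x * rescale : 0.0;

    // Prints a mean close to 3.0, i.e. the undropped value.
    std::printf("mean = %f, expected = %f\n", sum / num_trials, x);
    return 0;
}
```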