Commit 2ac0eefd authored by guangzlu's avatar guangzlu
Browse files

added dropout rescale into grouped_gemm_softmax_gemm

parent 6926effa
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace ck { namespace ck {
template <typename ThreadSliceDesc_M_K> template <typename DataType, typename ThreadSliceDesc_M_K>
struct BlockwiseDropout struct BlockwiseDropout
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -18,13 +18,14 @@ struct BlockwiseDropout ...@@ -18,13 +18,14 @@ struct BlockwiseDropout
template <typename CThreadBuffer> template <typename CThreadBuffer>
__host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
ushort p_dropout_16bits,
ck::philox ph, ck::philox ph,
const int repeat_index, const int repeat_index,
const int total_repeats) const int total_repeats)
{ {
auto if_dropout = [](bool keep, float val) { return keep ? val : float(0); }; auto execute_dropout = [&](bool keep, DataType val) {
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 8; int philox_calls = tmp_size / 8;
...@@ -45,11 +46,14 @@ struct BlockwiseDropout ...@@ -45,11 +46,14 @@ struct BlockwiseDropout
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
if_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] < p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
} }
ushort p_dropout_16bits;
DataType p_dropout_rescale;
}; };
} // namespace ck } // namespace ck
...@@ -24,6 +24,7 @@ namespace tensor_operation { ...@@ -24,6 +24,7 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename GemmAccDataType,
typename GroupKernelArg, typename GroupKernelArg,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -45,6 +46,7 @@ __global__ void ...@@ -45,6 +46,7 @@ __global__ void
const B1ElementwiseOperation b1_element_op, const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const ushort p_dropout_in_16bits, const ushort p_dropout_in_16bits,
GemmAccDataType p_dropout_rescale,
const unsigned long long seed) const unsigned long long seed)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...@@ -52,7 +54,7 @@ __global__ void ...@@ -52,7 +54,7 @@ __global__ void
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
ck::philox ph(seed, 0, block_id); ck::philox ph(seed, 0, block_id * 4);
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args)); cast_pointer_to_generic_address_space(group_kernel_args));
...@@ -111,6 +113,7 @@ __global__ void ...@@ -111,6 +113,7 @@ __global__ void
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits, p_dropout_in_16bits,
p_dropout_rescale,
ph); ph);
#else #else
ignore = group_kernel_args; ignore = group_kernel_args;
...@@ -642,6 +645,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -642,6 +645,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
is_dropout_ = p_dropout > 0.0; // is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout; p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0)); p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
} }
std::vector<GroupKernelArg> group_kernel_args_; std::vector<GroupKernelArg> group_kernel_args_;
...@@ -659,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -659,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float p_dropout_; float p_dropout_;
ushort p_dropout_in_16bits_; ushort p_dropout_in_16bits_;
unsigned long long seed_; unsigned long long seed_;
GemmAccDataType p_dropout_rescale_;
bool is_dropout_; bool is_dropout_;
}; };
...@@ -695,6 +701,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -695,6 +701,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm, kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GemmAccDataType,
GroupKernelArg, GroupKernelArg,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -718,6 +725,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -718,6 +725,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.b1_element_op_, arg.b1_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.p_dropout_in_16bits_, arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_); arg.seed_);
}; };
......
...@@ -383,6 +383,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -383,6 +383,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits, const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale,
ck::philox ph) ck::philox ph)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -728,7 +729,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -728,7 +729,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
decltype(thread_cluster_desc_m_n), decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{}; decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<decltype(thread_slice_desc_m_n)>{}; auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, p_dropout_rescale};
const index_t num_gemm1_k_block_outer_loop = const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...@@ -873,11 +875,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -873,11 +875,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
if constexpr(IsDropout) // dropout if constexpr(IsDropout) // dropout
{ {
blockwise_dropout.ApplyDropout(acc_thread_buf, blockwise_dropout.ApplyDropout(
p_dropout_in_16bits, acc_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
ph,
gemm1_k_block_outer_index,
num_gemm1_k_block_outer_loop);
} }
// TODO: may convert to log domain // TODO: may convert to log domain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment