Commit f196d88b authored by guangzlu
Browse files

added dropout into batched_gemm_softmax_gemm

parent 9fe6407e
......@@ -177,7 +177,8 @@ int run(int argc, char* argv[])
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
c_element_op,
0); // dropout ratio
if(!gemm.IsSupportedArgument(argument))
{
......
......@@ -115,7 +115,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -7,6 +7,7 @@
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
......@@ -39,7 +40,8 @@ template <typename GridwiseGemm,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
bool HasMainKBlockLoop,
bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
......@@ -64,7 +66,9 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
const unsigned long long seed)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -83,7 +87,11 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
const index_t block_id = get_block_1d_id();
ck::philox ph(seed, 0, block_id * 4);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
......@@ -100,7 +108,9 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask);
c0_matrix_mask,
p_dropout_in_16bits,
ph);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -463,7 +473,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
float p_dropout,
unsigned long long seed)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
......@@ -512,7 +524,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_grid_desc_g_n_k_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
seed_(seed)
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
......@@ -532,6 +545,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
is_dropout_ = p_dropout > 0.0; // dropout is enabled only when the dropout ratio is positive
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
}
void Print() const
......@@ -592,6 +609,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
ushort p_dropout_in_16bits_;
unsigned long long seed_;
bool is_dropout_;
};
// Invoker
......@@ -615,7 +637,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
......@@ -634,7 +656,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(stream_config,
kernel,
......@@ -659,18 +682,38 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.seed_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so only
// Gemm0's loop needs to be considered here
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
return ave_time;
......@@ -793,7 +836,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0)
{
return Argument{p_a,
p_b,
......@@ -819,7 +864,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op,
acc_element_op,
b1_element_op,
c_element_op};
c_element_op,
p_dropout,
seed};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -853,7 +900,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
......@@ -879,7 +928,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op,
acc_element_op,
b1_element_op,
c_element_op);
c_element_op,
p_dropout,
seed);
}
// polymorphic
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment