Commit f196d88b authored by guangzlu

Added dropout to batched_gemm_softmax_gemm

parent 9fe6407e
@@ -177,7 +177,8 @@ int run(int argc, char* argv[])
                                          b0_element_op,
                                          acc0_element_op,
                                          b1_element_op,
-                                         c_element_op);
+                                         c_element_op,
+                                         0); // dropout ratio

    if(!gemm.IsSupportedArgument(argument))
    {
...
...@@ -115,7 +115,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator ...@@ -115,7 +115,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
@@ -7,6 +7,7 @@
 #include <sstream>
 #include "ck/utility/common_header.hpp"
+#include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
@@ -39,7 +40,8 @@ template <typename GridwiseGemm,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          typename C0MatrixMask,
-         bool HasMainKBlockLoop>
+         bool HasMainKBlockLoop,
+         bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -64,7 +66,9 @@ __global__ void
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
-       const C0MatrixMask c0_matrix_mask)
+       const C0MatrixMask c0_matrix_mask,
+       const ushort p_dropout_in_16bits,
+       const unsigned long long seed)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
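The kernel now takes the keep-probability as a 16-bit integer (`p_dropout_in_16bits`) rather than a float, so the per-element decision reduces to an integer compare against raw 16-bit random draws. A minimal host-side sketch of that keep-and-rescale step, under the assumption that it mirrors what `GridwiseGemm::Run` does with its Philox draws (`std::mt19937` stands in for the device RNG; the array and values are mine):

```cpp
#include <cstdint>
#include <cstdio>
#include <random>

int main()
{
    const float    p_keep    = 0.9f;                       // 1 - p_dropout
    const uint16_t threshold = uint16_t(p_keep * 65535.f); // p_dropout_in_16bits

    std::mt19937 rng(1234); // stand-in for the kernel's per-block Philox stream
    std::uniform_int_distribution<uint32_t> draw16(0, 65535);

    float acc[8] = {1, 2, 3, 4, 5, 6, 7, 8}; // e.g. softmax outputs of gemm0
    for(float& x : acc)
        x = (draw16(rng) < threshold) ? x / p_keep : 0.f; // "inverted" dropout

    for(float x : acc)
        std::printf("%.3f ", x);
    std::printf("\n");
}
```

Scaling survivors by 1/p_keep at training time keeps the expected activation equal to the no-dropout value, so inference needs no rescaling.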
@@ -83,24 +87,30 @@ __global__ void
    const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

-   GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
-                                                 p_b_grid + b_batch_offset,
-                                                 p_b1_grid + b1_batch_offset,
-                                                 p_c_grid + c_batch_offset,
-                                                 p_lse_grid + lse_batch_offset,
-                                                 p_shared,
-                                                 a_element_op,
-                                                 b_element_op,
-                                                 acc_element_op,
-                                                 b1_element_op,
-                                                 c_element_op,
-                                                 a_grid_desc_ak0_m_ak1,
-                                                 b_grid_desc_bk0_n_bk1,
-                                                 b1_grid_desc_bk0_n_bk1,
-                                                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                 lse_grid_desc_m,
-                                                 block_2_ctile_map,
-                                                 c0_matrix_mask);
+   const index_t block_id = get_block_1d_id();
+   ck::philox ph(seed, 0, block_id * 4);
+
+   GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+       p_a_grid + a_batch_offset,
+       p_b_grid + b_batch_offset,
+       p_b1_grid + b1_batch_offset,
+       p_c_grid + c_batch_offset,
+       p_lse_grid + lse_batch_offset,
+       p_shared,
+       a_element_op,
+       b_element_op,
+       acc_element_op,
+       b1_element_op,
+       c_element_op,
+       a_grid_desc_ak0_m_ak1,
+       b_grid_desc_bk0_n_bk1,
+       b1_grid_desc_bk0_n_bk1,
+       c_grid_desc_mblock_mperblock_nblock_nperblock,
+       lse_grid_desc_m,
+       block_2_ctile_map,
+       c0_matrix_mask,
+       p_dropout_in_16bits,
+       ph);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
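The seed plus the `block_id * 4` offset feed a counter-based (Philox) generator: every block derives a disjoint, reproducible random stream from the same seed without sharing any RNG state. `ck::philox`'s actual interface lives in `ck/utility/philox_rand.hpp`; the toy mixer below is only an assumption-level illustration of the counter idea, not CK code:

```cpp
#include <cstdint>
#include <cstdio>

// Toy stateless mixer standing in for one Philox round (illustration only).
uint32_t mix(uint64_t seed, uint64_t counter)
{
    uint64_t x = seed ^ (counter * 0x9E3779B97F4A7C15ull);
    x ^= x >> 31;
    x *= 0xBF58476D1CE4E5B9ull;
    x ^= x >> 27;
    return uint32_t(x >> 16);
}

int main()
{
    const uint64_t seed = 1234;
    for(int block_id = 0; block_id < 2; ++block_id)
    {
        // Each block starts its counter at block_id * 4, as in the kernel above,
        // so consecutive blocks consume non-overlapping counter ranges.
        const uint64_t offset = uint64_t(block_id) * 4;
        std::printf("block %d draws: %u %u %u %u\n",
                    block_id,
                    mix(seed, offset + 0),
                    mix(seed, offset + 1),
                    mix(seed, offset + 2),
                    mix(seed, offset + 3));
    }
}
```

Re-running with the same seed and offsets reproduces the same draws, so a dropout mask never has to be materialized in memory.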
@@ -463,7 +473,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
-                CElementwiseOperation c_element_op)
+                CElementwiseOperation c_element_op,
+                float p_dropout,
+                unsigned long long seed)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_b1_grid_{p_b1_grid},
@@ -512,7 +524,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
              b_grid_desc_g_n_k_,
              b1_grid_desc_g_n_k_,
              c_grid_desc_g_m_n_,
-             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
+             type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())},
+         seed_(seed)
    {
        // TODO ANT: implement bias addition
        ignore = p_acc0_biases;
@@ -532,6 +545,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    c_grid_desc_m_n_);
        }
+
+        is_dropout_          = p_dropout > 0.0f;
+        p_dropout_           = 1.f - p_dropout;
+        p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
    }
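A quick worked example (standalone, my own illustration) of what those three lines compute: the keep-probability is floored into a 16-bit threshold, and the `is_dropout_` flag later selects the `IsDropout` kernel instantiation.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    for(float p_dropout : {0.0f, 0.1f, 0.5f})
    {
        const bool     is_dropout = p_dropout > 0.0f;
        const float    p_keep     = 1.f - p_dropout;               // p_dropout_
        const uint16_t threshold  = uint16_t(std::floor(p_keep * 65535.0));

        std::printf("p_dropout=%.1f -> is_dropout=%d threshold=%u/65535 rescale=%.4f\n",
                    p_dropout, is_dropout, unsigned(threshold), 1.f / p_keep);
    }
}
```

For p_dropout = 0.1 this prints threshold 58981, i.e. a draw below 58981/65535 ≈ 0.9 keeps the element and survivors are rescaled by 1/0.9 ≈ 1.1111.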
    void Print() const
@@ -592,6 +609,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+        float p_dropout_;
+        ushort p_dropout_in_16bits_;
+        unsigned long long seed_;
+        bool is_dropout_;
    };

    // Invoker
@@ -615,7 +637,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
        float ave_time = 0;

-       auto launch_kernel = [&](auto has_main_k_block_loop_) {
+       auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
            const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
                GridwiseGemm,
                ADataType, // TODO: distinguish A/B datatype
@@ -634,7 +656,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                typename GridwiseGemm::DefaultBlock2CTileMap,
                ComputeBasePtrOfStridedBatch,
                C0MatrixMask,
-               has_main_k_block_loop_>;
+               has_main_k_block_loop_,
+               is_dropout_>;

            return launch_and_time_kernel(stream_config,
                                          kernel,
@@ -659,18 +682,38 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                arg.block_2_ctile_map_,
                arg.batch_count_,
                arg.compute_base_ptr_of_batch_,
-               arg.c0_matrix_mask_);
+               arg.c0_matrix_mask_,
+               arg.p_dropout_in_16bits_,
+               arg.seed_);
        };

        // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we
        // only need to be concerned with Gemm0's loop
        if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
        {
-           ave_time = launch_kernel(integral_constant<bool, true>{});
+           if(arg.is_dropout_)
+           {
+               ave_time = launch_kernel(integral_constant<bool, true>{},
+                                        integral_constant<bool, true>{});
+           }
+           else
+           {
+               ave_time = launch_kernel(integral_constant<bool, true>{},
+                                        integral_constant<bool, false>{});
+           }
        }
        else
        {
-           ave_time = launch_kernel(integral_constant<bool, false>{});
+           if(arg.is_dropout_)
+           {
+               ave_time = launch_kernel(integral_constant<bool, false>{},
+                                        integral_constant<bool, true>{});
+           }
+           else
+           {
+               ave_time = launch_kernel(integral_constant<bool, false>{},
+                                        integral_constant<bool, false>{});
+           }
        }

        return ave_time;
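The four-way ladder exists because `HasMainKBlockLoop` and `IsDropout` are template parameters: each runtime combination has to select a distinct kernel instantiation, and passing `integral_constant` values is how a runtime bool becomes a compile-time one. A condensed standalone sketch of the pattern (names are mine; CK's own `integral_constant` plays the role of `std::integral_constant` here):

```cpp
#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop, bool IsDropout>
void run_kernel()
{
    // A real kernel can use `if constexpr (IsDropout)` so the disabled path
    // costs nothing; here we only show which instantiation was selected.
    std::printf("instantiated: loop=%d dropout=%d\n", HasMainKBlockLoop, IsDropout);
}

int main()
{
    const bool has_loop = true, is_dropout = false; // runtime flags, as in Invoker::Run

    auto launch = [](auto has_main_k_block_loop_, auto is_dropout_) {
        run_kernel<decltype(has_main_k_block_loop_)::value,
                   decltype(is_dropout_)::value>();
    };

    // Same shape as the diff: lift each runtime bool into a compile-time
    // std::integral_constant before calling the templated launcher.
    if(has_loop)
    {
        if(is_dropout)
            launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, true>{});
        else
            launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, false>{});
    }
    else
    {
        if(is_dropout)
            launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, true>{});
        else
            launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, false>{});
    }
}
```

The payoff is that the dropout-free instantiation can compile out the RNG and masking work entirely instead of branching on every element.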
@@ -793,7 +836,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
               BElementwiseOperation b_element_op,
               AccElementwiseOperation acc_element_op,
               B1ElementwiseOperation b1_element_op,
-              CElementwiseOperation c_element_op)
+              CElementwiseOperation c_element_op,
+              float p_dropout,
+              const unsigned long long seed = 0)
    {
        return Argument{p_a,
                        p_b,
@@ -819,7 +864,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                        b_element_op,
                        acc_element_op,
                        b1_element_op,
-                       c_element_op};
+                       c_element_op,
+                       p_dropout,
+                       seed};
    }

    static auto MakeInvoker() { return Invoker{}; }
@@ -853,7 +900,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
               BElementwiseOperation b_element_op,
               AccElementwiseOperation acc_element_op,
               B1ElementwiseOperation b1_element_op,
-              CElementwiseOperation c_element_op) override
+              CElementwiseOperation c_element_op,
+              float p_dropout,
+              const unsigned long long seed = 0) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
@@ -879,7 +928,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                                          b_element_op,
                                          acc_element_op,
                                          b1_element_op,
-                                         c_element_op);
+                                         c_element_op,
+                                         p_dropout,
+                                         seed);
    }

    // polymorphic
...