"git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "3871188c89e7841e54f40363a5bb7dc62afa2510"
Commit 9fe6407e authored by guangzlu
Browse files

Add dropout support to FLA training

parent 393470f5
...@@ -140,7 +140,8 @@ int run(int argc, char* argv[]) ...@@ -140,7 +140,8 @@ int run(int argc, char* argv[])
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", " << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc << std::endl; << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
<< std::endl;
} }
switch(init_method) switch(init_method)
...@@ -173,7 +174,6 @@ int run(int argc, char* argv[]) ...@@ -173,7 +174,6 @@ int run(int argc, char* argv[])
c_tensors.push_back(c_gs_ms_os_device_result); c_tensors.push_back(c_gs_ms_os_device_result);
lse_tensors.push_back(lse_gs_ms_device_result); lse_tensors.push_back(lse_gs_ms_device_result);
a_tensors_device.emplace_back(std::make_unique<DeviceMem>( a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize())); sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()));
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>( b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...@@ -217,7 +217,8 @@ int run(int argc, char* argv[]) ...@@ -217,7 +217,8 @@ int run(int argc, char* argv[])
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op); c_element_op,
0); // dropout ratio
// specify workspace for problem_desc // specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
...@@ -254,8 +255,8 @@ int run(int argc, char* argv[]) ...@@ -254,8 +255,8 @@ int run(int argc, char* argv[])
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
const auto& lse_gs_ms_lengths = problem_descs[i].lse_gs_ms_lengths; const auto& lse_gs_ms_lengths = problem_descs[i].lse_gs_ms_lengths;
const auto& lse_gs_ms_strides = problem_descs[i].lse_gs_ms_strides; const auto& lse_gs_ms_strides = problem_descs[i].lse_gs_ms_strides;
const auto& a_gs_ms_ks = a_tensors[i]; const auto& a_gs_ms_ks = a_tensors[i];
const auto& b0_gs_ns_ks = b0_tensors[i]; const auto& b0_gs_ns_ks = b0_tensors[i];
...@@ -305,9 +306,10 @@ int run(int argc, char* argv[]) ...@@ -305,9 +306,10 @@ int run(int argc, char* argv[])
}); });
// softmax // softmax
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}, &lse_g_m_host_result); auto ref_softmax_argument =
ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}, &lse_g_m_host_result);
ref_softmax_invoker.Run(ref_softmax_argument); ref_softmax_invoker.Run(ref_softmax_argument);
...@@ -347,25 +349,25 @@ int run(int argc, char* argv[]) ...@@ -347,25 +349,25 @@ int run(int argc, char* argv[])
// when BF16 is taken, set absolute error and relative error to 0.01 // when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> && if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>) std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{ {
rtol = 1e-2; rtol = 1e-2;
atol = 1e-2; atol = 1e-2;
} }
// bool pass_ = // bool pass_ =
// ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData); // ck::utils::check_err(c_gs_ms_os_device_result.mData,
bool pass_ = // c_gs_ms_os_host_result.mData);
ck::utils::check_err(c_gs_ms_os_device_result.mData, bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData, c_gs_ms_os_host_result.mData,
"Error: Incorrect results c!", "Error: Incorrect results c!",
rtol, rtol,
atol) && atol) &&
ck::utils::check_err(lse_gs_ms_device_result.mData, ck::utils::check_err(lse_gs_ms_device_result.mData,
lse_gs_ms_host_result.mData, lse_gs_ms_host_result.mData,
"Error: Incorrect results lse!", "Error: Incorrect results lse!",
rtol, rtol,
atol); atol);
pass &= pass_; pass &= pass_;
} }
} }
......
...@@ -127,7 +127,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator ...@@ -127,7 +127,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B0ElementwiseOperation b0_element_op, B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op, Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <sstream> #include <sstream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
...@@ -29,7 +30,8 @@ template <typename GridwiseGemm, ...@@ -29,7 +30,8 @@ template <typename GridwiseGemm,
typename AccElementwiseOperation, typename AccElementwiseOperation,
typename B1ElementwiseOperation, typename B1ElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop> bool HasMainKBlockLoop,
bool IsDropout>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -41,13 +43,17 @@ __global__ void ...@@ -41,13 +43,17 @@ __global__ void
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op, const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op,
const ushort p_dropout_in_16bits,
const unsigned long long seed)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
ck::philox ph(seed, 0, block_id);
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>( const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args)); cast_pointer_to_generic_address_space(group_kernel_args));
...@@ -82,10 +88,10 @@ __global__ void ...@@ -82,10 +88,10 @@ __global__ void
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset, arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
...@@ -103,7 +109,9 @@ __global__ void ...@@ -103,7 +109,9 @@ __global__ void
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_); arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
ph);
#else #else
ignore = group_kernel_args; ignore = group_kernel_args;
ignore = group_count; ignore = group_count;
...@@ -506,12 +514,15 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -506,12 +514,15 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
float p_dropout,
unsigned long long seed)
: a_element_op_{a_element_op}, : a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op}, b1_element_op_{b1_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op},
seed_(seed)
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size(); group_count_ = problem_desc_vec.size();
...@@ -531,11 +542,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -531,11 +542,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
for(std::size_t i = 0; i < group_count_; i++) for(std::size_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]); const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]); const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]); const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]); const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]); const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -547,7 +558,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -547,7 +558,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N( const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto lse_grid_desc_m = DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]); const auto lse_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K( const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
...@@ -571,7 +583,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -571,7 +583,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
// batch stride // batch stride
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n, type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); a_grid_desc_g_m_k,
b_grid_desc_g_n_k,
b1_grid_desc_g_n_k,
c_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
...@@ -622,6 +638,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -622,6 +638,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n}); c_grid_desc_m_n});
} }
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
} }
std::vector<GroupKernelArg> group_kernel_args_; std::vector<GroupKernelArg> group_kernel_args_;
...@@ -635,6 +655,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -635,6 +655,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
AccElementwiseOperation acc_element_op_; AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_; B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
float p_dropout_;
ushort p_dropout_in_16bits_;
unsigned long long seed_;
bool is_dropout_;
}; };
// Invoker // Invoker
...@@ -667,7 +692,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -667,7 +692,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel = const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm, kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GroupKernelArg, GroupKernelArg,
...@@ -676,7 +701,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -676,7 +701,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
AccElementwiseOperation, AccElementwiseOperation,
B1ElementwiseOperation, B1ElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
has_main_k_block_loop_>; has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -690,18 +716,38 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -690,18 +716,38 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.b_element_op_, arg.b_element_op_,
arg.acc_element_op_, arg.acc_element_op_,
arg.b1_element_op_, arg.b1_element_op_,
arg.c_element_op_); arg.c_element_op_,
arg.p_dropout_in_16bits_,
arg.seed_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop // to concern Gemm0's loop
if(all_has_main_k_block_loop) if(all_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
} }
else if(!some_has_main_k_block_loop) else if(!some_has_main_k_block_loop)
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); if(arg.is_dropout_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
} }
else else
{ {
...@@ -839,7 +885,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -839,7 +885,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0)
{ {
return Argument{p_a_vec, return Argument{p_a_vec,
p_b_vec, p_b_vec,
...@@ -853,7 +901,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -853,7 +901,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op, b_element_op,
acc_element_op, acc_element_op,
b1_element_op, b1_element_op,
c_element_op}; c_element_op,
p_dropout,
seed};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -872,7 +922,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -872,7 +922,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op,
float p_dropout,
const unsigned long long seed = 0) override
{ {
return std::make_unique<Argument>(p_a_vec, return std::make_unique<Argument>(p_a_vec,
p_b_vec, p_b_vec,
...@@ -886,7 +938,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -886,7 +938,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op, b_element_op,
acc_element_op, acc_element_op,
b1_element_op, b1_element_op,
c_element_op); c_element_op,
p_dropout,
seed);
} }
// polymorphic // polymorphic
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -15,6 +16,7 @@ ...@@ -15,6 +16,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck { namespace ck {
...@@ -357,7 +359,10 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -357,7 +359,10 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
}; };
template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask> template <bool HasMainKBlockLoop,
bool IsDropout,
typename Block2CTileMap,
typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
...@@ -376,7 +381,9 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -376,7 +381,9 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask) const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
ck::philox ph)
{ {
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -721,6 +728,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -721,6 +728,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
decltype(thread_cluster_desc_m_n), decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{}; decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<decltype(thread_slice_desc_m_n)>{};
const index_t num_gemm1_k_block_outer_loop = const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
...@@ -862,6 +871,15 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle ...@@ -862,6 +871,15 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
blockwise_softmax.Run(acc_thread_buf, workspace_buf); blockwise_softmax.Run(acc_thread_buf, workspace_buf);
if constexpr(IsDropout) // dropout
{
blockwise_dropout.ApplyDropout(acc_thread_buf,
p_dropout_in_16bits,
ph,
gemm1_k_block_outer_index,
num_gemm1_k_block_outer_loop);
}
// TODO: may convert to log domain // TODO: may convert to log domain
running_max_new = mathext::max(max, running_max); running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment