Commit 9fe6407e authored by guangzlu

Added dropout to FLA training

parent 393470f5
@@ -140,7 +140,8 @@ int run(int argc, char* argv[])
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc << std::endl;
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
<< std::endl;
}
switch(init_method)
@@ -173,7 +174,6 @@ int run(int argc, char* argv[])
c_tensors.push_back(c_gs_ms_os_device_result);
lse_tensors.push_back(lse_gs_ms_device_result);
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()));
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
@@ -217,7 +217,8 @@ int run(int argc, char* argv[])
b0_element_op,
acc0_element_op,
b1_element_op,
-c_element_op);
+c_element_op,
+0); // dropout ratio
// specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
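Note: the example passes 0 as the new trailing dropout-ratio argument, which disables the dropout path entirely (the Argument constructor later in this diff gates on p_dropout > 0). A minimal standalone sketch of that gating and of the threshold packing it feeds, with illustrative values:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_dropout = 0.0f; // value passed above; anything > 0 enables dropout
    bool is_dropout = p_dropout > 0.0f;  // mirrors is_dropout_ further down
    float p_keep    = 1.0f - p_dropout;  // mirrors p_dropout_ further down
    std::uint16_t threshold =
        static_cast<std::uint16_t>(std::floor(p_keep * 65535.0));
    std::printf("dropout enabled: %d, 16-bit keep threshold: %u\n",
                is_dropout, threshold);
    return 0;
}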
@@ -254,8 +255,8 @@ int run(int argc, char* argv[])
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
const auto& lse_gs_ms_lengths = problem_descs[i].lse_gs_ms_lengths;
const auto& lse_gs_ms_strides = problem_descs[i].lse_gs_ms_strides;
const auto& a_gs_ms_ks = a_tensors[i];
const auto& b0_gs_ns_ks = b0_tensors[i];
@@ -305,9 +306,10 @@ int run(int argc, char* argv[])
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
-auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}, &lse_g_m_host_result);
+auto ref_softmax_argument =
+    ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}, &lse_g_m_host_result);
ref_softmax_invoker.Run(ref_softmax_argument);
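Note: the reference softmax is also asked to fill lse_g_m_host_result, the per-row log-sum-exp that training kernels keep so the backward pass can recompute the probabilities. A minimal sketch of that statistic (illustrative only, not the CK reference implementation):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float row_logsumexp(const std::vector<float>& row)
{
    float m  = *std::max_element(row.begin(), row.end());
    double s = 0.0;
    for(float x : row)
        s += std::exp(x - m); // subtract the row max for numerical stability
    return m + static_cast<float>(std::log(s));
}

int main()
{
    std::vector<float> row{1.0f, 2.0f, 3.0f};
    std::printf("lse = %f\n", row_logsumexp(row)); // ~3.407606
    return 0;
}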
@@ -347,25 +349,25 @@ int run(int argc, char* argv[])
// when BF16 is used, set the absolute and relative error tolerance to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
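Note: the 1e-2 tolerance follows from bfloat16's precision: with 8 significand bits (7 stored plus the implicit leading one), its machine epsilon is 2^-7 ≈ 0.0078, just under the chosen bound. A quick check:

#include <cmath>
#include <cstdio>

int main()
{
    double bf16_eps = std::ldexp(1.0, -7); // 2^-7, bfloat16 machine epsilon
    std::printf("bf16 eps = %g, tolerance = %g\n", bf16_eps, 1e-2);
    return 0;
}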
// bool pass_ =
-// ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
+// ck::utils::check_err(c_gs_ms_os_device_result.mData,
+//                      c_gs_ms_os_host_result.mData);
-bool pass_ =
-    ck::utils::check_err(c_gs_ms_os_device_result.mData,
-                         c_gs_ms_os_host_result.mData,
-                         "Error: Incorrect results c!",
-                         rtol,
-                         atol) &&
-    ck::utils::check_err(lse_gs_ms_device_result.mData,
-                         lse_gs_ms_host_result.mData,
-                         "Error: Incorrect results lse!",
-                         rtol,
-                         atol);
+bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
+                                  c_gs_ms_os_host_result.mData,
+                                  "Error: Incorrect results c!",
+                                  rtol,
+                                  atol) &&
+             ck::utils::check_err(lse_gs_ms_device_result.mData,
+                                  lse_gs_ms_host_result.mData,
+                                  "Error: Incorrect results lse!",
+                                  rtol,
+                                  atol);
pass &= pass_;
}
}
......
@@ -127,7 +127,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
-CElementwiseOperation c_element_op) = 0;
+CElementwiseOperation c_element_op,
+float p_dropout,
+const unsigned long long seed = 0) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
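Note: the seed = 0 default on this pure virtual is repeated verbatim on the MakeArgumentPointer override further down. That matters in C++ because default arguments are bound through the static type of the call expression, not the dynamic type, so keeping base and derived defaults identical avoids surprises. A standalone illustration of the rule:

#include <cstdio>

struct Base
{
    virtual void run(int seed = 0) { std::printf("base, seed=%d\n", seed); }
    virtual ~Base() = default;
};

struct Derived : Base
{
    void run(int seed = 7) override { std::printf("derived, seed=%d\n", seed); }
};

int main()
{
    Derived d;
    Base& b = d;
    b.run(); // prints "derived, seed=0": Derived's body, Base's default argument
    return 0;
}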
......
@@ -7,6 +7,7 @@
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
@@ -29,7 +30,8 @@ template <typename GridwiseGemm,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
-bool HasMainKBlockLoop>
+bool HasMainKBlockLoop,
+bool IsDropout>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -41,13 +43,17 @@ __global__ void
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
-const CElementwiseOperation c_element_op)
+const CElementwiseOperation c_element_op,
+const ushort p_dropout_in_16bits,
+const unsigned long long seed)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
+ck::philox ph(seed, 0, block_id);
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
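Note: ck::philox is a counter-based generator, so constructing it from the seed plus per-stream indices (here the 1-D block id) gives every workgroup an independent, reproducible random stream with no RNG state held in memory. An illustrative stand-in for the counter-based idea (not ck::philox itself):

#include <cstdint>
#include <cstdio>

// Mix a global seed with per-stream indices into a pseudo-random word.
std::uint32_t counter_rng(std::uint64_t seed, std::uint64_t stream, std::uint64_t counter)
{
    std::uint64_t x = seed ^ (stream * 0x9E3779B97F4A7C15ULL) ^ (counter << 1);
    x ^= x >> 33;
    x *= 0xFF51AFD7ED558CCDULL;
    x ^= x >> 33; // 64-bit mix finalizer
    return static_cast<std::uint32_t>(x);
}

int main()
{
    // two "blocks" drawing the first value of their independent streams
    std::printf("%u %u\n", counter_rng(2023, 0, 0), counter_rng(2023, 1, 0));
    return 0;
}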
@@ -82,10 +88,10 @@ __global__ void
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
-const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
-    static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
+const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+    arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
-GridwiseGemm::template Run<HasMainKBlockLoop>(
+GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
@@ -103,7 +109,9 @@ __global__ void
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
-arg_ptr[group_id].c0_matrix_mask_);
+arg_ptr[group_id].c0_matrix_mask_,
+p_dropout_in_16bits,
+ph);
#else
ignore = group_kernel_args;
ignore = group_count;
@@ -506,12 +514,15 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
-CElementwiseOperation c_element_op)
+CElementwiseOperation c_element_op,
+float p_dropout,
+unsigned long long seed)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
-c_element_op_{c_element_op}
+c_element_op_{c_element_op},
+seed_(seed)
{
// TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size();
@@ -531,11 +542,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
const auto& problem_desc = problem_desc_vec[i];
@@ -547,7 +558,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
-const auto lse_grid_desc_m = DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
+const auto lse_grid_desc_m =
+    DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
@@ -571,7 +583,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
// batch stride
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
-a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n, type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
+a_grid_desc_g_m_k,
+b_grid_desc_g_n_k,
+b1_grid_desc_g_n_k,
+c_grid_desc_g_m_n,
+type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
@@ -622,6 +638,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_m_n});
}
+is_dropout_ = p_dropout > 0.0;
+p_dropout_ = 1.f - p_dropout;
+p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
}
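Note on the packing above: p_dropout_ ends up holding the keep probability 1 - p_dropout, which is mapped onto the uint16 range so the kernel can test raw 16-bit Philox draws with one integer compare. A worked example (the exact comparison BlockwiseDropout performs is not shown in this diff; the < below is illustrative):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_dropout = 0.1f;
    float p_keep    = 1.0f - p_dropout;
    std::uint16_t threshold =
        static_cast<std::uint16_t>(std::floor(p_keep * 65535.0)); // 58981
    std::uint16_t draw = 30000; // a hypothetical 16-bit random value
    bool keep = draw < threshold; // integer compare instead of float math
    std::printf("threshold=%u keep=%d\n", threshold, keep);
    return 0;
}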
std::vector<GroupKernelArg> group_kernel_args_;
@@ -635,6 +655,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
+float p_dropout_;
+ushort p_dropout_in_16bits_;
+unsigned long long seed_;
+bool is_dropout_;
};
// Invoker
@@ -667,7 +692,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
float ave_time = 0;
-auto launch_kernel = [&](auto has_main_k_block_loop_) {
+auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GroupKernelArg,
@@ -676,7 +701,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
-has_main_k_block_loop_>;
+has_main_k_block_loop_,
+is_dropout_>;
return launch_and_time_kernel(
stream_config,
@@ -690,18 +716,38 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
-arg.c_element_op_);
+arg.c_element_op_,
+arg.p_dropout_in_16bits_,
+arg.seed_);
};
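Note: the lambda now takes two integral_constant tags, so each (HasMainKBlockLoop, IsDropout) combination instantiates its own fully specialized kernel and the dropout code costs nothing when disabled. A standalone model of the dispatch pattern:

#include <cstdio>
#include <type_traits>

template <bool HasMainLoop, bool IsDropout>
void kernel() { std::printf("main_loop=%d dropout=%d\n", HasMainLoop, IsDropout); }

template <typename MainLoopTag, typename DropoutTag>
void launch(MainLoopTag, DropoutTag)
{
    kernel<MainLoopTag::value, DropoutTag::value>(); // compile-time selection
}

int main()
{
    bool has_main_loop = true, is_dropout = false; // runtime flags
    if(has_main_loop && is_dropout)
        launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, true>{});
    else if(has_main_loop)
        launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, false>{});
    else if(is_dropout)
        launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, false>{});
    return 0;
}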
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to consider Gemm0's loop
if(all_has_main_k_block_loop)
{
-ave_time = launch_kernel(integral_constant<bool, true>{});
+if(arg.is_dropout_)
+{
+    ave_time = launch_kernel(integral_constant<bool, true>{},
+                             integral_constant<bool, true>{});
+}
+else
+{
+    ave_time = launch_kernel(integral_constant<bool, true>{},
+                             integral_constant<bool, false>{});
+}
}
else if(!some_has_main_k_block_loop)
{
-ave_time = launch_kernel(integral_constant<bool, false>{});
+if(arg.is_dropout_)
+{
+    ave_time = launch_kernel(integral_constant<bool, false>{},
+                             integral_constant<bool, true>{});
+}
+else
+{
+    ave_time = launch_kernel(integral_constant<bool, false>{},
+                             integral_constant<bool, false>{});
+}
}
else
{
@@ -839,7 +885,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
-CElementwiseOperation c_element_op)
+CElementwiseOperation c_element_op,
+float p_dropout,
+const unsigned long long seed = 0)
{
return Argument{p_a_vec,
p_b_vec,
@@ -853,7 +901,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op,
acc_element_op,
b1_element_op,
-c_element_op};
+c_element_op,
+p_dropout,
+seed};
}
static auto MakeInvoker() { return Invoker{}; }
@@ -872,7 +922,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
-CElementwiseOperation c_element_op) override
+CElementwiseOperation c_element_op,
+float p_dropout,
+const unsigned long long seed = 0) override
{
return std::make_unique<Argument>(p_a_vec,
p_b_vec,
@@ -886,7 +938,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_element_op,
acc_element_op,
b1_element_op,
-c_element_op);
+c_element_op,
+p_dropout,
+seed);
}
// polymorphic
......
@@ -4,6 +4,7 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -15,6 +16,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
@@ -357,7 +359,10 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
-template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
+template <bool HasMainKBlockLoop,
+          bool IsDropout,
+          typename Block2CTileMap,
+          typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
@@ -376,7 +381,9 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
-const C0MatrixMask& c0_matrix_mask)
+const C0MatrixMask& c0_matrix_mask,
+const ushort p_dropout_in_16bits,
+ck::philox ph)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -721,6 +728,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
+auto blockwise_dropout = BlockwiseDropout<decltype(thread_slice_desc_m_n)>{};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
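Note: BlockwiseDropout lives in the newly included blockwise_dropout.hpp and its body is not part of this diff. As a rough mental model only (names and the keep test here are assumptions, not CK's code), a per-thread dropout pass over an accumulator tile looks like:

#include <cstdint>
#include <cstdio>
#include <random>

int main()
{
    float acc[8] = {1, 2, 3, 4, 5, 6, 7, 8}; // stand-in accumulator tile
    std::uint16_t keep_threshold = 58981;    // keep probability ~0.9
    std::mt19937 rng(2023);                  // stand-in for the Philox stream
    for(float& x : acc)
        if(static_cast<std::uint16_t>(rng() & 0xFFFFu) >= keep_threshold)
            x = 0.0f; // dropped element
    for(float x : acc)
        std::printf("%g ", x);
    std::printf("\n");
    return 0;
}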
@@ -862,6 +871,15 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
+if constexpr(IsDropout) // dropout
+{
+    blockwise_dropout.ApplyDropout(acc_thread_buf,
+                                   p_dropout_in_16bits,
+                                   ph,
+                                   gemm1_k_block_outer_index,
+                                   num_gemm1_k_block_outer_loop);
+}
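Note: dropout is applied to the softmax probabilities before the second GEMM. The standard inverted-dropout convention rescales survivors by 1/(1 - p) so the expected attention output is unchanged; whether ApplyDropout folds that rescale in is not visible in this hunk. The arithmetic:

#include <cstdio>

int main()
{
    float p    = 0.1f;     // dropout probability
    float keep = 1.0f - p; // survival probability (p_dropout_ above)
    float x    = 0.25f;    // a softmax probability
    // E[output] = keep * (x / keep) + p * 0 = x, so the estimate is unbiased
    std::printf("rescaled survivor: %f, expectation: %f\n", x / keep, keep * (x / keep));
    return 0;
}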
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
......