Commit 9f3bc9a3 authored by danyao12's avatar danyao12
Browse files

sync attn-bwd-dropout

parents f43ed837 e9e6081a
......@@ -163,5 +163,4 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_grouped_multihead_attention_forward.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -10,6 +10,13 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = true;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
......@@ -104,7 +111,7 @@ int run(int argc, char* argv[])
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute
......@@ -152,12 +159,12 @@ int run(int argc, char* argv[])
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "z_gs_ms_ns[" << i << "]: " << z_gs_ms_ns.mDesc << ", "
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
<< std::endl;
}
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{0});
switch(init_method)
{
......@@ -238,7 +245,7 @@ int run(int argc, char* argv[])
acc0_element_op,
b1_element_op,
c_element_op,
0, // dropout ratio
p_drop, // dropout ratio
{0, 448}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
......
......@@ -37,7 +37,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_multiheadattention_forward_xdl_cshuffle(
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -92,14 +92,22 @@ __global__ void
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
// unsigned short* p_z_grid_in = //
// (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
// : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared,
a_element_op,
......@@ -111,6 +119,7 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, ////////
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
......@@ -140,6 +149,7 @@ template <index_t NumDimG,
typename BDataType,
typename B1DataType,
typename CDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
......@@ -207,6 +217,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
BDataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -246,6 +257,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
BDataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
......@@ -295,6 +307,12 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
Number<B1K1>{});
}
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
{
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
......@@ -325,10 +343,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
......@@ -349,11 +370,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{
}
......@@ -378,6 +401,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
......@@ -388,6 +416,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
};
......@@ -408,6 +437,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
ZGridDesc_M_N,
LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
......@@ -465,6 +495,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
ZDataType* p_z_grid_;
LSEDataType* p_lse_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -473,6 +504,9 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
// batch & stride
......@@ -511,6 +545,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
......@@ -550,6 +585,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -562,6 +598,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto lse_grid_desc_m =
DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
......@@ -573,11 +611,20 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
// typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
z_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
......@@ -591,6 +638,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
b_grid_desc_g_n_k,
b1_grid_desc_g_n_k,
c_grid_desc_g_m_n,
z_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask
......@@ -614,11 +662,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
p_b_grid,
p_b1_grid,
p_c_grid,
p_z_grid,
p_lse_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_desc_m_n,
lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
compute_base_ptr_of_batch,
......@@ -705,16 +756,16 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
const auto kernel =
kernel_grouped_multiheadattention_forward_xdl_cshuffle<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_>;
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
GemmAccDataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_>;
return launch_and_time_kernel(
stream_config,
......@@ -891,6 +942,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
......@@ -907,6 +959,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
p_b_vec,
p_b1_vec,
p_c_vec,
p_z_vec,
p_lse_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
......@@ -928,6 +981,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec,
......@@ -944,6 +998,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
p_b_vec,
p_b1_vec,
p_c_vec,
p_z_vec,
p_lse_vec,
p_acc0_biases_vec,
p_acc1_biases_vec,
......
......@@ -120,8 +120,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
__host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
......@@ -140,7 +140,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) ////=> for z use
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M,
const index_t N) ////=> for z use
{
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
......@@ -1018,7 +1019,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if constexpr(IsDropout) // dropout
{
// save z to global
// save z to global
if(p_z_grid)
{
// P_dropped
......@@ -1027,11 +1028,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
true>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
......@@ -1041,7 +1043,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
}
}
//if constexpr(IsDropout) // dropout
// if constexpr(IsDropout) // dropout
//{
// blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
//}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment