Commit 610e0ab0 authored by guangzlu's avatar guangzlu
Browse files

updated dropout for bwd pt4 & pt5

parent 8edf2a72
...@@ -14,6 +14,7 @@ add_example_executable(example_batched_multihead_attention_backward_v2 batched_m ...@@ -14,6 +14,7 @@ add_example_executable(example_batched_multihead_attention_backward_v2 batched_m
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp) add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp) add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp) add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train_v5 batched_multihead_attention_train_v5.cpp)
add_custom_target(example_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
...@@ -122,6 +122,160 @@ struct BlockwiseDropout ...@@ -122,6 +122,160 @@ struct BlockwiseDropout
}); });
} }
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
index_t MRaw)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwdSaveZ(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf,
index_t MRaw)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
// }
//}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
// }
//}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
ushort p_dropout_16bits; ushort p_dropout_16bits;
DataType p_dropout_rescale; DataType p_dropout_rescale;
}; };
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -39,7 +39,7 @@ template <typename GridwiseGemm, ...@@ -39,7 +39,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
...@@ -71,7 +71,7 @@ __global__ void ...@@ -71,7 +71,7 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
...@@ -85,7 +85,9 @@ __global__ void ...@@ -85,7 +85,9 @@ __global__ void
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const float p_drop, const float p_drop,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -144,6 +146,9 @@ __global__ void ...@@ -144,6 +146,9 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx,
MRaw,
NRaw,
i); i);
} }
} }
...@@ -176,6 +181,9 @@ __global__ void ...@@ -176,6 +181,9 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx,
MRaw,
NRaw,
0); 0);
} }
#else #else
...@@ -278,6 +286,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -278,6 +286,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -808,8 +826,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -808,8 +826,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
// Print(); // Print();
} }
...@@ -869,8 +887,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -869,8 +887,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -933,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -933,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
...@@ -967,7 +985,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -967,7 +985,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
...@@ -979,11 +997,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -979,11 +997,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_drop_, arg.p_drop_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
#if 1
// if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
// {
// ave_time = launch_kernel(integral_constant<bool, true>{});
// }
// else
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time = launch_kernel(integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{});
#endif
return ave_time; return ave_time;
} }
...@@ -1003,6 +1034,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1003,6 +1034,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{ {
return false; return false;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt7.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -38,7 +38,7 @@ template <typename GridwiseGemm, ...@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
...@@ -70,8 +70,8 @@ __global__ void ...@@ -70,8 +70,8 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -84,7 +84,9 @@ __global__ void ...@@ -84,7 +84,9 @@ __global__ void
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
const float p_drop, const float p_drop,
const unsigned long long seed, const unsigned long long seed,
const unsigned long long offset) const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -134,7 +136,7 @@ __global__ void ...@@ -134,7 +136,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -143,6 +145,9 @@ __global__ void ...@@ -143,6 +145,9 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx,
MRaw,
NRaw,
i); i);
} }
} }
...@@ -166,7 +171,7 @@ __global__ void ...@@ -166,7 +171,7 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
...@@ -175,6 +180,9 @@ __global__ void ...@@ -175,6 +180,9 @@ __global__ void
c0_matrix_mask, c0_matrix_mask,
p_drop, p_drop,
ph, ph,
g_idx,
MRaw,
NRaw,
0); 0);
} }
#else #else
...@@ -284,6 +292,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -284,6 +292,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -821,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -821,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds); offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
// Print(); // Print();
} }
...@@ -882,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -882,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -950,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -950,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
...@@ -984,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -984,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
...@@ -996,11 +1014,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -996,11 +1014,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
arg.p_drop_, arg.p_drop_,
arg.seed_, arg.seed_,
arg.offset_); arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop // to concern Gemm0's loop
#if 1
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, true>{});
...@@ -1009,7 +1030,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1009,7 +1030,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{});
} }
#endif
return ave_time; return ave_time;
} }
...@@ -1029,6 +1050,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1029,6 +1050,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{ {
return false; return false;
......
...@@ -127,24 +127,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -127,24 +127,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// C desc for source in blockwise copy // C desc for source in blockwise copy
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n) MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{ {
const auto M = z_grid_desc_m_n.GetLength(I0); const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1); const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks; constexpr auto M4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size; constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor( return transform_tensor_descriptor(
z_grid_desc_m_n, z_grid_desc_m_n,
make_tuple(make_unmerge_transform( make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)), make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform( make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))), make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
{ {
...@@ -467,8 +469,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -467,8 +469,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype( using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>; MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcc) // S / dP Gemm (type 1 rcc)
struct Gemm0 struct Gemm0
...@@ -1183,8 +1185,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1183,8 +1185,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
...@@ -1194,6 +1196,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1194,6 +1196,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
const float p_drop, const float p_drop,
ck::philox& ph, ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t block_idx_n) const index_t block_idx_n)
{ {
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop); const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
...@@ -1558,47 +1563,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1558,47 +1563,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// z vgpr copy to global // z vgpr copy to global
// //
// z matrix threadwise desc // z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MGroupNum
m3, // NGroupNum m3, // MInputNum
m4, // NInputNum m4, // registerNum
n2)); // registerNum n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short, unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tenor_buffer;
z_tenor_buffer.Clear(); z_tenor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // z matrix global desc
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/ // ignore = p_z_tmp_grid;
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( p_z_grid,
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize()); z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global =
ushort, ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType, ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), decltype(
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId Sequence<I1, // MBlockId
I1, // NBlockID I1, // NBlockID
m0, // MRepeat m0, // MRepeat
I1, // NRepeat n0, // NRepeat
m1, // MWaveId m1, // MWaveId
n1, // NWaveId n1, // NWaveId
m2, // MPerXdl m2, // MPerXdl
...@@ -1610,17 +1613,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1610,17 +1613,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1, // DstScalarPerVector 1, // DstScalarPerVector
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(block_work_idx_n, // MBlockId make_multi_index(0, // MBlockId
0, // NBlockId block_work_idx_n, // NBlockId
0, // mrepeat 0, // mrepeat
0, // nrepeat 0, // nrepeat
wave_id[I0], // MWaveId wave_id[I0], // MWaveId
wave_id[I1], // NWaveId wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl 0, // MPerXdl
0, // group wave_m_n_id[I0], // group
wave_m_n_id[I0], // NInputIndex 0, // NInputIndex
0), wave_m_n_id[I1]),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
// //
...@@ -1743,8 +1746,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1743,8 +1746,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_mblock_mperblock_oblock_operblock, make_multi_index(1, 0, 0, 0)); y_grid_desc_mblock_mperblock_oblock_operblock, make_multi_index(1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
continue; continue;
} }
...@@ -1891,35 +1894,144 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1891,35 +1894,144 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// save z to global // save z to global
if(p_z_grid) if(p_z_grid)
{ {
// P_dropped // 8d thread_desc in thread scope
static_for<0, n0, 1>{}([&](auto i) { constexpr auto c_thread_lengths =
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
true, true>(
decltype(n0), s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer, z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf); z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
} }
else else
{ {
ignore = z_grid_buf; ignore = z_grid_buf;
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>( blockwise_dropout
s_slash_p_thread_buf, ph); .template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, MRaw);
} }
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
...@@ -2183,8 +2295,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2183,8 +2295,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index(1, 0, 0, 0)); make_multi_index(1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm0_m_block_outer_index < num_gemm0_m_block_outer_loop); // end j loop } while(++gemm0_m_block_outer_index < num_gemm0_m_block_outer_loop); // end j loop
......
...@@ -84,6 +84,17 @@ class philox ...@@ -84,6 +84,17 @@ class philox
out_tmp[3] = tmp_ph.w; out_tmp[3] = tmp_ph.w;
} }
__device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
out[0] = static_cast<ushort>(tmp_ph.x);
out[1] = static_cast<ushort>(tmp_ph.y);
out[2] = static_cast<ushort>(tmp_ph.z);
out[3] = static_cast<ushort>(tmp_ph.w);
}
private: private:
struct ull2 struct ull2
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment