"tests/vscode:/vscode.git/clone" did not exist on "2c03fe9952cfd3419fb4325a830a6958ed455c3d"
Commit dc8e0148 authored by guangzlu's avatar guangzlu
Browse files

added dropout shuffle for attn fwd, bwd v4 can pass now

parent 78c1482a
...@@ -78,7 +78,7 @@ using GemmDataType = F16; ...@@ -78,7 +78,7 @@ using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
using ZDataType = INT32; // INT32 using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
...@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia ...@@ -104,7 +104,7 @@ static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecia
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = true; static constexpr bool Deterministic = false;
// DIM should be a multiple of 8. // DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template. // If DIM <= 32 , ues prototype1 1st template.
...@@ -716,7 +716,7 @@ int run(int argc, char* argv[]) ...@@ -716,7 +716,7 @@ int run(int argc, char* argv[])
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 1; // 54 ck::index_t G0 = 1; // 54
ck::index_t G1 = 2; // 16 ck::index_t G1 = 1; // 16
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
......
...@@ -123,8 +123,9 @@ struct BlockwiseDropout ...@@ -123,8 +123,9 @@ struct BlockwiseDropout
} }
template <typename CThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void __host__ __device__ void ApplyDropoutAttnFwd(CThreadBuffer& in_thread_buf,
ApplyDropout_v1r1(CThreadBuffer& in_thread_buf, ck::philox& ph, index_t element_global_1d_id) // ck::philox& ph,
index_t element_global_1d_id) //
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
...@@ -161,7 +162,7 @@ struct BlockwiseDropout ...@@ -161,7 +162,7 @@ struct BlockwiseDropout
__host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf, __host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
ck::philox& ph, ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
index_t MRaw) // index_t MRaw)
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
...@@ -195,10 +196,11 @@ struct BlockwiseDropout ...@@ -195,10 +196,11 @@ struct BlockwiseDropout
} }
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout_v1r2(CThreadBuffer& in_thread_buf, __host__ __device__ void ApplyDropoutAttnBwdSaveZ(CThreadBuffer& in_thread_buf,
ck::philox& ph, ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf) ZThreadBuffer& z_thread_buf,
index_t MRaw)
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
...@@ -215,17 +217,17 @@ struct BlockwiseDropout ...@@ -215,17 +217,17 @@ struct BlockwiseDropout
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
} }
ushort tmp_id[tmp_size]; // ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) // for(int i = 0; i < philox_calls; i++)
{ //{
for(int j = 0; j < 4; j++) // for(int j = 0; j < 4; j++)
{ // {
tmp_id[i * 4 + j] = element_global_1d_id + i * 8; // tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
} // }
} //}
block_sync_lds(); block_sync_lds();
...@@ -235,18 +237,15 @@ struct BlockwiseDropout ...@@ -235,18 +237,15 @@ struct BlockwiseDropout
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset)); execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp_id[tmp_index]; z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
} }
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false> template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwdSaveZ(CThreadBuffer& in_thread_buf, __host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
ck::philox& ph, ZThreadBuffer& z_thread_buf)
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf,
index_t MRaw)
{ {
auto execute_dropout = [&](bool keep, DataType val) { auto execute_dropout = [&](bool keep, DataType val) {
...@@ -256,6 +255,26 @@ struct BlockwiseDropout ...@@ -256,6 +255,26 @@ struct BlockwiseDropout
return keep ? val * p_dropout_rescale : float(0); return keep ? val * p_dropout_rescale : float(0);
}; };
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
});
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat; constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4; int philox_calls = tmp_size / 4;
...@@ -263,17 +282,17 @@ struct BlockwiseDropout ...@@ -263,17 +282,17 @@ struct BlockwiseDropout
ushort tmp[tmp_size]; ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++) for(int i = 0; i < philox_calls; i++)
{ {
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
} }
ushort tmp_id[tmp_size]; // ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) // for(int i = 0; i < philox_calls; i++)
{ //{
for(int j = 0; j < 4; j++) // for(int j = 0; j < 4; j++)
{ // {
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw; // tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
} // }
} //}
block_sync_lds(); block_sync_lds();
...@@ -281,36 +300,12 @@ struct BlockwiseDropout ...@@ -281,36 +300,12 @@ struct BlockwiseDropout
static_for<0, MRepeat, 1>{}([&](auto iM) { static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = z_thread_buf(offset) = tmp[tmp_index];
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp_id[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
} }
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropout_v2(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
});
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer> template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrix(ck::philox& ph, __host__ __device__ void GenerateZMatrix(ck::philox& ph,
index_t element_global_1d_id, index_t element_global_1d_id,
...@@ -332,14 +327,14 @@ struct BlockwiseDropout ...@@ -332,14 +327,14 @@ struct BlockwiseDropout
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw); ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
} }
ushort tmp_id[tmp_size]; // ushort tmp_id[tmp_size];
for(int i = 0; i < philox_calls; i++) // for(int i = 0; i < philox_calls; i++)
{ //{
for(int j = 0; j < 4; j++) // for(int j = 0; j < 4; j++)
{ // {
tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw; // tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
} // }
} //}
block_sync_lds(); block_sync_lds();
...@@ -347,7 +342,7 @@ struct BlockwiseDropout ...@@ -347,7 +342,7 @@ struct BlockwiseDropout
static_for<0, MRepeat, 1>{}([&](auto iM) { static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) { static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{}; auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp_id[tmp_index]; z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1; tmp_index = tmp_index + 1;
}); });
}); });
......
...@@ -39,6 +39,7 @@ template <typename GridwiseGemm, ...@@ -39,6 +39,7 @@ template <typename GridwiseGemm,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
...@@ -70,6 +71,8 @@ __global__ void ...@@ -70,6 +71,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
...@@ -127,6 +130,7 @@ __global__ void ...@@ -127,6 +130,7 @@ __global__ void
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -159,6 +163,7 @@ __global__ void ...@@ -159,6 +163,7 @@ __global__ void
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -648,6 +653,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -648,6 +653,10 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_ =
GridwiseGemm::MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
z_grid_desc_m_n_);
if(p_lse_grid == nullptr) if(p_lse_grid == nullptr)
{ {
is_lse_storing_ = false; is_lse_storing_ = false;
...@@ -693,9 +702,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -693,9 +702,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -752,69 +765,72 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -752,69 +765,72 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = auto launch_kernel = [&](auto has_main_k_block_loop_,
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) { auto is_dropout_,
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle< auto is_lse_storing_) {
GridwiseGemm, const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
ADataType, // TODO: distiguish A/B datatype GridwiseGemm,
CDataType, ADataType, // TODO: distiguish A/B datatype
ZDataType, CDataType,
LSEDataType, ZDataType,
GemmAccDataType, LSEDataType,
AElementwiseOperation, GemmAccDataType,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, CElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::LSEGridDesc_M, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5,
ComputeBasePtrOfStridedBatch, DeviceOp::LSEGridDesc_M,
C0MatrixMask, typename GridwiseGemm::DefaultBlock2CTileMap,
has_main_k_block_loop_, ComputeBasePtrOfStridedBatch,
is_dropout_, C0MatrixMask,
is_lse_storing_, has_main_k_block_loop_,
Deterministic>; is_dropout_,
is_lse_storing_,
return launch_and_time_kernel( Deterministic>;
stream_config,
kernel, return launch_and_time_kernel(
dim3(grid_size), stream_config,
dim3(BlockSize), kernel,
0, dim3(grid_size),
arg.p_a_grid_, dim3(BlockSize),
arg.p_b_grid_, 0,
arg.p_b1_grid_, arg.p_a_grid_,
arg.p_c_grid_, arg.p_b_grid_,
arg.p_z_grid_, arg.p_b1_grid_,
arg.p_lse_grid_, arg.p_c_grid_,
arg.a_element_op_, arg.p_z_grid_,
arg.b_element_op_, arg.p_lse_grid_,
arg.acc_element_op_, arg.a_element_op_,
arg.b1_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.acc_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.b1_element_op_,
arg.b_grid_desc_bk0_n_bk1_, arg.c_element_op_,
arg.b1_grid_desc_bk0_n_bk1_, arg.a_grid_desc_ak0_m_ak1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.b_grid_desc_bk0_n_bk1_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.b1_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_, arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.batch_count_, arg.z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_,
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_), arg.lse_grid_desc_m_,
arg.compute_base_ptr_of_batch_, arg.block_2_ctile_map_,
arg.c0_matrix_mask_, arg.batch_count_,
arg.p_dropout_in_16bits_, arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
arg.p_dropout_rescale_, arg.compute_base_ptr_of_batch_,
arg.seed_, arg.c0_matrix_mask_,
arg.offset_, arg.p_dropout_in_16bits_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0], arg.p_dropout_rescale_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]); arg.seed_,
}; arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop // to concern Gemm0's loop
......
...@@ -122,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -122,7 +122,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy // C desc for source in gridwise copy
__host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( __host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{ {
...@@ -143,6 +143,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -143,6 +143,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
} }
// C shuffle desc for source in gridwise copy
__host__ __device__ static constexpr auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use to shuffle
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
// printf("M / MPerBlock %d, ", M / MPerBlock);
// printf("MXdlPerWave %d, " , MXdlPerWave);
// printf("Gemm0MWaves %d, " , Gemm0MWaves);
// printf("MPerXdl / N5 %d, " , MPerXdl / N5);
// printf("N5 %d \n" , N5);
// printf("N / NPerBlock %d, " , N / NPerBlock);
// printf("NXdlPerWave %d, " , NXdlPerWave);
// printf("Gemm0NWaves %d, " , Gemm0NWaves);
// printf("N3 %d, " , N3);
// printf("N4 %d, " , N4);
// printf("N5 %d, " , N5);
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl / N5, N5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 9>{}, Sequence<1, 3, 5, 7, 8, 10>{}));
}
__device__ static auto GetGemm0WaveIdx() __device__ static auto GetGemm0WaveIdx()
{ {
const index_t thread_id = get_thread_local_1d_id(); const index_t thread_id = get_thread_local_1d_id();
...@@ -381,6 +416,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -381,6 +416,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype( using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>; MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
using ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5 = remove_cvref_t<decltype(
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait struct SharedMemTrait
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -441,6 +479,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -441,6 +479,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5&
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
...@@ -856,6 +896,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -856,6 +896,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// z vgpr copy to global // z vgpr copy to global
// //
// z matrix threadwise desc // z matrix threadwise desc
// if(get_thread_global_1d_id()==0){
// printf("m2 is %d \n",m2.value);
// printf("n2 is %d \n",n2.value);
// printf("n3 is %d \n",n3.value);
// printf("n4 is %d \n",n4.value);
//}
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID I1, // NBlockID
...@@ -868,20 +916,138 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -868,20 +916,138 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum n3, // NInputNum
n4)); // registerNum n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, //
I1, //
m0, //
n0, //
m1, //
n1, //
m2, // m0
n2, // m1
n3, // n0
n4, // n1
I1)); // n2
// ignore = z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
StaticBuffer<AddressSpaceEnum::Vgpr, StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short, unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true> true>
z_tenor_buffer; z_tenor_tmp_buffer;
z_tenor_tmp_buffer.Clear();
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer; // z buffer after shuffle
z_tenor_buffer.Clear(); z_tenor_buffer.Clear();
// z matrix global desc
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize()); p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// printf("z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize() is
// %ld \n", z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
//
//}
const auto wave_id = GetGemm0WaveIdx(); const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// if(get_block_1d_id()==0){
// if(get_thread_local_1d_id()==0){
// printf("tid = 0 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==1){
// printf("tid = 1 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==2){
// printf("tid = 2 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==3){
// printf("tid = 3 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==32){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==64){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true /* ResetCoordAfterRun */>{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0,
wave_m_n_id[I1] % 4)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort, ushort,
ZDataType, ZDataType,
...@@ -1060,10 +1226,29 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1060,10 +1226,29 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// save z to global // save z to global
if(p_z_grid) if(p_z_grid)
{ {
blockwise_dropout.template ApplyDropout_v1r2<decltype(acc_thread_buf), blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tenor_tmp_buffer)>(
ph, global_elem_id, z_tenor_tmp_buffer);
z_tmp_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_tmp_buf);
block_sync_lds();
z_shuffle_thread_copy_global_to_vgpr.Run(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
z_grid_tmp_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tenor_buffer), decltype(z_tenor_buffer),
false>( false>(acc_thread_buf,
acc_thread_buf, ph, global_elem_id, z_tenor_buffer); z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run( z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
...@@ -1071,28 +1256,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1071,28 +1256,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_tenor_buffer, z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf); z_grid_buf);
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), block_sync_lds();
// decltype(z_tenor_buffer),
// false, z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// decltype(n0), z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// decltype(i)>( make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer); z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// z_thread_copy_vgpr_to_global.Run( make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow( z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
...@@ -1101,7 +1275,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1101,7 +1275,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{ {
// ignore = z_grid_buf; // ignore = z_grid_buf;
// P_dropped // P_dropped
blockwise_dropout.template ApplyDropout_v1r1<decltype(acc_thread_buf), false>( blockwise_dropout.template ApplyDropoutAttnFwd<decltype(acc_thread_buf), false>(
acc_thread_buf, ph, global_elem_id); acc_thread_buf, ph, global_elem_id);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment