Commit 7af1b43a authored by ltqin

one block dv pass

parent b281dac5
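
The hunks below move dV to a "one workgroup per N-block" scheme: the grid index becomes the N-block id, the outer loop walks M-tiles of Q/dY, and each workgroup accumulates its whole dV tile in registers (cleared before the loop, stored once with Set after it) instead of merging partial tiles with AtomicAdd. A minimal host-side sketch of why per-M-tile accumulation yields the same dV; the sizes are illustrative (borrowed from the example file) and none of the names below are the kernel's actual API:

```cpp
// dV[n][o] = sum_m P[m][n] * dY[m][o]  (dV = P^T * dY), computed by walking
// M-tiles and accumulating into one local dV tile, as the kernel now does.
#include <cstdio>
#include <vector>

int main()
{
    const int M = 128, N = 128, O = 64, MPerBlock = 32; // illustrative sizes
    std::vector<float> P(M * N, 0.01f), dY(M * O, 0.5f), dV(N * O, 0.0f);

    for(int m0 = 0; m0 < M; m0 += MPerBlock) // outer loop over M-tiles
        for(int n = 0; n < N; ++n)
            for(int o = 0; o < O; ++o)
                for(int m = m0; m < m0 + MPerBlock; ++m)
                    dV[n * O + o] += P[m * N + n] * dY[m * O + o];

    // a single plain store at the end suffices because this "workgroup" is
    // the sole owner of the dV tile; expect 128 * 0.01 * 0.5 = 0.64
    std::printf("dV[0] = %f\n", dV[0]);
    return 0;
}
```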
@@ -12,7 +12,7 @@ add_example_executable(example_batched_multihead_attention_backward batched_mult
 add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
 add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
-add_example_executable(example_batched_multihead_attention_backward_V1R2 batched_multihead_attention_backward_V2R2.cpp)
+add_example_executable(example_batched_multihead_attention_backward_V2R2 batched_multihead_attention_backward_V2R2.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
...
@@ -350,9 +350,9 @@ using DeviceGemmInstance =
         2,           // B1K1
         32,          // MPerXDL
         32,          // NPerXDL
-        4,           // MXdlPerWave
-        1,           // NXdlPerWave
-        1,           // Gemm1NXdlPerWave
+        1,           // MXdlPerWave
+        4,           // NXdlPerWave
+        4,           // Gemm1NXdlPerWave
         2,           // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
@@ -375,8 +375,8 @@ using DeviceGemmInstance =
         4,
         2,
         false,
-        4,              // CShuffleMXdlPerWavePerShuffle
-        1,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleMXdlPerWavePerShuffle
+        4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec,    // MaskingSpecialization
@@ -501,17 +501,17 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;
-    ck::index_t N  = 512;
+    ck::index_t M  = 128;
+    ck::index_t N  = 128;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
-    ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G0 = 1;
+    ck::index_t G1 = 1;
     bool input_permute  = false;
     bool output_permute = false;
-    float p_drop = 0.2;
+    float p_drop = 0;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
@@ -1040,6 +1040,7 @@ int run(int argc, char* argv[])
                      "error",
                      1e-2,
                      1e-2);
+        //std::cout << vgrad_gs_os_ns_device_result << std::endl;
     }

     return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
...
@@ -903,8 +903,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             Sequence<0, 1, 2, 3, 4, 5, 6, 7>,     // AccessOrder
             7,                                    // VectorDim
             2,                                    // ScalarPerVector
-            InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
+            InMemoryDataOperationEnum::Set,       // GlobalMemoryDataOperation
             1,                                    // DstScalarStrideInVector
             true>;
     };
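
The switch from AtomicAdd to Set for the dV global copy assumes a single producer per output tile. A toy sketch of the two merge strategies; illustrative only, with std::atomic standing in for the GPU atomic and none of the kernel's copy classes involved:

```cpp
#include <atomic>
#include <cstdio>

int main()
{
    const int m_tiles = 4;

    // AtomicAdd-style: every M-tile iteration merges a partial result into
    // memory, which stays correct even with several producers per tile.
    std::atomic<float> dv_atomic{0.0f};
    for(int m = 0; m < m_tiles; ++m)
    {
        const float partial = 0.25f; // one partial dV contribution per M-tile
        float cur = dv_atomic.load();
        while(!dv_atomic.compare_exchange_weak(cur, cur + partial)) {}
    }

    // Set-style: the sole owner accumulates privately and writes once.
    float acc = 0.0f;
    for(int m = 0; m < m_tiles; ++m)
        acc += 0.25f;
    const float dv_set = acc;

    std::printf("atomic=%f set=%f\n", dv_atomic.load(), dv_set); // both 1.0
    return 0;
}
```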
@@ -1177,7 +1177,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const C0MatrixMask& c0_matrix_mask,
         const float p_drop,
         ck::philox& ph,
-        const index_t block_idx_m)
+        const index_t block_idx_n)
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
@@ -1217,11 +1217,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             return;
         }

-        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+        const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
         // HACK: this force m/o_block_data_idx_on_grid into SGPR
-        const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx_m * NPerBlock);
+        const index_t n_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
         const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
@@ -1254,7 +1254,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto s_gemm_tile_q_blockwise_copy =
             typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                 q_grid_desc_k0_m_k1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                make_multi_index(0, 0, 0), // will loop over GemmM dimension
                 a_element_op,
                 Gemm0::a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
@@ -1264,7 +1264,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto s_gemm_tile_k_blockwise_copy =
             typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
                 k_grid_desc_k0_n_k1,
-                make_multi_index(0, 0, 0), // will loop over GemmN dimension
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
                 Gemm0::b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -1276,9 +1276,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();

         const auto s_gemm_tile_a_block_reset_copy_step =
-            make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+            make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), MPerBlock, 0);
         const auto s_gemm_tile_b_block_reset_copy_step =
-            make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
+            make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0);

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
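
The reset-step swap above encodes the new loop order: the Q window now rewinds K and advances by MPerBlock each outer iteration, while the K window only rewinds, since its N-tile is pinned to the workgroup. A small index-arithmetic sketch of that window movement, with illustrative names and sizes:

```cpp
#include <cstdio>

int main()
{
    const int M = 128, K = 64, MPerBlock = 32, KPerBlock = 32; // illustrative
    int q_k0 = 0, q_m = 0; // Q slice-window origin (k0, m)

    for(int m_iter = 0; m_iter < M / MPerBlock; ++m_iter)
    {
        for(int k_iter = 0; k_iter < K / KPerBlock; ++k_iter)
            q_k0 += KPerBlock; // the inner GEMM walks the K dimension

        // reset step (-K, +MPerBlock, 0): rewind K, advance to next M-tile
        q_k0 -= K;
        q_m += MPerBlock;
        std::printf("after M-tile %d: q_k0=%d q_m=%d\n", m_iter, q_k0, q_m);
    }
    return 0;
}
```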
@@ -1293,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto pgrad_gemm_tile_ygrad_blockwise_copy =
             typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
                 ygrad_grid_desc_o0_m_o1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                make_multi_index(0, 0, 0), // will loop over GemmM dimension
                 tensor_operation::element_wise::PassThrough{},
                 Gemm0::a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
@@ -1303,7 +1303,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto pgrad_gemm_tile_v_blockwise_copy =
             typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
                 v_grid_desc_o0_n_o1,
-                make_multi_index(0, 0, 0), // will loop over GemmN dimension
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
                 tensor_operation::element_wise::PassThrough{},
                 Gemm0::b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             1,
             false>{
             lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
-            make_multi_index(block_work_idx_m,         // mblock
+            make_multi_index(0,                        // mblock
                              acc0_thread_origin[I0],   // mrepeat
                              acc0_thread_origin[I2],   // mwave
                              acc0_thread_origin[I4])}; // mperxdl
@@ -1510,15 +1510,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             InMemoryDataOperationEnum::Set,
             1, // DstScalarStrideInVector
             true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                  make_multi_index(block_work_idx_m, // MBlockId
-                                   0,                // NBlockId
+                  make_multi_index(0,                // MBlockId
+                                   block_work_idx_n, // NBlockId
                                    0,                // mrepeat
                                    0,                // nrepeat
                                    wave_id[I0],      // MWaveId
                                    wave_id[I1],      // NWaveId
                                    wave_m_n_id[I1],  // MPerXdl
                                    0,                // group
                                    wave_m_n_id[I0],  // NInputIndex
                                    0),
                  tensor_operation::element_wise::PassThrough{}};
@@ -1551,7 +1551,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto vgrad_gemm_tile_ygrad_blockwise_copy =
             typename Gemm2::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
                 ygrad_grid_desc_m0_o_m1,
-                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                make_multi_index(0, // QLT
                                  o_block_data_idx_on_grid,
                                  0),
                 tensor_operation::element_wise::PassThrough{},
@@ -1563,6 +1563,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
         auto v_slash_k_grad_thread_buf     = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();
+        v_slash_k_grad_thread_buf.Clear();

         // dV: C VGPR-to-global copy
         const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
@@ -1597,7 +1598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto kgrad_gemm_tile_q_blockwise_copy =
             typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
                 q_grid_desc_m0_k_m1,
-                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                make_multi_index(0, // QLT
                                  o_block_data_idx_on_grid,
                                  0),
                 tensor_operation::element_wise::PassThrough{},
@@ -1645,10 +1646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto y_thread_data_on_block_idx =
             y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();

-        const auto y_thread_data_on_grid_idx =
-            make_multi_index(
-                block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
-            y_thread_data_on_block_idx;
+        const auto y_thread_data_on_grid_idx = y_thread_data_on_block_idx; // QLT

         // performs double duty for both y and ygrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
@@ -1773,8 +1771,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         index_t gemm1_m_block_outer_index = 0;
         do
         {
-            auto n_block_data_idx_on_grid =
-                __builtin_amdgcn_readfirstlane(gemm1_m_block_outer_index * NPerBlock);
+            auto m_block_data_idx_on_grid =
+                __builtin_amdgcn_readfirstlane(gemm1_m_block_outer_index * MPerBlock);
             if(c0_matrix_mask.IsTileSkippable(
                    m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
             {
@@ -1912,62 +1910,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             ignore = gemm2_b_block_buf;
             ignore = v_slash_k_grad_thread_buf;

-            // SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
-            //                                                 s_blockwise_gemm.GetWaveIdx()[I1]);
-            // constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
-            // static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
-            //               "");
-            // // TODO: tune gemm2 pipeline
-            // // dV = P_drop^T * dY
-            // v_slash_k_grad_thread_buf.Clear();
-            // static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
-            //     // load VGrad Gemm B
-            //     vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
-            //                                                  ygrad_grid_buf);
-            //     // load VGrad Gemm A
-            //     const auto p_slice_idx =
-            //         Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
-            //     constexpr auto mwave_range = make_tuple(
-            //         p_slice_idx[I2],
-            //         p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
-            //     constexpr auto nwave_range = make_tuple(
-            //         p_slice_idx[I3],
-            //         p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
-            //     if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
-            //     {
-            //         vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
-            //             Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            //             make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
-            //             s_slash_p_thread_buf,
-            //             Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            //             gemm2_a_block_buf);
-            //     }
-            //     // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
-            //     // p slice window is moved by loop index
-            //     vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
-            //         ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
-            //     block_sync_lds(); // sync before write
-            //     vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
-            //                                                   gemm2_b_block_buf);
-            //     block_sync_lds(); // sync before read
-            //     v_slash_k_grad_blockwise_gemm.Run(
-            //         gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
-            // }); // end gemm dV
-            // // atomic_add dV
-            // vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-            //                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-            //                                      v_slash_k_grad_thread_buf,
-            //                                      vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-            //                                      vgrad_grid_buf);
+            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+                                                            s_blockwise_gemm.GetWaveIdx()[I1]);
+            constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+            static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
+                          "");
+            // TODO: tune gemm2 pipeline
+            // dV = P_drop^T * dY
+            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
+                // load VGrad Gemm B
+                vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
+                                                             ygrad_grid_buf);
+                // load VGrad Gemm A
+                const auto p_slice_idx =
+                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+                constexpr auto mwave_range = make_tuple(
+                    p_slice_idx[I2],
+                    p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+                constexpr auto nwave_range = make_tuple(
+                    p_slice_idx[I3],
+                    p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));

+                if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                {
+                    vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
+                        Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
+                        s_slash_p_thread_buf,
+                        Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        gemm2_a_block_buf);
+                }

+                // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+                // p slice window is moved by loop index
+                vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+                    ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);

+                block_sync_lds(); // sync before write
+                vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
+                                                              gemm2_b_block_buf);

+                block_sync_lds(); // sync before read
+                v_slash_k_grad_blockwise_gemm.Run(
+                    gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
+            }); // end gemm dV

             // // gemm dP
             // block_sync_lds();
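
Assembling the comments in this hunk (dV = P_drop^T * dY) with the p_dropout/rp_dropout definitions earlier in the diff, the quantity the uncommented loop accumulates is, tile by tile (the dropout-mask form below is the usual inverted-dropout convention, stated as an assumption rather than read from the patch):

```latex
P = \mathrm{softmax}\!\left(\alpha\, Q K^{T}\right), \qquad
P_{\mathrm{drop}} = \frac{Z \odot P}{1 - p_{\mathrm{drop}}}, \qquad
dV = P_{\mathrm{drop}}^{T}\, dY = \sum_{m\text{-tiles}} P_{\mathrm{drop}}[m]^{T}\, dY[m]
```

where Z is the 0/1 dropout mask and 1 - p_drop is the keep probability (p_dropout in the code, with rp_dropout its reciprocal).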
@@ -2019,9 +2009,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // {
             //     // TODO: explore using dynamic buffer for a1 thread buffer
             //     // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
-            //     // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
+            //     // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given
+            //     that
             //     // the A1 source buffer is static buffer holding the output of first GEMM and
-            //     // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
+            //     // requires constexpr offset by design. Therefore, we pass tensor coordinate
+            //     offset
             //     // explicitly in Run() below.

             //     // preload data into LDS
@@ -2040,12 +2032,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // {
             //     static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
             //         qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
-            //                                                  Gemm1::a_block_slice_copy_step * i,
-            //                                                  sgrad_thread_buf,
+            //                                                  Gemm1::a_block_slice_copy_step *
+            //                                                  i, sgrad_thread_buf,
             //                                                  Gemm1::a_thread_desc_k0_m_k1,
             //                                                  make_tuple(I0, I0, I0),
             //                                                  gemm1_a_thread_buf);
-            //         qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
+            //         qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1,
+            //         k_grid_buf);

             //         block_sync_lds();
@@ -2065,11 +2058,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // {
             //     qgrad_gemm_tile_sgrad_blockwise_copy.Run(
             //         Gemm1::a_src_thread_desc_k0_m_k1,
-            //         Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
-            //         sgrad_thread_buf,
-            //         Gemm1::a_thread_desc_k0_m_k1,
-            //         make_tuple(I0, I0, I0),
-            //         gemm1_a_thread_buf);
+            //         Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop -
+            //         1>{}, sgrad_thread_buf, Gemm1::a_thread_desc_k0_m_k1, make_tuple(I0, I0,
+            //         I0), gemm1_a_thread_buf);

             //     block_sync_lds();
@@ -2107,7 +2098,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             //             gemm2_a_block_buf);
             //     }
-            //     // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+            //     // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
+            //     buffer
             //     // sgrad slice window is moved by loop index
             //     kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
             //                                                         Gemm2::b_block_slice_copy_step);
@@ -2135,11 +2127,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                 k_grid_desc_k0_n_k1,
                 s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
-            vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
-                ygrad_grid_desc_m0_o_m1,
-                Gemm2::b_block_reset_copy_step); // rewind M
-            vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
+            // vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+            //     ygrad_grid_desc_m0_o_m1,
+            //     Gemm2::b_block_reset_copy_step); // rewind M
+            // vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+            //     vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
             pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                 ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
             pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
@@ -2153,10 +2145,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
         } while(++gemm1_m_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

+        // atomic_add dV
+        vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                             v_slash_k_grad_thread_buf,
+                                             vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                             vgrad_grid_buf);
+
         ignore = c_element_op;
         ignore = qgrad_grid_buf;
         ignore = qgrad_grid_desc_mblock_mperblock_kblock_kperblock;
...
@@ -917,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
             {
                 block_sync_lds();
             }

            do
            {
                auto n_block_data_idx_on_grid =
...