Commit 7af1b43a authored by ltqin

one block dV pass

Each workgroup now owns one N block of dV: the grid index selects the N block, an outer loop walks the M blocks, dV = P_drop^T * dY is accumulated in registers across that loop, and the result is written once with Set instead of AtomicAdd. The example is reduced to a single-block problem (M = N = 128, G0 = G1 = 1, p_drop = 0) for validation.

parent b281dac5
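For reference, in the notation of the example's comments below (`y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o`), the quantity this commit computes is the dV term of the attention backward pass:

```latex
% P_drop is the post-dropout softmax matrix; the example below sets
% p_drop = 0, so P_drop = P.
\[
  P = \operatorname{softmax}\!\bigl(\alpha\, Q K^{\top}\bigr), \qquad
  Y = P_{\mathrm{drop}}\, V, \qquad
  dV = P_{\mathrm{drop}}^{\top}\, dY .
\]
```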
@@ -12,7 +12,7 @@ add_example_executable(example_batched_multihead_attention_backward batched_mult
 add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
 add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
-add_example_executable(example_batched_multihead_attention_backward_V1R2 batched_multihead_attention_backward_V2R2.cpp)
+add_example_executable(example_batched_multihead_attention_backward_V2R2 batched_multihead_attention_backward_V2R2.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
@@ -350,9 +350,9 @@ using DeviceGemmInstance =
 2, // B1K1
 32, // MPerXDL
 32, // NPerXDL
-4, // MXdlPerWave
-1, // NXdlPerWave
-1, // Gemm1NXdlPerWave
+1, // MXdlPerWave
+4, // NXdlPerWave
+4, // Gemm1NXdlPerWave
 2, // Gemm2NXdlPerWave
 S<4, 64, 1>, // ABlockTransfer
 S<1, 0, 2>,
@@ -375,8 +375,8 @@ using DeviceGemmInstance =
 4,
 2,
 false,
-4, // CShuffleMXdlPerWavePerShuffle
-1, // CShuffleNXdlPerWavePerShuffle
+1, // CShuffleMXdlPerWavePerShuffle
+4, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
 CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec, // MaskingSpecialization
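Why swap M and N here: the dV GEMM produces an N x O tile (one row per key), so the per-wave XDL repeats and the C-shuffle now favor N rather than M. A rough, hypothetical sanity check of the wave-tile coverage implied by the new repeats; MWaves/NWaves below are assumed values for illustration, not read from this file:

```cpp
// Hypothetical tile-shape arithmetic; the 2x2 wave grid of a 256-thread
// workgroup is an assumption for illustration only.
#include <cstdio>

int main()
{
    constexpr int MPerXDL = 32, NPerXDL = 32;           // from the hunk above
    constexpr int MXdlPerWave_old = 4, NXdlPerWave_old = 1;
    constexpr int MXdlPerWave_new = 1, NXdlPerWave_new = 4;
    constexpr int MWaves = 2, NWaves = 2;               // assumed wave grid

    std::printf("old wave-tile coverage: %d x %d\n",
                MXdlPerWave_old * MWaves * MPerXDL,     // 256
                NXdlPerWave_old * NWaves * NPerXDL);    // 64
    std::printf("new wave-tile coverage: %d x %d\n",
                MXdlPerWave_new * MWaves * MPerXDL,     // 64
                NXdlPerWave_new * NWaves * NPerXDL);    // 256
}
```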
@@ -501,17 +501,17 @@ int run(int argc, char* argv[])
 // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
 // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
 // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-ck::index_t M = 512;
-ck::index_t N = 512;
+ck::index_t M = 128;
+ck::index_t N = 128;
 ck::index_t K = DIM;
 ck::index_t O = DIM;
-ck::index_t G0 = 4;
-ck::index_t G1 = 6;
+ck::index_t G0 = 1;
+ck::index_t G1 = 1;
 bool input_permute = false;
 bool output_permute = false;
-float p_drop = 0.2;
+float p_drop = 0;
 const unsigned long long seed = 1;
 const unsigned long long offset = 0;
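Shrinking the problem to M = N = 128 with a single batch (and p_drop = 0, so P_drop = P) makes the whole dV computation fit in one workgroup, which is what the commit title refers to. A minimal sketch of that grid arithmetic, assuming the device op tiles N with NPerBlock = 128 (the real value is a template argument not visible in this hunk):

```cpp
// Minimal sketch of the "one block" claim; NPerBlock = 128 is an assumption.
#include <cstdio>

int main()
{
    const int M = 128, N = 128, G0 = 1, G1 = 1; // new example sizes
    const int NPerBlock = 128;                  // assumed tile size
    const int n_blocks  = N / NPerBlock;        // the M loop runs inside the kernel
    const int grid      = G0 * G1 * n_blocks;   // one workgroup total
    std::printf("batches=%d, N blocks=%d, grid=%d\n", G0 * G1, n_blocks, grid);
    (void)M;
}
```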
@@ -1040,6 +1040,7 @@ int run(int argc, char* argv[])
 "error",
 1e-2,
 1e-2);
+//std::cout << vgrad_gs_os_ns_device_result << std::endl;
 }
 return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
@@ -903,8 +903,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
 7, // VectorDim
 2, // ScalarPerVector
-InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
-1, // DstScalarStrideInVector
+InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
+1, // DstScalarStrideInVector
 true>;
 };
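Why Set is now sufficient: each dV element is written by exactly one workgroup after it finishes its M loop, whereas previously partial tiles from several M-block workgroups were added atomically into global memory. A simplified HIP-style contrast of the two policies (a standalone sketch, not the CK copy classes):

```cpp
// Sketch only: the two global-write policies, assuming <hip/hip_runtime.h>.
#include <hip/hip_runtime.h>

__device__ void write_atomic(float* dv, int i, float partial)
{
    // AtomicAdd policy: many workgroups each add a partial dV tile;
    // dv must be zero-initialized before the kernel launch.
    atomicAdd(&dv[i], partial);
}

__device__ void write_set(float* dv, int i, float full)
{
    // Set policy (this commit): exactly one workgroup owns each dV element
    // and stores the fully accumulated value once, after its M loop.
    dv[i] = full;
}
```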
@@ -1177,7 +1177,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 const C0MatrixMask& c0_matrix_mask,
 const float p_drop,
 ck::philox& ph,
-const index_t block_idx_m)
+const index_t block_idx_n)
 {
 const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
 const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
@@ -1217,11 +1217,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 return;
 }
-const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];
 // HACK: this force m/o_block_data_idx_on_grid into SGPR
-const index_t m_block_data_idx_on_grid =
-__builtin_amdgcn_readfirstlane(block_work_idx_m * NPerBlock);
+const index_t n_block_data_idx_on_grid =
+__builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
 const index_t o_block_data_idx_on_grid =
 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
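About the SGPR hack referenced above: a minimal sketch of what `__builtin_amdgcn_readfirstlane` buys here, under the assumption (true for these block indices) that the value is already uniform across the wavefront:

```cpp
// Sketch: readfirstlane broadcasts lane 0's value to the whole wavefront,
// which also tells the compiler the result is wave-uniform, so it can live
// in one scalar register (SGPR) instead of a vector register per lane.
// Safe only when every active lane already holds the same value.
#include <hip/hip_runtime.h>

__device__ int force_sgpr(int wave_uniform)
{
    return __builtin_amdgcn_readfirstlane(wave_uniform);
}
```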
@@ -1254,7 +1254,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto s_gemm_tile_q_blockwise_copy =
 typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
 q_grid_desc_k0_m_k1,
-make_multi_index(0, m_block_data_idx_on_grid, 0),
+make_multi_index(0, 0, 0), // will loop over GemmM dimension
 a_element_op,
 Gemm0::a_block_desc_ak0_m_ak1,
 make_multi_index(0, 0, 0),
@@ -1264,7 +1264,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto s_gemm_tile_k_blockwise_copy =
 typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
 k_grid_desc_k0_n_k1,
-make_multi_index(0, 0, 0), // will loop over GemmN dimension
+make_multi_index(0, n_block_data_idx_on_grid, 0),
 b_element_op,
 Gemm0::b_block_desc_bk0_n_bk1,
 make_multi_index(0, 0, 0),
@@ -1276,9 +1276,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
 const auto s_gemm_tile_a_block_reset_copy_step =
-make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), MPerBlock, 0);
 const auto s_gemm_tile_b_block_reset_copy_step =
-make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
+make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0);
 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
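The reset steps above encode the new traversal: after each inner K loop, the Q window rewinds K and advances one M block, while the K window only rewinds K because the workgroup's N block is fixed. A toy host-side walk of those window offsets (plain ints standing in for CK's multi-index machinery; the loop extents are made up):

```cpp
// Toy simulation of the slice-window walk after this change.
#include <cstdio>

int main()
{
    const int K0 = 4, MBlocks = 4; // assumed extents for illustration
    int q_k0 = 0, q_mblk = 0, k_k0 = 0;
    for(int j = 0; j < MBlocks; ++j)
    {
        for(int k = 0; k < K0; ++k) { ++q_k0; ++k_k0; } // inner K loop
        q_k0 -= K0; q_mblk += 1; // a_reset: rewind K, step one M block
        k_k0 -= K0;              // b_reset: rewind K, N block stays put
        std::printf("after M block %d: q=(k0=%d, mblk=%d), k=(k0=%d)\n",
                    j, q_k0, q_mblk, k_k0);
    }
}
```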
@@ -1293,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto pgrad_gemm_tile_ygrad_blockwise_copy =
 typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
 ygrad_grid_desc_o0_m_o1,
-make_multi_index(0, m_block_data_idx_on_grid, 0),
+make_multi_index(0, 0, 0), // will loop over GemmM dimension
 tensor_operation::element_wise::PassThrough{},
 Gemm0::a_block_desc_ak0_m_ak1,
 make_multi_index(0, 0, 0),
@@ -1303,7 +1303,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto pgrad_gemm_tile_v_blockwise_copy =
 typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
 v_grid_desc_o0_n_o1,
-make_multi_index(0, 0, 0), // will loop over GemmN dimension
+make_multi_index(0, n_block_data_idx_on_grid, 0),
 tensor_operation::element_wise::PassThrough{},
 Gemm0::b_block_desc_bk0_n_bk1,
 make_multi_index(0, 0, 0),
@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 1,
 false>{
 lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
-make_multi_index(block_work_idx_m, // mblock
+make_multi_index(0, // mblock
 acc0_thread_origin[I0], // mrepeat
 acc0_thread_origin[I2], // mwave
 acc0_thread_origin[I4])}; // mperxdl
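For reference, the `lse` buffer read here holds the forward pass's per-row log-sum-exp, which lets the backward pass recompute softmax rows without renormalizing. A small standalone sketch of that standard recomputation (the example values are arbitrary):

```cpp
// Recompute one softmax row from a stored log-sum-exp, flash-attention style.
#include <cmath>
#include <cstdio>

int main()
{
    const float s[4] = {0.5f, 1.0f, -0.25f, 2.0f}; // one row of alpha*Q*K^T
    float mx = s[0];
    for(float v : s) mx = v > mx ? v : mx;
    float sum = 0.0f;
    for(float v : s) sum += std::exp(v - mx);
    const float lse = mx + std::log(sum);          // what the forward pass stores
    for(float v : s)
        std::printf("%.4f ", std::exp(v - lse));   // softmax entry, no division
    std::printf("\n");
}
```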
@@ -1510,15 +1510,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 InMemoryDataOperationEnum::Set,
 1, // DstScalarStrideInVector
 true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-make_multi_index(block_work_idx_m, // MBlockId
-0, // NBlockId
-0, // mrepeat
-0, // nrepeat
-wave_id[I0], // MWaveId
-wave_id[I1], // NWaveId
-wave_m_n_id[I1], // MPerXdl
-0, // group
-wave_m_n_id[I0], // NInputIndex
+make_multi_index(0, // MBlockId
+block_work_idx_n, // NBlockId
+0, // mrepeat
+0, // nrepeat
+wave_id[I0], // MWaveId
+wave_id[I1], // NWaveId
+wave_m_n_id[I1], // MPerXdl
+0, // group
+wave_m_n_id[I0], // NInputIndex
 0),
 tensor_operation::element_wise::PassThrough{}};
@@ -1551,7 +1551,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto vgrad_gemm_tile_ygrad_blockwise_copy =
 typename Gemm2::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
 ygrad_grid_desc_m0_o_m1,
-make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+make_multi_index(0, // QLT
 o_block_data_idx_on_grid,
 0),
 tensor_operation::element_wise::PassThrough{},
@@ -1563,6 +1563,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
 auto v_slash_k_grad_thread_buf = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();
+v_slash_k_grad_thread_buf.Clear();
 // dV: C VGPR-to-global copy
 const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
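The `Clear()` hoisted to this point is what makes the single final write possible: the dV thread buffer is zeroed once, accumulates over every iteration of the outer M loop, and is only stored after the loop (see the hunk near the end of this diff). A tiny host-side sketch of the pattern:

```cpp
// Clear-once / accumulate-per-M-block / write-once, in miniature.
#include <cstdio>
#include <vector>

int main()
{
    const int MBlocks = 4, tile = 8;        // assumed sizes
    std::vector<float> dv_acc(tile, 0.0f);  // v_slash_k_grad_thread_buf.Clear()
    for(int j = 0; j < MBlocks; ++j)        // do { ... } while over M blocks
        for(int i = 0; i < tile; ++i)
            dv_acc[i] += 1.0f;              // blockwise GEMM accumulation
    std::printf("dv_acc[0]=%g (written once, with Set)\n", dv_acc[0]);
}
```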
@@ -1597,7 +1598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto kgrad_gemm_tile_q_blockwise_copy =
 typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
 q_grid_desc_m0_k_m1,
-make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+make_multi_index(0, // QLT
 o_block_data_idx_on_grid,
 0),
 tensor_operation::element_wise::PassThrough{},
@@ -1645,10 +1646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 const auto y_thread_data_on_block_idx =
 y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
-const auto y_thread_data_on_grid_idx =
-make_multi_index(
-block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
-y_thread_data_on_block_idx;
+const auto y_thread_data_on_grid_idx = y_thread_data_on_block_idx; // QLT
 // performs double duty for both y and ygrad
 auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
@@ -1773,8 +1771,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 index_t gemm1_m_block_outer_index = 0;
 do
 {
-auto n_block_data_idx_on_grid =
-__builtin_amdgcn_readfirstlane(gemm1_m_block_outer_index * NPerBlock);
+auto m_block_data_idx_on_grid =
+__builtin_amdgcn_readfirstlane(gemm1_m_block_outer_index * MPerBlock);
 if(c0_matrix_mask.IsTileSkippable(
 m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
 {
@@ -1912,62 +1910,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
-ignore = gemm2_b_block_buf;
-ignore = v_slash_k_grad_thread_buf;
-// SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
-// s_blockwise_gemm.GetWaveIdx()[I1]);
-// constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
-// static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
-// "");
-// // TODO: tune gemm2 pipeline
-// // dV = P_drop^T * dY
-// v_slash_k_grad_thread_buf.Clear();
-// static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
-// // load VGrad Gemm B
-// vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
-// ygrad_grid_buf);
-// // load VGrad Gemm A
-// const auto p_slice_idx =
-// Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
-// constexpr auto mwave_range = make_tuple(
-// p_slice_idx[I2],
-// p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
-// constexpr auto nwave_range = make_tuple(
-// p_slice_idx[I3],
-// p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
-// if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
-// {
-// vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
-// Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-// make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
-// s_slash_p_thread_buf,
-// Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-// gemm2_a_block_buf);
-// }
-// // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
-// // p slice window is moved by loop index
-// vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
-// ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
-// block_sync_lds(); // sync before write
-// vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
-// gemm2_b_block_buf);
-// block_sync_lds(); // sync before read
-// v_slash_k_grad_blockwise_gemm.Run(
-// gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
-// }); // end gemm dV
-// // atomic_add dV
-// vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-// v_slash_k_grad_thread_buf,
-// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-// vgrad_grid_buf);
+SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+s_blockwise_gemm.GetWaveIdx()[I1]);
+constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
+"");
+// TODO: tune gemm2 pipeline
+// dV = P_drop^T * dY
+static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
+// load VGrad Gemm B
+vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
+ygrad_grid_buf);
+// load VGrad Gemm A
+const auto p_slice_idx =
+Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+constexpr auto mwave_range = make_tuple(
+p_slice_idx[I2],
+p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+constexpr auto nwave_range = make_tuple(
+p_slice_idx[I3],
+p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
+if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
+{
+vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
+Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
+s_slash_p_thread_buf,
+Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+gemm2_a_block_buf);
+}
+// ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+// p slice window is moved by loop index
+vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
+block_sync_lds(); // sync before write
+vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
+gemm2_b_block_buf);
+block_sync_lds(); // sync before read
+v_slash_k_grad_blockwise_gemm.Run(
+gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
+}); // end gemm dV
 // // gemm dP
 // block_sync_lds();
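The re-enabled loop above follows CK's usual LDS staging discipline: read the next dY slice from global memory, barrier, commit it to LDS, barrier, then run the blockwise GEMM that accumulates into the dV buffer. A simplified HIP-style sketch of one iteration (not the actual CK copy classes):

```cpp
// One staging step: global -> registers -> (barrier) -> LDS -> (barrier) -> MFMA.
#include <hip/hip_runtime.h>

__device__ void stage_dy_slice(const float* dy_gmem, float* lds, int i)
{
    float reg = dy_gmem[i]; // RunRead: global -> registers
    __syncthreads();        // block_sync_lds(): wait until LDS readers are done
    lds[threadIdx.x] = reg; // RunWrite: registers -> LDS
    __syncthreads();        // block_sync_lds(): make the LDS write visible
    // v_slash_k_grad_blockwise_gemm.Run(...) would consume lds here,
    // accumulating into the dV thread buffer.
}
```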
@@ -2019,9 +2009,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 // {
 // // TODO: explore using dynamic buffer for a1 thread buffer
 // // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
-// // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
+// // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given
+// that
 // // the A1 source buffer is static buffer holding the output of first GEMM and
-// // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
+// // requires constexpr offset by design. Therefore, we pass tensor coordinate
+// offset
 // // explicitly in Run() below.
 // // preload data into LDS
@@ -2040,12 +2032,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 // {
 // static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
 // qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
-// Gemm1::a_block_slice_copy_step * i,
-// sgrad_thread_buf,
+// Gemm1::a_block_slice_copy_step *
+// i, sgrad_thread_buf,
 // Gemm1::a_thread_desc_k0_m_k1,
 // make_tuple(I0, I0, I0),
 // gemm1_a_thread_buf);
-// qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
+// qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1,
+// k_grid_buf);
 // block_sync_lds();
@@ -2065,11 +2058,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 // {
 // qgrad_gemm_tile_sgrad_blockwise_copy.Run(
 // Gemm1::a_src_thread_desc_k0_m_k1,
-// Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
-// sgrad_thread_buf,
-// Gemm1::a_thread_desc_k0_m_k1,
-// make_tuple(I0, I0, I0),
-// gemm1_a_thread_buf);
+// Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop -
+// 1>{}, sgrad_thread_buf, Gemm1::a_thread_desc_k0_m_k1, make_tuple(I0, I0,
+// I0), gemm1_a_thread_buf);
 // block_sync_lds();
@@ -2107,7 +2098,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 // gemm2_a_block_buf);
 // }
-// // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+// // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
+// buffer
 // // sgrad slice window is moved by loop index
 // kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
 // Gemm2::b_block_slice_copy_step);
@@ -2135,11 +2127,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
 k_grid_desc_k0_n_k1,
 s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
-vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
-ygrad_grid_desc_m0_o_m1,
-Gemm2::b_block_reset_copy_step); // rewind M
-vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
+// vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+// ygrad_grid_desc_m0_o_m1,
+// Gemm2::b_block_reset_copy_step); // rewind M
+// vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
 pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
 ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
 pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
@@ -2153,10 +2145,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
 } while(++gemm1_m_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
+// write out dV once after the M loop (the global op is now Set, not atomic_add)
+vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+v_slash_k_grad_thread_buf,
+vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+vgrad_grid_buf);
 ignore = c_element_op;
 ignore = qgrad_grid_buf;
 ignore = qgrad_grid_desc_mblock_mperblock_kblock_kperblock;
@@ -917,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
 {
 block_sync_lds();
 }
 do
 {
 auto n_block_data_idx_on_grid =