"docs/source/en/api/pipelines/latent_diffusion_uncond.mdx" did not exist on "a6e2c1fe5c02cae8a9f077f5d4e11b73d5791723"
Commit 9e16e38e authored by guangzlu
Browse files

added dropout shuffle in lds for fwd

parent cd6e9903
......@@ -255,11 +255,16 @@ struct BlockwiseDropout
return keep ? val * p_dropout_rescale : float(0);
};
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
}
......@@ -306,48 +311,6 @@ struct BlockwiseDropout
});
}
// Fills z_thread_buf with per-thread 16-bit random values that serve as the
// source for the dropout mask. Values are drawn from the counter-based Philox
// RNG four at a time, for a total of MRepeat * KRepeat values per thread, then
// scattered into the buffer following the M x K thread-slice layout.
//
// ph                   - Philox generator state (counter-based, so output is
//                        deterministic for a given seed/offset).
// element_global_1d_id - base offset making each thread's random stream unique;
//                        advanced by 8 * MRaw per 4-value batch so consecutive
//                        batches use disjoint counter offsets.
// z_thread_buf         - destination thread-local buffer, indexed via
//                        ThreadSliceDesc_M_K offsets.
// MRaw                 - raw M extent of the problem, used to stride the
//                        per-batch RNG counter offsets.
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrix(ck::philox& ph,
                                         index_t element_global_1d_id,
                                         ZThreadBuffer& z_thread_buf,
                                         index_t MRaw)
{
    constexpr int tmp_size = MRepeat * KRepeat;

    // get_random_4x16 produces exactly 4 x 16-bit values per call. If
    // tmp_size were not a multiple of 4, the truncating division below would
    // leave the tail of tmp[] uninitialized and silently copy garbage into
    // z_thread_buf, so enforce the precondition at compile time.
    static_assert(tmp_size % 4 == 0,
                  "MRepeat * KRepeat must be a multiple of 4 for 4x16 Philox batches");

    constexpr int philox_calls = tmp_size / 4;

    ushort tmp[tmp_size];
    for(int i = 0; i < philox_calls; i++)
    {
        ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
    }

    // LDS barrier retained from the original flow. The values generated above
    // live in registers, so this presumably orders this call against
    // surrounding LDS traffic in the caller — confirm before removing.
    block_sync_lds();

    // Scatter the linear tmp[] array into the thread buffer using the
    // compile-time M x K thread-slice offsets.
    int tmp_index = 0;
    static_for<0, MRepeat, 1>{}([&](auto iM) {
        static_for<0, KRepeat, 1>{}([&](auto iK) {
            auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
            z_thread_buf(offset) = tmp[tmp_index];
            tmp_index = tmp_index + 1;
        });
    });
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
......
......@@ -143,6 +143,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
//// Z desc for source in blockwise copy
//__host__ __device__ static constexpr auto GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() ////=>
//for z use
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
// return make_naive_tensor_descriptor_packed(
// make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// Number<MPerXdl>{}, Number<N3>{}, Number<N4>{}, Number<N5>{}));
//}
// C shuffle desc for source in gridwise copy
__host__ __device__ static constexpr auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
......@@ -156,18 +173,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
// printf("M / MPerBlock %d, ", M / MPerBlock);
// printf("MXdlPerWave %d, " , MXdlPerWave);
// printf("Gemm0MWaves %d, " , Gemm0MWaves);
// printf("MPerXdl / N5 %d, " , MPerXdl / N5);
// printf("N5 %d \n" , N5);
// printf("N / NPerBlock %d, " , N / NPerBlock);
// printf("NXdlPerWave %d, " , NXdlPerWave);
// printf("Gemm0NWaves %d, " , Gemm0NWaves);
// printf("N3 %d, " , N3);
// printf("N4 %d, " , N4);
// printf("N5 %d, " , N5);
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
......@@ -175,9 +180,50 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6, 9>{}, Sequence<1, 3, 5, 7, 8, 10>{}));
make_tuple(Sequence<0, 2, 4, 6, 9>{}, Sequence<1, 3, 5, 7, 8, 10>{})); // 0247,13568
}
// using ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 = remove_cvref_t<decltype(
// GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Z shuffle desc for source in blockwise copy
//__host__ __device__ static constexpr auto
// GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4(const
// ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4& z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4) ////=> for z
// use to shuffle
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
//
// constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
// transform_tensor_descriptor(
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// make_freeze_transform(Number<MXdlPerWave>{}),
// make_freeze_transform(Number<NXdlPerWave>{}),
// make_freeze_transform(Number<Gemm0MWaves>{}),
// make_freeze_transform(Number<Gemm0NWaves>{}),
// make_unmerge_transform(make_tuple(Number<MPerXdl / N5>{}, Number<N5>{})),
// make_freeze_transform(Number<N3>{}),
// make_freeze_transform(Number<N4>{}),
// make_freeze_transform(Number<N5>{})),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},Sequence<4>{},
// Sequence<5>{}, Sequence<6>{}, Sequence<7>{}), make_tuple(Sequence<0>{}, Sequence<1>{},
// Sequence<2>{}, Sequence<3>{},Sequence<4,7>{}, Sequence<5>{}, Sequence<6>{},
// Sequence<8>{}));
// return z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
// //return make_naive_tensor_descriptor_packed(
// // make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// // Number<MPerXdl / N5>{}, Number<N3>{}, Number<N4>{}, Number<N5>{},
// Number<N5>{}));
//}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
......@@ -904,6 +950,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// printf("n4 is %d \n",n4.value);
//}
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, //
n0, //
m1, //
n1, //
m2, // m0 1
n2, // n0 4
n3, // n1 1
n4, // m1 4
I1)); // n2 1
// constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 = //for gridwise
// copy
// make_naive_tensor_descriptor_packed(make_tuple(I1,
// I1,
// m0, //
// n0, //
// m1, //
// n1, //
// m2, // m0 1
// n2, // n0 4
// n3, // n1 1
// n4, // m1 4
// I1)); // n2 1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
......@@ -916,19 +997,50 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3, // NInputNum
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, //
I1, //
m0, //
n0, //
m1, //
n1, //
m2, // m0
n2, // m1
n3, // n0
n4, // n1
I1)); // n2
// ignore = z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// constexpr auto z_block_lengths = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths();
constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto zN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto zM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto zN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0),
make_pass_through_transform(zN0),
make_pass_through_transform(zM1),
make_pass_through_transform(zN1),
make_unmerge_transform(make_tuple(Number<zM2.value / zN4.value>{}, zN4)),
make_pass_through_transform(zN2),
make_pass_through_transform(zN3),
make_pass_through_transform(zN4)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 7>{},
Sequence<5>{},
Sequence<6>{},
Sequence<8>{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
......@@ -939,7 +1051,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize(),
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4.GetElementSpaceSize(),
true>
z_tenor_buffer; // z buffer after shuffle
z_tenor_buffer.Clear();
......@@ -948,10 +1060,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_z_grid,
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
// ignore = z_block_buf;
// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
......@@ -990,63 +1109,111 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
/*
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
1,
1,
true >{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, //
MBlockId 0, // NBlockId 0, // mrepeat 0, //
nrepeat wave_id[I0], // MWaveId wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), //
MPerXdl 0, // group wave_m_n_id[I0], // NInputIndex 0,
wave_m_n_id[I1] % 4)};
*/
auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5),
Sequence<I1, I1, m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>,
10,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4),
Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8,
1,
1,
true /* ResetCoordAfterRun */>{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0,
wave_m_n_id[I1] % 4)};
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0,
wave_m_n_id[I1] % 4)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
......@@ -1229,27 +1396,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tenor_tmp_buffer)>(
ph, global_elem_id, z_tenor_tmp_buffer);
z_tmp_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_tmp_buf);
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_tmp_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
// z_tmp_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_tmp_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_tmp_buf);
block_sync_lds();
z_shuffle_thread_copy_global_to_vgpr.Run(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
z_grid_tmp_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
z_shuffle_thread_copy_lds_to_vgpr.Run(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
z_block_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer);
// z_shuffle_thread_copy_global_to_vgpr.Run(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// z_grid_tmp_buf,
// z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
false>(acc_thread_buf,
z_tenor_buffer);
// ignore = z_tenor_buffer;
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
......@@ -1259,13 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_sync_lds();
z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
// z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
// z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment