Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9e16e38e
Commit
9e16e38e
authored
Jun 08, 2023
by
guangzlu
Browse files
added dropout shuffle in lds for fwd
parent
cd6e9903
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
288 additions
and
143 deletions
+288
-143
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+5
-42
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+283
-101
No files found.
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
9e16e38e
...
...
@@ -255,11 +255,16 @@ struct BlockwiseDropout
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
}
...
...
@@ -306,48 +311,6 @@ struct BlockwiseDropout
});
}
/// Fills `z_thread_buf` with 16-bit random values drawn from `ph`.
///
/// One value is produced for every (iM, iK) position of `ThreadSliceDesc_M_K`
/// (MRepeat x KRepeat positions), in iM-major / iK-minor order.
///
/// @param ph                    Philox generator; `get_random_4x16` is called once per
///                              group of four output values.
/// @param element_global_1d_id  Base value passed to the generator; call `i` uses
///                              `element_global_1d_id + i * 8 * MRaw`.
/// @param z_thread_buf          Destination buffer, indexed by the compile-time offset
///                              computed from `ThreadSliceDesc_M_K`.
/// @param MRaw                  Scales the per-call step added to `element_global_1d_id`.
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrix(ck::philox& ph,
                                         index_t element_global_1d_id,
                                         ZThreadBuffer& z_thread_buf,
                                         index_t MRaw)
{
    constexpr int tmp_size = MRepeat * KRepeat;

    // Round up so that a size that is not a multiple of 4 still gets its tail filled.
    // The previous truncating division left the last (tmp_size % 4) entries
    // uninitialized, and they were then copied into z_thread_buf.
    constexpr int philox_calls = (tmp_size + 3) / 4;

    // Scratch space sized to whole groups of 4, since each get_random_4x16 call is
    // given a 4-element stride (assumed to write exactly 4 values per call).
    ushort tmp[philox_calls * 4];

    for(int i = 0; i < philox_calls; i++)
    {
        ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
    }

    // NOTE(review): barrier kept from the original; this function does not touch LDS
    // itself, so it presumably orders work done by the caller - confirm before removing.
    block_sync_lds();

    // Running index into tmp; advances once per (iM, iK) visit.
    int tmp_index = 0;
    static_for<0, MRepeat, 1>{}([&](auto iM) {
        static_for<0, KRepeat, 1>{}([&](auto iK) {
            auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
            z_thread_buf(offset) = tmp[tmp_index];
            ++tmp_index;
        });
    });
}
ushort
p_dropout_16bits
;
DataType
p_dropout_rescale
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
9e16e38e
...
...
@@ -143,6 +143,23 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
9
>
{}));
}
//// Z desc for source in blockwise copy
//__host__ __device__ static constexpr auto GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() ////=>
//for z use
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
// return make_naive_tensor_descriptor_packed(
// make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// Number<MPerXdl>{}, Number<N3>{}, Number<N4>{}, Number<N5>{}));
//}
// C shuffle desc for source in gridwise copy
__host__
__device__
static
constexpr
auto
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
(
...
...
@@ -156,18 +173,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr
auto
N4
=
mfma
.
num_input_blks
;
constexpr
auto
N5
=
mfma
.
group_size
;
// printf("M / MPerBlock %d, ", M / MPerBlock);
// printf("MXdlPerWave %d, " , MXdlPerWave);
// printf("Gemm0MWaves %d, " , Gemm0MWaves);
// printf("MPerXdl / N5 %d, " , MPerXdl / N5);
// printf("N5 %d \n" , N5);
// printf("N / NPerBlock %d, " , N / NPerBlock);
// printf("NXdlPerWave %d, " , NXdlPerWave);
// printf("Gemm0NWaves %d, " , Gemm0NWaves);
// printf("N3 %d, " , N3);
// printf("N4 %d, " , N4);
// printf("N5 %d, " , N5);
return
transform_tensor_descriptor
(
z_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
...
...
@@ -175,9 +180,50 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_unmerge_transform
(
make_tuple
(
N
/
NPerBlock
,
NXdlPerWave
,
Gemm0NWaves
,
N3
,
N4
,
N5
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
9
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
10
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
,
9
>
{},
Sequence
<
1
,
3
,
5
,
7
,
8
,
10
>
{}));
// 0247,13568
}
// using ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4 = remove_cvref_t<decltype(
// GetZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Z shuffle desc for source in blockwise copy
//__host__ __device__ static constexpr auto
// GetZShuffleBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4(const
// ZBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4& z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4) ////=> for z
// use to shuffle
//{
// //const auto M = z_grid_desc_m_n.GetLength(I0);
// //const auto N = z_grid_desc_m_n.GetLength(I1);
//
// constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
// constexpr auto N3 = mfma.num_groups_per_blk;
// constexpr auto N4 = mfma.num_input_blks;
// constexpr auto N5 = mfma.group_size;
//
// constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4 =
// transform_tensor_descriptor(
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// make_freeze_transform(Number<MXdlPerWave>{}),
// make_freeze_transform(Number<NXdlPerWave>{}),
// make_freeze_transform(Number<Gemm0MWaves>{}),
// make_freeze_transform(Number<Gemm0NWaves>{}),
// make_unmerge_transform(make_tuple(Number<MPerXdl / N5>{}, Number<N5>{})),
// make_freeze_transform(Number<N3>{}),
// make_freeze_transform(Number<N4>{}),
// make_freeze_transform(Number<N5>{})),
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},Sequence<4>{},
// Sequence<5>{}, Sequence<6>{}, Sequence<7>{}), make_tuple(Sequence<0>{}, Sequence<1>{},
// Sequence<2>{}, Sequence<3>{},Sequence<4,7>{}, Sequence<5>{}, Sequence<6>{},
// Sequence<8>{}));
// return z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
// //return make_naive_tensor_descriptor_packed(
// // make_tuple(Number<MXdlPerWave>{}, Number<NXdlPerWave>{}, Number<Gemm0MWaves>{},
// Number<Gemm0NWaves>{},
// // Number<MPerXdl / N5>{}, Number<N3>{}, Number<N4>{}, Number<N5>{},
// Number<N5>{}));
//}
__device__
static
auto
GetGemm0WaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
...
...
@@ -904,6 +950,41 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// printf("n4 is %d \n",n4.value);
//}
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
));
// registerNum
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
//
n0
,
//
m1
,
//
n1
,
//
m2
,
// m0 1
n2
,
// n0 4
n3
,
// n1 1
n4
,
// m1 4
I1
));
// n2 1
// constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5 = //for gridwise
// copy
// make_naive_tensor_descriptor_packed(make_tuple(I1,
// I1,
// m0, //
// n0, //
// m1, //
// n1, //
// m2, // m0 1
// n2, // n0 4
// n3, // n1 1
// n4, // m1 4
// I1)); // n2 1
constexpr
auto
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
...
...
@@ -916,19 +997,50 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3
,
// NInputNum
n4
));
// registerNum
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
//
I1
,
//
m0
,
//
n0
,
//
m1
,
//
n1
,
//
m2
,
// m0
n2
,
// m1
n3
,
// n0
n4
,
// n1
I1
));
// n2
// ignore = z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
constexpr
auto
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
();
// constexpr auto z_block_lengths = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths();
constexpr
auto
zM0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I0
);
constexpr
auto
zN0
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I1
);
constexpr
auto
zM1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I2
);
constexpr
auto
zN1
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I3
);
constexpr
auto
zM2
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I4
);
constexpr
auto
zN2
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I5
);
constexpr
auto
zN3
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
zN4
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
=
transform_tensor_descriptor
(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_pass_through_transform
(
zM0
),
make_pass_through_transform
(
zN0
),
make_pass_through_transform
(
zM1
),
make_pass_through_transform
(
zN1
),
make_unmerge_transform
(
make_tuple
(
Number
<
zM2
.
value
/
zN4
.
value
>
{},
zN4
)),
make_pass_through_transform
(
zN2
),
make_pass_through_transform
(
zN3
),
make_pass_through_transform
(
zN4
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
,
7
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
8
>
{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
...
...
@@ -939,7 +1051,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_
m4_n4_n5
.
GetElementSpaceSize
(),
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_
n4
.
GetElementSpaceSize
(),
true
>
z_tenor_buffer
;
// z buffer after shuffle
z_tenor_buffer
.
Clear
();
...
...
@@ -948,10 +1060,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
auto
z_grid_tmp_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_z_grid
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
());
// auto z_grid_tmp_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_z_grid,
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5.GetElementSpaceSize());
ignore
=
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
;
// ignore = z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5;
auto
z_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ZDataType
*>
(
p_shared
),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
// ignore = z_block_buf;
// if(get_thread_global_1d_id()==0){
// printf("z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld \n",
...
...
@@ -990,7 +1109,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
/*
auto z_tmp_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
...
...
@@ -1035,10 +1154,58 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
10,
1,
1,
true
/* ResetCoordAfterRun */
>
{
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
make_multi_index
(
block_work_idx_m
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
true >{z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
make_multi_index(block_work_idx_m, //
MBlockId 0, // NBlockId 0, // mrepeat 0, //
nrepeat wave_id[I0], // MWaveId wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), //
MPerXdl 0, // group wave_m_n_id[I0], // NInputIndex 0,
wave_m_n_id[I1] % 4)};
*/
auto
z_tmp_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ZDataType
,
decltype
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
decltype
(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
m0
,
// MRepeat
n0
,
// NRepeat
m1
,
// MWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
n2
,
// NGroupNum
n3
,
// NInputNum
n4
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
// DstVectorDim
1
,
// DstScalarPerVector
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
true
>
{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_multi_index
(
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
z_shuffle_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
ZDataType
,
ushort
,
decltype
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
),
decltype
(
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
),
Sequence
<
m0
,
n0
,
m1
,
n1
,
m2
,
n2
,
n3
,
n4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
>
,
8
,
1
,
1
,
true
>
{
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
,
make_multi_index
(
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
...
...
@@ -1229,27 +1396,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tenor_tmp_buffer
)>(
ph
,
global_elem_id
,
z_tenor_tmp_buffer
);
z_tmp_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tmp_thread_copy_vgpr_to_lds
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_tmp_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_tmp_buf
);
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_buf
);
// z_tmp_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_tmp_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_tmp_buf);
block_sync_lds
();
z_shuffle_thread_copy_global_to_vgpr
.
Run
(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
z_grid_tmp_buf
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
z_shuffle_thread_copy_lds_to_vgpr
.
Run
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
,
z_block_buf
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
);
// z_shuffle_thread_copy_global_to_vgpr.Run(
// z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// z_grid_tmp_buf,
// z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer);
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
false
>(
acc_thread_buf
,
z_tenor_buffer
);
// ignore = z_tenor_buffer;
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -1259,13 +1441,13 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
block_sync_lds
();
z_tmp_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
//
z_tmp_thread_copy_vgpr_to_global.MoveDstSliceWindow(
//
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
//
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
z_shuffle_thread_copy_global_to_vgpr
.
MoveSrcSliceWindow
(
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
//
z_shuffle_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
//
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
//
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment