Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
af1059a3
Commit
af1059a3
authored
Jun 12, 2023
by
guangzlu
Browse files
fixed bugs for lds shuffle
parent
957ab734
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
50 deletions
+51
-50
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+51
-50
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
af1059a3
...
@@ -255,7 +255,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -255,7 +255,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
index_t
c_block_bytes_end
=
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
const
index_t
z_block_bytes_end
=
SharedMemTrait
::
z_shuffle_block_space_size
*
sizeof
(
ZDataType
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
,
z_block_bytes_end
);
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...
@@ -415,6 +422,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -415,6 +422,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
static
constexpr
auto
c_block_space_size
=
static
constexpr
auto
c_block_space_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
static
constexpr
auto
z_shuffle_block_space_size
=
MPerBlock
*
NPerBlock
;
};
};
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
...
@@ -874,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -874,7 +883,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
n3
,
// NInputNum
n3
,
// NInputNum
n4
));
// registerNum
n4
));
// registerNum
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
=
// for blockwise copy
constexpr
auto
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
=
// for blockwise copy
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
//
make_naive_tensor_descriptor_packed
(
make_tuple
(
m0
,
//
n0
,
//
n0
,
//
m1
,
//
m1
,
//
...
@@ -911,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -911,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
constexpr
auto
zN3
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
zN3
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I6
);
constexpr
auto
zN4
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
zN4
=
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetLength
(
I7
);
constexpr
auto
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
=
constexpr
auto
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
make_pass_through_transform
(
zM0
),
make_tuple
(
make_pass_through_transform
(
zM0
),
...
@@ -940,21 +949,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -940,21 +949,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence
<
8
>
{}));
Sequence
<
8
>
{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetElementSpaceSize
(),
true
>
z_tenor_tmp_buffer
;
z_tenor_tmp_buffer
.
Clear
();
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
unsigned
short
,
unsigned
short
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
.
GetElementSpaceSize
(),
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
.
GetElementSpaceSize
(),
true
>
true
>
z_tenor_buffer
;
// z buffer after shuffle
z_ten
s
or_buffer
;
// z buffer after shuffle
z_tenor_buffer
.
Clear
();
z_ten
s
or_buffer
.
Clear
();
// z matrix global desc
// z matrix global desc
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
z_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -964,6 +966,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -964,6 +966,17 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static_cast
<
ZDataType
*>
(
p_shared
),
static_cast
<
ZDataType
*>
(
p_shared
),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
.
GetElementSpaceSize
());
// if(get_thread_global_1d_id()==0){
// printf("z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize() is %ld \n",
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize().value);
// printf("z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n", z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// printf("z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld
// \n",z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize().value);
// printf("z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n",z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// }
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_id
=
GetGemm0WaveIdx
();
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
const
auto
wave_m_n_id
=
GetGemm0WaveMNIdx
(
wave_id
[
I2
]);
// I2: 0~63
...
@@ -1028,14 +1041,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1028,14 +1041,14 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto
z_shuffle_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
auto
z_shuffle_thread_copy_lds_to_vgpr
=
ThreadwiseTensorSliceTransfer_v2
<
ZDataType
,
ZDataType
,
ushort
,
ushort
,
decltype
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
),
decltype
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
),
decltype
(
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
),
decltype
(
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
),
Sequence
<
m0
,
n0
,
m1
,
n1
,
m2
,
n2
,
n3
,
n4
,
I1
>
,
Sequence
<
m0
,
n0
,
m1
,
n1
,
m2
,
n2
,
n3
,
n4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
>
,
8
,
8
,
1
,
1
,
1
,
1
,
true
>
{
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
,
true
>
{
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
,
make_multi_index
(
0
,
// mrepeat
make_multi_index
(
0
,
// mrepeat
0
,
// nrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I0
],
// MWaveId
...
@@ -1222,54 +1235,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1222,54 +1235,42 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
// save z to global
if
(
p_z_grid
)
{
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tenor_tmp_buffer
)>(
ph
,
global_elem_id
,
z_tenor_tmp_buffer
);
z_tmp_thread_copy_vgpr_to_lds
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_tmp_buffer
,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_buf
);
block_sync_lds
();
blockwise_dropout
.
template
GenerateZMatrixAttnFwd
<
decltype
(
z_tensor_buffer
)>(
ph
,
global_elem_id
,
z_tensor_buffer
);
// ignore = z_shuffle_thread_copy_lds_to_vgpr;
z_tmp_thread_copy_vgpr_to_lds
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tensor_buffer
,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
z_block_buf
);
z_shuffle_thread_copy_lds_to_vgpr
.
Run
(
z_shuffle_thread_copy_lds_to_vgpr
.
Run
(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
,
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
,
z_block_buf
,
z_block_buf
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
m
3_
n
3_n4
,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_
n
3_
m
3_n4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
);
z_ten
s
or_buffer
);
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
blockwise_dropout
.
template
ApplyDropoutWithZ
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_ten
s
or_buffer
),
false
>(
acc_thread_buf
,
false
>(
acc_thread_buf
,
z_tenor_buffer
);
z_ten
s
or_buffer
);
// ignore = z_tenor_buffer;
if
(
p_z_grid
)
{
// ignore = z_tensor_buffer;
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_ten
s
or_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_grid_buf
);
block_sync_lds
();
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
}
else
{
// ignore = z_grid_buf;
// P_dropped
blockwise_dropout
.
template
ApplyDropoutAttnFwd
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
,
global_elem_id
);
}
}
}
// TODO: may convert to log domain
// TODO: may convert to log domain
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment