Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5f4bfa4a
Commit
5f4bfa4a
authored
Jan 26, 2025
by
Qianfeng Zhang
Browse files
Tune the prefetching of V in qr_ks_vs_async pipeline
parent
45398bf4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
28 deletions
+34
-28
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+34
-28
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
5f4bfa4a
...
...
@@ -161,13 +161,18 @@ struct BlockFmhaPipelineQRKSVSAsync
kN0
==
BiasDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"wrong!"
);
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
index_t
k0_loops
=
kQKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
static_assert
(
1
<=
k1_loops
);
static_assert
(
2
<=
k1_loops
);
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
static_assert
(
NumVLdsBuffers
>=
2
);
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
...
...
@@ -366,7 +371,14 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
0
);
auto
v_buf
=
load_tile
(
v_dram_window
);
// prefetch load v tile
using
v_tile_type
=
decltype
(
load_tile
(
v_dram_window
));
statically_indexed_array
<
v_tile_type
,
NumVLdsBuffers
>
v_tiles
;
static_for
<
0
,
NumVLdsBuffers
,
1
>
{}([
&
](
auto
i_k1
)
{
v_tiles
[
i_k1
]
=
load_tile
(
v_dram_window
);
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
// STAGE 2, scale_s, add bias, mask, softmax
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
...
...
@@ -446,12 +458,12 @@ struct BlockFmhaPipelineQRKSVSAsync
s
.
get_tile_distribution
());
// Pcompute{j}
__builtin_amdgcn_sched_barrier
(
0
);
// store & prefetch next v, after the max reduction
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_
buf
);
shuffle_tile
(
v_shuffle_tmp
,
v_
tiles
[
I0
]
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{});
...
...
@@ -465,14 +477,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store the prefetch
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
if
constexpr
(
NumVLdsBuffers
>
1
)
{
v_buf
=
load_tile
(
v_dram_window
);
// load next v_buf
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
tile_elementwise_in
(
v_element_func
,
v_tiles
[
I0
]));
// store the prefetch
}
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -569,7 +574,8 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
NumVLdsBuffers
==
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
v_buf
=
load_tile
(
v_dram_window
);
// load next v_buf
v_tiles
[
I0
]
=
load_tile
(
v_dram_window
);
block_sync_lds
();
gemm_1
(
o_acc
,
...
...
@@ -584,15 +590,14 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_
buf
);
shuffle_tile
(
v_shuffle_tmp
,
v_
tiles
[
I0
]
);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
block_sync_lds
();
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
}
else
{
...
...
@@ -601,18 +606,18 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
block_sync_lds
();
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_tiles
[
I0
]));
}
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
else
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
if
constexpr
(
i_k1
>
0
&&
i_k1
<
k1_loops
-
1
)
v_
buf
=
load_tile
(
v_dram_window
);
// load next v_buf
if
constexpr
(
i_k1
<
k1_loops
-
NumVLdsBuffers
)
v_
tiles
[
number
<
i_k1
%
NumVLdsBuffers
>
{}]
=
load_tile
(
v_dram_window
);
block_sync_lds
();
gemm_1
(
...
...
@@ -628,14 +633,14 @@ struct BlockFmhaPipelineQRKSVSAsync
{
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_buf
);
shuffle_tile
(
v_shuffle_tmp
,
v_tiles
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}]);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
}
else
{
...
...
@@ -643,12 +648,13 @@ struct BlockFmhaPipelineQRKSVSAsync
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_buf
));
// store next v_buf
store_tile
(
v_lds_window_tmp
,
tile_elementwise_in
(
v_element_func
,
v_tiles
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}]));
}
if
constexpr
(
i_k1
>
0
&&
i_k1
<
k1_loops
-
1
)
if
constexpr
(
i_k1
<
k1_loops
-
NumVLdsBuffers
)
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment