Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
87b206fb
Commit
87b206fb
authored
Feb 03, 2025
by
Qianfeng Zhang
Browse files
Define statically indexed array v_lds_windows[] to reduce using of get_slice_tile()
parent
cde3b677
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
47 deletions
+27
-47
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+27
-47
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
87b206fb
...
...
@@ -313,6 +313,16 @@ struct BlockFmhaPipelineQRKSVSAsync
statically_indexed_array
<
v_tile_type
,
NumVLdsBuffers
>
v_tiles
;
using
v_lds_window_type
=
decltype
(
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{}));
statically_indexed_array
<
v_lds_window_type
,
NumVLdsBuffers
>
v_lds_windows
;
static_for
<
0
,
NumVLdsBuffers
,
1
>
{}([
&
](
auto
i_buf
)
{
v_lds_windows
[
i_buf
]
=
get_slice_tile
(
v_lds_window
,
sequence
<
i_buf
*
kN1
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN1
,
kK1
>
{});
});
index_t
i_total_loops
=
0
;
do
...
...
@@ -643,18 +653,13 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_tiles
[
I0
]);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{});
store_tile
(
v_lds_window
_tmp
,
v_lds_window
s
[
I0
]
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
// store the prefetch
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
store_tile
(
v_lds_windows
[
I0
],
tile_elementwise_in
(
v_element_func
,
v_tiles
[
I0
]));
// store the prefetch
}
...
...
@@ -672,13 +677,10 @@ struct BlockFmhaPipelineQRKSVSAsync
v_tiles
[
I0
]
=
load_tile
(
v_dram_window
);
block_sync_lds
();
gemm_1
(
o_acc
,
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
i_k1
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
((
i_k1
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{}));
v_lds_windows
[
number
<
i_k1
%
NumVLdsBuffers
>
{}]);
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
...
@@ -686,22 +688,14 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
v_shuffle_tmp
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_tiles
[
I0
]);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
block_sync_lds
();
store_tile
(
v_lds_window
_tmp
,
store_tile
(
v_lds_window
s
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}]
,
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
block_sync_lds
();
store_tile
(
v_lds_window
_tmp
,
store_tile
(
v_lds_window
s
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}]
,
tile_elementwise_in
(
v_element_func
,
v_tiles
[
I0
]));
}
...
...
@@ -715,13 +709,10 @@ struct BlockFmhaPipelineQRKSVSAsync
v_tiles
[
number
<
i_k1
%
NumVLdsBuffers
>
{}]
=
load_tile
(
v_dram_window
);
block_sync_lds
();
gemm_1
(
o_acc
,
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
i_k1
*
kK1
>
{},
sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
(
i_k1
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
((
i_k1
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{}));
v_lds_windows
[
number
<
i_k1
%
NumVLdsBuffers
>
{}]);
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
...
@@ -730,20 +721,12 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy
::
template
MakeShuffledVRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
v_shuffle_tmp
,
v_tiles
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}]);
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
store_tile
(
v_lds_windows
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}],
tile_elementwise_in
(
v_element_func
,
v_shuffle_tmp
));
}
else
{
auto
v_lds_window_tmp
=
get_slice_tile
(
v_lds_window
,
sequence
<
((
i_k1
+
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
i_k1
+
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{});
store_tile
(
v_lds_window_tmp
,
store_tile
(
v_lds_windows
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}],
tile_elementwise_in
(
v_element_func
,
v_tiles
[
number
<
(
i_k1
+
1
)
%
NumVLdsBuffers
>
{}]));
...
...
@@ -759,12 +742,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
sequence
<
kM0
,
kN0
>
{}),
get_slice_tile
(
v_lds_window
,
sequence
<
((
k1_loops
-
1
)
%
NumVLdsBuffers
)
*
kN1
,
0
>
{},
sequence
<
(((
k1_loops
-
1
)
%
NumVLdsBuffers
)
+
1
)
*
kN1
,
kK1
>
{}));
v_lds_windows
[
number
<
(
k1_loops
-
1
)
%
NumVLdsBuffers
>
{}]);
}
}
while
(
++
i_total_loops
<
num_total_loop
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment