Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
97efebdb
Commit
97efebdb
authored
Feb 02, 2025
by
Qianfeng Zhang
Browse files
Special treatment for hdim-96 to save vgprs in qr_ks_vs_async pipeline
parent
a94ac4bb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
139 additions
and
95 deletions
+139
-95
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+139
-95
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
97efebdb
...
...
@@ -76,25 +76,27 @@ struct BlockFmhaPipelineQRKSVSAsync
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kQKHeaddim
<
=
32
)
if
constexpr
(
kQKHeaddim
=
=
32
)
{
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<
=
64
)
else
if
constexpr
(
kQKHeaddim
=
=
64
)
{
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<
=
128
)
else
if
constexpr
(
kQKHeaddim
==
96
||
kQKHeaddim
=
=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
1
;
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<
=
256
)
else
if
constexpr
(
kQKHeaddim
=
=
256
)
{
return
1
;
}
else
return
1
;
}
}();
...
...
@@ -170,7 +172,6 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
static_assert
(
NumKLdsBuffers
>=
2
);
static_assert
(
NumVLdsBuffers
>=
2
);
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
...
...
@@ -269,7 +270,13 @@ struct BlockFmhaPipelineQRKSVSAsync
using
k_tile_type
=
decltype
(
load_tile
(
k_dram_window
));
statically_indexed_array
<
k_tile_type
,
k0_loops
>
k_tiles
;
auto
k_tiles
=
[
&
]()
{
// for hdim-96 and hdim-160, try to save vgprs
if
constexpr
(
kQKHeaddim
<
kSubQKHeaddim
)
return
statically_indexed_array
<
k_tile_type
,
2
>
{};
else
return
statically_indexed_array
<
k_tile_type
,
k0_loops
>
{};
}();
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
...
...
@@ -295,6 +302,8 @@ struct BlockFmhaPipelineQRKSVSAsync
index_t
i_total_loops
=
0
;
do
{
if
constexpr
(
kQKHeaddim
==
kSubQKHeaddim
)
{
if
(
i_total_loops
==
0
)
// executed by fist iteration
{
...
...
@@ -375,8 +384,8 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
kN0
,
0
});
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
...
...
@@ -389,7 +398,6 @@ struct BlockFmhaPipelineQRKSVSAsync
clear_tile
(
s_acc
);
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
});
...
...
@@ -398,8 +406,8 @@ struct BlockFmhaPipelineQRKSVSAsync
else
// last iteration
{
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
...
...
@@ -408,11 +416,47 @@ struct BlockFmhaPipelineQRKSVSAsync
clear_tile
(
s_acc
);
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
});
};
};
}
else
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]);
clear_tile
(
s_acc
);
// initialize C
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_tiles
[
number
<
(
i_k0
+
1
)
%
2
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
(
i_k0
+
1
)
%
2
>
{}]);
};
});
if
(
i_total_loops
<
num_total_loop
-
1
)
{
move_tile_window
(
k_dram_window
,
{
kN0
,
-
k0_loops
*
kK0
});
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
};
__builtin_amdgcn_sched_barrier
(
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment