Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
475c0d2c
Commit
475c0d2c
authored
Feb 02, 2025
by
Qianfeng Zhang
Browse files
Use array of tiles to represent Q in vgprs
parent
119dd2ac
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
22 deletions
+22
-22
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+13
-22
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
...ine/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
+9
-0
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
475c0d2c
...
@@ -179,7 +179,14 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -179,7 +179,14 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_lengths
(),
q_dram_block_window_tmp
.
get_window_origin
(),
q_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
auto
q
=
load_tile
(
q_dram_window
);
using
q_tile_type
=
decltype
(
load_tile
(
q_dram_window
));
statically_indexed_array
<
q_tile_type
,
k0_loops
>
q_tiles
;
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
q_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -308,10 +315,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -308,10 +315,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds
();
block_sync_lds
();
// execute current unroll of gemm_0
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_tmp
);
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
k_lds_window
,
...
@@ -333,11 +337,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -333,11 +337,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds
();
block_sync_lds
();
// execute last unroll of gemm_0
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
q_tiles
[
number
<
k0_loops
-
1
>
{}],
k_lds_window_tmp
);
get_slice_tile
(
q
,
sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window_tmp
);
}
}
else
// there is only single iteration
else
// there is only single iteration
{
{
...
@@ -356,10 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -356,10 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds
();
block_sync_lds
();
// execute current unroll of gemm_0
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_tmp
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
{
...
@@ -396,10 +393,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -396,10 +393,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds
();
block_sync_lds
();
// execute last unroll of gemm_0
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_tmp
);
});
});
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
...
@@ -418,10 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -418,10 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds
();
block_sync_lds
();
// execute last unroll of gemm_0
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
q_tiles
[
number
<
i_k0
>
{}],
k_lds_window_tmp
);
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_tmp
);
});
});
};
};
};
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp
View file @
475c0d2c
...
@@ -13,6 +13,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
...
@@ -13,6 +13,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
/* AsyncCopy = */
true
,
/* AsyncCopy = */
true
,
/* NumPrefetchV = */
2
>
/* NumPrefetchV = */
2
>
{
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQRegTileDistribution
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
return
BlockGemm
::
template
MakeABlockTileDistribution
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kK0
>();
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment