Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b1da29ba
"sgl-kernel/vscode:/vscode.git/clone" did not exist on "bcda0c9ee6a6e687e53ac933f3541dd5c5a1fe9b"
Commit
b1da29ba
authored
Jan 27, 2025
by
Qianfeng Zhang
Browse files
Switch to separate code blocks according to iteration index
parent
90e99a95
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
47 deletions
+73
-47
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+72
-42
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+1
-5
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
b1da29ba
...
...
@@ -169,8 +169,10 @@ struct BlockFmhaPipelineQRKSVSAsync
static_assert
(
2
<=
k0_loops
);
static_assert
(
2
<=
k1_loops
);
constexpr
auto
NumKLdsBuffers
=
Policy
::
template
GetNumKLdsBuffers
<
Problem
>();
constexpr
auto
NumVLdsBuffers
=
Policy
::
template
GetNumVLdsBuffers
<
Problem
>();
static_assert
(
NumKLdsBuffers
>=
2
);
static_assert
(
NumVLdsBuffers
>=
2
);
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
...
@@ -190,8 +192,8 @@ struct BlockFmhaPipelineQRKSVSAsync
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
kSubQKHeaddim
>
{}
),
{
0
,
0
});
auto
k_lds_window
=
make_tile_window
(
k_lds
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>().
get_lengths
(
),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -261,23 +263,18 @@ struct BlockFmhaPipelineQRKSVSAsync
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
using
k_tile_type
=
decltype
(
load_tile
(
k_dram_window
));
statically_indexed_array
<
k_tile_type
,
k0_loops
>
k_tiles
;
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
i_k0
]
=
load_tile
(
k_dram_window
);
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
});
move_tile_window
(
k_dram_window
,
{
0
,
-
k0_loops
*
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -318,7 +315,6 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
0
);
// prefetch K tile
index_t
i_total_loops
=
0
;
// ensure loading of Q from LDS completely done
...
...
@@ -326,51 +322,85 @@ struct BlockFmhaPipelineQRKSVSAsync
do
{
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
i_k0
*
kN0
,
0
>
{},
sequence
<
(
i_k0
+
1
)
*
kN0
,
kK0
>
{});
if
(
i_total_loops
==
0
)
// executed by fist iteration
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
I0
]);
store_tile
(
k_lds_window_tmp
,
k_tiles
[
i_k0
]);
});
clear_tile
(
s_acc
);
// initialize C
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 1, QK gemm
clear_tile
(
s_acc
);
// initialize C
if
(
i_total_loops
<
num_total_loop
-
1
)
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
move_tile_window
(
k_dram_window
,
{
kN0
,
0
});
k_tiles
[
number
<
i_k0
+
1
>
{}]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
};
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
k_tiles
[
i_k0
]
=
load_tile
(
k_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
block_sync_lds
();
// execute current unroll of gemm_0
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_tmp
);
if
constexpr
(
i_k0
<
k0_loops
-
1
)
{
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
((
i_k0
+
1
)
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
(((
i_k0
+
1
)
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
+
1
>
{}]);
};
});
move_tile_window
(
k_dram_window
,
{
0
,
-
k0_loops
*
kK0
});
}
else
// executed by intermediate and last iteration
{
clear_tile
(
s_acc
);
// initialize C
__builtin_amdgcn_sched_barrier
(
0
);
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
auto
k_lds_window_tmp
=
get_slice_tile
(
k_lds_window
,
sequence
<
(
i_k0
%
NumKLdsBuffers
)
*
kN0
,
0
>
{},
sequence
<
((
i_k0
%
NumKLdsBuffers
)
+
1
)
*
kN0
,
kK0
>
{});
store_tile
(
k_lds_window_tmp
,
k_tiles
[
number
<
i_k0
>
{}]);
// ensure k is completely updated on LDS
block_sync_lds
();
// execute last unroll of gemm_0
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window_tmp
);
});
};
__builtin_amdgcn_sched_barrier
(
0
);
// executed by first and intermediate iteration
if
(
i_total_loops
<
num_total_loop
-
1
)
{
move_tile_window
(
k_dram_window
,
{
kN0
,
0
});
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
gemm_0
(
s_acc
,
get_slice_tile
(
q
,
sequence
<
0
,
i_k0
*
kK0
>
{},
sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
get_slice_tile
(
k_lds_window
,
sequence
<
i_k0
*
kN0
,
0
>
{},
sequence
<
(
i_k0
+
1
)
*
kN0
,
kK0
>
{}));
});
k_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
k_dram_window
);
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
});
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
move_tile_window
(
k_dram_window
,
{
0
,
-
(
k0_loops
-
1
)
*
kK0
});
}
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
using
v_tile_type
=
decltype
(
load_tile
(
v_dram_window
));
statically_indexed_array
<
v_tile_type
,
NumVLdsBuffers
>
v_tiles
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
b1da29ba
...
...
@@ -296,11 +296,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
if
constexpr
(
KLoadOnce
)
{
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
constexpr
index_t
k0_loops
=
BlockFmhaShape
::
kQKHeaddim
/
BlockFmhaShape
::
kK0
;
return
k0_loops
;
return
2
;
}
else
return
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment