Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d3b01d2b
Commit
d3b01d2b
authored
Feb 04, 2025
by
Qianfeng Zhang
Browse files
Re-arrange the codes before the main-loop
parent
d55852bc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
62 deletions
+65
-62
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+65
-62
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
d3b01d2b
...
...
@@ -188,6 +188,38 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
using
k_tile_type
=
decltype
(
load_tile
(
k_dram_window
));
auto
k_tiles
=
[
&
]()
{
// for hdim-96 and hdim-160, try to save vgprs
if
constexpr
(
kQKHeaddim
<
kSubQKHeaddim
)
return
statically_indexed_array
<
k_tile_type
,
2
>
{};
else
return
statically_indexed_array
<
k_tile_type
,
k0_loops
>
{};
}();
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
...
@@ -195,6 +227,23 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
k_lds_window
=
make_tile_window
(
k_lds
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
using
k_lds_window_type
=
decltype
(
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{}));
statically_indexed_array
<
k_lds_window_type
,
NumKLdsBuffers
>
k_lds_windows
;
static_for
<
0
,
NumKLdsBuffers
,
1
>
{}([
&
](
auto
i_buf
)
{
k_lds_windows
[
i_buf
]
=
get_slice_tile
(
k_lds_window
,
sequence
<
i_buf
*
kN0
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN0
,
kK0
>
{});
});
__builtin_amdgcn_sched_barrier
(
0
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
...
...
@@ -202,6 +251,22 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
v_lds_window
=
make_tile_window
(
v_lds
,
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
using
v_tile_type
=
decltype
(
load_tile
(
v_dram_window
));
statically_indexed_array
<
v_tile_type
,
NumVLdsBuffers
>
v_tiles
;
using
v_lds_window_type
=
decltype
(
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{}));
statically_indexed_array
<
v_lds_window_type
,
NumVLdsBuffers
>
v_lds_windows
;
static_for
<
0
,
NumVLdsBuffers
,
1
>
{}([
&
](
auto
i_buf
)
{
v_lds_windows
[
i_buf
]
=
get_slice_tile
(
v_lds_window
,
sequence
<
i_buf
*
kN1
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN1
,
kK1
>
{});
});
__builtin_amdgcn_sched_barrier
(
0
);
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
...
...
@@ -230,12 +295,6 @@ struct BlockFmhaPipelineQRKSVSAsync
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
{
...
...
@@ -257,42 +316,6 @@ struct BlockFmhaPipelineQRKSVSAsync
}
}
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
{
seqlen_k_start
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
get_bottom_tensor_view
(),
k_dram_block_window
.
get_window_lengths
(),
k_dram_block_window
.
get_window_origin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
using
k_tile_type
=
decltype
(
load_tile
(
k_dram_window
));
auto
k_tiles
=
[
&
]()
{
// for hdim-96 and hdim-160, try to save vgprs
if
constexpr
(
kQKHeaddim
<
kSubQKHeaddim
)
return
statically_indexed_array
<
k_tile_type
,
2
>
{};
else
return
statically_indexed_array
<
k_tile_type
,
k0_loops
>
{};
}();
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
using
k_lds_window_type
=
decltype
(
get_slice_tile
(
k_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN0
,
kK0
>
{}));
statically_indexed_array
<
k_lds_window_type
,
NumKLdsBuffers
>
k_lds_windows
;
static_for
<
0
,
NumKLdsBuffers
,
1
>
{}([
&
](
auto
i_buf
)
{
k_lds_windows
[
i_buf
]
=
get_slice_tile
(
k_lds_window
,
sequence
<
i_buf
*
kN0
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN0
,
kK0
>
{});
});
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
...
@@ -303,26 +326,6 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
{
0
,
seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
using
v_tile_type
=
decltype
(
load_tile
(
v_dram_window
));
statically_indexed_array
<
v_tile_type
,
NumVLdsBuffers
>
v_tiles
;
using
v_lds_window_type
=
decltype
(
get_slice_tile
(
v_lds_window
,
sequence
<
0
,
0
>
{},
sequence
<
kN1
,
kK1
>
{}));
statically_indexed_array
<
v_lds_window_type
,
NumVLdsBuffers
>
v_lds_windows
;
static_for
<
0
,
NumVLdsBuffers
,
1
>
{}([
&
](
auto
i_buf
)
{
v_lds_windows
[
i_buf
]
=
get_slice_tile
(
v_lds_window
,
sequence
<
i_buf
*
kN1
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN1
,
kK1
>
{});
});
index_t
i_total_loops
=
0
;
do
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment