Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2e612c02
"...resnet50_tensorflow.git" did not exist on "6a47721eafea2bcbcce124fa9ec38d907824a4d6"
Commit
2e612c02
authored
Feb 09, 2025
by
Qianfeng Zhang
Browse files
Adjust the pipeline codes
parent
a72e100e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
19 deletions
+11
-19
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+11
-19
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
2e612c02
...
@@ -182,23 +182,10 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -182,23 +182,10 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
Policy
::
template
MakeQRegTileDistribution
<
Problem
>());
using
q_tile_type
=
decltype
(
load_tile
(
q_dram_window
));
using
q_tile_type
=
decltype
(
load_tile
(
q_dram_window
));
statically_indexed_array
<
q_tile_type
,
k0_loops
>
q_tiles
;
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
q_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
});
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
q_origin
=
q_dram_window
.
get_window_origin
();
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
k_dram_block_window
=
auto
k_dram_block_window
=
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
k_dram_block_window_tmp
.
get_bottom_tensor_view
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
k_dram_block_window_tmp
.
get_window_lengths
(),
...
@@ -222,6 +209,15 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -222,6 +209,15 @@ struct BlockFmhaPipelineQRKSVSAsync
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
k_tiles
[
I0
]
=
load_tile
(
k_dram_window
);
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
statically_indexed_array
<
q_tile_type
,
k0_loops
>
q_tiles
;
static_for
<
0
,
k0_loops
,
1
>
{}([
&
](
auto
i_k0
)
{
q_tiles
[
number
<
i_k0
>
{}]
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
0
,
kK0
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// K tile in LDS
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
k_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
...
@@ -239,8 +235,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -239,8 +235,6 @@ struct BlockFmhaPipelineQRKSVSAsync
k_lds_window
,
sequence
<
i_buf
*
kN0
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN0
,
kK0
>
{});
k_lds_window
,
sequence
<
i_buf
*
kN0
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN0
,
kK0
>
{});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
v_dram_block_window_tmp
.
get_bottom_tensor_view
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
v_dram_block_window_tmp
.
get_window_lengths
(),
...
@@ -268,8 +262,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -268,8 +262,6 @@ struct BlockFmhaPipelineQRKSVSAsync
v_lds_window
,
sequence
<
i_buf
*
kN1
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN1
,
kK1
>
{});
v_lds_window
,
sequence
<
i_buf
*
kN1
,
0
>
{},
sequence
<
(
i_buf
+
1
)
*
kN1
,
kK1
>
{});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
// Block GEMM
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
...
@@ -298,6 +290,8 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -298,6 +290,8 @@ struct BlockFmhaPipelineQRKSVSAsync
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
set_tile
(
m
,
-
numeric
<
SMPLComputeDataType
>::
infinity
());
clear_tile
(
l
);
clear_tile
(
l
);
const
auto
num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
// check early exit if no work to do
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
)
{
{
...
@@ -661,8 +655,6 @@ struct BlockFmhaPipelineQRKSVSAsync
...
@@ -661,8 +655,6 @@ struct BlockFmhaPipelineQRKSVSAsync
const
auto
p
=
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
!
kPreloadWholeNextIterationK
)
if
constexpr
(
!
kPreloadWholeNextIterationK
)
{
{
if
(
i_total_loops
<
num_total_loop
-
1
)
if
(
i_total_loops
<
num_total_loop
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment