Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
db952741
"docs/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "2cf1a333b63a303fd4b65dd499f2e9b606e6525a"
Commit
db952741
authored
Jan 01, 2025
by
Po Yen Chen
Browse files
Only check incomplete split in first&last iterations
parent
32ef8a18
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
33 deletions
+46
-33
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+46
-33
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
db952741
...
...
@@ -272,10 +272,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
aligned_physical_seqlen_k_start
)},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
v_dram_block_window_lengths
,
{
0
,
aligned_physical_seqlen_k_start
},
// TODO: hdim split?
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
[
i_page_block_v
,
v_dram_block_window
]
=
v_page_block_navigator
.
make_tile_window
(
v_dram_block_window_lengths
,
{
0
,
aligned_physical_seqlen_k_start
});
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
...
...
@@ -289,10 +287,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
do
{
// STAGE 1, QK gemm
// K DRAM tile window for load
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
,
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
k_dram_block_window
,
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
...
...
@@ -334,6 +331,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
// V DRAM tile window for load
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window
,
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
...
...
@@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
move_tile_window
(
bias_dram_window
,
{
0
,
kN0
});
/// TODO: only check in first/last iteration without increasing code size
// only check in first/last iterations
if
constexpr
(
kHasUnevenSplits
)
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
if
constexpr
(
kIsPagedKV
)
{
return
col
<
physical_seqlen_k_start_
||
physical_seqlen_k_end_
<=
col
;
}
else
{
return
physical_seqlen_k_end_
<=
col
;
}
});
if
(
1
<
num_splits
&&
(
i_total_loops
==
0
||
i_total_loops
==
num_total_loop
-
1
))
{
const
auto
k_origin
=
k_page_block_navigator
.
to_global_window_origin
(
i_page_block_k
,
k_dram_block_window
.
get_window_origin
());
set_tile_if
(
s_acc
,
-
numeric
<
SMPLComputeDataType
>::
infinity
(),
[
&
,
physical_seqlen_k_start_
=
physical_seqlen_k_start
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
if
constexpr
(
kIsPagedKV
)
{
return
col
<
physical_seqlen_k_start_
||
physical_seqlen_k_end_
<=
col
;
}
else
{
return
physical_seqlen_k_end_
<=
col
;
}
});
}
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
...
...
@@ -444,6 +448,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// move K tile window
i_page_block_k
=
k_page_block_navigator
.
move_tile_window
(
i_page_block_k
,
k_dram_block_window
,
{
kN0
,
0
});
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
s
=
cast_tile
<
SMPLComputeDataType
>
(
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
...
...
@@ -549,12 +558,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
i_page_block_v
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v
,
v_dram_window
,
{
0
,
kK1
});
// moving v_dram_window is an in-page-block operation, so there is
// no need to invoke v_page_block_navigator.move_tile_window() here.
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
cast_tile
<
PDataType
>
(
tile_elementwise_in
(
p_compute_element_func
,
p_compute
));
__builtin_amdgcn_sched_barrier
(
0
);
// move V tile window
i_page_block_v
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v
,
v_dram_block_window
,
{
0
,
kN0
});
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 3, KV gemm
if
constexpr
(
k1_loops
>
1
)
{
...
...
@@ -582,13 +598,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
i_page_block_v_
=
v_page_block_navigator
.
move_tile_window
(
i_page_block_v_
,
v_dram_window_
,
{
0
,
kK1
});
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
// move K tile windows
i_page_block_k
=
k_page_block_navigator
.
move_tile_window
(
i_page_block_k
,
k_dram_block_window
,
{
kN0
,
0
});
// tail
{
block_sync_lds
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment