Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
65bbe6ea
Commit
65bbe6ea
authored
Dec 24, 2024
by
Po Yen Chen
Browse files
Use vector load if paged-vcache is in column major
parent
1fef9106
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
8 deletions
+61
-8
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+14
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+47
-4
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
65bbe6ea
...
@@ -717,10 +717,15 @@ struct FmhaFwdSplitKVKernel
...
@@ -717,10 +717,15 @@ struct FmhaFwdSplitKVKernel
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
number
<
1
>
{});
// We assume that page-block size is always divisible by vector size. So we can use
// vector load on seqlen_k direction. However, the seqlen_k may not be divisible by
// vector size as well. So we will have to override data points which are located
// outside [0, seqlen_k) to 0.0 in pipeline.
return
pad_tensor_view
(
return
pad_tensor_view
(
v_dram_naive
,
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
false
,
kPadSeqLenK
>
{});
sequence
<
false
,
!
kIsPagedKV
&&
kPadSeqLenK
>
{});
}
}
};
};
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram
=
[
&
]()
{
...
@@ -786,9 +791,14 @@ struct FmhaFwdSplitKVKernel
...
@@ -786,9 +791,14 @@ struct FmhaFwdSplitKVKernel
num_blocks
,
num_blocks
,
kargs
.
page_block_size
,
kargs
.
page_block_size
,
v_dram
,
v_dram
,
make_v_dram
(
nullptr
,
make_v_dram
(
nullptr
,
[
&
]
{
(
kv_l2p_offset
+
kargs
.
seqlen_k
)
-
if
constexpr
(
std
::
is_same_v
<
VLayout
,
(
num_blocks
-
1
)
*
kargs
.
page_block_size
));
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
(
kv_l2p_offset
+
kargs
.
seqlen_k
)
-
(
num_blocks
-
1
)
*
kargs
.
page_block_size
;
else
return
kargs
.
page_block_size
;
}()));
}
}
else
else
{
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
65bbe6ea
...
@@ -63,7 +63,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -63,7 +63,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
// We assume that page-block size is always divisible by vector size. So we can use
// vector load on seqlen_k direction. However, the seqlen_k may not be divisible by
// vector size as well. So we will have to override data points which are located
// outside [0, seqlen_k) to 0.0.
return
kIsPagedKV
?
Policy
::
template
GetAlignmentV
<
Problem
>()
:
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
}();
static
constexpr
index_t
kAlignmentOacc
=
static
constexpr
index_t
kAlignmentOacc
=
...
@@ -335,8 +341,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -335,8 +341,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
});
});
}
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
{
// tail
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
get_slice_tile
(
q_tile
,
...
@@ -550,6 +556,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -550,6 +556,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
else
else
{
{
// Override data points which are located outside [0, seqlen_k) to 0.0
if
constexpr
(
kIsPagedKV
&&
kPadSeqLenK
)
{
if
(
v_page_block_navigator
.
is_last_block
(
i_page_block_v
))
{
const
auto
v_origin
=
v_page_block_navigator
.
to_global_window_origin
(
i_page_block_v
,
v_dram_window
.
get_window_origin
());
set_tile_if
(
v_prefetch
,
type_convert
<
VDataType
>
(
0.0
),
[
&
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
v_origin
.
at
(
number
<
1
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
physical_seqlen_k_end_
<=
col
;
});
}
}
store_tile
(
v_lds_window
,
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
}
}
...
@@ -565,7 +589,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -565,7 +589,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
,
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
,
&
i_page_block_v_
=
i_page_block_v
,
&
i_page_block_v_
=
i_page_block_v
,
&
v_dram_window_
=
v_dram_window
](
auto
i_k1
)
{
&
v_dram_window_
=
v_dram_window
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window_
);
// load next v
auto
v
=
load_tile
(
v_dram_window_
);
// load next v
block_sync_lds
();
block_sync_lds
();
gemm_1
(
o_acc
,
gemm_1
(
o_acc
,
get_slice_tile
(
get_slice_tile
(
...
@@ -583,6 +607,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
...
@@ -583,6 +607,25 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
}
else
else
{
{
// Override data points which are located outside [0, seqlen_k) to 0.0
if
constexpr
(
kIsPagedKV
&&
kPadSeqLenK
)
{
if
(
v_page_block_navigator
.
is_last_block
(
i_page_block_v_
))
{
const
auto
v_origin
=
v_page_block_navigator
.
to_global_window_origin
(
i_page_block_v_
,
v_dram_window_
.
get_window_origin
());
set_tile_if
(
v
,
type_convert
<
VDataType
>
(
0.0
),
[
&
,
physical_seqlen_k_end_
=
physical_seqlen_k_end
](
auto
tile_idx
)
{
const
auto
col
=
v_origin
.
at
(
number
<
1
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
physical_seqlen_k_end_
<=
col
;
});
}
}
store_tile
(
v_lds_window
,
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment