Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9a771b0b
Commit
9a771b0b
authored
Sep 25, 2024
by
Po Yen, Chen
Browse files
Support kVLoadOnce in fwd kernel
parent
b75c9265
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
14 deletions
+53
-14
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+53
-14
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
9a771b0b
...
...
@@ -44,6 +44,20 @@ struct is_k_load_once<Pipeline, std::void_t<decltype(Pipeline::kKLoadOnce)>>
template
<
typename
Pipeline
>
static
constexpr
bool
is_k_load_once_v
=
is_k_load_once
<
Pipeline
>::
value
;
template
<
typename
Pipeline
,
typename
=
void
>
struct
is_v_load_once
:
std
::
false_type
{
};
template
<
typename
Pipeline
>
struct
is_v_load_once
<
Pipeline
,
std
::
void_t
<
decltype
(
Pipeline
::
kVLoadOnce
)
>>
:
std
::
bool_constant
<
Pipeline
::
kVLoadOnce
>
{
};
template
<
typename
Pipeline
>
static
constexpr
bool
is_v_load_once_v
=
is_v_load_once
<
Pipeline
>::
value
;
}
// namespace detail
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
...
...
@@ -680,11 +694,21 @@ struct FmhaFwdKernel
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
if
constexpr
(
detail
::
is_v_load_once_v
<
FmhaPipeline
>
)
{
return
pad_tensor_view
(
v_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
else
{
return
pad_tensor_view
(
v_dram_transposed
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
}
else
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -694,16 +718,26 @@ struct FmhaFwdKernel
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
if
constexpr
(
detail
::
is_v_load_once_v
<
FmhaPipeline
>
)
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
else
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenK
>
{});
}
}
}();
auto
q_dram_window
=
make_tile_window
(
q_dram
,
[
&
]
()
{
[
&
]
{
if
constexpr
(
detail
::
is_q_load_once_v
<
FmhaPipeline
>
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0BlockLength
>
{});
...
...
@@ -714,7 +748,7 @@ struct FmhaFwdKernel
auto
k_dram_window
=
make_tile_window
(
k_dram
,
[
&
]
()
{
[
&
]
{
if
constexpr
(
detail
::
is_k_load_once_v
<
FmhaPipeline
>
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0BlockLength
>
{});
...
...
@@ -723,9 +757,14 @@ struct FmhaFwdKernel
}(),
{
0
,
0
});
auto
v_dram_window
=
make_tile_window
(
v_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
auto
v_dram_window
=
make_tile_window
(
v_dram
,
[
&
]
{
if
constexpr
(
detail
::
is_v_load_once_v
<
FmhaPipeline
>
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN1
>
{},
number
<
FmhaPipeline
::
kK1
>
{});
}(),
{
i_n1
,
0
});
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment