Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7b6fe38e
Commit
7b6fe38e
authored
Sep 20, 2024
by
Po Yen, Chen
Browse files
Add type traits to switch tensor view/tile window size
parent
427206a5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
54 additions
and
7 deletions
+54
-7
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+54
-7
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
7b6fe38e
...
@@ -16,6 +16,35 @@
...
@@ -16,6 +16,35 @@
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
namespace
ck_tile
{
namespace
ck_tile
{
namespace
detail
{
template
<
typename
Pipeline
,
typename
=
void
>
struct
is_q_load_once
:
std
::
false_type
{
};
template
<
typename
Pipeline
>
struct
is_q_load_once
<
Pipeline
,
std
::
void_t
<
decltype
(
Pipeline
::
kQLoadOnce
)
>>
:
std
::
bool_constant
<
Pipeline
::
kQLoadOnce
>
{
};
template
<
typename
Pipeline
>
static
constexpr
bool
is_q_load_once_v
=
is_q_load_once
<
Pipeline
>::
value
;
template
<
typename
Pipeline
,
typename
=
void
>
struct
is_k_load_once
:
std
::
false_type
{
};
template
<
typename
Pipeline
>
struct
is_k_load_once
<
Pipeline
,
std
::
void_t
<
decltype
(
Pipeline
::
kKLoadOnce
)
>>
:
std
::
bool_constant
<
Pipeline
::
kKLoadOnce
>
{
};
template
<
typename
Pipeline
>
static
constexpr
bool
is_k_load_once_v
=
is_k_load_once
<
Pipeline
>::
value
;
}
// namespace detail
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdKernel
struct
FmhaFwdKernel
...
@@ -596,7 +625,7 @@ struct FmhaFwdKernel
...
@@ -596,7 +625,7 @@ struct FmhaFwdKernel
make_tuple
(
kargs
.
stride_q
,
1
),
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
1
>
{});
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
if
constexpr
(
detail
::
is_q_load_once_v
<
FmhaPipeline
>
)
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
q_dram_naive
,
q_dram_naive
,
...
@@ -619,10 +648,20 @@ struct FmhaFwdKernel
...
@@ -619,10 +648,20 @@ struct FmhaFwdKernel
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
number
<
1
>
{});
return
pad_tensor_view
(
if
constexpr
(
detail
::
is_k_load_once_v
<
FmhaPipeline
>
)
k_dram_naive
,
{
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
return
pad_tensor_view
(
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0BlockLength
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
}();
}();
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
...
@@ -665,7 +704,7 @@ struct FmhaFwdKernel
...
@@ -665,7 +704,7 @@ struct FmhaFwdKernel
auto
q_dram_window
=
make_tile_window
(
auto
q_dram_window
=
make_tile_window
(
q_dram
,
q_dram
,
[
&
]()
{
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
if
constexpr
(
detail
::
is_q_load_once_v
<
FmhaPipeline
>
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0BlockLength
>
{});
number
<
FmhaPipeline
::
kK0BlockLength
>
{});
else
else
...
@@ -674,7 +713,15 @@ struct FmhaFwdKernel
...
@@ -674,7 +713,15 @@ struct FmhaFwdKernel
{
i_m0
,
0
});
{
i_m0
,
0
});
auto
k_dram_window
=
make_tile_window
(
auto
k_dram_window
=
make_tile_window
(
k_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
{
0
,
0
});
k_dram
,
[
&
]()
{
if
constexpr
(
detail
::
is_k_load_once_v
<
FmhaPipeline
>
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0BlockLength
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
{
0
,
0
});
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
v_dram
,
make_tile_window
(
v_dram
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment