Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0739bc5a
Commit
0739bc5a
authored
Dec 20, 2024
by
Po Yen Chen
Browse files
Use kv_perm to control key/value layout
parent
7c0e5822
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
24 deletions
+26
-24
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+26
-24
No files found.
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
0739bc5a
...
...
@@ -580,12 +580,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
:
(
seqlen_kpads
[
0
]
<
0
?
seqstart_k_host
.
back
()
:
seqstart_k_with_padding_host
.
back
()));
bool
kv_perm
=
(
mode
==
mode_enum
::
group
&&
0
<
page_block_size
?
true
:
i_perm
);
ck_tile
::
HostTensor
<
QDataType
>
q_host
(
get_lengths
(
i_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_q
));
ck_tile
::
HostTensor
<
KDataType
>
k_host
(
0
<
page_block_size
?
get_lengths
(
i
_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_q
)
:
get_lengths
(
i
_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_q
));
?
get_lengths
(
kv
_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_q
)
:
get_lengths
(
kv
_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_q
));
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile
::
HostTensor
<
KDataType
>
knew_host
(
0
<
seqlen_knew
...
...
@@ -594,10 +596,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
VDataType
>
v_host
(
0
<
page_block_size
?
(
is_v_rowmajor
?
get_lengths
(
i
_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_v
)
:
get_lengths
(
i
_perm
,
max_num_page_blocks
,
nhead_k
,
hdim_v
,
page_block_size
))
:
(
is_v_rowmajor
?
get_lengths
(
i
_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
)
:
get_lengths
(
i
_perm
,
shape_batch
,
nhead_k
,
hdim_v
,
shape_seqlen_k
)));
?
get_lengths
(
kv
_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_v
)
:
get_lengths
(
kv
_perm
,
max_num_page_blocks
,
nhead_k
,
hdim_v
,
page_block_size
))
:
(
is_v_rowmajor
?
get_lengths
(
kv
_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
)
:
get_lengths
(
kv
_perm
,
shape_batch
,
nhead_k
,
hdim_v
,
shape_seqlen_k
)));
ck_tile
::
HostTensor
<
VDataType
>
vnew_host
(
0
<
seqlen_knew
?
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
batch
,
nhead_k
,
seqlen_knew
,
hdim_v
)
...
...
@@ -762,9 +764,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
mode
==
mode_enum
::
group
&&
0
<
page_block_size
)
{
if
(
!
(
i_perm
&&
!
is_v_rowmajor
)
)
if
(
!
is_v_rowmajor
)
{
std
::
cerr
<<
"make sure input layout is correct"
<<
std
::
endl
;
std
::
cerr
<<
"make sure input layout is correct
: -vlayout=r
"
<<
std
::
endl
;
return
false
;
}
...
...
@@ -877,14 +879,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const
ck_tile
::
index_t
stride_q
=
(
i_perm
?
hdim_q
:
nhead
*
hdim_q
);
const
ck_tile
::
index_t
stride_k
=
(
i
_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_k
=
(
kv
_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_knew
=
(
i_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_v
=
[
&
]()
{
if
(
is_v_rowmajor
)
return
i
_perm
?
hdim_v
:
nhead_k
*
hdim_v
;
return
kv
_perm
?
hdim_v
:
nhead_k
*
hdim_v
;
else
return
0
<
page_block_size
?
(
i
_perm
?
page_block_size
:
nhead_k
*
page_block_size
)
:
(
i
_perm
?
shape_seqlen_k
:
nhead_k
*
shape_seqlen_k
);
return
0
<
page_block_size
?
(
kv
_perm
?
page_block_size
:
nhead_k
*
page_block_size
)
:
(
kv
_perm
?
shape_seqlen_k
:
nhead_k
*
shape_seqlen_k
);
}();
const
ck_tile
::
index_t
stride_vnew
=
[
&
]()
{
if
(
is_v_rowmajor
)
...
...
@@ -899,16 +901,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// setup nhead_stride_* arguments
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_k
=
(
0
<
page_block_size
?
(
i
_perm
?
page_block_size
*
hdim_q
:
hdim_q
)
:
(
i
_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
));
(
0
<
page_block_size
?
(
kv
_perm
?
page_block_size
*
hdim_q
:
hdim_q
)
:
(
kv
_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
));
const
ck_tile
::
index_t
nhead_stride_knew
=
(
i_perm
?
seqlen_knew
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_v
=
[
&
]()
{
if
(
is_v_rowmajor
)
return
0
<
page_block_size
?
(
i
_perm
?
page_block_size
*
hdim_v
:
hdim_v
)
:
(
i
_perm
?
shape_seqlen_k
*
hdim_v
:
hdim_v
);
return
0
<
page_block_size
?
(
kv
_perm
?
page_block_size
*
hdim_v
:
hdim_v
)
:
(
kv
_perm
?
shape_seqlen_k
*
hdim_v
:
hdim_v
);
else
return
0
<
page_block_size
?
(
i
_perm
?
hdim_v
*
page_block_size
:
page_block_size
)
:
(
i
_perm
?
hdim_v
*
shape_seqlen_k
:
shape_seqlen_k
);
return
0
<
page_block_size
?
(
kv
_perm
?
hdim_v
*
page_block_size
:
page_block_size
)
:
(
kv
_perm
?
hdim_v
*
shape_seqlen_k
:
shape_seqlen_k
);
}();
const
ck_tile
::
index_t
nhead_stride_vnew
=
[
&
]()
{
if
(
is_v_rowmajor
)
...
...
@@ -1278,7 +1280,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
0
<
page_block_size
)
{
if
(
i
_perm
)
{
if
(
kv
_perm
)
{
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
block_table_host
(
wb
,
i
[
1
]
/
page_block_size
),
i
[
0
]
/
nr
,
i
[
1
]
%
page_block_size
,
i
[
2
]);
});
...
...
@@ -1290,7 +1292,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
#endif
{
if
(
i
_perm
)
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
1
]
+
key_offset
,
i
[
2
]);
});
if
(
kv
_perm
)
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
1
]
+
key_offset
,
i
[
2
]);
});
else
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
cache_b_idx
,
i
[
1
]
+
key_offset
,
i
[
0
]
/
nr
,
i
[
2
]);
});
}
...
...
@@ -1330,7 +1332,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
0
<
page_block_size
)
{
if
(
is_v_rowmajor
)
{
if
(
i
_perm
)
{
if
(
kv
_perm
)
{
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
block_table_host
(
wb
,
i
[
2
]
/
page_block_size
),
i
[
0
]
/
nr
,
i
[
2
]
%
page_block_size
,
i
[
1
]);
});
...
...
@@ -1342,7 +1344,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else
{
if
(
i
_perm
)
{
if
(
kv
_perm
)
{
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
block_table_host
(
wb
,
i
[
2
]
/
page_block_size
),
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]
%
page_block_size
);
});
...
...
@@ -1357,13 +1359,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if
(
is_v_rowmajor
)
{
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if
(
i
_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
,
i
[
1
]);
});
if
(
kv
_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
,
i
[
1
]);
});
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
2
]
+
key_offset
,
i
[
0
]
/
nr
,
i
[
1
]);
});
}
else
{
if
(
i
_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]
+
key_offset
);
});
if
(
kv
_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]
+
key_offset
);
});
else
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
1
],
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
);
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment