Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
612a35d6
Commit
612a35d6
authored
Dec 17, 2024
by
Po Yen Chen
Browse files
Use maximum 8 splits
parent
45721793
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
10 deletions
+12
-10
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+12
-10
No files found.
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
612a35d6
...
...
@@ -234,8 +234,7 @@ int override_num_splits_if_necessary(int batch,
if
(
num_splits
<
1
&&
p_drop
==
0.0
f
)
{
return
num_splits_heuristic
(
batch
*
nhead
*
num_m_blocks
,
props
.
multiProcessorCount
*
2
,
16
);
return
num_splits_heuristic
(
batch
*
nhead
*
num_m_blocks
,
props
.
multiProcessorCount
*
2
,
8
);
}
return
num_splits
;
...
...
@@ -625,11 +624,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
1
<
num_splits
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
shape_batch
,
nhead
,
num_splits
,
shape_seqlen_q
,
hdim_v
}
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
nhead
,
num_splits
,
shape_seqlen_q
,
hdim_v
}
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
...
...
@@ -1042,7 +1041,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args
.
num_splits
=
num_splits
;
if
(
1
<
num_splits
)
{
if
(
1
<
num_splits
)
{
args
.
lse_acc_ptr
=
lse_acc_buf
.
GetDeviceBuffer
();
args
.
o_acc_ptr
=
o_acc_buf
.
GetDeviceBuffer
();
...
...
@@ -1053,7 +1053,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
args
.
batch_stride_o_acc
=
batch_stride_o_acc
;
args
.
split_stride_lse_acc
=
split_stride_lse_acc
;
args
.
split_stride_o_acc
=
split_stride_o_acc
;
}
else
{
}
else
{
// following attributes are ignored by fmha_fwd_splitkv()
args
.
lse_acc_ptr
=
nullptr
;
args
.
o_acc_ptr
=
nullptr
;
...
...
@@ -1088,7 +1090,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
float
fwd_ave_time
=
[
&
]
{
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
1
<
num_splits
||
use_kvcache
)
if
(
1
<
=
num_splits
||
use_kvcache
)
{
fmha_fwd_splitkv_traits
fmha_splitkv_traits
;
init_traits
(
fmha_splitkv_traits
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment