Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
337f073d
Commit
337f073d
authored
Nov 28, 2024
by
Po Yen Chen
Browse files
Move num_splits_heuristic() to fmha_fwd.hpp for reusability
parent
2da4b185
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
35 deletions
+36
-35
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+0
-35
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+36
-0
No files found.
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
337f073d
...
@@ -177,41 +177,6 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
...
@@ -177,41 +177,6 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
}
}
}
}
/// Pick the number of splits to use for a launch.
///
/// For every candidate split count s in [1, max_splits] the function computes
///     n_blocks = (batch_nhead_mblocks * s) / num_SMs
///     eff      = n_blocks / ceil(n_blocks)
/// and returns the smallest s whose eff is at least 85% of the best eff seen.
///
/// @param batch_nhead_mblocks  amount of work available before splitting.
/// @param num_SMs              number of compute units the work is spread over.
/// @param max_splits           upper bound on the split count; it is additionally
///                             clamped to num_SMs.
/// @return the chosen split count; 1 whenever splitting is not needed or the
///         clamped upper bound leaves no candidate other than 1.
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
    // If we have enough to almost fill the SMs, then just use 1 split
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }

    max_splits = std::min(max_splits, num_SMs);

    // Nothing to choose between. This also rejects zero/negative bounds, which
    // would otherwise turn into a huge unsigned size in reserve() and throw
    // std::length_error, and it guarantees num_SMs >= 2 for the division below.
    if(max_splits <= 1)
    {
        return 1;
    }

    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(static_cast<std::vector<float>::size_type>(max_splits));

    // First pass: record the efficiency of every candidate and track the best one.
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        const float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
        const float eff      = n_blocks / std::ceil(n_blocks);
        if(eff > max_efficiency)
        {
            max_efficiency = eff;
        }
        efficiency.push_back(eff);
    }

    // Second pass: prefer the smallest split count that is "good enough",
    // i.e. within 85% of the best efficiency (compared in double precision).
    const double threshold = 0.85 * max_efficiency;
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(efficiency[num_splits - 1] >= threshold)
        {
            return num_splits;
        }
    }
    return 1;
}
int
override_num_splits_if_necessary
(
int
batch
,
int
override_num_splits_if_necessary
(
int
batch
,
int
nhead
,
int
nhead
,
int
max_seqlen_q
,
int
max_seqlen_q
,
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
337f073d
...
@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits
...
@@ -813,3 +813,39 @@ struct fmha_fwd_appendkv_traits
float
fmha_fwd_appendkv
(
fmha_fwd_appendkv_traits
,
float
fmha_fwd_appendkv
(
fmha_fwd_appendkv_traits
,
fmha_fwd_appendkv_args
,
fmha_fwd_appendkv_args
,
const
ck_tile
::
stream_config
&
);
const
ck_tile
::
stream_config
&
);
/// Pick the number of splits to use for a launch.
///
/// For every candidate split count s in [1, max_splits] the function computes
///     n_blocks = (batch_nhead_mblocks * s) / num_SMs
///     eff      = n_blocks / ceil(n_blocks)
/// and returns the smallest s whose eff is at least 85% of the best eff seen.
///
/// @tparam Int                 integer type used for all counts (defaults to int).
/// @param batch_nhead_mblocks  amount of work available before splitting.
/// @param num_SMs              number of compute units the work is spread over.
/// @param max_splits           upper bound on the split count; it is additionally
///                             clamped to num_SMs.
/// @return the chosen split count; 1 whenever splitting is not needed or the
///         clamped upper bound leaves no candidate other than 1.
template <typename Int = int>
Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
{
    // If we have enough to almost fill the SMs, then just use 1 split
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }

    max_splits = std::min(max_splits, num_SMs);

    // Nothing to choose between. This also rejects zero/negative bounds, which
    // would otherwise turn into a huge unsigned size in reserve() and throw
    // std::length_error, and it guarantees num_SMs >= 2 for the division below.
    if(max_splits <= 1)
    {
        return 1;
    }

    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(static_cast<std::vector<float>::size_type>(max_splits));

    // First pass: record the efficiency of every candidate and track the best one.
    for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        const float n_blocks = float(batch_nhead_mblocks * num_splits) / num_SMs;
        const float eff      = n_blocks / std::ceil(n_blocks);
        if(eff > max_efficiency)
        {
            max_efficiency = eff;
        }
        efficiency.push_back(eff);
    }

    // Second pass: prefer the smallest split count that is "good enough",
    // i.e. within 85% of the best efficiency (compared in double precision).
    const double threshold = 0.85 * max_efficiency;
    for(Int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(efficiency[num_splits - 1] >= threshold)
        {
            return num_splits;
        }
    }
    return 1;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment