Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
287a53bf
"...composable_kernel_rocm.git" did not exist on "421996707e28caec9ce3702e3f9b451bf5d4c969"
Commit
287a53bf
authored
Nov 28, 2024
by
Po Yen Chen
Browse files
Update num_splits heuristic
parent
2cd79708
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
32 deletions
+14
-32
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+14
-32
No files found.
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
287a53bf
...
...
@@ -176,53 +176,35 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
}
}
int
num_splits_heuristic
(
int
batch_nhead_mblocks
,
int
num_SMs
,
int
num_n_blocks
,
int
max_splits
)
int
num_splits_heuristic
(
int
batch_nhead_mblocks
,
int
num_SMs
,
int
max_splits
)
{
// If we have enough to almost fill the SMs, then just use 1 split
if
(
batch_nhead_mblocks
>=
0.8
f
*
num_SMs
)
{
return
1
;
}
max_splits
=
std
::
min
({
max_splits
,
num_SMs
,
num_n_blocks
});
max_splits
=
std
::
min
({
max_splits
,
num_SMs
});
float
max_efficiency
=
0.
f
;
std
::
vector
<
float
>
efficiency
;
efficiency
.
reserve
(
max_splits
);
auto
ceildiv
=
[](
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
};
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto
is_split_eligible
=
[
&
ceildiv
,
&
num_n_blocks
](
int
num_splits
)
{
return
num_splits
==
1
||
ceildiv
(
num_n_blocks
,
num_splits
)
!=
ceildiv
(
num_n_blocks
,
num_splits
-
1
);
};
for
(
int
num_splits
=
1
;
num_splits
<=
max_splits
;
num_splits
++
)
{
if
(
!
is_split_eligible
(
num_splits
))
{
efficiency
.
push_back
(
0.
f
);
}
else
float
n_blocks
=
float
(
batch_nhead_mblocks
*
num_splits
)
/
num_SMs
;
float
eff
=
n_blocks
/
std
::
ceil
(
n_blocks
);
if
(
eff
>
max_efficiency
)
{
float
n_waves
=
float
(
batch_nhead_mblocks
*
num_splits
)
/
num_SMs
;
float
eff
=
n_waves
/
ceil
(
n_waves
);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if
(
eff
>
max_efficiency
)
{
max_efficiency
=
eff
;
}
efficiency
.
push_back
(
eff
);
max_efficiency
=
eff
;
}
efficiency
.
push_back
(
eff
);
}
for
(
int
num_splits
=
1
;
num_splits
<=
max_splits
;
num_splits
++
)
{
if
(
!
is_split_eligible
(
num_splits
))
{
continue
;
}
if
(
efficiency
[
num_splits
-
1
]
>=
0.85
*
max_efficiency
)
{
// printf("num_splits chosen = %d\n", num_splits);
return
num_splits
;
}
}
...
...
@@ -266,15 +248,15 @@ int override_num_splits_if_necessary(
return
64
;
// meet unsupported hdim_q/hdim_v
}();
const
int
kN1
=
hdim_v
;
//
const int kN1 = hdim_v;
const
int
num_m_blocks
=
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
);
const
int
num_n_blocks
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
//
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
// always 1
if
(
num_splits
<
1
&&
p_drop
==
0.0
f
)
{
return
num_splits_heuristic
(
batch
*
nhead
*
num_m_blocks
,
props
.
multiProcessorCount
*
2
,
num_n_blocks
,
128
);
batch
*
nhead
*
num_m_blocks
,
props
.
multiProcessorCount
*
2
,
32
);
}
return
num_splits
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment