Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
314fa8ab
Unverified
Commit
314fa8ab
authored
Oct 16, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 16, 2025
Browse files
[Attention] Tune CUTLASS MLA num_splits (#26846)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
334535b6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
16 deletions
+21
-16
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
+21
-16
No files found.
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
View file @
314fa8ab
...
@@ -125,32 +125,37 @@ public:
...
@@ -125,32 +125,37 @@ public:
}
}
static
void
set_split_kv
(
KernelArguments
&
args
)
{
static
void
set_split_kv
(
KernelArguments
&
args
)
{
// printf("set_split_kv start");
if
(
args
.
split_kv
>=
1
)
return
;
if
(
args
.
split_kv
>=
1
)
return
;
auto
[
H
,
K
,
D
,
B
]
=
args
.
problem_shape
;
auto
[
H
,
K
,
D
,
B
]
=
args
.
problem_shape
;
// std::cout << H << " " << K << " " << D << " " << B << "\n";
int
sm_count
=
args
.
hw_info
.
sm_count
;
int
sm_count
=
args
.
hw_info
.
sm_count
;
// printf(" sm_count = %d\n", sm_count);
float
seq_length_k
=
static_cast
<
float
>
(
K
)
/
1024.0
f
;
int
max_splits
=
ceil_div
(
K
,
128
);
int
max_splits
=
1
;
max_splits
=
min
(
16
,
max_splits
);
// TODO: This avoids a hang when the batch size larger than 1 and
if
(
B
<=
4
&&
seq_length_k
>=
16
)
{
// there is more than 1 kv_splits.
max_splits
=
16
;
// Discuss with NVIDIA how this can be fixed.
}
if
(
B
>
1
)
{
else
if
(
B
<=
8
&&
seq_length_k
>=
4
)
{
max_splits
=
min
(
1
,
max_splits
);
max_splits
=
8
;
}
else
if
((
B
<=
16
&&
seq_length_k
>=
8
)
||
(
B
==
48
&&
seq_length_k
>=
32
))
{
max_splits
=
4
;
}
else
if
((
B
<=
32
&&
seq_length_k
>=
16
)
||
(
B
==
96
&&
seq_length_k
>=
16
))
{
max_splits
=
2
;
}
else
{
max_splits
=
1
;
}
}
//
printf(" max_splits = %d\n", max_splits);
//
Wave-aware scheduling: ensure integer number of waves in K dimension
int
sms_per_batch
=
max
(
1
,
sm_count
/
B
);
int
sms_per_batch
=
max
(
1
,
sm_count
/
B
);
// printf(" sms_per_batch = %d\n", sms_per_batch);
int
split_heur
=
min
(
max_splits
,
sms_per_batch
);
int
split_heur
=
min
(
max_splits
,
sms_per_batch
);
int
waves
=
ceil_div
(
B
*
split_heur
,
sm_count
);
int
waves
=
ceil_div
(
B
*
split_heur
,
sm_count
);
int
k_waves
=
ceil_div
(
max_splits
,
split_heur
);
int
k_waves
=
ceil_div
(
max_splits
,
split_heur
);
int
split_wave_aware
=
ceil_div
(
max_splits
,
k_waves
);
int
split_wave_aware
=
ceil_div
(
max_splits
,
k_waves
);
args
.
split_kv
=
split_wave_aware
;
args
.
split_kv
=
split_wave_aware
;
// printf(" args.split_kv = %d\n", args.split_kv);
}
}
/// Determines whether the GEMM can execute the given problem.
/// Determines whether the GEMM can execute the given problem.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment