Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2e7054da
Unverified
Commit
2e7054da
authored
Dec 09, 2025
by
Hashem Hashemi
Committed by
GitHub
Dec 09, 2025
Browse files
Improve wvsplitK tile and balance heristics. (#29937)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
3c680f4a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
49 deletions
+48
-49
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+48
-49
No files found.
csrc/rocm/skinny_gemms.cu
View file @
2e7054da
...
@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
...
@@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
}
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
// Find the min val of div2 that doesn't increase N/(div1*div2)
int
mindiv
(
int
N
,
int
div1
,
int
div2
)
{
int
mindiv
(
int
N
,
int
div1
,
int
div2
)
{
int
nPrRnd
=
div1
*
div2
;
int
nPrRnd
=
div1
*
div2
;
int
rnds0
=
N
/
nPrRnd
;
int
rnds
[
13
];
nPrRnd
-=
div1
*
3
;
for
(
int
i
=
0
;
i
<
13
;
i
++
)
{
int
rnds3
=
N
/
nPrRnd
;
rnds
[
i
]
=
(
N
+
nPrRnd
-
1
)
/
nPrRnd
;
nPrRnd
-=
div1
;
nPrRnd
-=
div1
;
int
rnds4
=
N
/
nPrRnd
;
}
nPrRnd
-=
div1
;
for
(
int
i
=
12
;
i
>=
0
;
i
--
)
int
rnds5
=
N
/
nPrRnd
;
if
(
rnds
[
0
]
==
rnds
[
i
])
return
(
div2
-
i
);
nPrRnd
-=
div1
;
int
rnds6
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds7
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds8
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rnds9
=
N
/
nPrRnd
;
nPrRnd
-=
div1
;
int
rtn
=
div2
;
if
(
rnds0
==
rnds3
)
rtn
=
div2
-
3
;
if
(
rnds0
==
rnds4
)
rtn
=
div2
-
4
;
if
(
rnds0
==
rnds5
)
rtn
=
div2
-
5
;
if
(
rnds0
==
rnds6
)
rtn
=
div2
-
6
;
if
(
rnds0
==
rnds7
)
rtn
=
div2
-
7
;
if
(
rnds0
==
rnds8
)
rtn
=
div2
-
8
;
if
(
rnds0
==
rnds9
)
rtn
=
div2
-
9
;
return
rtn
;
}
}
torch
::
Tensor
wvSplitK
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
torch
::
Tensor
wvSplitK
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
...
@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int
max_lds_len
=
get_lds_size
()
/
2
;
const
int
max_lds_len
=
get_lds_size
()
/
2
;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
#define WVSPLITK(_YTILE, _UNRL, _N) \
_N) \
{ \
{ \
dim3 block(64, 16); \
dim3 block(64, _WvPrGrp); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
} else if (K_in * N_in <= max_lds_len * 1.2) { \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
biasf4, c, __wvPrGrp, CuCount); \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
else \
biasf4, c, __wvPrGrp, CuCount); \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
} else { \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
biasf4, c, __wvPrGrp, CuCount); \
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
}
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
#define WVSPLIT_TILE(_sYT, __N) \
} \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
}
}
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
in_b
.
scalar_type
(),
"wvSplitK"
,
[
&
]
{
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
in_b
.
scalar_type
(),
"wvSplitK"
,
[
&
]
{
...
@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
?
reinterpret_cast
<
const
fptype
*>
(
in_bias
->
data_ptr
())
?
reinterpret_cast
<
const
fptype
*>
(
in_bias
->
data_ptr
())
:
nullptr
;
:
nullptr
;
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
// first shoot for biggest tile-size that keeps all simd busy,
// then cut the active waves to balance their distribution...
int
sYT
=
(
M_in
+
CuCount
*
4
-
1
)
/
(
CuCount
*
4
);
switch
(
N_in
)
{
switch
(
N_in
)
{
case
1
:
case
1
:
WVSPLIT
K
(
16
,
2
,
2
,
2
,
2
,
2
,
2
,
1
)
WVSPLIT
_TILE
(
sYT
,
1
)
break
;
break
;
case
2
:
case
2
:
WVSPLIT
K
(
16
,
2
,
2
,
2
,
2
,
2
,
2
,
2
)
WVSPLIT
_TILE
(
sYT
,
2
)
break
;
break
;
case
3
:
case
3
:
WVSPLIT
K
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
3
)
WVSPLIT
_TILE
(
sYT
,
3
)
break
;
break
;
case
4
:
case
4
:
WVSPLIT
K
(
16
,
4
,
7
,
7
,
1
,
1
,
1
,
4
)
WVSPLIT
_TILE
(
sYT
,
4
)
break
;
break
;
default:
default:
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment