Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0e50a42
Unverified
Commit
a0e50a42
authored
Feb 24, 2026
by
Hashem Hashemi
Committed by
GitHub
Feb 24, 2026
Browse files
Convert wvSplitKQ to 16x16 MFMA in prep for mi4xx. (#34100)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
9fa5b25a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
46 deletions
+14
-46
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+14
-46
No files found.
csrc/rocm/skinny_gemms.cu
View file @
a0e50a42
...
@@ -1902,7 +1902,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1902,7 +1902,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float
sB
=
*
s_B
;
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
while
(
m
<
M
)
{
floatx16
sum
[
N
][
YTILE
]
=
{};
scalar8
sum
[
N
][
YTILE
]
=
{};
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
...
@@ -1936,7 +1936,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1936,7 +1936,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
sum
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_
32x32x16
_fp8_fp8
(
sum
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_
16x16x32
_fp8_fp8
(
bigA
[
n
][
k2
].
l
[
i
/
8
],
bigB
[
y
][
k2
].
l
[
i
/
8
],
sum
[
n
][
y
],
0
,
0
,
bigA
[
n
][
k2
].
l
[
i
/
8
],
bigB
[
y
][
k2
].
l
[
i
/
8
],
sum
[
n
][
y
],
0
,
0
,
0
);
0
);
}
}
...
@@ -1949,31 +1949,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1949,31 +1949,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm16
=
sum
[
n
][
y
][
8
];
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
1
);
// row_shl1
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
9
],
0x101
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
1
);
// row_shl2
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
10
],
0x102
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
1
);
// row_shl3
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
11
],
0x103
,
0xf
,
0xf
,
1
);
accm0
+=
__shfl_down
(
accm0
,
20
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
4
],
0x108
,
0xf
,
0xf
,
accm0
+=
__shfl_down
(
accm0
,
40
);
1
);
// row_shl8
sum
[
n
][
y
][
0
]
=
accm0
;
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
12
],
0x108
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
5
],
0x109
,
0xf
,
0xf
,
1
);
// row_shl9
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
13
],
0x109
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
6
],
0x10a
,
0xf
,
0xf
,
1
);
// row_shl10
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
14
],
0x10a
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
7
],
0x10b
,
0xf
,
0xf
,
1
);
// row_shl11
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
15
],
0x10b
,
0xf
,
0xf
,
1
);
accm0
+=
__shfl
(
accm0
,
36
);
accm16
+=
__shfl
(
accm16
,
52
);
sum
[
n
][
y
][
0
]
=
accm0
+
__shfl
(
accm16
,
16
);
}
}
}
}
...
@@ -2064,7 +2048,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -2064,7 +2048,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float
sB
=
*
s_B
;
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
while
(
m
<
M
)
{
floatx16
sum
[
N
][
YTILE
]
=
{};
scalar8
sum
[
N
][
YTILE
]
=
{};
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
for
(
uint32_t
k1
=
0
;
k1
<
K
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigA
[
N
][
UNRL
]
=
{};
bigType
bigB
[
YTILE
][
UNRL
];
bigType
bigB
[
YTILE
][
UNRL
];
...
@@ -2100,7 +2084,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -2100,7 +2084,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
i
=
0
;
i
<
A_CHUNK
;
i
+=
8
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
sum
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_
32x32x16
_fp8_fp8
(
sum
[
n
][
y
]
=
__builtin_amdgcn_mfma_f32_
16x16x32
_fp8_fp8
(
bigA
[
n
][
k2
].
l
[
i
/
8
],
bigB
[
y
][
k2
].
l
[
i
/
8
],
sum
[
n
][
y
],
0
,
0
,
bigA
[
n
][
k2
].
l
[
i
/
8
],
bigB
[
y
][
k2
].
l
[
i
/
8
],
sum
[
n
][
y
],
0
,
0
,
0
);
0
);
}
}
...
@@ -2113,31 +2097,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -2113,31 +2097,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm0
=
sum
[
n
][
y
][
0
];
float
accm16
=
sum
[
n
][
y
][
8
];
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
1
],
0x101
,
0xf
,
0xf
,
1
);
// row_shl1
1
);
// row_shl1
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
9
],
0x101
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
2
],
0x102
,
0xf
,
0xf
,
1
);
// row_shl2
1
);
// row_shl2
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
10
],
0x102
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
3
],
0x103
,
0xf
,
0xf
,
1
);
// row_shl3
1
);
// row_shl3
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
11
],
0x103
,
0xf
,
0xf
,
1
);
accm0
+=
__shfl_down
(
accm0
,
20
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
4
],
0x108
,
0xf
,
0xf
,
accm0
+=
__shfl_down
(
accm0
,
40
);
1
);
// row_shl8
sum
[
n
][
y
][
0
]
=
accm0
;
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
12
],
0x108
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
5
],
0x109
,
0xf
,
0xf
,
1
);
// row_shl9
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
13
],
0x109
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
6
],
0x10a
,
0xf
,
0xf
,
1
);
// row_shl10
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
14
],
0x10a
,
0xf
,
0xf
,
1
);
accm0
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
7
],
0x10b
,
0xf
,
0xf
,
1
);
// row_shl11
accm16
+=
__builtin_amdgcn_mov_dpp
(
sum
[
n
][
y
][
15
],
0x10b
,
0xf
,
0xf
,
1
);
accm0
+=
__shfl
(
accm0
,
36
);
accm16
+=
__shfl
(
accm16
,
52
);
sum
[
n
][
y
][
0
]
=
accm0
+
__shfl
(
accm16
,
16
);
}
}
}
}
...
@@ -2242,16 +2210,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
...
@@ -2242,16 +2210,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
:
nullptr
;
:
nullptr
;
switch
(
N_in
)
{
switch
(
N_in
)
{
case
1
:
case
1
:
WVSPLITKQ
(
1
2
,
2
,
2
,
2
,
2
,
1
)
WVSPLITKQ
(
1
6
,
2
,
2
,
2
,
2
,
1
)
break
;
break
;
case
2
:
case
2
:
WVSPLITKQ
(
1
2
,
2
,
2
,
2
,
2
,
2
)
WVSPLITKQ
(
1
6
,
2
,
2
,
2
,
2
,
2
)
break
;
break
;
case
3
:
case
3
:
WVSPLITKQ
(
8
,
2
,
2
,
1
,
1
,
3
)
WVSPLITKQ
(
16
,
2
,
2
,
2
,
2
,
3
)
break
;
break
;
case
4
:
case
4
:
WVSPLITKQ
(
4
,
2
,
2
,
1
,
1
,
4
)
WVSPLITKQ
(
16
,
2
,
2
,
2
,
2
,
4
)
break
;
break
;
default:
default:
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment