Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4b7869d6
Unverified
Commit
4b7869d6
authored
Apr 23, 2026
by
Matthias Gehre
Committed by
GitHub
Apr 23, 2026
Browse files
[ROCm] Add gfx1102/gfx1103 support (#40037)
Signed-off-by:
Matthias Gehre
<
matthias.gehre@amd.com
>
parent
4a79262e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
35 deletions
+21
-35
CMakeLists.txt
CMakeLists.txt
+1
-1
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+2
-11
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+18
-23
No files found.
CMakeLists.txt
View file @
4b7869d6
...
...
@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set
(
PYTHON_SUPPORTED_VERSIONS
"3.10"
"3.11"
"3.12"
"3.13"
)
# Supported AMD GPU architectures.
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201"
)
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201"
)
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
...
...
csrc/rocm/attention.cu
View file @
4b7869d6
...
...
@@ -40,15 +40,6 @@ using __hip_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#define __HIP__FP8MFMA__
#endif
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1150__) || defined(__gfx1151__))
#define __HIP__GFX11__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
...
...
@@ -1629,7 +1620,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
}
#elif defined(__HIP__GFX11__)
#elif defined(__GFX11__)
using
floatx8
=
__attribute__
((
__vector_size__
(
8
*
sizeof
(
float
))))
float
;
...
...
@@ -2388,7 +2379,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
out_ptr
[
threadIdx
.
x
]
=
from_float
<
scalar_t
>
(
acc
);
}
#elif defined(__HIP__GFX12__)
#elif defined(__GFX12__)
using
floatx8
=
__attribute__
((
__vector_size__
(
8
*
sizeof
(
float
))))
float
;
...
...
csrc/rocm/skinny_gemms.cu
View file @
4b7869d6
...
...
@@ -26,16 +26,11 @@
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && \
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \
defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__))
// Combined RDNA macro (gfx11 + gfx12) - both use 32-wide wavefronts
#if defined(__GFX11__) || defined(__GFX12__)
#define __HIP__GFX1X__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__MI3XX__
#endif
...
...
@@ -1845,7 +1840,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
return
out_c
;
}
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__GFX12__)
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
...
...
@@ -1893,7 +1888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float
sum
[
N
][
YTILE
]
=
{};
#else
...
...
@@ -1931,7 +1926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
#pragma unroll
...
...
@@ -1955,7 +1950,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
// Final reduction
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
...
...
@@ -1993,7 +1988,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif
const
bool
writeback_lane
=
#ifdef __HIP__GFX12__
#ifdef __GFX12__
threadIdx
.
x
==
(
THRDS
-
1
);
#else
threadIdx
.
x
==
0
;
...
...
@@ -2009,7 +2004,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
#ifdef __HIP__GFX12__
#ifdef __GFX12__
float
result
=
sum
[
n
][
y
]
*
sA
*
sB
;
#else
float
result
=
sum
[
n
][
y
][
0
]
*
sA
*
sB
;
...
...
@@ -2027,7 +2022,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kap
,
const
int
Kbp
,
...
...
@@ -2039,9 +2034,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#if defined(__HIP__MI3XX__) || defined(__GFX12__)
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
...
...
@@ -2088,7 +2083,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float
sB
=
*
s_B
;
while
(
m
<
M
)
{
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float
sum
[
N
][
YTILE
]
=
{};
#else
...
...
@@ -2128,7 +2123,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
n
=
0
;
n
<
N
;
n
++
)
{
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for
(
int
y
=
0
;
y
<
YTILE
;
++
y
)
{
#pragma unroll
...
...
@@ -2152,7 +2147,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
// Final reduction
#ifdef __HIP__GFX12__
#ifdef __GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
...
...
@@ -2190,7 +2185,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif
const
bool
writeback_lane
=
#ifdef __HIP__GFX12__
#ifdef __GFX12__
threadIdx
.
x
==
(
THRDS
-
1
);
#else
threadIdx
.
x
==
0
;
...
...
@@ -2206,7 +2201,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
y
=
0
;
y
<
YTILE
;
y
++
)
{
if
(
y
+
m
>=
M
)
break
;
// To avoid mem access fault.
#ifdef __HIP__GFX12__
#ifdef __GFX12__
float
result
=
sum
[
n
][
y
]
*
sA
*
sB
;
#else
float
result
=
sum
[
n
][
y
][
0
]
*
sA
*
sB
;
...
...
@@ -2224,7 +2219,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
#else // !defined(__HIP__MI3XX__) && !defined(__GFX12__)
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kap
,
const
int
Kbp
,
...
...
@@ -2236,7 +2231,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const
int
CuCount
)
{
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#endif // defined(__HIP__MI3XX__) || defined(__GFX12__)
void
wvSplitKQ
(
const
at
::
Tensor
&
in_b
,
const
at
::
Tensor
&
in_a
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
at
::
Tensor
&
out_c
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment