Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a103043
Unverified
Commit
7a103043
authored
Jan 16, 2026
by
Hashem Hashemi
Committed by
GitHub
Jan 16, 2026
Browse files
Atomics Reduce Counting Optimization for SplitK Skinny GEMMs. (#29843)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
9fd918e5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
635 additions
and
10 deletions
+635
-10
csrc/rocm/ops.h
csrc/rocm/ops.h
+4
-0
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+545
-9
csrc/rocm/torch_bindings.cpp
csrc/rocm/torch_bindings.cpp
+6
-0
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+53
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+6
-0
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+21
-1
No files found.
csrc/rocm/ops.h
View file @
7a103043
...
@@ -9,6 +9,10 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -9,6 +9,10 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
const
int64_t
CuCount
);
const
int64_t
CuCount
);
// Split-K skinny GEMM variant that uses atomic reduce counting.
// Defined in csrc/rocm/skinny_gemms.cu.
//   in_a    : 2-D tensor; size(0) is used as M and size(1) as K.
//   in_b    : 2-D tensor; size(0) is used as N.
//   in_bias : optional bias; treated as absent when empty.
//   CuCount : used as the launch grid size.
// Returns a newly allocated tensor of shape {in_b.size(0), in_a.size(0)}.
torch
::
Tensor
wvSplitKrc
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
const
int64_t
CuCount
);
void
wvSplitKQ
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
void
wvSplitKQ
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
at
::
Tensor
&
out_c
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
at
::
Tensor
&
out_c
,
const
at
::
Tensor
&
scale_a
,
const
at
::
Tensor
&
scale_b
,
const
at
::
Tensor
&
scale_a
,
const
at
::
Tensor
&
scale_b
,
...
...
csrc/rocm/skinny_gemms.cu
View file @
7a103043
...
@@ -287,6 +287,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
...
@@ -287,6 +287,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
V0 += (s.x + s.y); \
V0 += (s.x + s.y); \
}
}
// Minimum of two unsigned 32-bit values.
// Both operands are pinned to uint32_t so that call sites mixing signed and
// unsigned arguments cannot end up with LLVM silently upcasting to double.
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
  return (b < a) ? b : a;
}
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
// This version targets cases where A[] fits LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
...
@@ -334,11 +339,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -334,11 +339,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
max_lds_len
);
for
(
uint32_t
k
=
0
;
k
<
min
__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
max_lds_len
))
break
;
if
(
k_in
>=
min
__
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
...
@@ -633,11 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -633,11 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
max_lds_len
);
for
(
uint32_t
k
=
0
;
k
<
min
__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
max_lds_len
))
break
;
if
(
k_in
>=
min
__
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
...
@@ -954,11 +959,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -954,11 +959,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
//----------------------------------------------------
#define PCML
#define PCML
#ifndef PCML
#ifndef PCML
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
max_lds_len
);
for
(
uint32_t
k
=
0
;
k
<
min
__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
max_lds_len
))
break
;
if
(
k_in
>=
min
__
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
...
@@ -975,7 +980,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -975,7 +980,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
?
kFit
?
kFit
:
(
kFit
-
kFit
%
TUC
);
// round down to a multiple of TUC
:
(
kFit
-
kFit
%
TUC
);
// round down to a multiple of TUC
// if (kFit == 0) kFit = TUC;
// if (kFit == 0) kFit = TUC;
kFit
=
min
(
kFit
,
K
);
kFit
=
min
__
(
kFit
,
K
);
float
sum
[
N
][
YTILE
];
float
sum
[
N
][
YTILE
];
scalar8
sum4
[
N
][
YTILE
];
scalar8
sum4
[
N
][
YTILE
];
...
@@ -1251,6 +1256,7 @@ int mindiv(int N, int div1, int div2) {
...
@@ -1251,6 +1256,7 @@ int mindiv(int N, int div1, int div2) {
}
}
for
(
int
i
=
12
;
i
>=
0
;
i
--
)
for
(
int
i
=
12
;
i
>=
0
;
i
--
)
if
(
rnds
[
0
]
==
rnds
[
i
])
return
(
div2
-
i
);
if
(
rnds
[
0
]
==
rnds
[
i
])
return
(
div2
-
i
);
return
0
;
}
}
torch
::
Tensor
wvSplitK
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
torch
::
Tensor
wvSplitK
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
...
@@ -1352,6 +1358,536 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -1352,6 +1358,536 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
return
out_c
;
return
out_c
;
}
}
#if defined(__gfx950__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS
// capacity
#define WVSPLITKRC_1KPASS
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
,
int
GrpsShrB
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
__restrict__
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
float
*
glbl
,
scalar_t
*
C
,
const
int
CuCount
)
{
// Use upper half of glbl buffer for atomic reduce counting
int
*
cntr
=
(
int
*
)(
&
glbl
[
M
*
N
]);
constexpr
int
NTILE
=
16
;
constexpr
int
WVLDS_
=
(
NTILE
*
THRDS
*
A_CHUNK
);
constexpr
int
APAD
=
1
;
constexpr
int
ASTRD
=
64
;
constexpr
int
BPAD
=
1
;
constexpr
int
BSTRD
=
64
;
constexpr
int
WVLDS
=
((
WVLDS_
+
(
WVLDS_
/
BSTRD
)
*
4
*
BPAD
));
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
using
scalar16
=
__attribute__
((
__vector_size__
((
A_CHUNK
*
2
)
*
sizeof
(
float
))))
float
;
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
2
)
*
sizeof
(
float
))))
float
;
using
half4
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
2
)
*
sizeof
(
__bf16
))))
__bf16
;
union
bigType
{
scalar_t
h
[
A_CHUNK
];
float
f
[
A_CHUNK
/
2
];
unsigned
int
i
[
A_CHUNK
/
2
];
float2
f2
[
A_CHUNK
/
4
];
unsigned
long
l
[
A_CHUNK
/
4
];
double
d
[
A_CHUNK
/
4
];
half4
h4
[
A_CHUNK
/
4
];
scalar8
h8
;
};
using
big4
=
__attribute__
((
__vector_size__
(
4
*
sizeof
(
bigType
))))
__bf16
;
__shared__
scalar_t
stg
[
WvPrGrp
*
WVLDS
/
GrpsShrB
];
unsigned
int
*
myStg
=
(
unsigned
int
*
)(
&
stg
[
WVLDS
*
(
threadIdx
.
y
/
GrpsShrB
)]);
__shared__
scalar_t
s
[
max_lds_len
-
WvPrGrp
*
WVLDS
/
GrpsShrB
];
#ifndef WVSPLITKRC_1KPASS
constexpr
int
TUC_
=
(
THRDS
*
UNRL
*
A_CHUNK
);
// find biggest k size that fits padded into LDS
constexpr
uint32_t
kFit__
=
(
max_lds_len
-
WvPrGrp
*
WVLDS
/
GrpsShrB
)
/
N
;
constexpr
uint32_t
kFit_
=
(
kFit__
*
ASTRD
)
/
(
APAD
+
ASTRD
);
uint32_t
kFit
=
kFit_
-
(
kFit_
%
TUC_
);
uint32_t
kfitsPerRdc
=
(
K
+
kFit
-
1
)
/
kFit
;
// find best k split to fill the CUs
if
(((
K
+
kfitsPerRdc
*
kFit
-
1
)
/
(
kfitsPerRdc
*
kFit
))
*
numCuWithFullK
<=
CuCount
)
while
(
true
)
{
while
(
kFit
>
TUC_
)
{
uint32_t
kFit_
=
kFit
-
TUC_
;
if
(((
K
+
(
kfitsPerRdc
*
kFit_
-
1
))
/
(
kfitsPerRdc
*
kFit_
))
*
numCuWithFullK
>
CuCount
)
break
;
kFit
=
kFit_
;
}
if
(((
K
+
((
kfitsPerRdc
-
1
)
*
kFit
-
1
))
/
((
kfitsPerRdc
-
1
)
*
kFit
))
*
numCuWithFullK
<=
CuCount
)
kfitsPerRdc
--
;
else
break
;
}
#else
int
constexpr
kFit
=
512
;
int
constexpr
kfitsPerRdc
=
1
;
#endif
bool
doRdc
=
(
kfitsPerRdc
*
kFit
<
K
);
uint32_t
numCuWithFullK
=
((
M
+
(
WvPrGrp
*
YTILE
/
GrpsShrB
)
-
1
)
/
(
WvPrGrp
*
YTILE
/
GrpsShrB
));
uint32_t
Mmod
=
numCuWithFullK
*
(
WvPrGrp
*
YTILE
/
GrpsShrB
);
// given above k-split, find this wave's position
uint32_t
kFitPdd
=
kFit
+
(
kFit
/
ASTRD
)
*
APAD
;
uint32_t
m0
=
(
blockIdx
.
x
*
WvPrGrp
/
GrpsShrB
)
*
YTILE
;
uint32_t
m1
=
((
threadIdx
.
y
%
WvPrGrp
)
/
GrpsShrB
)
*
YTILE
;
uint32_t
m
=
(
m0
+
m1
)
%
Mmod
;
const
uint32_t
k_str
=
(
m0
/
Mmod
)
*
kFit
*
kfitsPerRdc
;
uint32_t
k_end
=
(
m0
/
Mmod
+
1
)
*
kFit
*
kfitsPerRdc
;
const
uint32_t
k_rnd
=
(
K
+
kFit
*
kfitsPerRdc
-
1
)
/
(
kFit
*
kfitsPerRdc
);
scalar8
sum4
[
N
/
NTILE
/
GrpsShrB
][
1
];
bigType
bigB_
[
YTILE
/
GrpsShrB
][
UNRL
];
const
uint32_t
bLoader
=
(
threadIdx
.
y
%
GrpsShrB
);
uint32_t
kBase
=
0
;
if
(
k_str
>=
K
)
return
;
if
(
m
>=
Mmod
)
return
;
bool
noreloada
=
false
;
constexpr
bool
FAST_UNSAFE_RDC_INIT
=
false
;
#ifdef WVSPLITKRC_1KPASS
// Early glbl init, B[] loading, if 1KPASS
if
constexpr
(
FAST_UNSAFE_RDC_INIT
)
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
if
(
doRdc
)
if
(
k_str
==
0
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr_
=
mindx
+
M
*
nindx_
/
4
;
__hip_atomic_store
(
&
cntr
[
adr_
],
0
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
__hip_atomic_store
(
&
glbl
[
adr
],
0
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
}
}
// Load first B[] chunk
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k_str
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
++
)
bigB_
[
y
][
k2
].
h8
=
(
loadnt
(
(
scalar8
*
)(
&
B_
[
min__
(
y
*
GrpsShrB
+
bLoader
+
m
,
M
-
1
)
*
K
])));
}
{
#else
while
(
m
<
Mmod
)
{
#endif
#ifndef WVSPLITKRC_1KPASS
if
constexpr
(
FAST_UNSAFE_RDC_INIT
)
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
if
(
doRdc
)
if
(
k_str
==
0
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr_
=
mindx
+
M
*
nindx_
/
4
;
__hip_atomic_store
(
&
cntr
[
adr_
],
0
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
__hip_atomic_store
(
&
glbl
[
adr
],
0
,
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
}
}
}
}
#endif
#ifndef WVSPLITKRC_1KPASS
for
(
uint32_t
k1
=
k_str
;
k1
<
k_end
;
k1
+=
THRDS
*
A_CHUNK
*
UNRL
)
{
#else
const
uint32_t
k1
=
k_str
;
{
#endif
#ifndef WVSPLITKRC_1KPASS
const
bool
reloada
=
(
!
noreloada
)
&&
((
k1
==
k_str
)
||
(
k1
==
k_str
+
kBase
+
kFit
))
&&
(
k1
<
k_end
);
// load next chunk of A[] to LDS
if
(
reloada
)
{
if
(
k1
!=
k_str
)
kBase
+=
kFit
;
__syncthreads
();
#else
const
bool
reloada
=
(
!
noreloada
)
&&
((
k1
==
k_str
)
||
(
k1
==
k_str
+
kBase
+
kFit
))
&&
(
k1
<
k_end
);
if
(
reloada
)
{
#endif
constexpr
int
sprdN
=
4
;
const
uint32_t
thrd
=
((
threadIdx
.
y
/
sprdN
)
*
THRDS
+
threadIdx
.
x
);
#ifndef WVSPLITKRC_1KPASS
#pragma unroll
for
(
int
k
=
0
;
k
<
kFit
;
k
+=
THRDS
*
(
WvPrGrp
/
sprdN
)
*
A_CHUNK
)
{
#else
const
unsigned
int
k
=
0
;
{
#endif
unsigned
int
kOff
=
k
+
(
thrd
*
A_CHUNK
);
unsigned
int
kOffcp
=
min__
(
K
-
A_CHUNK
,
k_str
+
kOff
);
const
unsigned
int
k_in
=
kOffcp
+
((
threadIdx
.
y
%
sprdN
))
*
K
;
const
unsigned
int
k_ot
=
kOff
+
((
threadIdx
.
y
%
sprdN
))
*
kFitPdd
;
for
(
unsigned
int
n
=
0
;
n
<
N
/
2
;
n
+=
sprdN
)
{
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k_in
+
n
*
K
]),
(
int
*
)(
&
s
[(
k_ot
+
n
*
kFitPdd
)]),
16
,
0
,
0
);
if
(((
threadIdx
.
y
%
sprdN
))
+
n
+
N
/
2
>=
actlN
)
continue
;
__builtin_amdgcn_global_load_lds
(
(
int
*
)(
&
A
[
k_in
+
(
n
+
N
/
2
)
*
K
]),
(
int
*
)(
&
s
[(
k_ot
+
(
n
+
N
/
2
)
*
kFitPdd
)]),
16
,
0
,
0
);
}
// Stage loaded B[] to LDS for MFMA swizzling...
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
const
bool
oob_k
=
(
k_
>=
K
);
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
++
)
{
uint32_t
idx
=
threadIdx
.
x
*
4
+
(
y
*
GrpsShrB
+
bLoader
)
*
((
THRDS
+
BPAD
)
*
4
);
// zero out if oob
*
((
scalar8
*
)
&
myStg
[
idx
])
=
(
oob_k
||
(
y
*
GrpsShrB
+
bLoader
+
m
>=
M
))
?
0
:
bigB_
[
y
][
k2
].
h8
;
}
}
}
}
}
#ifndef WVSPLITKRC_1KPASS
// Fire load of next B[] chunk...
if
((
k1
+
THRDS
*
A_CHUNK
*
UNRL
<
k_end
)
&&
(
k1
+
THRDS
*
A_CHUNK
*
UNRL
<
K
))
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
THRDS
*
A_CHUNK
*
UNRL
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
++
)
bigB_
[
y
][
k2
].
h8
=
(
loadnt
(
(
scalar8
*
)(
&
B_
[
min__
(
y
*
GrpsShrB
+
bLoader
+
m
,
M
-
1
)
*
K
])));
}
#endif
// B[] staging is cooperative across GrpsShrB, so sync here before reading
// back
__syncthreads
();
// read back B[] swizzled for MFMA...
bigType
bigB
[
YTILE
][
UNRL
];
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
y
=
0
;
y
<
YTILE
;
y
++
)
{
unsigned
int
idx
=
(
threadIdx
.
x
%
YTILE
)
*
((
THRDS
+
BPAD
)
*
4
)
+
(
threadIdx
.
x
/
YTILE
)
*
4
+
y
*
16
;
bigB
[
y
][
k2
].
h8
=
*
((
scalar8
*
)
&
myStg
[
idx
]);
}
}
// Read back A[] swizzled for MFMA...
bigType
bigA
[
N
/
GrpsShrB
][
UNRL
];
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
-
kBase
-
k_str
;
#pragma unroll
for
(
uint32_t
nt
=
0
;
nt
<
N
/
GrpsShrB
;
nt
+=
NTILE
)
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
NTILE
;
n
++
)
{
uint32_t
idxa
=
(
nt
+
(
threadIdx
.
x
%
NTILE
)
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
*
kFitPdd
+
A_CHUNK
*
((
threadIdx
.
x
/
NTILE
)
+
n
*
4
)
+
k
;
bigA
[
nt
+
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
idxa
])));
}
}
// Do the MFMAs
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
#pragma unroll
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
0
],
bigB
[
0
][
k2
].
h4
[
0
],
(
k1
==
k_str
)
?
((
scalar8
){
0
})
:
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
1
],
bigB
[
0
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
else
{
// bf16
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
0
],
bigB
[
0
][
k2
].
h4
[
0
],
(
k1
==
k_str
)
?
((
scalar8
){
0
})
:
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
1
],
bigB
[
0
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
#pragma unroll
for
(
uint32_t
j
=
1
;
j
<
YTILE
;
j
++
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
0
],
bigB
[
j
][
k2
].
h4
[
0
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
1
],
bigB
[
j
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
else
{
// bf16
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
0
],
bigB
[
j
][
k2
].
h4
[
0
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
1
],
bigB
[
j
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
}
}
}
}
if
(
!
doRdc
)
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
{
scalar_t
biases
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{
0
};
if
(
BIAS
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
M
];
}
}
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum4
[
nt
][
0
][
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2bfloat16
(
sum4
[
nt
][
0
][
j
]);
}
else
{
if
(
BIAS
)
sum4
[
nt
][
0
][
j
]
+=
__half2float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2half
(
sum4
[
nt
][
0
][
j
]);
}
}
}
}
}
else
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
{
int
my_cntr
;
if
(
!
BIAS
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
atomicAdd
(
&
glbl
[
adr
],
sum4
[
nt
][
0
][
j
]);
}
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr_
=
mindx
+
M
*
nindx_
/
4
;
my_cntr
=
atomicAdd
(
&
cntr
[
adr_
],
1
);
float
vals
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
if
(
my_cntr
+
1
==
k_rnd
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
vals
[
nt
][
j
]
=
glbl
[
adr
];
}
}
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
if
(
nindx
>=
actlN
)
break
;
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
C
[
adr
]
=
__float2bfloat16
(
vals
[
nt
][
j
]);
}
else
{
C
[
adr
]
=
__float2half
(
vals
[
nt
][
j
]);
}
}
}
}
}
else
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
scalar_t
biases
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
// Atomic add the output, read biases
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
atomicAdd
(
&
glbl
[
adr
],
sum4
[
nt
][
0
][
j
]);
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
M
];
}
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr_
=
mindx
+
M
*
nindx_
/
4
;
// Update the complete counter
my_cntr
=
atomicAdd
(
&
cntr
[
adr_
],
1
);
float
vals
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
// If we're the last k-shard, read back the value and convert...
if
(
my_cntr
+
1
==
k_rnd
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
vals
[
nt
][
j
]
=
glbl
[
adr
];
}
}
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
if
(
nindx
>=
actlN
)
break
;
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
vals
[
nt
][
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2bfloat16
(
vals
[
nt
][
j
]);
}
else
{
vals
[
nt
][
j
]
+=
__half2float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2half
(
vals
[
nt
][
j
]);
}
}
}
}
}
}
#ifndef WVSPLITKRC_1KPASS
m0
+=
CuCount
*
WvPrGrp
*
YTILE
/
GrpsShrB
;
m
=
(
m0
+
m1
)
%
Mmod
;
k_str
=
(
m0
/
Mmod
)
*
kFit
*
kfitsPerRdc
;
k_end
=
(
m0
/
Mmod
+
1
)
*
kFit
*
kfitsPerRdc
;
if
(
k_str
>=
K
)
break
;
kBase
=
0
;
#endif
}
}
#else // !defined(__gfx950__) TODO: Add NAVI support
// Fallback definition of the wvSplitKrc_ kernel for compilation passes that
// do not take the gfx950 branch. It exists so that the templated launch in
// the host function wvSplitKrc() still resolves; the body consists solely of
// UNREACHABLE_CODE (macro defined elsewhere in the project).
// The parameter list mirrors the gfx950 implementation.
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
          int UNRL, int N, int GrpsShrB>
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
                            const int Bx, const int By, const scalar_t* B,
                            const scalar_t* __restrict__ A,
                            const scalar_t* __restrict__ BIAS, float* glbl,
                            scalar_t* C, const int CuCount) {
  UNREACHABLE_CODE
}
#endif // defined(__gfx950__) TODO: Add NAVI support
/// Host launcher for the split-K skinny GEMM with atomic reduce counting
/// (kernel wvSplitKrc_).
///
/// @param in_a    2-D tensor; size(0) is M and size(1) is K. fp16 or bf16.
/// @param in_b    2-D tensor; size(0) is N and size(1) must equal K. Same
///                dtype as in_a. N rounded up to a power of two must be one
///                of 16, 32, 64 or 128.
/// @param in_bias Optional 1-D or 2-D bias with the same dtype as in_a;
///                treated as absent when empty.
/// @param CuCount Number of workgroups to launch (used as the grid size).
/// @return        Newly allocated {N, M} tensor with in_b's dtype and device.
/// @throws c10::Error (via TORCH_CHECK) when an argument check fails, and
///         std::runtime_error when N has no matching kernel instantiation.
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
                         const std::optional<at::Tensor>& in_bias,
                         const int64_t CuCount) {
  // The kernel indexes both operands as row-major 2-D arrays sharing K, so a
  // rank or K mismatch would otherwise become an out-of-bounds device read.
  TORCH_CHECK(in_a.dim() == 2 && in_b.dim() == 2,
              "wvSplitKrc expects 2-D in_a and in_b");

  auto M_in = in_a.size(0);
  auto N_in = in_b.size(0);
  auto K_in = in_a.size(1);

  TORCH_CHECK(in_b.size(1) == K_in,
              "wvSplitKrc: in_a and in_b must have the same K dimension");

  const bool has_bias = in_bias.has_value() && in_bias->numel() > 0;

  // Bias extents used by the kernel for modulo (broadcast) indexing: Bx is
  // the innermost extent, By the outer extent of a 2-D bias (1 otherwise).
  auto Bx_in =
      has_bias
          ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
          : 1;
  auto By_in = (has_bias && in_bias->sizes().size() == 2) ? in_bias->size(0)
                                                          : 1;

  TORCH_CHECK(in_a.dtype() == in_b.dtype());
  TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
  TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
              in_a.dtype() == torch::kBFloat16);
  // The bias buffer is reinterpreted as the activation dtype below, so a
  // different dtype would be read as garbage.
  TORCH_CHECK(!has_bias || in_bias->dtype() == in_a.dtype(),
              "wvSplitKrc: in_bias must have the same dtype as in_a");

  auto out_c = torch::empty(
      {N_in, M_in},
      torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));

  // Round N up to the next power of two to select the kernel instantiation.
  // __builtin_clz(0) is undefined and the builtin takes a 32-bit argument, so
  // only the range that can possibly be supported is evaluated. Everything
  // else maps to 0 and is rejected by the switch default below.
  unsigned int N_p2 = 0;
  if (N_in == 1) {
    N_p2 = 1U;
  } else if (N_in > 1 && N_in <= 128) {
    N_p2 = 1U << (32 - __builtin_clz(static_cast<unsigned int>(N_in - 1)));
  }

  // Scratch buffer: float accumulators for the atomic reduction, followed by
  // the per-tile completion counters (the kernel places them at glbl[M * N]).
  auto axl_glbl = torch::empty(
      {N_p2 + N_p2 / 4, M_in + M_in / 4},
      torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
  axl_glbl.zero_();  // disable for FAST_UNSAFE_RDC_INIT

  // NOTE(review): as built, the kernel appears to handle one 512-wide K slice
  // per workgroup without looping over M tiles, so callers seem responsible
  // for keeping ceil(K / 512) * (number of M tiles) <= CuCount. Confirm
  // against the kernel before calling this op outside the Python dispatcher.
  dim3 grid(CuCount);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // const int max_lds_len = get_lds_size() / 2;

#define WVSPLITKrc(_WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB)                    \
  {                                                                           \
    dim3 block(64, _WvPrGrp);                                                 \
    wvSplitKrc_<fptype, 64, _YTILE, _WvPrGrp, 8, _UNRL, _N, _GrpsShrB>        \
        <<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
                                     biasf4, glbl, c, CuCount);               \
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
    using fptype = typename scalar<scalar_t>::type;
    fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
    const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
    const fptype* biasf4 =
        has_bias ? reinterpret_cast<const fptype*>(in_bias->data_ptr())
                 : nullptr;
    fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
    auto glbl = axl_glbl.data_ptr<float>();
    switch (N_p2) {
      case 16:
        WVSPLITKrc(4, 16, 1, 16, 1) break;
      case 32:
        WVSPLITKrc(4, 16, 1, 32, 2) break;
      case 64:
        WVSPLITKrc(4, 16, 1, 64, 2) break;
      case 128:
        WVSPLITKrc(4, 16, 1, 128, 4) break;
      default:
        throw std::runtime_error(
            "Unsupported N value: " + std::to_string(M_in) + "," +
            std::to_string(K_in) + "," + std::to_string(N_in));
    }
  });
  return out_c;
}
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
A_CHUNK
,
int
UNRL
,
int
N
>
...
@@ -1381,7 +1917,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1381,7 +1917,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__
fp8_t
s
[
max_lds_len
];
__shared__
fp8_t
s
[
max_lds_len
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
<
min
__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
}
}
__syncthreads
();
__syncthreads
();
...
@@ -1570,7 +2106,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1570,7 +2106,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__
fp8_t
s
[
max_lds_len
];
__shared__
fp8_t
s
[
max_lds_len
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
<
min
__
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
}
}
__syncthreads
();
__syncthreads
();
...
...
csrc/rocm/torch_bindings.cpp
View file @
7a103043
...
@@ -26,6 +26,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
...
@@ -26,6 +26,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
"Tensor"
);
"Tensor"
);
rocm_ops
.
impl
(
"wvSplitK"
,
torch
::
kCUDA
,
&
wvSplitK
);
rocm_ops
.
impl
(
"wvSplitK"
,
torch
::
kCUDA
,
&
wvSplitK
);
// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops
.
def
(
"wvSplitKrc(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> "
"Tensor"
);
rocm_ops
.
impl
(
"wvSplitKrc"
,
torch
::
kCUDA
,
&
wvSplitKrc
);
// wvSplitK for fp8
// wvSplitK for fp8
rocm_ops
.
def
(
rocm_ops
.
def
(
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
"wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, "
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
7a103043
...
@@ -8,9 +8,11 @@ import torch
...
@@ -8,9 +8,11 @@ import torch
import
vllm._custom_ops
as
ops
import
vllm._custom_ops
as
ops
from
tests.kernels.quant_utils
import
ref_dynamic_per_tensor_fp8_quant
from
tests.kernels.quant_utils
import
ref_dynamic_per_tensor_fp8_quant
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
on_gfx950
from
vllm.utils.platform_utils
import
get_cu_count
from
vllm.utils.platform_utils
import
get_cu_count
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
BIAS_MODES
=
[
0
,
1
,
2
]
# Specific (N, K, M) combinations for targeted testing
# Specific (N, K, M) combinations for targeted testing
NKM_FACTORS_LLMM1
=
[
NKM_FACTORS_LLMM1
=
[
# Small, medium, large cases
# Small, medium, large cases
...
@@ -43,6 +45,31 @@ NKM_FACTORS_WVSPLITK = [
...
@@ -43,6 +45,31 @@ NKM_FACTORS_WVSPLITK = [
(
4
,
256
,
8
),
(
4
,
256
,
8
),
]
]
NKM_FACTORS_WVSPLITKRC
=
[
(
16
,
2880
,
128
),
(
16
,
2880
,
640
),
(
17
,
2880
,
128
),
(
17
,
2880
,
640
),
(
25
,
2880
,
128
),
(
25
,
2880
,
640
),
(
31
,
2880
,
128
),
(
31
,
2880
,
640
),
(
32
,
2880
,
128
),
(
32
,
2880
,
640
),
(
40
,
2880
,
128
),
(
40
,
2880
,
640
),
(
60
,
2880
,
128
),
(
60
,
2880
,
640
),
(
64
,
2880
,
128
),
(
64
,
2880
,
640
),
(
81
,
2880
,
128
),
(
81
,
2880
,
640
),
(
98
,
2880
,
128
),
(
98
,
2880
,
640
),
(
128
,
2880
,
128
),
(
128
,
2880
,
640
),
]
NKM_FACTORS_WVSPLITK_FP8
=
[
NKM_FACTORS_WVSPLITK_FP8
=
[
# FP8-specific cases with K % 16 == 0
# FP8-specific cases with K % 16 == 0
(
1
,
16
,
16
),
(
1
,
16
,
16
),
...
@@ -60,6 +87,32 @@ NKM_FACTORS_WVSPLITK_FP8 = [
...
@@ -60,6 +87,32 @@ NKM_FACTORS_WVSPLITK_FP8 = [
SEEDS
=
[
0
]
SEEDS
=
[
0
]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode):
    """Check ops.wvSplitKrc against torch.nn.functional.linear.

    bias_mode selects the bias: 0 -> none, 1 -> shape (m,), 2 -> shape (n, m).
    """
    torch.manual_seed(seed)
    num_cus = get_cu_count()

    def centered_rand(*shape):
        # Samples from torch.rand shifted by -0.5, in the dtype under test.
        return torch.rand(*shape, dtype=dtype, device="cuda") - 0.5

    # normalize to avoid large output-bias deltas
    scale = math.sqrt(2 / k)
    activations = centered_rand(n, k) * scale
    weights = centered_rand(m, k) * scale

    if bias_mode == 1:
        bias = centered_rand(m)
    elif bias_mode == 2:
        bias = centered_rand(n, m)
    else:
        bias = None

    expected = torch.nn.functional.linear(activations, weights, bias)
    flat_activations = activations.view(-1, activations.size(-1))
    actual = ops.wvSplitKrc(weights, flat_activations, num_cus, bias)

    assert torch.allclose(actual, expected, rtol=0.01)
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_LLMM1
)
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_LLMM1
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"rows_per_block"
,
[
2
,
4
,
8
,
16
])
@
pytest
.
mark
.
parametrize
(
"rows_per_block"
,
[
2
,
4
,
8
,
16
])
...
...
vllm/_custom_ops.py
View file @
7a103043
...
@@ -2072,6 +2072,12 @@ def wvSplitK(
...
@@ -2072,6 +2072,12 @@ def wvSplitK(
return
torch
.
ops
.
_rocm_C
.
wvSplitK
(
a
,
b
,
bias
,
cu_count
)
return
torch
.
ops
.
_rocm_C
.
wvSplitK
(
a
,
b
,
bias
,
cu_count
)
def wvSplitKrc(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    """Call the ``_rocm_C.wvSplitKrc`` custom op.

    The underlying op takes ``bias`` before ``cu_count``; this wrapper
    reorders the arguments accordingly and returns the op's result.
    """
    rocm_ops = torch.ops._rocm_C
    return rocm_ops.wvSplitKrc(a, b, bias, cu_count)
def
wvSplitKQ
(
def
wvSplitKQ
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/utils.py
View file @
7a103043
...
@@ -129,12 +129,32 @@ def use_aiter_triton_gemm(n, m, k, dtype):
...
@@ -129,12 +129,32 @@ def use_aiter_triton_gemm(n, m, k, dtype):
def
rocm_unquantized_gemm_impl
(
def
rocm_unquantized_gemm_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.platforms.rocm
import
on_gfx9
from
vllm.platforms.rocm
import
on_gfx9
,
on_gfx950
n
=
x
.
numel
()
/
x
.
size
(
-
1
)
n
=
x
.
numel
()
/
x
.
size
(
-
1
)
m
=
weight
.
shape
[
0
]
m
=
weight
.
shape
[
0
]
k
=
weight
.
shape
[
1
]
k
=
weight
.
shape
[
1
]
import
math
use_skinny_reduce_counting
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_gfx950
()
and
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
and
(
n
>=
16
and
n
<=
128
and
k
>
512
and
math
.
ceil
(
k
/
512
)
*
math
.
ceil
(
m
/
16
)
<
get_cu_count
()
)
# k == 2880 and (m == 640 or m == 128))
)
if
use_skinny_reduce_counting
:
cu_count
=
get_cu_count
()
x_view
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
out
=
ops
.
wvSplitKrc
(
weight
,
x_view
,
cu_count
,
bias
)
return
out
.
reshape
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment