Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ed17f54c
Unverified
Commit
ed17f54c
authored
Feb 07, 2026
by
Hashem Hashemi
Committed by
GitHub
Feb 07, 2026
Browse files
Perf tuning and expansion of cases covered for wvSplitKrc (#33493)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
860981d8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
214 additions
and
223 deletions
+214
-223
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+143
-184
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+52
-31
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+19
-8
No files found.
csrc/rocm/skinny_gemms.cu
View file @
ed17f54c
...
@@ -1365,13 +1365,12 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -1365,13 +1365,12 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
return
out_c
;
return
out_c
;
}
}
#if defined(__gfx950__) // TODO: Add NAVI support
// This version targets skinny cases where the CUs are not filled
//
This version targets big A[] cases, where it is much larger than LDS
//
Wave-SplitK is used with reduction done via atomics.
// capacity
#if defined(__gfx950__)
#define WVSPLITKRC_1KPASS
#define WVSPLITKRC_1KPASS
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
,
int
GrpsShrB
>
int
UNRL
,
int
N
,
int
GrpsShrB
,
int
CHUNKK
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
const
int
Bx
,
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
const
int
Bx
,
...
@@ -1383,12 +1382,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1383,12 +1382,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
int
*
cntr
=
(
int
*
)(
&
glbl
[
M
*
N
]);
int
*
cntr
=
(
int
*
)(
&
glbl
[
M
*
N
]);
constexpr
int
NTILE
=
16
;
constexpr
int
NTILE
=
16
;
constexpr
int
WVLDS_
=
(
NTILE
*
THRDS
*
A_CHUNK
);
constexpr
int
APAD
=
1
;
constexpr
int
APAD
=
1
;
constexpr
int
ASTRD
=
64
;
constexpr
int
ASTRD
=
64
;
constexpr
int
BPAD
=
1
;
constexpr
int
BPAD
=
1
;
constexpr
int
BSTRD
=
64
;
constexpr
int
WVLDS_
=
THRDS
*
A_CHUNK
/
CHUNKK
;
constexpr
int
WVLDS
=
((
WVLDS_
+
(
WVLDS_
/
BSTRD
)
*
4
*
BPAD
));
constexpr
int
WVLDS
=
((
WVLDS_
+
A_CHUNK
*
BPAD
))
*
YTILE
;
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
...
@@ -1442,17 +1440,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1442,17 +1440,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
break
;
break
;
}
}
#else
#else
int
constexpr
kFit
=
512
;
int
constexpr
kFit
=
512
/
CHUNKK
;
int
constexpr
kfitsPerRdc
=
1
;
int
constexpr
kfitsPerRdc
=
1
;
#endif
#endif
bool
doRdc
=
(
kfitsPerRdc
*
kFit
<
K
)
;
bool
doRdc
=
true
;
// Assuming
(kfitsPerRdc * kFit < K)
is always true
uint32_t
numCuWithFullK
=
uint32_t
numCuWithFullK
=
((
M
+
(
WvPrGrp
*
YTILE
/
GrpsShrB
)
-
1
)
/
(
WvPrGrp
*
YTILE
/
GrpsShrB
));
((
M
+
(
WvPrGrp
*
YTILE
/
GrpsShrB
)
-
1
)
/
(
WvPrGrp
*
YTILE
/
GrpsShrB
));
uint32_t
Mmod
=
numCuWithFullK
*
(
WvPrGrp
*
YTILE
/
GrpsShrB
);
uint32_t
Mmod
=
numCuWithFullK
*
(
WvPrGrp
*
YTILE
/
GrpsShrB
);
// given above k-split, find this wave's position
// given above k-split, find this wave's position
uint32_t
kFitPdd
=
kFit
+
(
kFit
/
ASTRD
)
*
APAD
;
uint32_t
kFitPdd
=
kFit
*
CHUNKK
+
(
(
kFit
*
CHUNKK
)
/
ASTRD
)
*
APAD
;
uint32_t
m0
=
(
blockIdx
.
x
*
WvPrGrp
/
GrpsShrB
)
*
YTILE
;
uint32_t
m0
=
(
blockIdx
.
x
*
WvPrGrp
/
GrpsShrB
)
*
YTILE
;
uint32_t
m1
=
((
threadIdx
.
y
%
WvPrGrp
)
/
GrpsShrB
)
*
YTILE
;
uint32_t
m1
=
((
threadIdx
.
y
%
WvPrGrp
)
/
GrpsShrB
)
*
YTILE
;
uint32_t
m
=
(
m0
+
m1
)
%
Mmod
;
uint32_t
m
=
(
m0
+
m1
)
%
Mmod
;
...
@@ -1460,8 +1458,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1460,8 +1458,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k_end
=
(
m0
/
Mmod
+
1
)
*
kFit
*
kfitsPerRdc
;
uint32_t
k_end
=
(
m0
/
Mmod
+
1
)
*
kFit
*
kfitsPerRdc
;
const
uint32_t
k_rnd
=
(
K
+
kFit
*
kfitsPerRdc
-
1
)
/
(
kFit
*
kfitsPerRdc
);
const
uint32_t
k_rnd
=
(
K
+
kFit
*
kfitsPerRdc
-
1
)
/
(
kFit
*
kfitsPerRdc
);
scalar8
sum4
[
N
/
NTILE
/
GrpsShrB
][
1
];
scalar8
sum4
[
N
/
NTILE
/
GrpsShrB
][
1
]
=
{
0
}
;
bigType
bigB_
[
YTILE
/
GrpsShrB
][
UNRL
];
bigType
bigB_
[
YTILE
/
GrpsShrB
/
CHUNKK
][
UNRL
];
const
uint32_t
bLoader
=
(
threadIdx
.
y
%
GrpsShrB
);
const
uint32_t
bLoader
=
(
threadIdx
.
y
%
GrpsShrB
);
uint32_t
kBase
=
0
;
uint32_t
kBase
=
0
;
if
(
k_str
>=
K
)
return
;
if
(
k_str
>=
K
)
return
;
...
@@ -1498,12 +1496,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1498,12 +1496,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k_str
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k
=
k_str
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
uint32_t
k_
=
k
+
(
threadIdx
.
x
%
(
THRDS
/
CHUNKK
))
*
A_CHUNK
;
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
#pragma unroll
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
++
)
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
+=
CHUNKK
)
bigB_
[
y
][
k2
].
h8
=
(
loadnt
(
bigB_
[
y
/
CHUNKK
][
k2
].
h8
=
(
loadnt
(
(
scalar8
*
)(
&
B_
[
min__
(
y
*
GrpsShrB
+
bLoader
+
m
,
M
-
1
)
*
K
])));
(
scalar8
*
)(
&
B_
[
min__
((
y
+
threadIdx
.
x
/
(
THRDS
/
CHUNKK
))
*
GrpsShrB
+
bLoader
+
m
,
M
-
1
)
*
K
])));
}
}
{
{
#else
#else
...
@@ -1556,48 +1557,51 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1556,48 +1557,51 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if
(
reloada
)
{
if
(
reloada
)
{
#endif
#endif
constexpr
int
sprdN
=
4
;
constexpr
int
sprdN
=
4
;
const
uint32_t
thrd
=
((
threadIdx
.
y
/
sprdN
)
*
THRDS
+
threadIdx
.
x
);
const
uint32_t
thrd
=
threadIdx
.
x
%
(
THRDS
/
CHUNKK
);
#ifndef WVSPLITKRC_1KPASS
#ifndef WVSPLITKRC_1KPASS
#pragma unroll
#pragma unroll
for
(
int
k
=
0
;
k
<
kFit
;
k
+=
THRDS
*
(
WvPrGrp
/
sprdN
)
*
A_CHUNK
)
{
for
(
int
k
=
0
;
k
<
kFit
;
k
+=
(
THRDS
*
(
WvPrGrp
/
sprdN
)
*
A_CHUNK
)
/
CHUNKK
)
{
#else
#else
const
unsigned
int
k
=
0
;
const
unsigned
int
k
=
0
;
{
{
#endif
#endif
unsigned
int
kOff
=
k
+
(
thrd
*
A_CHUNK
);
unsigned
int
kOff
=
k
+
(
thrd
*
A_CHUNK
);
unsigned
int
kOffcp
=
min__
(
K
-
A_CHUNK
,
k_str
+
kOff
);
unsigned
int
kOffcp
=
const
unsigned
int
k_in
=
kOffcp
+
((
threadIdx
.
y
%
sprdN
))
*
K
;
k_str
+
kOff
;
// min__(K - A_CHUNK, k_str + kOff);
const
unsigned
int
k_ot
=
kOff
+
((
threadIdx
.
y
%
sprdN
))
*
kFitPdd
;
for
(
unsigned
int
n
=
0
;
n
<
N
;
n
+=
CHUNKK
*
sprdN
)
{
for
(
unsigned
int
n
=
0
;
n
<
N
/
2
;
n
+=
sprdN
)
{
__builtin_amdgcn_global_load_lds
((
int
*
)(
&
A
[
k_in
+
n
*
K
]),
(
int
*
)(
&
s
[(
k_ot
+
n
*
kFitPdd
)]),
16
,
0
,
0
);
if
(((
threadIdx
.
y
%
sprdN
))
+
n
+
N
/
2
>=
actlN
)
continue
;
__builtin_amdgcn_global_load_lds
(
__builtin_amdgcn_global_load_lds
(
(
int
*
)(
&
A
[
k_in
+
(
n
+
N
/
2
)
*
K
]),
(
int
*
)(
&
A
[
min__
(
(
int
*
)(
&
s
[(
k_ot
+
(
n
+
N
/
2
)
*
kFitPdd
)]),
16
,
0
,
0
);
K
*
actlN
-
A_CHUNK
,
kOffcp
+
K
*
(
n
/
CHUNKK
+
(
N
/
CHUNKK
)
*
(
threadIdx
.
x
/
(
64
/
CHUNKK
))
+
(
threadIdx
.
y
%
sprdN
)))]),
(
int
*
)(
&
s
[(
k
+
kFitPdd
*
((
n
/
CHUNKK
)
+
(
threadIdx
.
y
%
sprdN
)))]),
16
,
0
,
0
);
}
}
// Stage loaded B[] to LDS for MFMA swizzling...
// Stage loaded B[] to LDS for MFMA swizzling...
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
uint32_t
k_
=
k
+
(
threadIdx
.
x
%
(
THRDS
/
CHUNKK
))
*
A_CHUNK
;
const
bool
oob_k
=
(
k_
>=
K
);
const
bool
oob_k
=
(
k_
>=
K
);
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
++
)
{
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
+=
CHUNKK
)
{
uint32_t
idx
=
threadIdx
.
x
*
4
+
uint32_t
idx
=
(
y
*
GrpsShrB
+
bLoader
)
*
((
THRDS
+
BPAD
)
*
4
);
(
threadIdx
.
x
%
(
THRDS
/
CHUNKK
))
*
4
+
((
y
+
threadIdx
.
x
/
(
THRDS
/
CHUNKK
))
*
GrpsShrB
+
bLoader
)
*
((
THRDS
/
CHUNKK
+
BPAD
)
*
4
);
// zero out if oob
// zero out if oob
*
((
scalar8
*
)
&
myStg
[
idx
])
=
*
((
scalar8
*
)
&
myStg
[
idx
])
=
(
oob_k
||
(
y
*
GrpsShrB
+
bLoader
+
m
>=
M
)
)
(
oob_k
)
// TODO: ever necessary (y*
GrpsShrB
+
bLoader
+m
>=M)
?
?
0
?
0
:
bigB_
[
y
][
k2
].
h8
;
:
bigB_
[
y
/
CHUNKK
][
k2
].
h8
;
}
}
}
}
}
}
}
}
}
}
#ifndef WVSPLITKRC_1KPASS
#ifndef WVSPLITKRC_1KPASS
// Fire load of next B[] chunk...
// Fire load of next B[] chunk...
if
((
k1
+
THRDS
*
A_CHUNK
*
UNRL
<
k_end
)
&&
if
((
k1
+
THRDS
*
A_CHUNK
*
UNRL
<
k_end
)
&&
...
@@ -1608,40 +1612,50 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1608,40 +1612,50 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
const
scalar_t
*
B_
=
&
B
[
min__
(
k_
,
K
-
A_CHUNK
)];
#pragma unroll
#pragma unroll
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
++
)
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
GrpsShrB
;
y
+=
CHUNKK
)
bigB_
[
y
][
k2
].
h8
=
(
loadnt
(
bigB_
[
y
/
CHUNKK
][
k2
].
h8
=
(
loadnt
(
(
scalar8
*
)(
&
B_
[
min__
(
y
*
GrpsShrB
+
bLoader
+
m
,
M
-
1
)
*
K
])));
(
scalar8
*
)(
&
B_
[
min__
((
y
+
threadIdx
.
x
/
(
THRDS
/
CHUNKK
))
*
GrpsShrB
+
bLoader
+
m
,
M
-
1
)
*
K
])));
}
}
#endif
#endif
// B[] staging is cooperative across GrpsShrB, so sync here before reading
// B[] staging is cooperative across GrpsShrB, so sync here before reading
// back
// back. This wait is currently inserted by the compiler, but not guaranteed.
asm
volatile
(
"s_waitcnt 0"
);
__syncthreads
();
__syncthreads
();
// read back B[] swizzled for MFMA...
// read back B[] swizzled for MFMA...
bigType
bigB
[
YTILE
][
UNRL
];
bigType
bigB
[
YTILE
/
CHUNKK
][
UNRL
];
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
y
=
0
;
y
<
YTILE
;
y
++
)
{
for
(
uint32_t
y
=
0
;
y
<
YTILE
/
CHUNKK
;
y
++
)
{
unsigned
int
idx
=
(
threadIdx
.
x
%
YTILE
)
*
((
THRDS
+
BPAD
)
*
4
)
+
unsigned
int
idx
=
(
threadIdx
.
x
%
YTILE
)
*
((
THRDS
/
CHUNKK
+
BPAD
)
*
4
)
+
(
threadIdx
.
x
/
YTILE
)
*
4
+
y
*
16
;
(
threadIdx
.
x
/
YTILE
)
*
4
+
y
*
16
;
bigB
[
y
][
k2
].
h8
=
*
((
scalar8
*
)
&
myStg
[
idx
]);
bigB
[
y
][
k2
].
h8
=
*
((
scalar8
*
)
&
myStg
[
idx
]);
}
}
}
}
// Read back A[] swizzled for MFMA...
// Read back A[] swizzled for MFMA...
bigType
bigA
[
N
/
GrpsShrB
][
UNRL
];
bigType
bigA
[
N
/
GrpsShrB
/
CHUNKK
][
UNRL
];
#pragma unroll
#pragma unroll
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
-
kBase
-
k_str
;
uint32_t
k
=
k1
+
k2
*
THRDS
*
A_CHUNK
-
kBase
-
k_str
;
#pragma unroll
#pragma unroll
for
(
uint32_t
nt
=
0
;
nt
<
N
/
GrpsShrB
;
nt
+=
NTILE
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
GrpsShrB
;
nt
+=
NTILE
)
#pragma unroll
#pragma unroll
for
(
uint32_t
n
=
0
;
n
<
NTILE
;
n
++
)
{
for
(
uint32_t
n
=
0
;
n
<
NTILE
/
CHUNKK
;
n
++
)
{
uint32_t
idxa
=
(
nt
+
(
threadIdx
.
x
%
NTILE
)
+
uint32_t
idxa
=
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
*
((
nt
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
%
(
N
/
CHUNKK
)
+
(
threadIdx
.
x
%
NTILE
))
*
kFitPdd
+
kFitPdd
+
((
nt
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
(
N
/
CHUNKK
))
*
A_CHUNK
*
(
64
/
CHUNKK
)
+
A_CHUNK
*
((
threadIdx
.
x
/
NTILE
)
+
n
*
4
)
+
k
;
A_CHUNK
*
((
threadIdx
.
x
/
NTILE
)
+
n
*
4
)
+
k
;
bigA
[
nt
+
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
idxa
])));
bigA
[
nt
/
CHUNKK
+
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
idxa
])));
}
}
}
}
...
@@ -1650,122 +1664,37 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1650,122 +1664,37 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
for
(
uint32_t
k2
=
0
;
k2
<
UNRL
;
k2
++
)
{
#pragma unroll
#pragma unroll
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
0
],
bigB
[
0
][
k2
].
h4
[
0
],
(
k1
==
k_str
)
?
((
scalar8
){
0
})
:
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
1
],
bigB
[
0
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
else
{
// bf16
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
0
],
bigB
[
0
][
k2
].
h4
[
0
],
(
k1
==
k_str
)
?
((
scalar8
){
0
})
:
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
0
][
k2
].
h4
[
1
],
bigB
[
0
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
#pragma unroll
#pragma unroll
for
(
uint32_t
j
=
1
;
j
<
YTILE
;
j
++
)
{
for
(
uint32_t
j
=
0
;
j
<
YTILE
/
CHUNKK
;
j
++
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
half
>
)
{
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x32_f16
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
0
],
bigB
[
j
][
k2
].
h4
[
0
],
sum4
[
nt
][
0
],
bigA
[
nt
*
(
YTILE
/
CHUNKK
)
+
j
][
k2
].
h8
,
bigB
[
j
][
k2
].
h8
,
0
,
0
,
0
);
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
1
],
bigB
[
j
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
else
{
// bf16
}
else
{
// bf16
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x32_bf16
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
0
],
bigB
[
j
][
k2
].
h4
[
0
],
sum4
[
nt
][
0
],
bigA
[
nt
*
(
YTILE
/
CHUNKK
)
+
j
][
k2
].
h8
,
bigB
[
j
][
k2
].
h8
,
0
,
0
,
0
);
sum4
[
nt
][
0
],
0
,
0
,
0
);
sum4
[
nt
][
0
]
=
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k
(
bigA
[
nt
*
NTILE
+
j
][
k2
].
h4
[
1
],
bigB
[
j
][
k2
].
h4
[
1
],
sum4
[
nt
][
0
],
0
,
0
,
0
);
}
}
}
}
}
}
}
}
}
}
if
(
!
doRdc
)
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
{
scalar_t
biases
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{
0
};
if
(
BIAS
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
M
];
}
}
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
(
BIAS
)
sum4
[
nt
][
0
][
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2bfloat16
(
sum4
[
nt
][
0
][
j
]);
}
else
{
if
(
BIAS
)
sum4
[
nt
][
0
][
j
]
+=
__half2float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2half
(
sum4
[
nt
][
0
][
j
]);
}
}
}
}
}
else
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
{
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
{
int
my_cntr
;
int
my_cntr
;
if
(
!
BIAS
)
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
atomicAdd
(
&
glbl
[
adr
],
sum4
[
nt
][
0
][
j
]);
}
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr_
=
mindx
+
M
*
nindx_
/
4
;
my_cntr
=
atomicAdd
(
&
cntr
[
adr_
],
1
);
float
vals
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
if
(
my_cntr
+
1
==
k_rnd
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
vals
[
nt
][
j
]
=
glbl
[
adr
];
}
}
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
if
(
nindx
>=
actlN
)
break
;
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
C
[
adr
]
=
__float2bfloat16
(
vals
[
nt
][
j
]);
}
else
{
C
[
adr
]
=
__float2half
(
vals
[
nt
][
j
]);
}
}
}
}
}
else
{
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
g_mindx
=
m
*
4
+
(
threadIdx
.
x
%
64
);
// coalesced atomic reduction
scalar_t
biases
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
scalar_t
biases
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
// Atomic add the output, read biases
// Atomic add the output, read biases
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int
adr
=
mindx
+
M
*
nindx
;
// int adr = mindx + M * nindx;
atomicAdd
(
&
glbl
[
adr
],
sum4
[
nt
][
0
][
j
]);
int
g_nindx
=
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
M
];
j
+
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
+
M
*
g_nindx
*
4
;
atomicAdd
(
&
glbl
[
g_adr
],
sum4
[
nt
][
0
][
j
]);
}
}
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
...
@@ -1775,19 +1704,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1775,19 +1704,28 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
float
vals
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
float
vals
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
// If we're the last k-shard, read back the value and convert...
// If we're the last k-shard, read back the value and convert...
if
(
my_cntr
+
1
==
k_rnd
)
{
if
(
my_cntr
+
1
==
k_rnd
)
{
if
(
BIAS
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr
=
mindx
+
M
*
nindx
;
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
Bx
];
vals
[
nt
][
j
]
=
glbl
[
adr
];
}
}
}
}
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
g_nindx
=
j
+
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
+
M
*
g_nindx
*
4
;
vals
[
nt
][
j
]
=
glbl
[
g_adr
];
}
}
__builtin_amdgcn_sched_barrier
(
0
);
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
if
(
nindx
>=
actlN
)
break
;
if
(
nindx
<
actlN
)
{
int
adr
=
mindx
+
M
*
nindx
;
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
vals
[
nt
][
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
vals
[
nt
][
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
...
@@ -1800,7 +1738,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1800,7 +1738,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
}
}
}
}
#ifndef WVSPLITKRC_1KPASS
#ifndef WVSPLITKRC_1KPASS
m0
+=
CuCount
*
WvPrGrp
*
YTILE
/
GrpsShrB
;
m0
+=
CuCount
*
WvPrGrp
*
YTILE
/
GrpsShrB
;
...
@@ -1814,7 +1751,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1814,7 +1751,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
,
int
GrpsShrB
>
int
UNRL
,
int
N
,
int
GrpsShrB
,
int
CHUNKK
>
__global__
void
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
__global__
void
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
A
,
...
@@ -1859,10 +1796,10 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -1859,10 +1796,10 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// const int max_lds_len = get_lds_size() / 2;
// const int max_lds_len = get_lds_size() / 2;
#define WVSPLITKrc(_
WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB)
\
#define WVSPLITKrc(_
N, _GrpsShrB, _CHUNKK)
\
{ \
{ \
dim3 block(64,
_WvPrGrp);
\
dim3 block(64,
4);
\
wvSplitKrc_<fptype, 64,
_YTILE, _WvPrGrp, 8, _UNRL, _N, _GrpsShrB>
\
wvSplitKrc_<fptype, 64,
16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK>
\
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, glbl, c, CuCount); \
biasf4, glbl, c, CuCount); \
}
}
...
@@ -1877,15 +1814,37 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
...
@@ -1877,15 +1814,37 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
:
nullptr
;
:
nullptr
;
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
auto
glbl
=
axl_glbl
.
data_ptr
<
float
>
();
auto
glbl
=
axl_glbl
.
data_ptr
<
float
>
();
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int
rndup_cus
=
((
M_in
+
64
-
1
)
/
64
)
*
((
K_in
+
512
-
1
)
/
512
);
// How many of 4 waves in a group can work on same 16 Ms at same time? First
// try to maximize this. This reduces the Ms each group works on, i.e.
// increasing the number of CUs needed.
int
GrpsShrB
=
min
(
N_p2
/
16
,
4
);
// Given the above, how many CUs would we need?
int
CuNeeded
=
rndup_cus
*
GrpsShrB
;
if
(
CuNeeded
>
CuCount
)
std
::
runtime_error
(
"Invalid wvSplitKrc size"
);
// Can we increase SplitK by shrinking the K-shard to 256?
int
chunkk
=
(
CuNeeded
*
2
<=
CuCount
)
?
2
:
1
;
switch
(
N_p2
)
{
switch
(
N_p2
)
{
case
16
:
case
16
:
WVSPLITKrc
(
4
,
16
,
1
,
16
,
1
)
break
;
WVSPLITKrc
(
16
,
1
,
1
)
break
;
case
32
:
case
32
:
WVSPLITKrc
(
4
,
16
,
1
,
32
,
2
)
break
;
if
(
chunkk
==
2
)
WVSPLITKrc
(
32
,
2
,
2
)
else
if
(
chunkk
==
1
)
WVSPLITKrc
(
32
,
2
,
1
)
break
;
case
64
:
case
64
:
WVSPLITKrc
(
4
,
16
,
1
,
64
,
2
)
break
;
if
(
chunkk
==
2
)
WVSPLITKrc
(
64
,
4
,
2
)
else
if
(
chunkk
==
1
)
WVSPLITKrc
(
64
,
4
,
1
)
break
;
case
128
:
case
128
:
WVSPLITKrc
(
4
,
16
,
1
,
128
,
4
)
break
;
if
(
chunkk
==
2
)
WVSPLITKrc
(
128
,
4
,
2
)
else
if
(
chunkk
==
1
)
WVSPLITKrc
(
128
,
4
,
1
)
break
;
default:
default:
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"Unsupported N value: "
+
std
::
to_string
(
M_in
)
+
","
+
"Unsupported N value: "
+
std
::
to_string
(
M_in
)
+
","
+
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
ed17f54c
...
@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [
...
@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [
(
4
,
256
,
8
),
(
4
,
256
,
8
),
]
]
NKM_FACTORS_WVSPLITKRC
=
[
N_FACTORS_WVSPLITKRC
=
[
(
16
,
2880
,
128
),
13
,
(
16
,
2880
,
640
),
16
,
(
17
,
2880
,
128
),
17
,
(
17
,
2880
,
640
),
25
,
(
25
,
2880
,
128
),
29
,
(
25
,
2880
,
640
),
31
,
(
31
,
2880
,
128
),
32
,
(
31
,
2880
,
640
),
41
,
(
32
,
2880
,
128
),
51
,
(
32
,
2880
,
640
),
64
,
(
40
,
2880
,
128
),
71
,
(
40
,
2880
,
640
),
81
,
(
60
,
2880
,
128
),
91
,
(
60
,
2880
,
640
),
103
,
(
64
,
2880
,
128
),
117
,
(
64
,
2880
,
640
),
128
,
(
81
,
2880
,
128
),
(
81
,
2880
,
640
),
(
98
,
2880
,
128
),
(
98
,
2880
,
640
),
(
128
,
2880
,
128
),
(
128
,
2880
,
640
),
]
]
K_FACTORS_WVSPLITKRC
=
[
2880
,
2880
+
8
,
3072
,
3072
+
8
]
M_FACTORS_WVSPLITKRC
=
[
128
,
128
+
16
,
256
,
256
+
16
,
640
,
640
+
16
]
NKM_FACTORS_WVSPLITK_FP8
=
[
NKM_FACTORS_WVSPLITK_FP8
=
[
# FP8-specific cases with K % 16 == 0
# FP8-specific cases with K % 16 == 0
(
1
,
16
,
16
),
(
1
,
16
,
16
),
...
@@ -113,30 +110,54 @@ def pad_fp8(weight):
...
@@ -113,30 +110,54 @@ def pad_fp8(weight):
return
F
.
pad
(
weight
,
(
0
,
num_pad
),
"constant"
,
0
)[...,
:
-
num_pad
]
return
F
.
pad
(
weight
,
(
0
,
num_pad
),
"constant"
,
0
)[...,
:
-
num_pad
]
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITKRC
)
@
pytest
.
mark
.
parametrize
(
"xnorm"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"n"
,
N_FACTORS_WVSPLITKRC
)
@
pytest
.
mark
.
parametrize
(
"k"
,
K_FACTORS_WVSPLITKRC
)
@
pytest
.
mark
.
parametrize
(
"m"
,
M_FACTORS_WVSPLITKRC
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"bias_mode"
,
BIAS_MODES
)
@
pytest
.
mark
.
parametrize
(
"bias_mode"
,
BIAS_MODES
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
@
pytest
.
mark
.
skipif
(
not
on_gfx950
(),
reason
=
"only meant for gfx950"
)
@
pytest
.
mark
.
skipif
(
not
on_gfx950
(),
reason
=
"only meant for gfx950"
)
def
test_rocm_wvsplitkrc_kernel
(
n
,
k
,
m
,
dtype
,
seed
,
bias_mode
):
def
test_rocm_wvsplitkrc_kernel
(
xnorm
,
n
,
k
,
m
,
dtype
,
seed
,
bias_mode
):
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
cu_count
=
get_cu_count
()
cu_count
=
get_cu_count
()
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
# Smallest power of 2 that is >= n
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
N_p2
=
1
<<
(
n
-
1
).
bit_length
()
B
=
(
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus
=
((
m
+
64
-
1
)
//
64
)
*
((
k
+
512
-
1
)
//
512
)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB
=
min
(
N_p2
//
16
,
4
)
# Given the above, how many CUs would we need?
CuNeeded
=
rndup_cus
*
GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc
=
CuNeeded
<=
cu_count
if
not
fits_wvsplitkrc
:
pytest
.
skip
(
"Too large for wvSplitKrc"
)
xavier
=
(
math
.
sqrt
(
2
/
k
)
if
xnorm
else
1
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
BIAS
=
None
BIAS
=
None
if
bias_mode
==
1
:
if
bias_mode
==
1
:
BIAS
=
torch
.
rand
(
m
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
BIAS
=
torch
.
rand
(
m
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
elif
bias_mode
==
2
:
elif
bias_mode
==
2
:
BIAS
=
torch
.
rand
(
n
,
m
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
BIAS
=
torch
.
rand
(
n
,
m
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
ref_out
=
torch
.
nn
.
functional
.
linear
(
A
,
B
,
BIAS
)
ref_out
=
torch
.
nn
.
functional
.
linear
(
A
,
B
,
BIAS
)
out
=
ops
.
wvSplitKrc
(
B
,
A
.
view
(
-
1
,
A
.
size
(
-
1
)),
cu_count
,
BIAS
)
out
=
ops
.
wvSplitKrc
(
B
,
A
.
view
(
-
1
,
A
.
size
(
-
1
)),
cu_count
,
BIAS
)
assert
torch
.
allclose
(
out
,
ref_out
,
rtol
=
0.01
)
if
xnorm
:
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-3
,
rtol
=
1e-8
)
else
:
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-3
,
rtol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_LLMM1
)
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_LLMM1
)
...
...
vllm/model_executor/layers/utils.py
View file @
ed17f54c
...
@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
...
@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.platforms.rocm
import
on_gfx9
,
on_gfx950
from
vllm.platforms.rocm
import
on_gfx9
,
on_gfx950
n
=
x
.
numel
()
/
x
.
size
(
-
1
)
n
=
x
.
numel
()
/
/
x
.
size
(
-
1
)
m
=
weight
.
shape
[
0
]
m
=
weight
.
shape
[
0
]
k
=
weight
.
shape
[
1
]
k
=
weight
.
shape
[
1
]
import
math
cu_count
=
get_cu_count
()
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
return
gemm_a16w16
(
x
,
weight
,
bias
)
return
gemm_a16w16
(
x
,
weight
,
bias
)
# Smallest power of 2 that is >= n
N_p2
=
1
<<
(
n
-
1
).
bit_length
()
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus
=
((
m
+
64
-
1
)
//
64
)
*
((
k
+
512
-
1
)
//
512
)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB
=
min
(
N_p2
//
16
,
4
)
# Given the above, how many CUs would we need?
CuNeeded
=
rndup_cus
*
GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc
=
CuNeeded
<=
cu_count
use_skinny_reduce_counting
=
(
use_skinny_reduce_counting
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_gfx950
()
and
on_gfx950
()
and
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
and
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
and
(
and
(
n
>
=
1
6
10
<=
n
<
=
1
28
and
n
<=
128
and
k
%
8
==
0
and
k
>
512
and
k
>
512
and
math
.
ceil
(
k
/
512
)
*
math
.
ceil
(
m
/
16
)
<
get_cu_count
()
and
m
%
16
==
0
and
fits_wvsplitkrc
and
x
.
is_contiguous
()
and
x
.
is_contiguous
()
)
)
# k == 2880 and (m == 640 or m == 128))
)
)
if
use_skinny_reduce_counting
:
if
use_skinny_reduce_counting
:
cu_count
=
get_cu_count
()
x_view
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
x_view
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
out
=
ops
.
wvSplitKrc
(
weight
,
x_view
,
cu_count
,
bias
)
out
=
ops
.
wvSplitKrc
(
weight
,
x_view
,
cu_count
,
bias
)
return
out
.
reshape
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
return
out
.
reshape
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment