Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
721ae79f
Unverified
Commit
721ae79f
authored
Mar 10, 2026
by
Hashem Hashemi
Committed by
GitHub
Mar 10, 2026
Browse files
Improvements to wvSplitKrc skinny GEMM solution (#34304)
Signed-off-by:
Hashem Hashemi
<
hashem.hashemi@amd.com
>
parent
aefc59f0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
168 additions
and
97 deletions
+168
-97
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+150
-84
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+7
-4
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+11
-9
No files found.
csrc/rocm/skinny_gemms.cu
View file @
721ae79f
...
...
@@ -12,6 +12,7 @@
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#include "core/batch_invariant.hpp"
// TODO(rasmith): The kernels in this file are susceptible to integer overflow
// issues, do not take strides, and are unable to handle PyTorch tensors that
...
...
@@ -1224,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
#if defined(__gfx950__)
#define WVSPLITKRC_1KPASS
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
,
int
GrpsShrB
,
int
CHUNKK
>
int
UNRL
,
int
N
,
int
GrpsShrB
,
int
CHUNKK
,
int
DTRMNSTC
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
__restrict__
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
float
*
glbl
,
scalar_t
*
C
,
const
int
CuCount
)
{
// Use upper half of glbl buffer for atomic reduce counting
int
*
cntr
=
(
int
*
)(
&
glbl
[
M
*
N
]);
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
Kap
,
const
int
M
,
const
int
Bx
,
const
int
By
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
B
,
const
scalar_t
*
__restrict__
BIAS
,
float
*
glbl
,
int
*
cntr
,
scalar_t
*
C
,
const
int
CuCount
)
{
constexpr
int
NTILE
=
16
;
constexpr
int
APAD
=
1
;
constexpr
int
ASTRD
=
64
;
...
...
@@ -1425,10 +1423,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
unsigned
int
kOffcp
=
min__
(
K
-
A_CHUNK
,
k_str
+
kOff
);
for
(
unsigned
int
n
=
0
;
n
<
N
;
n
+=
CHUNKK
*
sprdN
)
{
__builtin_amdgcn_global_load_lds
(
(
int
*
)(
&
A
[
min__
(
K
*
actlN
-
A_
CHUNK
,
kOffcp
+
K
*
(
n
/
CHUNKK
+
(
N
/
CHUNKK
)
*
(
threadIdx
.
x
/
(
64
/
CHUNKK
))
+
(
int
*
)(
&
A
[
min__
(
Kap
*
actlN
-
A_CHUNK
,
kOffcp
+
Kap
*
(
n
/
CHUNK
K
+
(
N
/
CHUNKK
)
*
(
threadIdx
.
x
/
(
64
/
CHUNKK
))
+
(
threadIdx
.
y
%
sprdN
)))]),
(
int
*
)(
&
s
[(
k
+
kFitPdd
*
((
n
/
CHUNKK
)
+
(
threadIdx
.
y
%
sprdN
)))]),
...
...
@@ -1533,30 +1531,66 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
union
flt4
{
scalar8
s8
;
float2
f2
[
2
];
float4
f4
;
};
if
(
m
+
(
threadIdx
.
x
%
16
)
<
M
)
{
int
my_cntr
;
int
mindx
=
m
+
(
threadIdx
.
x
%
16
);
int
g_mindx
=
m
*
4
+
(
threadIdx
.
x
%
64
);
// coalesced atomic reduction
scalar_t
biases
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
// Atomic add the output, read biases
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
// int adr = mindx + M * nindx;
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
int
g_nindx
=
j
+
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
+
M
*
g_nindx
*
4
;
atomicAdd
(
&
glbl
[
g_adr
],
sum4
[
nt
][
0
][
j
]);
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
*
4
+
0
+
M
*
g_nindx
*
4
;
if
(
DTRMNSTC
)
{
flt4
flt4_
=
{.
s8
=
sum4
[
nt
][
0
]};
__hip_atomic_store
((
float2
*
)
&
glbl
[
g_adr
+
M
*
N
*
(
m0
/
Mmod
)],
flt4_
.
f2
[
0
],
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
__hip_atomic_store
((
float2
*
)
&
glbl
[
g_adr
+
2
+
M
*
N
*
(
m0
/
Mmod
)],
flt4_
.
f2
[
1
],
__ATOMIC_RELAXED
,
__HIP_MEMORY_SCOPE_AGENT
);
}
else
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
atomicAdd
((
&
glbl
[
g_adr
+
j
]),
sum4
[
nt
][
0
][
j
]);
}
}
__atomic_signal_fence
(
__ATOMIC_SEQ_CST
);
asm
volatile
(
"s_waitcnt vmcnt(0)"
:::
"memory"
);
__atomic_signal_fence
(
__ATOMIC_SEQ_CST
);
int
nindx_
=
(
0
+
(
threadIdx
.
x
/
16
)
*
4
)
+
0
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
int
adr_
=
mindx
+
M
*
nindx_
/
4
;
// Update the complete counter
my_cntr
=
atomicAdd
(
&
cntr
[
adr_
],
1
);
float
vals
[
N
/
NTILE
/
GrpsShrB
][
4
]
=
{};
// make sure LDS is free for write out staging
if
(
DTRMNSTC
)
__syncthreads
();
// Update the complete counter
flt4
vals
[
N
/
NTILE
/
GrpsShrB
]
=
{};
// If we're the last k-shard, read back the value and convert...
if
(
my_cntr
+
1
==
k_rnd
)
{
cntr
[
adr_
]
=
0
;
// clear for next round
if
constexpr
(
DTRMNSTC
)
{
#pragma unroll
for
(
int
ks
=
0
;
ks
<
k_rnd
;
ks
++
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
int
g_nindx
=
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
*
4
+
0
+
M
*
g_nindx
*
4
;
__builtin_amdgcn_global_load_lds
(
(
float4
*
)(
&
glbl
[
g_adr
+
M
*
N
*
ks
]),
&
(((
float4
*
)
s
)[(
threadIdx
.
y
*
THRDS
)
+
ks
*
THRDS
*
4
+
nt
*
THRDS
*
4
*
k_rnd
]),
16
,
0
,
0
);
}
}
if
(
BIAS
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
...
...
@@ -1565,12 +1599,29 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
Bx
];
}
}
asm
volatile
(
"s_waitcnt 0"
);
for
(
int
ks
=
0
;
ks
<
k_rnd
;
ks
++
)
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
float4
eval
=
((
float4
*
)
s
)[(
threadIdx
.
x
+
threadIdx
.
y
*
THRDS
)
+
ks
*
THRDS
*
4
+
nt
*
THRDS
*
4
*
k_rnd
];
vals
[
nt
].
f4
+=
eval
;
}
}
}
else
{
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
g_nindx
=
j
+
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
+
M
*
g_nindx
*
4
;
vals
[
nt
][
j
]
=
glbl
[
g_adr
];
(
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
))
/
4
;
int
g_adr
=
g_mindx
*
4
+
0
+
M
*
g_nindx
*
4
;
vals
[
nt
].
f4
=
*
(
float4
*
)(
&
glbl
[
g_adr
]);
*
(
float4
*
)(
&
glbl
[
g_adr
])
=
{};
// clear out for next round
}
if
(
BIAS
)
for
(
uint32_t
nt
=
0
;
nt
<
N
/
NTILE
/
GrpsShrB
;
nt
++
)
{
for
(
uint32_t
j
=
0
;
j
<
4
;
j
++
)
{
int
nindx
=
(
j
+
(
threadIdx
.
x
/
16
)
*
4
)
+
nt
*
NTILE
+
(
N
/
GrpsShrB
)
*
(
threadIdx
.
y
%
GrpsShrB
);
biases
[
nt
][
j
]
=
BIAS
[(
mindx
%
Bx
)
+
(
nindx
%
By
)
*
Bx
];
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -1581,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if
(
nindx
<
actlN
)
{
int
adr
=
mindx
+
M
*
nindx
;
if
constexpr
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
)
{
vals
[
nt
][
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2bfloat16
(
vals
[
nt
][
j
]);
vals
[
nt
]
.
s8
[
j
]
+=
__bfloat162float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2bfloat16
(
vals
[
nt
]
.
s8
[
j
]);
}
else
{
vals
[
nt
][
j
]
+=
__half2float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2half
(
vals
[
nt
][
j
]);
vals
[
nt
]
.
s8
[
j
]
+=
__half2float
(
biases
[
nt
][
j
]);
C
[
adr
]
=
__float2half
(
vals
[
nt
]
.
s8
[
j
]);
}
}
}
...
...
@@ -1604,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
,
int
GrpsShrB
,
int
CHUNKK
>
__global__
void
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
M
,
const
int
Bx
,
const
int
B
y
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
int
UNRL
,
int
N
,
int
GrpsShrB
,
int
CHUNKK
,
int
DTRMNSTC
>
__global__
void
wvSplitKrc_
(
const
int
actlN
,
const
int
K
,
const
int
Kap
,
const
int
M
,
const
int
B
x
,
const
int
B
y
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
const
scalar_t
*
__restrict__
BIAS
,
float
*
glbl
,
//
int* cntr,
scalar_t
*
C
,
const
int
CuCount
){
UNREACHABLE_CODE
}
int
*
cntr
,
scalar_t
*
C
,
const
int
CuCount
){
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
torch
::
Tensor
wvSplitKrc
(
const
at
::
Tensor
&
in_a
,
const
at
::
Tensor
&
in_b
,
const
std
::
optional
<
at
::
Tensor
>&
in_bias
,
const
int64_t
CuCount
)
{
auto
M_in
=
in_a
.
size
(
0
);
auto
N_in
=
in_b
.
size
(
0
);
auto
K_in
=
in_a
.
size
(
1
);
int
_DTRMNSTC
=
1
;
// vllm::vllm_is_batch_invariant();
auto
M_in
=
in_b
.
size
(
0
);
auto
N_in
=
in_a
.
size
(
0
);
auto
K_in
=
in_b
.
size
(
1
);
auto
Kap_in
=
in_a
.
stride
(
0
);
auto
Bx_in
=
(
in_bias
.
has_value
()
&&
in_bias
->
numel
()
>
0
)
?
(
in_bias
->
sizes
().
size
()
==
2
)
?
in_bias
->
size
(
1
)
:
in_bias
->
size
(
0
)
...
...
@@ -1635,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
auto
out_c
=
torch
::
empty
(
{
N_in
,
M_in
},
torch
::
TensorOptions
().
dtype
(
in_
b
.
dtype
()).
device
(
in_
b
.
device
()));
torch
::
TensorOptions
().
dtype
(
in_
a
.
dtype
()).
device
(
in_
a
.
device
()));
auto
N_p2
=
1U
<<
(
32
-
__builtin_clz
(
N_in
-
1
));
auto
axl_glbl
=
torch
::
empty
(
{
N_p2
+
N_p2
/
4
,
M_in
+
M_in
/
4
},
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
in_b
.
device
()));
axl_glbl
.
zero_
();
// disable for FAST_UNSAFE_RDC_INIT
dim3
grid
(
CuCount
);
...
...
@@ -1649,25 +1700,6 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// const int max_lds_len = get_lds_size() / 2;
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \
dim3 block(64, 4); \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, glbl, c, CuCount); \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
in_b
.
scalar_type
(),
"wvSplitKrc"
,
[
&
]
{
using
fptype
=
typename
scalar
<
scalar_t
>::
type
;
fptype
*
af4
=
reinterpret_cast
<
fptype
*>
(
in_a
.
data_ptr
());
const
fptype
*
bf4
=
reinterpret_cast
<
const
fptype
*>
(
in_b
.
data_ptr
());
const
fptype
*
biasf4
=
(
in_bias
.
has_value
()
&&
in_bias
->
numel
()
>
0
)
?
reinterpret_cast
<
const
fptype
*>
(
in_bias
->
data_ptr
())
:
nullptr
;
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
auto
glbl
=
axl_glbl
.
data_ptr
<
float
>
();
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int
rndup_cus
=
((
M_in
+
64
-
1
)
/
64
)
*
((
K_in
+
512
-
1
)
/
512
);
...
...
@@ -1680,24 +1712,58 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
// Given the above, how many CUs would we need?
int
CuNeeded
=
rndup_cus
*
GrpsShrB
;
if
(
CuNeeded
>
CuCount
)
std
::
runtime_error
(
"Invalid wvSplitKrc size"
);
if
(
CuNeeded
>
CuCount
)
throw
std
::
runtime_error
(
"Invalid wvSplitKrc size"
);
// Can we increase SplitK by shrinking the K-shared to 256?
int
chunkk
=
(
CuNeeded
*
2
<=
CuCount
)
?
2
:
1
;
static
torch
::
Tensor
axl_glbl
=
torch
::
zeros
(
128
*
1024
*
(
_DTRMNSTC
?
12
:
1
),
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat32
).
device
(
in_a
.
device
()))
.
detach
();
static
torch
::
Tensor
axl_cntr
=
torch
::
zeros
(
128
*
1024
*
(
_DTRMNSTC
?
12
:
1
)
/
4
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
in_a
.
device
()))
.
detach
();
auto
glbl
=
axl_glbl
.
data_ptr
<
float
>
();
auto
cntr
=
axl_cntr
.
data_ptr
<
int
>
();
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \
dim3 block(64, 4); \
if (_DTRMNSTC) \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 1> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
else \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 0> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES
(
in_a
.
scalar_type
(),
"wvSplitKrc"
,
[
&
]
{
using
fptype
=
typename
scalar
<
scalar_t
>::
type
;
const
fptype
*
af4
=
reinterpret_cast
<
const
fptype
*>
(
in_a
.
data_ptr
());
const
fptype
*
bf4
=
reinterpret_cast
<
const
fptype
*>
(
in_b
.
data_ptr
());
const
fptype
*
biasf4
=
(
in_bias
.
has_value
()
&&
in_bias
->
numel
()
>
0
)
?
reinterpret_cast
<
const
fptype
*>
(
in_bias
->
data_ptr
())
:
nullptr
;
fptype
*
c
=
reinterpret_cast
<
fptype
*>
(
out_c
.
data_ptr
());
switch
(
N_p2
)
{
case
16
:
WVSPLITKrc
(
16
,
1
,
1
)
break
;
case
32
:
if
(
chunkk
==
2
)
WVSPLITKrc
(
32
,
2
,
2
)
else
if
(
chunkk
==
1
)
WVSPLITKrc
(
32
,
2
,
1
)
break
;
if
(
chunkk
==
2
)
WVSPLITKrc
(
32
,
2
,
2
)
else
WVSPLITKrc
(
32
,
2
,
1
)
break
;
case
64
:
if
(
chunkk
==
2
)
WVSPLITKrc
(
64
,
4
,
2
)
else
if
(
chunkk
==
1
)
WVSPLITKrc
(
64
,
4
,
1
)
break
;
if
(
chunkk
==
2
)
WVSPLITKrc
(
64
,
4
,
2
)
else
WVSPLITKrc
(
64
,
4
,
1
)
break
;
case
128
:
if
(
chunkk
==
2
)
WVSPLITKrc
(
128
,
4
,
2
)
else
if
(
chunkk
==
1
)
WVSPLITKrc
(
128
,
4
,
1
)
break
;
if
(
chunkk
==
2
)
WVSPLITKrc
(
128
,
4
,
2
)
else
WVSPLITKrc
(
128
,
4
,
1
)
break
;
default:
throw
std
::
runtime_error
(
"Unsupported N value: "
+
std
::
to_string
(
M_in
)
+
","
+
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
721ae79f
...
...
@@ -70,7 +70,6 @@ N_FACTORS_WVSPLITKRC = [
117
,
128
,
]
K_FACTORS_WVSPLITKRC
=
[
2880
,
2880
+
8
,
3072
,
3072
+
8
]
M_FACTORS_WVSPLITKRC
=
[
128
,
128
+
16
,
256
,
256
+
16
,
640
,
640
+
16
]
...
...
@@ -123,10 +122,11 @@ def pad_fp8(weight):
@
pytest
.
mark
.
parametrize
(
"m"
,
M_FACTORS_WVSPLITKRC
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"padded_a"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"bias_mode"
,
BIAS_MODES
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
@
pytest
.
mark
.
skipif
(
not
on_gfx950
(),
reason
=
"only meant for gfx950"
)
def
test_rocm_wvsplitkrc_kernel
(
xnorm
,
n
,
k
,
m
,
dtype
,
seed
,
bias_mode
):
def
test_rocm_wvsplitkrc_kernel
(
xnorm
,
n
,
k
,
m
,
dtype
,
seed
,
padded_a
,
bias_mode
):
torch
.
manual_seed
(
seed
)
cu_count
=
num_compute_units
()
...
...
@@ -141,7 +141,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
# Given the above, how many CUs would we need?
CuNeeded
=
rndup_cus
*
GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc
=
CuNeeded
<=
cu_count
fits_wvsplitkrc
=
(
N_p2
*
m
*
((
k
+
512
-
1
)
//
512
))
<=
128
*
1024
*
12
fits_wvsplitkrc
&=
CuNeeded
<=
cu_count
if
not
fits_wvsplitkrc
:
pytest
.
skip
(
"Too large for wvSplitKrc"
)
...
...
@@ -151,6 +152,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
B
=
(
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
)
*
xavier
if
padded_a
:
A
=
pad_fp8
(
A
)
BIAS
=
None
if
bias_mode
==
1
:
...
...
@@ -159,7 +162,7 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
BIAS
=
torch
.
rand
(
n
,
m
,
dtype
=
dtype
,
device
=
"cuda"
)
*
2
-
1
ref_out
=
torch
.
nn
.
functional
.
linear
(
A
,
B
,
BIAS
)
out
=
ops
.
wvSplitKrc
(
B
,
A
.
view
(
-
1
,
A
.
size
(
-
1
))
,
cu_count
,
BIAS
)
out
=
ops
.
wvSplitKrc
(
A
,
B
,
cu_count
,
BIAS
)
if
xnorm
:
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-3
,
rtol
=
1e-8
)
...
...
vllm/model_executor/layers/utils.py
View file @
721ae79f
...
...
@@ -129,10 +129,6 @@ def rocm_unquantized_gemm_impl(
k
=
weight
.
shape
[
1
]
cu_count
=
num_compute_units
()
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
return
gemm_a16w16
(
x
,
weight
,
bias
)
# Next ^2 of n
N_p2
=
1
<<
(
n
-
1
).
bit_length
()
...
...
@@ -145,7 +141,10 @@ def rocm_unquantized_gemm_impl(
# Given the above, how many CUs would we need?
CuNeeded
=
rndup_cus
*
GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc
=
CuNeeded
<=
cu_count
fits_wvsplitkrc
=
(
N_p2
*
m
*
((
k
+
512
-
1
)
//
512
)
)
<=
128
*
1024
*
12
# deterministic
fits_wvsplitkrc
&=
CuNeeded
<=
cu_count
use_skinny_reduce_counting
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
...
...
@@ -157,13 +156,16 @@ def rocm_unquantized_gemm_impl(
and
k
>
512
and
m
%
16
==
0
and
fits_wvsplitkrc
and
x
.
is_contiguous
()
and
weight
.
is_contiguous
()
)
)
if
use_skinny_reduce_counting
:
x_view
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
out
=
ops
.
wvSplitKrc
(
weight
,
x_view
,
cu_count
,
bias
)
return
out
.
reshape
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
return
ops
.
wvSplitKrc
(
x
,
weight
,
cu_count
,
bias
)
if
use_aiter_triton_gemm
(
n
,
m
,
k
,
x
.
dtype
):
from
aiter.ops.triton.gemm_a16w16
import
gemm_a16w16
return
gemm_a16w16
(
x
,
weight
,
bias
)
use_skinny
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment