Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
306d6040
Unverified
Commit
306d6040
authored
May 31, 2025
by
Charlie Fu
Committed by
GitHub
May 31, 2025
Browse files
[ROCm][Kernel] Add gfx950 support for skinny gemms (#18010)
Signed-off-by:
charlifu
<
charlifu@amd.com
>
parent
f2c3f66d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
91 additions
and
54 deletions
+91
-54
csrc/rocm/skinny_gemms.cu
csrc/rocm/skinny_gemms.cu
+70
-43
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+9
-5
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-2
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+8
-2
No files found.
csrc/rocm/skinny_gemms.cu
View file @
306d6040
...
@@ -13,14 +13,34 @@
...
@@ -13,14 +13,34 @@
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__))
#if defined(__HIPCC__) && \
#define __HIP__MI300_MI250__
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
#define __HIP__GFX9__
#endif
#endif
#if defined(__HIPCC__) && defined(__gfx942__)
#if defined(__HIPCC__) &&
(
defined(__gfx942__)
|| defined(__gfx950__))
#define __HIP__MI3
00
__
#define __HIP__MI3
XX
__
#endif
#endif
// Compile-time LDS size in bytes for the architecture being compiled:
// 160 KiB when __gfx950__ is defined, 64 KiB otherwise.
// The expansion is parenthesized so the macro behaves as a single value in
// any expression (e.g. `x / LDS_SIZE`), not only as a left-hand operand.
#if defined(__gfx950__)
  #define LDS_SIZE (160 * 1024)
#else
  #define LDS_SIZE (64 * 1024)
#endif
// Returns the LDS size in bytes assumed for the current device: 160 KiB when
// the device's gcnArchName contains "gfx95", 64 KiB otherwise.
//
// The value is computed on the first call and reused afterwards, so it
// reflects whichever device was current at that first call.
int get_lds_size() {
  // A function-local static is initialized exactly once, and since C++11 that
  // initialization is thread-safe. This replaces a hand-maintained
  // `is_cached` flag that was read and written without synchronization.
  static const int lds_size = [] {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    const std::string device_arch = dprops->gcnArchName;
    const bool is_gfx95 = device_arch.find("gfx95") != std::string::npos;
    return is_gfx95 ? 160 * 1024 : 64 * 1024;
  }();
  return lds_size;
}
#if defined(NDEBUG)
#if defined(NDEBUG)
#undef NDEBUG
#undef NDEBUG
#include <assert.h>
#include <assert.h>
...
@@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
...
@@ -267,7 +287,7 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
V0 += (s.x + s.y); \
V0 += (s.x + s.y); \
}
}
#if defined(__HIP__
MI300_MI250
__) // TODO: Add NAVI support
#if defined(__HIP__
GFX9
__) // TODO: Add NAVI support
// This version targets cases where A[] fits LDS capacity
// This version targets cases where A[] fits LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
UNRL
,
int
N
>
...
@@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -275,7 +295,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
#if defined(__HIP__MI300__)
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
#if defined(__HIP__MI3XX__)
constexpr
bool
use_mfma
=
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
);
constexpr
bool
use_mfma
=
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
);
#else
#else
constexpr
bool
use_mfma
=
false
;
constexpr
bool
use_mfma
=
false
;
...
@@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -295,13 +316,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
};
};
//----------------------------------------------------
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Reserving 64
/160
KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
// then this is not going to work!
//----------------------------------------------------
//----------------------------------------------------
__shared__
scalar_t
s
[
1024
*
32
];
__shared__
scalar_t
s
[
max_lds_len
];
//----------------------------------------------------
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Fetch the activation matrix to LDS
...
@@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -312,11 +333,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
32
*
1024
);
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
32
*
1024
))
break
;
if
(
k_in
>=
min
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
...
@@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -517,7 +538,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
}
}
#else // !defined(__HIP__
MI300_MI250
__) TODO: Add NAVI support
#else // !defined(__HIP__
GFX9
__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
__global__
void
wvSplitK_hf_sml_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
...
@@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
...
@@ -525,9 +546,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
UNREACHABLE_CODE
}
}
#endif // defined(__HIP__
MI300_MI250
__) TODO: Add NAVI support
#endif // defined(__HIP__
GFX9
__) TODO: Add NAVI support
#if defined(__HIP__
MI300_MI250
__) // TODO: Add NAVI support
#if defined(__HIP__
GFX9
__) // TODO: Add NAVI support
// This version targets cases where A[] marginally exceeds LDS capacity
// This version targets cases where A[] marginally exceeds LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
UNRL
,
int
N
>
...
@@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -535,7 +556,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
#if defined(__HIP__MI300__)
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
#if defined(__HIP__MI3XX__)
constexpr
bool
use_mfma
=
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
);
constexpr
bool
use_mfma
=
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
);
#else
#else
constexpr
bool
use_mfma
=
false
;
constexpr
bool
use_mfma
=
false
;
...
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -561,7 +583,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: When activation matrix is larger than 64 KB
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
// then this is not going to work!
//----------------------------------------------------
//----------------------------------------------------
__shared__
scalar_t
s
[
1024
*
32
];
__shared__
scalar_t
s
[
max_lds_len
];
//----------------------------------------------------
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
// Computation of columns that need to be committed to memory!
...
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -598,11 +620,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
//----------------------------------------------------
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
32
*
1024
);
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
32
*
1024
))
break
;
if
(
k_in
>=
min
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
...
@@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -686,7 +708,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Fetch A activation matrix in interleaved fashion from LDS or memory
// Fetch A activation matrix in interleaved fashion from LDS or memory
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
if
(
k_
+
K
*
n
<
32
*
1024
)
if
(
k_
+
K
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
else
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
...
@@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -817,7 +839,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
}
#else // !defined(__HIP__
MI300_MI250
__) TODO: Add NAVI support
#else // !defined(__HIP__
GFX9
__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
__global__
void
wvSplitK_hf_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
...
@@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
...
@@ -825,9 +847,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
UNREACHABLE_CODE
}
}
#endif // defined(__HIP__
MI300_MI250
__) TODO: Add NAVI support
#endif // defined(__HIP__
GFX9
__) TODO: Add NAVI support
#if defined(__HIP__
MI300_MI250
__) // TODO: Add NAVI support
#if defined(__HIP__
GFX9
__) // TODO: Add NAVI support
// This version targets big A[] cases, where it is much larger than LDS capacity
// This version targets big A[] cases, where it is much larger than LDS capacity
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
UNRL
,
int
N
>
...
@@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -835,7 +857,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
scalar_t
*
__restrict__
A
,
scalar_t
*
C
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
#if defined(__HIP__MI300__)
constexpr
int
max_lds_len
=
LDS_SIZE
/
2
;
#if defined(__HIP__MI3XX__)
constexpr
bool
use_mfma
=
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
);
constexpr
bool
use_mfma
=
(
std
::
is_same_v
<
scalar_t
,
__hip_bfloat16
>
);
#else
#else
constexpr
bool
use_mfma
=
false
;
constexpr
bool
use_mfma
=
false
;
...
@@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -855,13 +878,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
};
};
//----------------------------------------------------
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Reserving 64
/160
KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
// then this is not going to work!
//----------------------------------------------------
//----------------------------------------------------
__shared__
scalar_t
s
[
1024
*
32
];
__shared__
scalar_t
s
[
max_lds_len
];
//----------------------------------------------------
//----------------------------------------------------
// Computation of columns that need to be committed to memory!
// Computation of columns that need to be committed to memory!
...
@@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -902,11 +925,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
//----------------------------------------------------
#define PCML
#define PCML
#ifndef PCML
#ifndef PCML
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
32
*
1024
);
for
(
uint32_t
k
=
0
;
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
uint32_t
k_in
=
k
+
((
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
);
if
(
k_in
>=
min
(
K
*
N
,
32
*
1024
))
break
;
if
(
k_in
>=
min
(
K
*
N
,
max_lds_len
))
break
;
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
*
((
bigType
*
)(
&
s
[
k_in
]))
=
*
((
bigType
*
)(
&
A
[
k_in
]));
}
}
...
@@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -916,7 +939,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#define TUC (THRDS * UNRL * A_CHUNK)
#define TUC (THRDS * UNRL * A_CHUNK)
uint32_t
kBase
=
0
;
uint32_t
kBase
=
0
;
// find biggest k size that fits in LDS
// find biggest k size that fits in LDS
uint32_t
kFit
=
(
32
*
1024
)
/
N
;
uint32_t
kFit
=
(
max_lds_len
)
/
N
;
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple
// of TUC
// of TUC
kFit
=
(
kFit
%
TUC
==
0
)
kFit
=
(
kFit
%
TUC
==
0
)
...
@@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1164,7 +1187,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
}
}
}
#else // !defined(__HIP__
MI300_MI250
__) TODO: Add NAVI support
#else // !defined(__HIP__
GFX9
__) TODO: Add NAVI support
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
template
<
typename
scalar_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
UNRL
,
int
N
>
__global__
void
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
__global__
void
wvSplitK_hf_big_
(
const
int
K
,
const
int
M
,
const
scalar_t
*
B
,
...
@@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
...
@@ -1172,7 +1195,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
UNREACHABLE_CODE
}
}
#endif // defined(__HIP__
MI300_MI250
__) TODO: Add NAVI support
#endif // defined(__HIP__
GFX9
__) TODO: Add NAVI support
int
mindiv
(
int
N
,
int
div1
,
int
div2
)
{
int
mindiv
(
int
N
,
int
div1
,
int
div2
)
{
int
nPrRnd
=
div1
*
div2
;
int
nPrRnd
=
div1
*
div2
;
...
@@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
...
@@ -1222,17 +1245,18 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_a
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_a
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int
max_lds_len
=
get_lds_size
()
/
2
;
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
_N) \
{ \
{ \
dim3 block(64, _WvPrGrp); \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <=
32 * 1024
) && (M_in % _YTILEs == 0)) {
\
if ((K_in * N_in <=
max_lds_len
) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
CuCount); \
CuCount); \
} else if (K_in * N_in <=
32 * 1024
* 1.2) {
\
} else if (K_in * N_in <=
max_lds_len
* 1.2) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
<<<grid, block, 0, stream>>>(K_in, M_in, af4, bf4, c, __wvPrGrp, \
...
@@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
...
@@ -1272,7 +1296,7 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
return
out_c
;
return
out_c
;
}
}
#if defined(__HIP__MI3
00
__) // TODO: Add NAVI support
#if defined(__HIP__MI3
XX
__) // TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
...
@@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1281,6 +1305,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
CuCount
)
{
constexpr
int
max_lds_len
=
LDS_SIZE
;
using
scalar8
=
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
4
)
*
sizeof
(
float
))))
float
;
__attribute__
((
__vector_size__
((
A_CHUNK
/
4
)
*
sizeof
(
float
))))
float
;
using
intx2
=
__attribute__
((
__vector_size__
(
2
*
sizeof
(
int
))))
int
;
using
intx2
=
__attribute__
((
__vector_size__
(
2
*
sizeof
(
int
))))
int
;
...
@@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1296,10 +1321,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8
h8
;
scalar8
h8
;
};
};
__shared__
fp8_t
s
[
1024
*
64
];
__shared__
fp8_t
s
[
max_lds_len
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min
(
K
*
N
,
64
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
}
}
__syncthreads
();
__syncthreads
();
...
@@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1436,7 +1461,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
}
}
#else // !defined(__HIP__MI3
00
__) TODO: Add NAVI support
#else // !defined(__HIP__MI3
XX
__) TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
__global__
void
wvSplitKQ_hf_sml_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
...
@@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
...
@@ -1446,9 +1471,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
UNREACHABLE_CODE
UNREACHABLE_CODE
}
}
#endif // defined(__HIP__MI3
00
__) TODO: Add NAVI support
#endif // defined(__HIP__MI3
XX
__) TODO: Add NAVI support
#if defined(__HIP__MI3
00
__) // TODO: Add NAVI support
#if defined(__HIP__MI3
XX
__) // TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
__global__
void
__launch_bounds__
(
WvPrGrp
*
THRDS
)
...
@@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1456,6 +1481,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
const
fp8_t
*
__restrict__
A
,
scalar_t
*
C
,
const
fp8_t
*
__restrict__
A
,
scalar_t
*
C
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
float
*
__restrict__
s_A
,
const
float
*
__restrict__
s_B
,
const
int
_WvPrGrp
,
const
int
CuCount
)
{
const
int
_WvPrGrp
,
const
int
CuCount
)
{
constexpr
int
max_lds_len
=
LDS_SIZE
;
using
scalar8
=
using
scalar8
=
__attribute__
((
__vector_size__
((
A_CHUNK
/
4
)
*
sizeof
(
float
))))
float
;
__attribute__
((
__vector_size__
((
A_CHUNK
/
4
)
*
sizeof
(
float
))))
float
;
using
intx2
=
__attribute__
((
__vector_size__
(
2
*
sizeof
(
int
))))
int
;
using
intx2
=
__attribute__
((
__vector_size__
(
2
*
sizeof
(
int
))))
int
;
...
@@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1471,10 +1497,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8
h8
;
scalar8
h8
;
};
};
__shared__
fp8_t
s
[
1024
*
64
];
__shared__
fp8_t
s
[
max_lds_len
];
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
for
(
uint32_t
k
=
(
threadIdx
.
y
*
THRDS
+
threadIdx
.
x
)
*
A_CHUNK
;
k
<
min
(
K
*
N
,
64
*
1024
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
k
<
min
(
K
*
N
,
max_lds_len
);
k
+=
THRDS
*
WvPrGrp
*
A_CHUNK
)
{
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
*
((
bigType
*
)(
&
s
[
k
]))
=
*
((
bigType
*
)(
&
A
[
k
]));
}
}
__syncthreads
();
__syncthreads
();
...
@@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1517,7 +1543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
uint32_t
k_
=
k
+
threadIdx
.
x
*
A_CHUNK
;
if
(
k_
>=
K
)
break
;
if
(
k_
>=
K
)
break
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
if
(
k_
+
K
*
n
<
64
*
1024
)
if
(
k_
+
K
*
n
<
max_lds_len
)
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
s
[
k_
+
K
*
n
])));
else
else
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
bigA
[
n
][
k2
]
=
*
((
const
bigType
*
)(
&
(
A
[
k_
+
K
*
n
])));
...
@@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
...
@@ -1608,7 +1634,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
m
+=
CuCount
*
_WvPrGrp
*
YTILE
;
}
}
}
}
#else // !defined(__HIP__MI3
00
__) TODO: Add NAVI support
#else // !defined(__HIP__MI3
XX
__) TODO: Add NAVI support
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
template
<
typename
scalar_t
,
typename
fp8_t
,
int
THRDS
,
int
YTILE
,
int
WvPrGrp
,
int
A_CHUNK
,
int
UNRL
,
int
N
>
int
A_CHUNK
,
int
UNRL
,
int
N
>
__global__
void
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
__global__
void
wvSplitKQ_hf_
(
const
int
K
,
const
int
Kp
,
const
int
M
,
...
@@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
...
@@ -1618,7 +1644,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const
int
CuCount
)
{
const
int
CuCount
)
{
UNREACHABLE_CODE
UNREACHABLE_CODE
}
}
#endif // defined(__HIP__MI3
00
__) TODO: Add NAVI support
#endif // defined(__HIP__MI3
XX
__) TODO: Add NAVI support
void
wvSplitKQ
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
at
::
Tensor
&
out_c
,
void
wvSplitKQ
(
at
::
Tensor
&
in_a
,
at
::
Tensor
&
in_b
,
at
::
Tensor
&
out_c
,
at
::
Tensor
&
scale_a
,
at
::
Tensor
&
scale_b
,
at
::
Tensor
&
scale_a
,
at
::
Tensor
&
scale_b
,
...
@@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
...
@@ -1638,12 +1664,13 @@ void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
dim3
grid
(
CuCount
);
dim3
grid
(
CuCount
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_a
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
in_a
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
int
max_lds_len
=
get_lds_size
();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
_N) \
{ \
{ \
dim3 block(64, _WvPrGrp); \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <=
64 * 1024
) && (M_in % _YTILEs == 0)) {
\
if ((K_in * N_in <=
max_lds_len
) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, a_ptr, b_ptr, c_ptr, \
...
...
tests/kernels/quant_utils.py
View file @
306d6040
...
@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
...
@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
# Using the default value (240.0) from pytorch will cause accuracy
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX
=
224.0
ROCM_FP8
FNUZ
_MAX
=
224.0
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
...
@@ -26,9 +26,11 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
qtype_traits
=
torch
.
iinfo
(
quant_dtype
)
if
quant_dtype
==
torch
.
int8
\
qtype_traits
=
torch
.
iinfo
(
quant_dtype
)
if
quant_dtype
==
torch
.
int8
\
else
torch
.
finfo
(
quant_dtype
)
else
torch
.
finfo
(
quant_dtype
)
qtype_traits_max
=
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
qtype_traits_max
=
ROCM_FP8FNUZ_MAX
if
current_platform
.
is_rocm
()
\
and
current_platform
.
is_fp8_fnuz
()
\
else
qtype_traits
.
max
else
qtype_traits
.
max
qtype_traits_min
=
-
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
qtype_traits_min
=
-
ROCM_FP8FNUZ_MAX
if
current_platform
.
is_rocm
()
\
and
current_platform
.
is_fp8_fnuz
()
\
else
qtype_traits
.
min
else
qtype_traits
.
min
qtype_max
=
as_float32_tensor
(
qtype_traits_max
)
qtype_max
=
as_float32_tensor
(
qtype_traits_max
)
s_1
=
as_float32_tensor
(
1.0
)
s_1
=
as_float32_tensor
(
1.0
)
...
@@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
...
@@ -70,9 +72,11 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
->
tuple
[
torch
.
tensor
,
torch
.
tensor
]:
->
tuple
[
torch
.
tensor
,
torch
.
tensor
]:
fp8_traits
=
torch
.
finfo
(
FP8_DTYPE
)
fp8_traits
=
torch
.
finfo
(
FP8_DTYPE
)
fp8_traits_max
=
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
fp8_traits_max
=
ROCM_FP8FNUZ_MAX
if
current_platform
.
is_rocm
()
\
and
current_platform
.
is_fp8_fnuz
()
\
else
fp8_traits
.
max
else
fp8_traits
.
max
fp8_traits_min
=
-
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
fp8_traits_min
=
-
ROCM_FP8FNUZ_MAX
if
current_platform
.
is_rocm
()
\
and
current_platform
.
is_fp8_fnuz
()
\
else
fp8_traits
.
min
else
fp8_traits
.
min
fp8_max
=
as_float32_tensor
(
fp8_traits_max
)
fp8_max
=
as_float32_tensor
(
fp8_traits_max
)
one
=
as_float32_tensor
(
1.0
)
one
=
as_float32_tensor
(
1.0
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
306d6040
...
@@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
...
@@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
input_2d
:
torch
.
Tensor
,
input_2d
:
torch
.
Tensor
,
output_shape
:
list
)
->
torch
.
Tensor
:
output_shape
:
list
)
->
torch
.
Tensor
:
from
vllm.platforms.rocm
import
on_mi
250_mi300
from
vllm.platforms.rocm
import
on_mi
3xx
if
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_mi
250_mi300
(
if
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_mi
3xx
(
)
and
qinput
.
shape
[
0
]
==
1
and
qinput
.
shape
[
1
]
%
16
==
0
:
)
and
qinput
.
shape
[
0
]
==
1
and
qinput
.
shape
[
1
]
%
16
==
0
:
output
=
ops
.
wvSplitKQ
(
weight
.
t
(),
qinput
,
out_dtype
,
scale_a
,
scale_b
,
output
=
ops
.
wvSplitKQ
(
weight
.
t
(),
qinput
,
out_dtype
,
scale_a
,
scale_b
,
current_platform
.
get_cu_count
())
current_platform
.
get_cu_count
())
...
...
vllm/model_executor/layers/utils.py
View file @
306d6040
...
@@ -70,9 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
...
@@ -70,9 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
def
rocm_unquantized_gemm
(
x
:
torch
.
Tensor
,
def
rocm_unquantized_gemm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
from
vllm.platforms.rocm
import
on_
mi250_mi300
from
vllm.platforms.rocm
import
on_
gfx9
k
=
weight
.
shape
[
1
]
k
=
weight
.
shape
[
1
]
use_skinny
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_
mi250_mi300
()
and
\
use_skinny
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_
gfx9
()
and
\
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
\
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
\
and
k
%
8
==
0
and
bias
is
None
)
and
k
%
8
==
0
and
bias
is
None
)
...
...
vllm/platforms/rocm.py
View file @
306d6040
...
@@ -105,9 +105,15 @@ def on_gfx1x() -> bool:
...
@@ -105,9 +105,15 @@ def on_gfx1x() -> bool:
@
cache
@
cache
def
on_mi
250_mi300
()
->
bool
:
def
on_mi
3xx
()
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx942"
,
"gfx950"
])
@cache
def on_gfx9() -> bool:
    """Report whether the "cuda" device's gcnArchName names a supported gfx9 arch.

    Returns True when the architecture string contains "gfx90a", "gfx942" or
    "gfx950". The result is memoized by ``@cache``, so the device is queried
    only on the first call.
    """
    arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
    for candidate in ("gfx90a", "gfx942", "gfx950"):
        if candidate in arch_name:
            return True
    return False
@
cache
@
cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment