Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dacd573
Unverified
Commit
2dacd573
authored
Nov 13, 2025
by
wangxiyuan
Committed by
GitHub
Nov 13, 2025
Browse files
[platform] Move get_cu_count to utils (#27005)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
d75ad048
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
28 additions
and
18 deletions
+28
-18
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+19
-5
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-1
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+2
-1
vllm/platforms/interface.py
vllm/platforms/interface.py
+0
-7
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+0
-4
vllm/utils/platform_utils.py
vllm/utils/platform_utils.py
+5
-0
No files found.
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
2dacd573
...
...
@@ -8,6 +8,7 @@ import torch
import
vllm._custom_ops
as
ops
from
tests.kernels.quant_utils
import
ref_dynamic_per_tensor_fp8_quant
from
vllm.platforms
import
current_platform
from
vllm.utils.platform_utils
import
get_cu_count
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
# Specific (N, K, M) combinations for targeted testing
...
...
@@ -85,7 +86,7 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
def
test_rocm_wvsplitk_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
cu_count
=
current_platform
.
get_cu_count
()
cu_count
=
get_cu_count
()
A
=
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
B
=
torch
.
rand
(
m
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
...
...
@@ -102,7 +103,7 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
def
test_rocm_wvsplitk_bias1D_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
cu_count
=
current_platform
.
get_cu_count
()
cu_count
=
get_cu_count
()
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
...
...
@@ -121,7 +122,7 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
def
test_rocm_wvsplitk_bias2D_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
cu_count
=
current_platform
.
get_cu_count
()
cu_count
=
get_cu_count
()
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
A
=
(
torch
.
rand
(
n
,
k
,
dtype
=
dtype
,
device
=
"cuda"
)
-
0.5
)
*
xavier
...
...
@@ -153,7 +154,14 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
ref_out
=
torch
.
_scaled_mm
(
A
,
B
.
t
(),
out_dtype
=
dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
)
out
=
ops
.
wvSplitKQ
(
B
,
A
,
dtype
,
scale_a
,
scale_b
,
current_platform
.
get_cu_count
())
out
=
ops
.
wvSplitKQ
(
B
,
A
,
dtype
,
scale_a
,
scale_b
,
get_cu_count
(),
)
assert
torch
.
allclose
(
out
,
ref_out
,
rtol
=
0.01
)
...
...
@@ -180,7 +188,13 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
A
,
B
.
t
(),
out_dtype
=
dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
BIAS
)
out
=
ops
.
wvSplitKQ
(
B
,
A
,
dtype
,
scale_a
,
scale_b
,
current_platform
.
get_cu_count
(),
BIAS
B
,
A
,
dtype
,
scale_a
,
scale_b
,
get_cu_count
(),
BIAS
,
)
assert
torch
.
allclose
(
out
,
ref_out
,
rtol
=
0.01
)
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
2dacd573
...
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
flashinfer_scaled_fp8_mm
,
has_flashinfer
from
vllm.utils.platform_utils
import
get_cu_count
from
vllm.utils.torch_utils
import
direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
...
...
@@ -200,7 +201,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
out_dtype
,
scale_a
,
scale_b
,
current_platform
.
get_cu_count
(),
get_cu_count
(),
bias
,
)
else
:
...
...
vllm/model_executor/layers/utils.py
View file @
2dacd573
...
...
@@ -11,6 +11,7 @@ from vllm import envs
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.utils.platform_utils
import
get_cu_count
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -151,7 +152,7 @@ def rocm_unquantized_gemm_impl(
x_view
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
if
m
>
8
and
0
<
n
<=
4
:
cu_count
=
current_platform
.
get_cu_count
()
cu_count
=
get_cu_count
()
out
=
ops
.
wvSplitK
(
weight
,
x_view
,
cu_count
,
bias
)
return
out
.
reshape
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
elif
m
%
4
==
0
and
n
==
1
and
k
<=
8192
and
bias
is
None
:
...
...
vllm/platforms/interface.py
View file @
2dacd573
...
...
@@ -545,13 +545,6 @@ class Platform:
cls
.
_global_graph_pool
=
self
.
graph_pool_handle
()
return
cls
.
_global_graph_pool
@
classmethod
def
get_cu_count
(
cls
,
device_id
:
int
=
0
)
->
int
:
"""
Returns the total number of compute units (CU) on single GPU.
"""
raise
NotImplementedError
@
classmethod
def
get_static_graph_wrapper_cls
(
cls
)
->
str
:
"""
...
...
vllm/platforms/rocm.py
View file @
2dacd573
...
...
@@ -423,10 +423,6 @@ class RocmPlatform(Platform):
def
opaque_attention_op
(
cls
)
->
bool
:
return
True
@
classmethod
def
get_cu_count
(
cls
,
device_id
:
int
=
0
)
->
int
:
return
torch
.
cuda
.
get_device_properties
(
device_id
).
multi_processor_count
@
classmethod
def
is_navi
(
cls
)
->
bool
:
return
"gfx1"
in
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
...
...
vllm/utils/platform_utils.py
View file @
2dacd573
...
...
@@ -24,6 +24,11 @@ def xpu_is_initialized() -> bool:
return
torch
.
xpu
.
is_initialized
()
def
get_cu_count
(
cls
,
device_id
:
int
=
0
)
->
int
:
"""Returns the total number of compute units (CU) on single GPU."""
return
torch
.
cuda
.
get_device_properties
(
device_id
).
multi_processor_count
def
cuda_get_device_properties
(
device
,
names
:
Sequence
[
str
],
init_cuda
=
False
)
->
tuple
[
Any
,
...]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment