Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5e666f72
Unverified
Commit
5e666f72
authored
Jun 19, 2025
by
kourosh hakhamaneshi
Committed by
GitHub
Jun 19, 2025
Browse files
[Bugfix][Ray] Set the cuda context eagerly in the ray worker (#19583)
parent
e3a3e4db
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
107 additions
and
0 deletions
+107
-0
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+9
-0
tests/cuda/test_cuda_context.py
tests/cuda/test_cuda_context.py
+80
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+11
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+7
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
5e666f72
...
...
@@ -271,6 +271,15 @@ steps:
commands
:
-
pytest -v -s prefix_caching
-
label
:
Platform Tests (CUDA)
mirror_hardwares
:
[
amdexperimental
]
source_file_dependencies
:
-
vllm/
-
tests/cuda
commands
:
-
pytest -v -s cuda/test_cuda_context.py
-
label
:
Samplers Test
# 36min
mirror_hardwares
:
[
amdexperimental
]
source_file_dependencies
:
...
...
tests/cuda/test_cuda_context.py
0 → 100644
View file @
5e666f72
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ctypes
from
concurrent.futures
import
ThreadPoolExecutor
import
pytest
import
torch
from
vllm.platforms
import
current_platform
def
check_cuda_context
():
"""Check CUDA driver context status"""
try
:
cuda
=
ctypes
.
CDLL
(
'libcuda.so'
)
device
=
ctypes
.
c_int
()
result
=
cuda
.
cuCtxGetDevice
(
ctypes
.
byref
(
device
))
return
(
True
,
device
.
value
)
if
result
==
0
else
(
False
,
None
)
except
Exception
:
return
False
,
None
def
run_cuda_test_in_thread
(
device_input
,
expected_device_id
):
"""Run CUDA context test in separate thread for isolation"""
try
:
# New thread should have no CUDA context initially
valid_before
,
device_before
=
check_cuda_context
()
if
valid_before
:
return
False
,
\
"CUDA context should not exist in new thread, "
\
f
"got device
{
device_before
}
"
# Test setting CUDA context
current_platform
.
set_device
(
device_input
)
# Verify context is created correctly
valid_after
,
device_id
=
check_cuda_context
()
if
not
valid_after
:
return
False
,
"CUDA context should be valid after set_cuda_context"
if
device_id
!=
expected_device_id
:
return
False
,
\
f
"Expected device
{
expected_device_id
}
, got
{
device_id
}
"
return
True
,
"Success"
except
Exception
as
e
:
return
False
,
f
"Exception in thread:
{
str
(
e
)
}
"
class
TestSetCudaContext
:
"""Test suite for the set_cuda_context function."""
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"CUDA not available"
)
@
pytest
.
mark
.
parametrize
(
argnames
=
"device_input,expected_device_id"
,
argvalues
=
[
(
0
,
0
),
(
torch
.
device
(
'cuda:0'
),
0
),
(
'cuda:0'
,
0
),
],
ids
=
[
"int"
,
"torch_device"
,
"string"
])
def
test_set_cuda_context_parametrized
(
self
,
device_input
,
expected_device_id
):
"""Test setting CUDA context in isolated threads."""
with
ThreadPoolExecutor
(
max_workers
=
1
)
as
executor
:
future
=
executor
.
submit
(
run_cuda_test_in_thread
,
device_input
,
expected_device_id
)
success
,
message
=
future
.
result
(
timeout
=
30
)
assert
success
,
message
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"CUDA not available"
)
def
test_set_cuda_context_invalid_device_type
(
self
):
"""Test error handling for invalid device type."""
with
pytest
.
raises
(
ValueError
,
match
=
"Expected a cuda device"
):
current_platform
.
set_device
(
torch
.
device
(
'cpu'
))
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
,
"-v"
])
vllm/platforms/cuda.py
View file @
5e666f72
...
...
@@ -71,6 +71,17 @@ class CudaPlatformBase(Platform):
# though vLLM doesn't support these GPUs.
return
[
torch
.
float32
]
@
classmethod
def
set_device
(
cls
,
device
:
torch
.
device
)
->
None
:
"""
Set the device for the current platform.
"""
super
().
set_device
(
device
)
# With this trick we can force the device to be set eagerly
# see https://github.com/pytorch/pytorch/issues/155668
# for why and when it is needed
_
=
torch
.
zeros
(
1
,
device
=
device
)
@
classmethod
def
get_device_capability
(
cls
,
device_id
:
int
=
0
...
...
vllm/platforms/interface.py
View file @
5e666f72
...
...
@@ -298,6 +298,13 @@ class Platform:
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
@
classmethod
def
set_device
(
cls
,
device
:
torch
.
device
)
->
None
:
"""
Set the device for the current platform.
"""
torch
.
cuda
.
set_device
(
device
)
@
classmethod
def
pre_register_and_update
(
cls
,
parser
:
Optional
[
FlexibleArgumentParser
]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment