Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d830656a
Unverified
Commit
d830656a
authored
Jul 02, 2024
by
Nick Hill
Committed by
GitHub
Jul 03, 2024
Browse files
[BugFix] Avoid unnecessary Ray import warnings (#6079)
parent
d18bab35
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
9 deletions
+28
-9
vllm/config.py
vllm/config.py
+7
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+5
-0
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+16
-7
No files found.
vllm/config.py
View file @
d830656a
...
...
@@ -682,11 +682,13 @@ class ParallelConfig:
from
vllm.executor
import
ray_utils
backend
=
"mp"
ray_found
=
ray_utils
.
ray
is
not
None
ray_found
=
ray_utils
.
ray
_
is
_available
()
if
cuda_device_count_stateless
()
<
self
.
world_size
:
if
not
ray_found
:
raise
ValueError
(
"Unable to load Ray which is "
"required for multi-node inference"
)
"required for multi-node inference, "
"please install Ray with `pip install "
"ray`."
)
from
ray_utils
.
ray_import_err
backend
=
"ray"
elif
ray_found
:
if
self
.
placement_group
:
...
...
@@ -718,6 +720,9 @@ class ParallelConfig:
raise
ValueError
(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'."
)
if
self
.
distributed_executor_backend
==
"ray"
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
if
not
self
.
disable_custom_all_reduce
and
self
.
world_size
>
1
:
if
is_hip
():
self
.
disable_custom_all_reduce
=
True
...
...
vllm/engine/async_llm_engine.py
View file @
d830656a
...
...
@@ -380,6 +380,11 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_args
.
engine_use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
distributed_executor_backend
=
(
engine_config
.
parallel_config
.
distributed_executor_backend
)
...
...
vllm/executor/ray_utils.py
View file @
d830656a
...
...
@@ -42,14 +42,26 @@ try:
output
=
pickle
.
dumps
(
output
)
return
output
ray_import_err
=
None
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import Ray with %r. For multi-node inference, "
"please install Ray with `pip install ray`."
,
e
)
ray
=
None
# type: ignore
ray_import_err
=
e
RayWorkerWrapper
=
None
# type: ignore
def
ray_is_available
()
->
bool
:
"""Returns True if Ray is available."""
return
ray
is
not
None
def
assert_ray_available
():
"""Raise an exception if Ray is not available."""
if
ray
is
None
:
raise
ValueError
(
"Failed to import Ray, please install Ray with "
"`pip install ray`."
)
from
ray_import_err
def
initialize_ray_cluster
(
parallel_config
:
ParallelConfig
,
ray_address
:
Optional
[
str
]
=
None
,
...
...
@@ -65,10 +77,7 @@ def initialize_ray_cluster(
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
"""
if
ray
is
None
:
raise
ImportError
(
"Ray is not installed. Please install Ray to use multi-node "
"serving."
)
assert_ray_available
()
# Connect to a ray cluster.
if
is_hip
()
or
is_xpu
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment