Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4e7ee664
Unverified
Commit
4e7ee664
authored
Apr 16, 2024
by
SangBin Cho
Committed by
GitHub
Apr 16, 2024
Browse files
[Core] Fix engine-use-ray broken (#4105)
parent
37e84a40
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
8 deletions
+16
-8
tests/async_engine/test_api_server.py
tests/async_engine/test_api_server.py
+13
-4
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+3
-4
No files found.
tests/async_engine/test_api_server.py
View file @
4e7ee664
...
...
@@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict:
@
pytest
.
fixture
def
api_server
(
tokenizer_pool_size
:
int
):
def
api_server
(
tokenizer_pool_size
:
int
,
engine_use_ray
:
bool
,
worker_use_ray
:
bool
):
script_path
=
Path
(
__file__
).
parent
.
joinpath
(
"api_server_async_engine.py"
).
absolute
()
uvicorn_process
=
subprocess
.
Popen
(
[
commands
=
[
sys
.
executable
,
"-u"
,
str
(
script_path
),
"--model"
,
"facebook/opt-125m"
,
"--host"
,
"127.0.0.1"
,
"--tokenizer-pool-size"
,
str
(
tokenizer_pool_size
)
])
]
if
engine_use_ray
:
commands
.
append
(
"--engine-use-ray"
)
if
worker_use_ray
:
commands
.
append
(
"--worker-use-ray"
)
uvicorn_process
=
subprocess
.
Popen
(
commands
)
yield
uvicorn_process
.
terminate
()
@
pytest
.
mark
.
parametrize
(
"tokenizer_pool_size"
,
[
0
,
2
])
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
):
@
pytest
.
mark
.
parametrize
(
"worker_use_ray"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"engine_use_ray"
,
[
False
,
True
])
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
,
worker_use_ray
:
bool
,
engine_use_ray
:
bool
):
"""
Run the API server and test it.
...
...
vllm/engine/async_llm_engine.py
View file @
4e7ee664
...
...
@@ -333,8 +333,7 @@ class AsyncLLMEngine:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
raise
NotImplementedError
(
"Neuron is not supported for "
"async engine yet."
)
elif
(
engine_config
.
parallel_config
.
worker_use_ray
or
engine_args
.
engine_use_ray
):
elif
engine_config
.
parallel_config
.
worker_use_ray
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
...
...
@@ -410,8 +409,8 @@ class AsyncLLMEngine:
else
:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments.
cache_config
=
args
[
1
]
parallel_config
=
args
[
2
]
cache_config
=
kw
args
[
"cache_config"
]
parallel_config
=
kw
args
[
"parallel_config"
]
if
parallel_config
.
tensor_parallel_size
==
1
:
num_gpus
=
cache_config
.
gpu_memory_utilization
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment