Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f842a7af
Unverified
Commit
f842a7af
authored
Sep 11, 2024
by
youkaichao
Committed by
GitHub
Sep 11, 2024
Browse files
[misc] remove engine_use_ray (#8126)
parent
a65cb160
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
32 additions
and
197 deletions
+32
-197
tests/async_engine/test_api_server.py
tests/async_engine/test_api_server.py
+4
-14
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+3
-11
tests/async_engine/test_openapi_server.py
tests/async_engine/test_openapi_server.py
+1
-6
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-11
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+21
-142
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-3
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+0
-1
vllm/envs.py
vllm/envs.py
+0
-9
No files found.
tests/async_engine/test_api_server.py
View file @
f842a7af
import
os
import
subprocess
import
sys
import
time
...
...
@@ -26,8 +25,7 @@ def _query_server_long(prompt: str) -> dict:
@
pytest
.
fixture
def
api_server
(
tokenizer_pool_size
:
int
,
engine_use_ray
:
bool
,
worker_use_ray
:
bool
):
def
api_server
(
tokenizer_pool_size
:
int
,
worker_use_ray
:
bool
):
script_path
=
Path
(
__file__
).
parent
.
joinpath
(
"api_server_async_engine.py"
).
absolute
()
commands
=
[
...
...
@@ -37,25 +35,17 @@ def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
str
(
tokenizer_pool_size
)
]
# Copy the environment variables and append `VLLM_ALLOW_ENGINE_USE_RAY=1`
# to prevent `--engine-use-ray` raises an exception due to it deprecation
env_vars
=
os
.
environ
.
copy
()
env_vars
[
"VLLM_ALLOW_ENGINE_USE_RAY"
]
=
"1"
if
engine_use_ray
:
commands
.
append
(
"--engine-use-ray"
)
if
worker_use_ray
:
commands
.
append
(
"--worker-use-ray"
)
uvicorn_process
=
subprocess
.
Popen
(
commands
,
env
=
env_vars
)
uvicorn_process
=
subprocess
.
Popen
(
commands
)
yield
uvicorn_process
.
terminate
()
@
pytest
.
mark
.
parametrize
(
"tokenizer_pool_size"
,
[
0
,
2
])
@
pytest
.
mark
.
parametrize
(
"worker_use_ray"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"engine_use_ray"
,
[
False
,
True
])
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
,
worker_use_ray
:
bool
,
engine_use_ray
:
bool
):
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
,
worker_use_ray
:
bool
):
"""
Run the API server and test it.
...
...
tests/async_engine/test_async_llm_engine.py
View file @
f842a7af
import
asyncio
import
os
from
asyncio
import
CancelledError
from
dataclasses
import
dataclass
from
typing
import
Optional
...
...
@@ -72,14 +71,12 @@ class MockEngine:
class
MockAsyncLLMEngine
(
AsyncLLMEngine
):
def
_init_engine
(
self
,
*
args
,
**
kwargs
):
return
MockEngine
()
_engine_class
=
MockEngine
@
pytest
.
mark
.
asyncio
async
def
test_new_requests_event
():
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
False
,
engine_use_ray
=
False
)
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
False
)
engine
.
start_background_loop
()
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
step_calls
==
0
...
...
@@ -112,16 +109,11 @@ async def test_new_requests_event():
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
# Allow deprecated engine_use_ray to not raise exception
os
.
environ
[
"VLLM_ALLOW_ENGINE_USE_RAY"
]
=
"1"
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
True
,
engine_use_ray
=
True
)
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
True
)
assert
engine
.
get_model_config
()
is
not
None
assert
engine
.
get_tokenizer
()
is
not
None
assert
engine
.
get_decoding_config
()
is
not
None
os
.
environ
.
pop
(
"VLLM_ALLOW_ENGINE_USE_RAY"
)
def
start_engine
():
wait_for_gpu_memory_to_clear
(
...
...
tests/async_engine/test_openapi_server
_ray
.py
→
tests/async_engine/test_openapi_server.py
View file @
f842a7af
...
...
@@ -19,16 +19,11 @@ def server():
"--max-model-len"
,
"2048"
,
"--enforce-eager"
,
"--engine-use-ray"
,
"--chat-template"
,
str
(
chatml_jinja_path
),
]
# Allow `--engine-use-ray`, otherwise the launch of the server throw
# an error due to try to use a deprecated feature
env_dict
=
{
"VLLM_ALLOW_ENGINE_USE_RAY"
:
"1"
}
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
,
env_dict
=
env_dict
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
...
...
vllm/engine/arg_utils.py
View file @
f842a7af
...
...
@@ -1035,7 +1035,6 @@ class EngineArgs:
@
dataclass
class
AsyncEngineArgs
(
EngineArgs
):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray
:
bool
=
False
disable_log_requests
:
bool
=
False
@
staticmethod
...
...
@@ -1043,16 +1042,6 @@ class AsyncEngineArgs(EngineArgs):
async_args_only
:
bool
=
False
)
->
FlexibleArgumentParser
:
if
not
async_args_only
:
parser
=
EngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
'--engine-use-ray'
,
action
=
'store_true'
,
help
=
'Use Ray to start the LLM engine in a '
'separate process as the server process.'
'(DEPRECATED. This argument is deprecated '
'and will be removed in a future update. '
'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
'use it. See '
'https://github.com/vllm-project/vllm/issues/7045.'
')'
)
parser
.
add_argument
(
'--disable-log-requests'
,
action
=
'store_true'
,
help
=
'Disable logging requests.'
)
...
...
vllm/engine/async_llm_engine.py
View file @
f842a7af
...
...
@@ -16,7 +16,7 @@ from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
PromptComponents
,
SchedulerOutputState
)
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
EncoderDecoderLLMInputs
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
...
...
@@ -30,7 +30,6 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -590,9 +589,6 @@ class AsyncLLMEngine:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
...
...
@@ -604,41 +600,23 @@ class AsyncLLMEngine:
def
__init__
(
self
,
worker_use_ray
:
bool
,
engine_use_ray
:
bool
,
*
args
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
**
kwargs
)
->
None
:
self
.
worker_use_ray
=
worker_use_ray
self
.
engine_use_ray
=
engine_use_ray
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_
init_
engine
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_engine
_class
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self
.
use_process_request_outputs_callback
=
not
self
.
engine_use_ray
self
.
use_process_request_outputs_callback
=
True
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
if
self
.
engine_use_ray
:
print_warning_once
(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
"be removed in a future update. "
"See https://github.com/vllm-project/vllm/issues/7045."
)
if
envs
.
VLLM_ALLOW_ENGINE_USE_RAY
:
print_warning_once
(
"VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray"
)
else
:
raise
ValueError
(
"`--engine-use-ray` is deprecated. "
"Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
"force use it"
)
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
...
...
@@ -725,16 +703,11 @@ class AsyncLLMEngine:
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_args
.
engine_use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
# Create the async LLM engine.
engine
=
cls
(
executor_class
.
uses_ray
,
engine_args
.
engine_use_ray
,
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
...
...
@@ -777,10 +750,6 @@ class AsyncLLMEngine:
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_tokenizer
.
remote
(
# type: ignore
lora_request
)
return
await
(
self
.
engine
.
get_tokenizer_group
().
get_lora_tokenizer_async
(
lora_request
))
...
...
@@ -814,26 +783,6 @@ class AsyncLLMEngine:
self
.
_background_loop_unshielded
=
None
self
.
background_loop
=
None
def
_init_engine
(
self
,
*
args
,
**
kwargs
)
->
Union
[
_AsyncLLMEngine
,
"ray.ObjectRef"
]:
if
not
self
.
engine_use_ray
:
engine_class
=
self
.
_engine_class
elif
self
.
worker_use_ray
:
engine_class
=
ray
.
remote
(
num_cpus
=
0
)(
self
.
_engine_class
).
remote
else
:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments.
cache_config
=
kwargs
[
"cache_config"
]
parallel_config
=
kwargs
[
"parallel_config"
]
if
(
parallel_config
.
tensor_parallel_size
==
1
and
parallel_config
.
pipeline_parallel_size
==
1
):
num_gpus
=
cache_config
.
gpu_memory_utilization
else
:
num_gpus
=
1
engine_class
=
ray
.
remote
(
num_gpus
=
num_gpus
)(
self
.
_engine_class
).
remote
return
engine_class
(
*
args
,
**
kwargs
)
async
def
engine_step
(
self
,
virtual_engine
:
int
)
->
bool
:
"""Kick the engine to process the waiting requests.
...
...
@@ -844,13 +793,8 @@ class AsyncLLMEngine:
for
new_request
in
new_requests
:
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
try
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
add_request
.
remote
(
# type: ignore
**
new_request
)
else
:
await
self
.
engine
.
add_request_async
(
**
new_request
)
await
self
.
engine
.
add_request_async
(
**
new_request
)
except
ValueError
as
e
:
# TODO: use a vLLM specific error for failed validation
self
.
_request_tracker
.
process_exception
(
...
...
@@ -862,10 +806,7 @@ class AsyncLLMEngine:
if
aborted_requests
:
await
self
.
_engine_abort
(
aborted_requests
)
if
self
.
engine_use_ray
:
request_outputs
=
await
self
.
engine
.
step
.
remote
()
# type: ignore
else
:
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
# Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside
...
...
@@ -891,16 +832,10 @@ class AsyncLLMEngine:
return
all_finished
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
if
self
.
engine_use_ray
:
await
self
.
engine
.
abort_request
.
remote
(
request_ids
)
# type: ignore
else
:
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
async
def
run_engine_loop
(
self
):
if
self
.
engine_use_ray
:
pipeline_parallel_size
=
1
# type: ignore
else
:
pipeline_parallel_size
=
\
pipeline_parallel_size
=
\
self
.
engine
.
parallel_config
.
pipeline_parallel_size
has_requests_in_progress
=
[
False
]
*
pipeline_parallel_size
while
True
:
...
...
@@ -912,12 +847,7 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
if
self
.
engine_use_ray
:
await
(
self
.
engine
.
stop_remote_worker_execution_loop
.
remote
()
# type: ignore
)
else
:
await
self
.
engine
.
stop_remote_worker_execution_loop_async
()
await
self
.
engine
.
stop_remote_worker_execution_loop_async
()
await
self
.
_request_tracker
.
wait_for_new_requests
()
logger
.
debug
(
"Got new requests!"
)
requests_in_progress
=
[
...
...
@@ -938,17 +868,9 @@ class AsyncLLMEngine:
for
task
in
done
:
result
=
task
.
result
()
virtual_engine
=
requests_in_progress
.
index
(
task
)
if
self
.
engine_use_ray
:
has_unfinished_requests
=
(
await
(
self
.
engine
.
has_unfinished_requests_for_virtual_engine
.
remote
(
# type: ignore
virtual_engine
)))
else
:
has_unfinished_requests
=
(
self
.
engine
.
has_unfinished_requests_for_virtual_engine
(
virtual_engine
))
has_unfinished_requests
=
(
self
.
engine
.
has_unfinished_requests_for_virtual_engine
(
virtual_engine
))
if
result
or
has_unfinished_requests
:
requests_in_progress
[
virtual_engine
]
=
(
asyncio
.
create_task
(
...
...
@@ -1190,52 +1112,29 @@ class AsyncLLMEngine:
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_model_config
.
remote
()
# type: ignore
else
:
return
self
.
engine
.
get_model_config
()
return
self
.
engine
.
get_model_config
()
async
def
get_parallel_config
(
self
)
->
ParallelConfig
:
"""Get the parallel configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_parallel_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_parallel_config
()
return
self
.
engine
.
get_parallel_config
()
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Get the decoding configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_decoding_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_decoding_config
()
return
self
.
engine
.
get_decoding_config
()
async
def
get_scheduler_config
(
self
)
->
SchedulerConfig
:
"""Get the scheduling configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_scheduler_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_scheduler_config
()
return
self
.
engine
.
get_scheduler_config
()
async
def
get_lora_config
(
self
)
->
LoRAConfig
:
"""Get the lora configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_lora_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_lora_config
()
return
self
.
engine
.
get_lora_config
()
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
do_log_stats
.
remote
(
# type: ignore
scheduler_outputs
,
model_output
)
else
:
self
.
engine
.
do_log_stats
()
self
.
engine
.
do_log_stats
()
async
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
...
...
@@ -1244,37 +1143,17 @@ class AsyncLLMEngine:
if
self
.
is_stopped
:
raise
AsyncEngineDeadError
(
"Background loop is stopped."
)
if
self
.
engine_use_ray
:
try
:
await
self
.
engine
.
check_health
.
remote
()
# type: ignore
except
ray
.
exceptions
.
RayActorError
as
e
:
raise
RuntimeError
(
"Engine is dead."
)
from
e
else
:
await
self
.
engine
.
check_health_async
()
await
self
.
engine
.
check_health_async
()
logger
.
debug
(
"Health check took %fs"
,
time
.
perf_counter
()
-
t
)
async
def
is_tracing_enabled
(
self
)
->
bool
:
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
is_tracing_enabled
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
is_tracing_enabled
()
return
self
.
engine
.
is_tracing_enabled
()
def
add_logger
(
self
,
logger_name
:
str
,
logger
:
StatLoggerBase
)
->
None
:
if
self
.
engine_use_ray
:
ray
.
get
(
self
.
engine
.
add_logger
.
remote
(
# type: ignore
logger_name
=
logger_name
,
logger
=
logger
))
else
:
self
.
engine
.
add_logger
(
logger_name
=
logger_name
,
logger
=
logger
)
self
.
engine
.
add_logger
(
logger_name
=
logger_name
,
logger
=
logger
)
def
remove_logger
(
self
,
logger_name
:
str
)
->
None
:
if
self
.
engine_use_ray
:
ray
.
get
(
self
.
engine
.
remove_logger
.
remote
(
# type: ignore
logger_name
=
logger_name
))
else
:
self
.
engine
.
remove_logger
(
logger_name
=
logger_name
)
self
.
engine
.
remove_logger
(
logger_name
=
logger_name
)
async
def
start_profile
(
self
)
->
None
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
...
...
vllm/engine/llm_engine.py
View file @
f842a7af
...
...
@@ -3,8 +3,8 @@ import time
from
collections
import
deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
...
...
@@ -397,7 +397,7 @@ class LLMEngine:
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
self
.
process_request_outputs_callback
=
None
self
.
process_request_outputs_callback
:
Optional
[
Callable
]
=
None
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
...
...
vllm/entrypoints/openai/run_batch.py
View file @
f842a7af
...
...
@@ -195,7 +195,6 @@ async def main(args):
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_BATCH_RUNNER
)
# When using single vLLM without engine_use_ray
model_config
=
await
engine
.
get_model_config
()
if
args
.
disable_log_requests
:
...
...
vllm/envs.py
View file @
f842a7af
...
...
@@ -58,7 +58,6 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_RPC_GET_DATA_TIMEOUT_MS
:
int
=
5000
VLLM_ALLOW_ENGINE_USE_RAY
:
bool
=
False
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
...
...
@@ -391,14 +390,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_RPC_GET_DATA_TIMEOUT_MS"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_RPC_GET_DATA_TIMEOUT_MS"
,
"5000"
)),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045
"VLLM_ALLOW_ENGINE_USE_RAY"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_ENGINE_USE_RAY"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment