Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9206b3d7
Unverified
Commit
9206b3d7
authored
Feb 15, 2025
by
Cody Yu
Committed by
GitHub
Feb 15, 2025
Browse files
[V1][PP] Run engine busy loop with batch queue (#13064)
parent
ed0de3e4
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
299 additions
and
15 deletions
+299
-15
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+51
-0
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+88
-1
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+17
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+73
-6
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+9
-8
vllm/v1/executor/ray_distributed_executor.py
vllm/v1/executor/ray_distributed_executor.py
+61
-0
No files found.
tests/v1/core/test_scheduler.py
View file @
9206b3d7
...
...
@@ -213,3 +213,54 @@ def test_schedule_partial_requests():
assert
output
.
num_scheduled_tokens
[
requests
[
0
].
request_id
]
==
1
assert
output
.
num_scheduled_tokens
[
requests
[
1
].
request_id
]
==
700
assert
requests
[
2
].
request_id
not
in
output
.
num_scheduled_tokens
def test_schedule_concurrent_batches():
    """Schedule two batches back to back before any model output arrives.

    Each of the two requests (512 prompt tokens each) ends up in its own
    batch. Once the output of the first batch is fed back, the first
    request can be scheduled again for one more token even though the
    second batch has not reported its output yet.
    """

    def one_token_output(request):
        # Fake model-runner result covering only `request`, with a single
        # sampled token (id 0) and no logprobs.
        req_id = request.request_id
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=[0],
            logprobs=None,
            prompt_logprobs_dict={},
        )

    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
        max_num_seqs=2,
    )
    requests = create_requests(
        num_requests=2,
        num_tokens=512,
    )
    first, second = requests[0], requests[1]

    # Batch 0: only the first request has been added so far.
    scheduler.add_request(first)
    batch0 = scheduler.schedule()
    assert len(batch0.scheduled_new_reqs) == 1
    assert batch0.num_scheduled_tokens[first.request_id] == 512

    # Batch 1: the first request is still in flight, so only the second
    # request gets scheduled.
    scheduler.add_request(second)
    batch1 = scheduler.schedule()
    assert len(batch1.scheduled_new_reqs) == 1
    assert batch1.num_scheduled_tokens[second.request_id] == 512

    # Feed back the model output of the first request.
    model_runner_output = one_token_output(first)
    scheduler.update_from_output(batch0, model_runner_output)

    # Next step: the first request is schedulable again (one decode token)
    # while the second request is still running.
    batch2 = scheduler.schedule()
    assert batch2.num_scheduled_tokens[first.request_id] == 1

    # Feed back the model output of the second request.
    model_runner_output = one_token_output(second)
    scheduler.update_from_output(batch1, model_runner_output)
tests/v1/engine/test_engine_core.py
View file @
9206b3d7
# SPDX-License-Identifier: Apache-2.0
import
copy
import
threading
import
time
import
uuid
from
concurrent.futures
import
Future
import
pytest
from
transformers
import
AutoTokenizer
...
...
@@ -12,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core
import
EngineCore
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
,
UniProcExecutor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
reason
=
"V1 currently only supported on CUDA."
,
...
...
@@ -191,3 +196,85 @@ def test_engine_core_advanced_sampling(monkeypatch):
)
engine_core
.
add_request
(
request2
)
_check_engine_state
()
@fork_new_process_for_each_test
def test_engine_core_concurrent_batches(monkeypatch):
    """
    Test that the engine can handle multiple concurrent batches.

    A dummy executor advertises two concurrent batches and returns a Future
    from execute_model, so EngineCore enables its batch queue. The test
    first fills the queue with two batches and then steps the engine until
    both requests have finished.
    """

    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
        request = make_request()
        request.sampling_params.max_tokens = max_tokens
        return request

    class DummyExecutor(UniProcExecutor):

        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
            super().initialize(kv_cache_config)

            # This executor actually can only run 1 batch at a time
            self.semaphore = threading.Semaphore(1)

        def execute_model(
            self,
            scheduler_output,
        ) -> Future[ModelRunnerOutput]:
            """Make execute_model non-blocking."""
            future: Future[ModelRunnerOutput] = Future()

            def _thread_wrapper(scheduler_output, future):
                # Always resolve the future: if the RPC (or the copy)
                # fails, forward the exception so that future.result()
                # raises in the engine instead of blocking forever.
                try:
                    with self.semaphore:
                        output = self.collective_rpc(
                            "execute_model", args=(scheduler_output, ))
                        # Make a copy because output[0] may be reused
                        # by the next batch.
                        output = copy.deepcopy(output[0])
                        future.set_result(output)
                except Exception as exc:
                    future.set_exception(exc)

            threading.Thread(target=_thread_wrapper,
                             args=(scheduler_output, future)).start()
            return future

        @property
        def max_concurrent_batches(self) -> int:
            return 2

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        engine_args = EngineArgs(
            model=MODEL_NAME,
            # To test concurrent batches.
            max_num_seqs=2,
            # Avoid all requests being scheduled at once.
            enable_prefix_caching=False,
            max_num_batched_tokens=10,
        )
        vllm_config = engine_args.create_engine_config()
        engine_core = EngineCore(vllm_config=vllm_config,
                                 log_stats=False,
                                 executor_class=DummyExecutor)
        assert engine_core.batch_queue is not None

        # Add two requests in a row.
        req = make_request_with_max_tokens(5)
        engine_core.add_request(req)
        req = make_request_with_max_tokens(5)
        engine_core.add_request(req)

        # First saturate the batch queue.
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 1
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 2
        assert engine_core.scheduler.get_num_unfinished_requests() == 2

        # Loop through both requests.
        while engine_core.scheduler.get_num_unfinished_requests() == 2:
            engine_core.step_with_batch_queue()

        # Reaching here when got the result of the first request.
        while engine_core.scheduler.get_num_unfinished_requests() == 1:
            engine_core.step_with_batch_queue()
vllm/v1/core/scheduler.py
View file @
9206b3d7
...
...
@@ -58,6 +58,9 @@ class Scheduler:
# Priority queues for requests.
self
.
waiting
:
Deque
[
Request
]
=
deque
()
self
.
running
:
List
[
Request
]
=
[]
# The requests that have been scheduled and are being executed
# by the executor.
self
.
scheduled_req_ids
:
Set
[
str
]
=
set
()
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
...
...
@@ -118,6 +121,11 @@ class Scheduler:
req_index
=
0
while
req_index
<
len
(
self
.
running
)
and
token_budget
>
0
:
request
=
self
.
running
[
req_index
]
if
request
.
request_id
in
self
.
scheduled_req_ids
:
# This request has already been scheduled.
req_index
+=
1
continue
num_new_tokens
=
request
.
num_tokens
-
request
.
num_computed_tokens
num_new_tokens
=
min
(
num_new_tokens
,
token_budget
)
assert
num_new_tokens
>
0
...
...
@@ -164,6 +172,7 @@ class Scheduler:
# Schedule the request.
scheduled_running_reqs
.
append
(
request
)
self
.
scheduled_req_ids
.
add
(
request
.
request_id
)
req_to_new_block_ids
[
request
.
request_id
]
=
[
b
.
block_id
for
b
in
new_blocks
]
...
...
@@ -251,6 +260,7 @@ class Scheduler:
self
.
waiting
.
popleft
()
self
.
running
.
append
(
request
)
self
.
scheduled_req_ids
.
add
(
request
.
request_id
)
if
request
.
status
==
RequestStatus
.
WAITING
:
scheduled_new_reqs
.
append
(
request
)
self
.
request_scheduled
(
request
,
scheduled_timestamp
)
...
...
@@ -519,6 +529,7 @@ class Scheduler:
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
()))
self
.
scheduled_req_ids
.
remove
(
request
.
request_id
)
if
not
stopped
:
new_running
.
append
(
request
)
...
...
@@ -575,6 +586,8 @@ class Scheduler:
if
request
.
status
==
RequestStatus
.
RUNNING
:
self
.
running
.
remove
(
request
)
if
request
.
request_id
in
self
.
scheduled_req_ids
:
self
.
scheduled_req_ids
.
remove
(
request
.
request_id
)
else
:
self
.
waiting
.
remove
(
request
)
request
.
status
=
finished_status
...
...
@@ -595,6 +608,10 @@ class Scheduler:
def has_unfinished_requests(self) -> bool:
    """Return True when at least one request is still unfinished."""
    num_unfinished = self.get_num_unfinished_requests()
    return num_unfinished > 0
def get_num_unscheduled_requests(self) -> int:
    """Number of requests that are not being processed by the executor.

    Computed as the count of unfinished requests minus the number of
    request IDs currently held in ``scheduled_req_ids``.
    """
    num_unfinished = self.get_num_unfinished_requests()
    num_in_flight = len(self.scheduled_req_ids)
    return num_unfinished - num_in_flight
def
reset_prefix_cache
(
self
)
->
bool
:
return
self
.
kv_cache_manager
.
reset_prefix_cache
()
...
...
vllm/v1/engine/core.py
View file @
9206b3d7
...
...
@@ -4,8 +4,9 @@ import queue
import
signal
import
threading
import
time
from
concurrent.futures
import
Future
from
multiprocessing.connection
import
Connection
from
typing
import
Any
,
List
,
Tuple
,
Type
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Type
import
psutil
import
zmq
...
...
@@ -18,11 +19,12 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
get_exception_traceback
,
zmq_socket_ctx
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_configs
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
,
SchedulerOutput
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
)
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.version
import
__version__
as
VLLM_VERSION
...
...
@@ -66,9 +68,22 @@ class EngineCore:
log_stats
=
self
.
log_stats
,
)
# Setup MM Input Mapper.
self
.
mm_input_cache_server
=
MMInputCacheServer
(
vllm_config
.
model_config
)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self
.
batch_queue_size
=
self
.
model_executor
.
max_concurrent_batches
self
.
batch_queue
:
Optional
[
queue
.
Queue
[
Tuple
[
Future
[
ModelRunnerOutput
],
SchedulerOutput
]]]
=
None
if
self
.
batch_queue_size
>
1
:
logger
.
info
(
"Batch queue is enabled with size %d"
,
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
Tuple
[
int
,
int
]:
start
=
time
.
time
()
...
...
@@ -135,7 +150,55 @@ class EngineCore:
scheduler_output
=
self
.
scheduler
.
schedule
()
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
output
)
scheduler_output
,
output
)
# type: ignore
return
engine_core_outputs
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
    """Run one engine step using the batch queue.

    Returns None when the step produced nothing to output.

    Flow of a single call:
    1. If there are unscheduled requests and the batch queue still has
       room, schedule a new batch and submit it to the executor without
       waiting for it. When that batch contains tokens, return None right
       away; model outputs are not collected during such a step.
    2. Otherwise (nothing was scheduled, or the scheduled batch was
       empty), wait for the oldest batch in the queue to finish.
    3. Update the scheduler with that batch's model output and return the
       resulting engine core outputs.
    """
    assert self.batch_queue is not None

    # Step 1: try to launch another batch. Neither scheduling nor the
    # queue insertion blocks.
    has_unscheduled = self.scheduler.get_num_unscheduled_requests() > 0
    if has_unscheduled and not self.batch_queue.full():
        new_batch = self.scheduler.schedule()
        num_tokens = new_batch.total_num_scheduled_tokens
        if num_tokens > 0:
            pending = self.model_executor.execute_model(new_batch)
            self.batch_queue.put_nowait(
                (pending, new_batch))  # type: ignore
        if num_tokens != 0:
            # A batch was launched; there is nothing to output this step.
            return None

    # Steps 2 and 3: every request is already scheduled or the queue is
    # full, so wait for the first batch in the queue.
    try:
        future, finished_batch = self.batch_queue.get(
            timeout=POLLING_TIMEOUT_S)
        # Blocking until the first result is available.
        model_output = future.result()
        self.batch_queue.task_done()
        return self.scheduler.update_from_output(finished_batch,
                                                 model_output)
    except queue.Empty:
        # The .get() above timed out on an empty queue: return an empty
        # EngineCoreOutputs for logging.
        return EngineCoreOutputs(
            outputs=[], scheduler_stats=self.scheduler.make_stats())
def
shutdown
(
self
):
...
...
@@ -226,6 +289,9 @@ class EngineCoreProc(EngineCore):
def
run_busy_loop
(
self
):
"""Core busy loop of the EngineCore."""
step_fn
=
(
self
.
step
if
self
.
batch_queue
is
None
else
self
.
step_with_batch_queue
)
# Loop until process is sent a SIGINT or SIGTERM
while
True
:
# 1) Poll the input queue until there is work to do.
...
...
@@ -249,10 +315,11 @@ class EngineCoreProc(EngineCore):
self
.
_handle_client_request
(
*
req
)
# 3) Step the engine core.
outputs
=
s
elf
.
s
tep
()
outputs
=
step
_fn
()
# 5) Put EngineCoreOutputs into the output queue.
self
.
output_queue
.
put_nowait
(
outputs
)
# 4) Put EngineCoreOutputs into the output queue.
if
outputs
is
not
None
:
self
.
output_queue
.
put_nowait
(
outputs
)
def
_handle_client_request
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
...
...
vllm/v1/executor/abstract.py
View file @
9206b3d7
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Type
from
concurrent.futures
import
Future
from
typing
import
List
,
Type
,
Union
from
vllm.config
import
VllmConfig
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
as
RayDistributedExecutorV0
)
from
vllm.executor.uniproc_executor
import
(
# noqa
ExecutorWithExternalLauncher
as
ExecutorWithExternalLauncherV0
)
from
vllm.executor.uniproc_executor
import
(
# noqa
...
...
@@ -33,6 +32,8 @@ class Executor(ExecutorBase):
f
"ExecutorBase. Got
{
distributed_executor_backend
}
."
)
executor_class
=
distributed_executor_backend
elif
distributed_executor_backend
==
"ray"
:
from
vllm.v1.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
)
executor_class
=
RayDistributedExecutor
elif
distributed_executor_backend
==
"mp"
:
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
...
...
@@ -70,11 +71,15 @@ class Executor(ExecutorBase):
def
execute_model
(
self
,
scheduler_output
,
)
->
ModelRunnerOutput
:
)
->
Union
[
ModelRunnerOutput
,
Future
[
ModelRunnerOutput
]]
:
output
=
self
.
collective_rpc
(
"execute_model"
,
args
=
(
scheduler_output
,
))
return
output
[
0
]
@property
def max_concurrent_batches(self) -> int:
    """Maximum number of batches this executor reports it can run at once.

    This implementation always reports a single batch.
    """
    return 1
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,
))
...
...
@@ -85,7 +90,3 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class
ExecutorWithExternalLauncher
(
ExecutorWithExternalLauncherV0
,
Executor
):
pass
class
RayDistributedExecutor
(
RayDistributedExecutorV0
,
Executor
):
pass
vllm/v1/executor/ray_distributed_executor.py
0 → 100644
View file @
9206b3d7
# SPDX-License-Identifier: Apache-2.0
from
concurrent.futures
import
Future
from
typing
import
Union
from
vllm.executor.ray_distributed_executor
import
(
# noqa
RayDistributedExecutor
as
RayDistributedExecutorV0
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
class FutureWrapper(Future):
    """Adapter exposing a Ray output reference through the Future
    interface returned by .execute_model().
    """

    def __init__(self, ref):
        super().__init__()
        # Wrapped reference; result() delegates to its .get().
        self.ref = ref

    def result(self, timeout=None):
        """Return the value obtained from the wrapped reference's .get().

        Raises:
            NotImplementedError: if any ``timeout`` other than None is
                given.
        """
        if timeout is None:
            return self.ref.get()
        raise NotImplementedError("timeout is not supported")
class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
    """Ray distributed executor using Ray Compiled Graphs."""

    @property
    def max_concurrent_batches(self) -> int:
        """Number of batches that may be executed concurrently.

        The executor is written with pipeline parallelism in mind (up to
        PP-size batches in flight), but in this revision the value is
        fixed at 1, so execute_model() always takes its blocking path.
        """
        # NOTE: the pipeline-parallel lookup below is commented out in
        # this revision and a constant 1 is returned instead:
        #self.vllm_config.parallel_config.pipeline_parallel_size
        return 1

    def execute_model(
        self,
        scheduler_output,
    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
        """Execute the model on the Ray workers.

        Args:
            scheduler_output: The scheduler output to execute.

        Returns:
            The model runner output when only one batch may run at a time;
            otherwise a FutureWrapper around the first output reference.
        """
        # Lazily build the compiled DAG on the first call.
        if self.forward_dag is None:  # type: ignore
            self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

        outputs = self.forward_dag.execute(scheduler_output)  # type: ignore
        first_ref = outputs[0]

        # With PP, hand back a FutureWrapper immediately so that the
        # scheduler can yield to the next batch.
        if self.max_concurrent_batches != 1:
            return FutureWrapper(first_ref)

        # Without PP, block here until the result is available.
        return first_ref.get()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment