Unverified Commit b34474bf authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Feature] Support async scheduling + PP (#32359)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 6218034d
...@@ -9,6 +9,7 @@ from vllm.config import ( ...@@ -9,6 +9,7 @@ from vllm.config import (
ECTransferConfig, ECTransferConfig,
KVTransferConfig, KVTransferConfig,
ModelConfig, ModelConfig,
ParallelConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig, SpeculativeConfig,
VllmConfig, VllmConfig,
...@@ -53,6 +54,7 @@ def create_scheduler( ...@@ -53,6 +54,7 @@ def create_scheduler(
num_speculative_tokens: int | None = None, num_speculative_tokens: int | None = None,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
async_scheduling: bool = False, async_scheduling: bool = False,
pipeline_parallel_size: int = 1,
use_ec_connector: bool = False, use_ec_connector: bool = False,
ec_role: str | None = None, ec_role: str | None = None,
) -> Scheduler | AsyncScheduler: ) -> Scheduler | AsyncScheduler:
...@@ -133,6 +135,7 @@ def create_scheduler( ...@@ -133,6 +135,7 @@ def create_scheduler(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
parallel_config=ParallelConfig(pipeline_parallel_size=pipeline_parallel_size),
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config, speculative_config=speculative_config,
ec_transfer_config=ec_transfer_config, ec_transfer_config=ec_transfer_config,
......
...@@ -563,11 +563,6 @@ class VllmConfig: ...@@ -563,11 +563,6 @@ class VllmConfig:
if self.scheduler_config.async_scheduling: if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities. # Async scheduling explicitly enabled, hard fail any incompatibilities.
if self.parallel_config.pipeline_parallel_size > 1:
raise ValueError(
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
# Currently, async scheduling only support eagle speculative # Currently, async scheduling only support eagle speculative
# decoding. # decoding.
if self.speculative_config is not None: if self.speculative_config is not None:
...@@ -589,14 +584,7 @@ class VllmConfig: ...@@ -589,14 +584,7 @@ class VllmConfig:
) )
elif self.scheduler_config.async_scheduling is None: elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option. # Enable async scheduling unless there is an incompatible option.
if self.parallel_config.pipeline_parallel_size > 1: if (
logger.warning_once(
"Async scheduling is not yet supported with "
"pipeline_parallel_size > 1 and will be disabled.",
scope="local",
)
self.scheduler_config.async_scheduling = False
elif (
self.speculative_config is not None self.speculative_config is not None
and self.speculative_config.method not in get_args(EagleModelTypes) and self.speculative_config.method not in get_args(EagleModelTypes)
): ):
......
...@@ -283,6 +283,13 @@ class Scheduler(SchedulerInterface): ...@@ -283,6 +283,13 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if self.use_pp and request.num_output_placeholders > 0:
req_index += 1
continue
if ( if (
request.num_output_placeholders > 0 request.num_output_placeholders > 0
# This is (num_computed_tokens + 1) - (num_output_placeholders - 1). # This is (num_computed_tokens + 1) - (num_output_placeholders - 1).
......
...@@ -411,9 +411,9 @@ class MultiprocExecutor(Executor): ...@@ -411,9 +411,9 @@ class MultiprocExecutor(Executor):
@cached_property @cached_property
def max_concurrent_batches(self) -> int: def max_concurrent_batches(self) -> int:
if self.scheduler_config.async_scheduling: # PP requires PP-size concurrent batches to fill the pipeline.
return 2 pp_size = self.parallel_config.pipeline_parallel_size
return self.parallel_config.pipeline_parallel_size return 2 if pp_size <= 1 and self.scheduler_config.async_scheduling else pp_size
def _get_output_rank(self) -> int: def _get_output_rank(self) -> int:
# Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1 # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
......
...@@ -111,9 +111,8 @@ class RayDistributedExecutor(Executor): ...@@ -111,9 +111,8 @@ class RayDistributedExecutor(Executor):
"""Ray distributed executor supports pipeline parallelism, """Ray distributed executor supports pipeline parallelism,
meaning that it allows PP size batches to be executed concurrently. meaning that it allows PP size batches to be executed concurrently.
""" """
if self.scheduler_config.async_scheduling: pp_size = self.parallel_config.pipeline_parallel_size
return 2 return 2 if pp_size <= 1 and self.scheduler_config.async_scheduling else pp_size
return self.parallel_config.pipeline_parallel_size
def shutdown(self) -> None: def shutdown(self) -> None:
if logger: if logger:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment