Unverified Commit 42333026 authored by Yannick Schnider's avatar Yannick Schnider Committed by GitHub
Browse files

[Feature] Pluggable platform-specific scheduler (#13161)


Signed-off-by: default avatarYannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: default avatarYannick Schnider <Yannick.Schnider1@ibm.com>
parent caf7ff44
...@@ -531,6 +531,7 @@ steps: ...@@ -531,6 +531,7 @@ steps:
- pip uninstall vllm_add_dummy_platform -y - pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests # end platform plugin tests
# other tests continue here: # other tests continue here:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model - pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py - pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
......
# SPDX-License-Identifier: Apache-2.0
from vllm.core.scheduler import Scheduler
class DummyScheduler(Scheduler):
def schedule(self):
raise Exception("Exception raised by DummyScheduler")
def test_scheduler_plugins():
import pytest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
with pytest.raises(Exception) as exception_info:
engine_args = EngineArgs(
model="facebook/opt-125m",
enforce_eager=True, # reduce test time
scheduler_cls=DummyScheduler,
)
engine = LLMEngine.from_engine_args(engine_args=engine_args)
sampling_params = SamplingParams(max_tokens=1)
engine.add_request("0", "foo", sampling_params)
engine.step()
assert str(exception_info.value) == "Exception raised by DummyScheduler"
...@@ -1495,6 +1495,10 @@ class SchedulerConfig: ...@@ -1495,6 +1495,10 @@ class SchedulerConfig:
chunked_prefill_enabled: bool = field(init=False) chunked_prefill_enabled: bool = field(init=False)
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
...@@ -192,6 +192,7 @@ class EngineArgs: ...@@ -192,6 +192,7 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs" scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None override_pooler_config: Optional[PoolerConfig] = None
...@@ -938,6 +939,13 @@ class EngineArgs: ...@@ -938,6 +939,13 @@ class EngineArgs:
'priority (lower value means earlier handling) and time of ' 'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).') 'arrival deciding any ties).')
parser.add_argument(
'--scheduler-cls',
default=EngineArgs.scheduler_cls,
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
'is the default scheduler. Can be a class directly or the path to '
'a class of form "mod.custom_class".')
parser.add_argument( parser.add_argument(
'--override-neuron-config', '--override-neuron-config',
type=json.loads, type=json.loads,
...@@ -1273,10 +1281,12 @@ class EngineArgs: ...@@ -1273,10 +1281,12 @@ class EngineArgs:
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray), and parallel_config.use_ray),
policy=self.scheduling_policy, policy=self.scheduling_policy,
scheduler_cls=self.scheduler_cls,
max_num_partial_prefills=self.max_num_partial_prefills, max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold, long_prefill_token_threshold=self.long_prefill_token_threshold,
) )
lora_config = LoRAConfig( lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias, bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
......
...@@ -19,8 +19,7 @@ import vllm.envs as envs ...@@ -19,8 +19,7 @@ import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig) VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.engine.metrics_types import StatLoggerBase, Stats
from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.interfaces import (
...@@ -58,7 +57,8 @@ from vllm.transformers_utils.tokenizer_group import ( ...@@ -58,7 +57,8 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -346,6 +346,11 @@ class LLMEngine: ...@@ -346,6 +346,11 @@ class LLMEngine:
# Create the scheduler. # Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of # NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor. # GPU and CPU blocks, which are profiled in the distributed executor.
if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
self.vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = self.vllm_config.scheduler_config.scheduler_cls
self.scheduler = [ self.scheduler = [
Scheduler( Scheduler(
self.scheduler_config, self.cache_config, self.lora_config, self.scheduler_config, self.cache_config, self.lora_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment