Unverified Commit a6ea3add authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

[Auto Sync] Update scheduler.py, spec_info.py, run_suite.py... (20251027) (#12235)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: gongwei-130 <56567052+gongwei-130@users.noreply.github.com>
parent 326c84c4
...@@ -327,8 +327,28 @@ class Scheduler( ...@@ -327,8 +327,28 @@ class Scheduler(
# Launch a draft worker for speculative decoding # Launch a draft worker for speculative decoding
self.launch_draft_worker( draft_worker_kwargs = dict(
gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
if server_args.speculative_draft_load_format is not None:
server_args.load_format = server_args.speculative_draft_load_format
logger.info(
f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
)
# Draft workers are looked up via `SpeculativeAlgorithm` registry; new
# algorithms should register their factory instead of patching this code.
if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
draft_worker_kwargs["enable_overlap"] = self.enable_overlap
self.draft_worker = self.spec_algorithm.create_draft_worker(
**draft_worker_kwargs
) )
# Dispatch the model worker # Dispatch the model worker
...@@ -557,57 +577,6 @@ class Scheduler( ...@@ -557,57 +577,6 @@ class Scheduler(
] ]
) )
def launch_draft_worker(
    self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
    """Instantiate the draft worker matching ``self.spec_algorithm``.

    When ``server_args.speculative_draft_load_format`` is set it overwrites
    ``server_args.load_format`` before the worker is built. The result is
    stored on ``self.draft_worker``; it is ``None`` when the algorithm is
    neither EAGLE-family, STANDALONE nor NGRAM.
    """
    if server_args.speculative_draft_load_format is not None:
        server_args.load_format = server_args.speculative_draft_load_format
        logger.info(
            f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
        )

    # Worker modules are imported lazily so only the selected one is loaded.
    algorithm = self.spec_algorithm
    if algorithm.is_eagle():
        from sglang.srt.speculative.eagle_worker import EAGLEWorker
        from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2

        # The V2 worker is used only when overlap scheduling is enabled.
        worker_cls = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
    elif algorithm.is_standalone():
        from sglang.srt.speculative.standalone_worker import (
            StandaloneWorker as worker_cls,
        )
    elif algorithm.is_ngram():
        from sglang.srt.speculative.ngram_worker import NGRAMWorker as worker_cls
    else:
        self.draft_worker = None
        return

    # Every worker type takes the same keyword arguments.
    self.draft_worker = worker_cls(
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        moe_ep_rank=moe_ep_rank,
        server_args=server_args,
        nccl_port=port_args.nccl_port,
        target_worker=self.tp_worker,
        dp_rank=dp_rank,
    )
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs): def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
context = zmq.Context(2) context = zmq.Context(2)
self.idle_sleeper = None self.idle_sleeper = None
......
from __future__ import annotations
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict
from enum import IntEnum, auto from enum import IntEnum, auto
from functools import lru_cache from typing import (
from typing import List, Tuple Any,
Callable,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
# A callable invoked with keyword arguments only to build a draft worker
# (`_wrap_worker_class` calls it as `worker_cls(**kwargs)`).
DraftWorkerClass = Callable[..., Any]
# A callable invoked as `factory(algorithm, **kwargs)` by
# `SpeculativeAlgorithm.create_draft_worker`; its return value is the worker.
DraftWorkerFactory = Callable[..., Any]
class SpeculativeAlgorithm(IntEnum):
NONE = auto()
EAGLE = auto()
EAGLE3 = auto()
STANDALONE = auto()
NGRAM = auto()
def is_none(self): class _SpeculativeAlgorithmMeta(type):
return self == SpeculativeAlgorithm.NONE def __iter__(cls) -> Iterator["SpeculativeAlgorithm"]:
return iter(cls._registration_order)
def is_eagle(self):
return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3
def is_eagle3(self): class SpeculativeAlgorithm(metaclass=_SpeculativeAlgorithmMeta):
return self == SpeculativeAlgorithm.EAGLE3 """Registry-backed representation of speculative decoding algorithms."""
def is_standalone(self): __slots__ = ("name", "value", "_draft_worker_factory")
return self == SpeculativeAlgorithm.STANDALONE
def is_ngram(self): _registry_by_name: Dict[str, "SpeculativeAlgorithm"] = {}
return self == SpeculativeAlgorithm.NGRAM _registry_by_value: Dict[int, "SpeculativeAlgorithm"] = {}
_registration_order: List["SpeculativeAlgorithm"] = []
_flags: DefaultDict[str, Set[int]] = defaultdict(set)
_next_value: int = 0
@lru_cache(maxsize=None) def __init__(
@staticmethod self,
def from_string(name: str): name: str,
name_map = { value: int,
"EAGLE": SpeculativeAlgorithm.EAGLE, draft_worker_factory: Optional[DraftWorkerFactory] = None,
"EAGLE3": SpeculativeAlgorithm.EAGLE3, ):
"STANDALONE": SpeculativeAlgorithm.STANDALONE, self.name = name
"NGRAM": SpeculativeAlgorithm.NGRAM, self.value = value
None: SpeculativeAlgorithm.NONE, self._draft_worker_factory = draft_worker_factory
}
if name is not None: def __repr__(self) -> str: # pragma: no cover - trivial
name = name.upper() return f"SpeculativeAlgorithm.{self.name}"
return name_map[name]
def __str__(self) -> str:  # pragma: no cover - trivial
    """Return the algorithm's ``name`` attribute."""
    return self.name
def __hash__(self) -> int:
    """Hash on ``value`` so hashing agrees with the value-based ``__eq__``."""
    return hash(self.value)
def __eq__(self, other: object) -> bool:
    """Compare two algorithms by ``value``.

    Returns ``NotImplemented`` for any operand that is not a
    ``SpeculativeAlgorithm`` so Python can try the reflected comparison.
    """
    if not isinstance(other, SpeculativeAlgorithm):
        return NotImplemented
    return self.value == other.value
def __int__(self) -> int:
    """Return the numeric ``value``, so ``int(algorithm)`` works."""
    return self.value
@classmethod
def register(
    cls,
    name: str,
    *,
    aliases: Optional[Sequence[str]] = None,
    value: Optional[int] = None,
    draft_worker_factory: Optional["DraftWorkerFactory"] = None,
) -> "SpeculativeAlgorithm":
    """Create a new algorithm and add it to the class registries.

    Args:
        name: Algorithm name; stored upper-cased and exposed as a class
            attribute of the same name.
        aliases: Extra names that resolve to the new algorithm.
        value: Explicit numeric id. When omitted, the next free id is used.
        draft_worker_factory: Optional factory used by
            ``create_draft_worker``.

    Returns:
        The newly registered algorithm.

    Raises:
        ValueError: If the name or the numeric value is already registered.
    """
    normalized_name = name.upper()
    if normalized_name in cls._registry_by_name:
        raise ValueError(
            f"SpeculativeAlgorithm '{normalized_name}' already registered"
        )

    if value is None:
        value = cls._next_value
    # Equality and hashing are value-based, so two algorithms sharing a value
    # would be indistinguishable and would clobber `_registry_by_value`.
    # Reject the collision before any registry is touched.
    if value in cls._registry_by_value:
        raise ValueError(
            f"SpeculativeAlgorithm value {value} already registered "
            f"for {cls._registry_by_value[value]!r}"
        )
    cls._next_value = max(cls._next_value, value + 1)

    algorithm = cls(
        normalized_name,
        value,
        draft_worker_factory=draft_worker_factory,
    )

    cls._registry_by_name[normalized_name] = algorithm
    cls._registry_by_value[value] = algorithm
    cls._registration_order.append(algorithm)
    # Expose the algorithm enum-style, e.g. `SpeculativeAlgorithm.EAGLE`.
    setattr(cls, normalized_name, algorithm)

    if aliases:
        cls.register_aliases(algorithm, *aliases)

    return algorithm
@classmethod
def register_aliases(cls, algorithm: "SpeculativeAlgorithm", *aliases: str) -> None:
    """Make each alias (upper-cased) resolve to ``algorithm`` by name.

    An existing entry with the same upper-cased name is overwritten.
    """
    cls._registry_by_name.update((alias.upper(), algorithm) for alias in aliases)
@classmethod
def register_draft_worker(
    cls,
    algorithm: "SpeculativeAlgorithm | str",
    factory: "DraftWorkerFactory",
) -> None:
    """Attach (or replace) the draft worker factory of an algorithm.

    ``algorithm`` may be an instance or a registered name; it is resolved
    through ``_ensure_algorithm``.
    """
    cls._ensure_algorithm(algorithm)._draft_worker_factory = factory
@classmethod
def _ensure_algorithm(
    cls, algorithm: "SpeculativeAlgorithm | str"
) -> "SpeculativeAlgorithm":
    """Normalize an algorithm identifier to a ``SpeculativeAlgorithm``.

    Instances are returned unchanged and strings are looked up with
    ``from_string``.

    Raises:
        TypeError: If ``algorithm`` is neither an instance nor a string.
    """
    if isinstance(algorithm, SpeculativeAlgorithm):
        return algorithm
    if not isinstance(algorithm, str):
        raise TypeError(f"Unsupported algorithm identifier: {algorithm!r}")
    return cls.from_string(algorithm)
@classmethod
def _add_flag(
    cls, flag: "str | Sequence[str]", algorithm: "SpeculativeAlgorithm | str"
) -> None:
    """Tag an algorithm with one or more flags.

    ``flag`` is a single name or a sequence of names. Each is upper-cased and
    the algorithm's ``value`` is added to that flag's member set.
    """
    resolved = cls._ensure_algorithm(algorithm)
    # A bare string is one flag, not a sequence of one-character flags.
    flag_names = (flag,) if isinstance(flag, str) else flag
    for flag_name in flag_names:
        cls._flags[flag_name.upper()].add(resolved.value)
@classmethod
def from_string(cls, name: Optional[str]) -> "SpeculativeAlgorithm":
    """Look up an algorithm by name or alias, case-insensitively.

    Args:
        name: Registered name or alias; ``None`` selects ``cls.NONE``.

    Raises:
        ValueError: If no algorithm is registered under ``name``.
    """
    # `@classmethod` must appear only once: a classmethod wrapping another
    # classmethod fails at call time on interpreters without descriptor
    # chaining.
    if name is None:
        return cls.NONE
    try:
        return cls._registry_by_name[name.upper()]
    except KeyError as exc:
        raise ValueError(f"Unknown speculative algorithm '{name}'") from exc
@classmethod
def from_value(cls, value: int) -> "SpeculativeAlgorithm":
    """Look up an algorithm by its numeric id.

    Raises:
        ValueError: If no algorithm is registered under ``value``.
    """
    registry = cls._registry_by_value
    try:
        algorithm = registry[value]
    except KeyError as exc:
        raise ValueError(f"Unknown speculative algorithm id {value}") from exc
    return algorithm
def _has_flag(self, flag: str) -> bool:
    """Return whether this algorithm was tagged with ``flag`` (case-insensitive)."""
    # Use `.get` so probing an unknown flag does not create an empty entry
    # in the defaultdict.
    members = type(self)._flags.get(flag.upper())
    return members is not None and self.value in members
def is_none(self) -> bool:
    """Return whether this is the ``SpeculativeAlgorithm.NONE`` object itself."""
    # Identity, not equality: only the registered NONE instance qualifies.
    return self is SpeculativeAlgorithm.NONE
def is_eagle(self) -> bool:
    """Return whether this algorithm carries the ``EAGLE`` flag."""
    return self._has_flag("EAGLE")
def is_eagle3(self) -> bool:
    """Return whether this algorithm carries the ``EAGLE3`` flag."""
    return self._has_flag("EAGLE3")
def is_standalone(self) -> bool:
    """Return whether this algorithm carries the ``STANDALONE`` flag."""
    return self._has_flag("STANDALONE")
def is_ngram(self) -> bool:
    """Return whether this algorithm carries the ``NGRAM`` flag."""
    return self._has_flag("NGRAM")
def create_draft_worker(self, **factory_kwargs: Any) -> Any:
    """Build this algorithm's draft worker.

    The registered factory is called as ``factory(self, **factory_kwargs)``.
    Returns ``None`` when no factory has been registered.
    """
    factory = self._draft_worker_factory
    return None if factory is None else factory(self, **factory_kwargs)
# Registry helpers backed by `SpeculativeAlgorithm`.
# Re-entrant lock held by `register_speculative_algorithm` and
# `list_registered_workers` while they read or modify the registries.
_LOCK = threading.RLock()
# Worker class (or factory callable) last passed to
# `register_speculative_algorithm` for each algorithm.
_REGISTERED_WORKERS: Dict[SpeculativeAlgorithm, DraftWorkerClass] = {}
# Supported flag names, each mapped to a callable that tags the given
# algorithm with that flag. The flags drive `is_eagle()`, `is_eagle3()`,
# `is_standalone()` and `is_ngram()`.
_FLAG_MARKERS: Dict[str, Callable[[Union[SpeculativeAlgorithm, str]], None]] = {
"EAGLE": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE", algorithm),
"EAGLE3": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE3", algorithm),
"STANDALONE": lambda algorithm: SpeculativeAlgorithm._add_flag(
"STANDALONE", algorithm
),
"NGRAM": lambda algorithm: SpeculativeAlgorithm._add_flag("NGRAM", algorithm),
}
def _wrap_worker_class(worker_cls: "DraftWorkerClass") -> "DraftWorkerFactory":
    """Adapt a worker class to the ``factory(algorithm, **kwargs)`` signature.

    The returned factory ignores the algorithm argument and forwards only the
    keyword arguments to ``worker_cls``.
    """

    def _build_worker(_: "SpeculativeAlgorithm", **kwargs: Any) -> Any:
        return worker_cls(**kwargs)

    return _build_worker
def register_speculative_algorithm(
    name: str,
    worker_cls: "DraftWorkerClass",
    *,
    aliases: Optional[Sequence[str]] = None,
    flags: Optional[Iterable[str]] = None,
    value: Optional[int] = None,
    override_worker: bool = False,
) -> "SpeculativeAlgorithm":
    """Register a speculative algorithm and the associated draft worker class.

    If ``name`` (or an alias equal to it) is already registered, the existing
    algorithm is reused and ``value`` is ignored. Otherwise a new algorithm is
    created.

    Args:
        name: Algorithm name, case-insensitive.
        worker_cls: Callable invoked with keyword arguments to build the
            draft worker.
        aliases: Extra names that resolve to the algorithm.
        flags: Flag names (keys of ``_FLAG_MARKERS``) to tag the algorithm with.
        value: Explicit numeric id, used only for a new algorithm.
        override_worker: Allow replacing a worker that was already registered
            for an existing algorithm.

    Returns:
        The registered (or reused) algorithm.

    Raises:
        ValueError: If a flag is unsupported, or a worker is already
            registered and ``override_worker`` is false.

    Example:
        >>> from sglang.srt.speculative.spec_info import register_speculative_algorithm
        >>> register_speculative_algorithm("MY_ALGO", MyDraftWorker, flags=("EAGLE",))
    """
    name_upper = name.upper()

    # Resolve every flag before touching a registry, so an unsupported flag
    # cannot leave a half-registered algorithm behind.
    markers = []
    for flag in flags or ():
        marker = _FLAG_MARKERS.get(flag.upper())
        if marker is None:
            raise ValueError(f"Unsupported flag '{flag}'")
        markers.append(marker)

    with _LOCK:
        try:
            algorithm = SpeculativeAlgorithm.from_string(name_upper)
        except ValueError:
            algorithm = SpeculativeAlgorithm.register(
                name_upper,
                aliases=aliases,
                value=value,
            )
        else:
            # Check for a conflicting worker before adding aliases, so a
            # rejected call leaves the name registry unchanged.
            if not override_worker and algorithm in _REGISTERED_WORKERS:
                raise ValueError(
                    f"Worker already registered for {algorithm!r}. "
                    "Pass override_worker=True to replace it."
                )
            if aliases:
                SpeculativeAlgorithm.register_aliases(algorithm, *aliases)

        SpeculativeAlgorithm.register_draft_worker(
            algorithm, _wrap_worker_class(worker_cls)
        )
        _REGISTERED_WORKERS[algorithm] = worker_cls

        for marker in markers:
            marker(algorithm)

    return algorithm
def list_registered_workers() -> Dict[str, "DraftWorkerClass"]:
    """Return a snapshot of registered speculative worker classes keyed by algorithm name.

    The result is a new dict built while holding the registry lock, so later
    registrations do not change it.
    """
    with _LOCK:
        snapshot = {}
        for algorithm, worker_cls in _REGISTERED_WORKERS.items():
            snapshot[algorithm.name] = worker_cls
        return snapshot
def _create_eagle_worker(**kwargs: Any) -> Any:
    """Build an EAGLE draft worker.

    ``enable_overlap`` (default ``False``) is removed from ``kwargs`` and
    selects ``EAGLEWorkerV2`` over ``EAGLEWorker``. The remaining keyword
    arguments go to the chosen class. Only the selected module is imported.
    """
    if kwargs.pop("enable_overlap", False):
        from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2 as worker_cls
    else:
        from sglang.srt.speculative.eagle_worker import EAGLEWorker as worker_cls

    return worker_cls(**kwargs)
def _create_standalone_worker(**kwargs: Any) -> Any:
    """Build a ``StandaloneWorker``, importing its module on first use."""
    from sglang.srt.speculative.standalone_worker import (
        StandaloneWorker as worker_cls,
    )

    return worker_cls(**kwargs)
def _create_ngram_worker(**kwargs: Any) -> Any:
    """Build an ``NGRAMWorker``, importing its module on first use."""
    from sglang.srt.speculative.ngram_worker import NGRAMWorker as worker_cls

    return worker_cls(**kwargs)
# Register built-in algorithms. This runs at import time, in this order.
# Third-party integrations should import `SpeculativeAlgorithm` and either
# call `register_speculative_algorithm` or use the `SpeculativeAlgorithm`
# registration classmethods defined above to attach additional draft workers.
# NONE is registered without a factory, so `create_draft_worker()` returns
# None for it.
SpeculativeAlgorithm.register("NONE")
# "NEXTN" is accepted as an alias for EAGLE.
register_speculative_algorithm(
"EAGLE",
aliases=("NEXTN",),
worker_cls=_create_eagle_worker,
flags=("EAGLE",),
)
# EAGLE3 carries both flags, so `is_eagle()` and `is_eagle3()` are both true.
register_speculative_algorithm(
"EAGLE3",
worker_cls=_create_eagle_worker,
flags=("EAGLE", "EAGLE3"),
)
register_speculative_algorithm(
"STANDALONE",
worker_cls=_create_standalone_worker,
flags=("STANDALONE",),
)
register_speculative_algorithm(
"NGRAM",
worker_cls=_create_ngram_worker,
flags=("NGRAM",),
)
class SpecInputType(IntEnum): class SpecInputType(IntEnum):
......
...@@ -119,6 +119,7 @@ suites = { ...@@ -119,6 +119,7 @@ suites = {
TestFile("test_retract_decode.py", 90), TestFile("test_retract_decode.py", 90),
TestFile("test_score_api.py", 310), TestFile("test_score_api.py", 310),
TestFile("test_server_args.py", 1), TestFile("test_server_args.py", 1),
TestFile("test_speculative_registry.py", 1),
TestFile("test_skip_tokenizer_init.py", 117), TestFile("test_skip_tokenizer_init.py", 117),
TestFile("test_srt_endpoint.py", 130), TestFile("test_srt_endpoint.py", 130),
TestFile("test_srt_engine.py", 261), TestFile("test_srt_engine.py", 261),
......
import unittest
from sglang.srt.speculative import spec_info as spec_info_module
from sglang.srt.speculative.spec_info import (
SpeculativeAlgorithm,
register_speculative_algorithm,
)
class DummyWorker:
    """Stand-in draft worker that records the keyword arguments it was built with."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
class SpeculativeRegistryTests(unittest.TestCase):
    """Tests for the `SpeculativeAlgorithm` registry and its registration helpers."""

    BUILTIN_NAMES = ["NONE", "EAGLE", "EAGLE3", "STANDALONE", "NGRAM"]

    def test_nextn_alias_maps_to_eagle(self):
        eagle = SpeculativeAlgorithm.from_string("EAGLE")
        alias = SpeculativeAlgorithm.from_string("NEXTN")
        self.assertIs(alias, eagle)

    def test_register_speculative_algorithm_registers_worker_and_flags(self):
        original_next_value = SpeculativeAlgorithm._next_value
        algo = register_speculative_algorithm(
            "TEST_SPEC_ALGO",
            DummyWorker,
            aliases=("TEST_SPEC_ALIAS",),
            flags=("EAGLE",),
            override_worker=True,
        )
        # Cleanups run in reverse order: the id counter is restored first,
        # then the temporary algorithm is removed from every registry.
        self.addCleanup(self._cleanup_registered_algorithm, algo, ("TEST_SPEC_ALIAS",))
        self.addCleanup(
            setattr, SpeculativeAlgorithm, "_next_value", original_next_value
        )

        self.assertIs(SpeculativeAlgorithm.from_string("TEST_SPEC_ALGO"), algo)
        self.assertIs(SpeculativeAlgorithm.from_string("TEST_SPEC_ALIAS"), algo)
        self.assertTrue(algo.is_eagle())
        self.assertIs(SpeculativeAlgorithm.from_value(int(algo)), algo)
        self.assertIn(algo, list(spec_info_module._REGISTERED_WORKERS))

        worker = algo.create_draft_worker(example_arg=42)
        self.assertIsInstance(worker, DummyWorker)
        self.assertEqual(worker.kwargs["example_arg"], 42)

    def test_builtin_algorithms_flags_and_factories(self):
        cases = {
            "NONE": {
                "is_none": True,
                "is_eagle": False,
                "is_eagle3": False,
                "is_standalone": False,
                "is_ngram": False,
                "has_factory": False,
            },
            "EAGLE": {
                "is_none": False,
                "is_eagle": True,
                "is_eagle3": False,
                "is_standalone": False,
                "is_ngram": False,
                "has_factory": True,
            },
            "EAGLE3": {
                "is_none": False,
                "is_eagle": True,
                "is_eagle3": True,
                "is_standalone": False,
                "is_ngram": False,
                "has_factory": True,
            },
            "STANDALONE": {
                "is_none": False,
                "is_eagle": False,
                "is_eagle3": False,
                "is_standalone": True,
                "is_ngram": False,
                "has_factory": True,
            },
            "NGRAM": {
                "is_none": False,
                "is_eagle": False,
                "is_eagle3": False,
                "is_standalone": False,
                "is_ngram": True,
                "has_factory": True,
            },
        }

        for name, expectations in cases.items():
            with self.subTest(name=name):
                algo = SpeculativeAlgorithm.from_string(name)
                self.assertEqual(algo.name, name)
                self.assertEqual(algo.is_none(), expectations["is_none"])
                self.assertEqual(algo.is_eagle(), expectations["is_eagle"])
                self.assertEqual(algo.is_eagle3(), expectations["is_eagle3"])
                self.assertEqual(algo.is_standalone(), expectations["is_standalone"])
                self.assertEqual(algo.is_ngram(), expectations["is_ngram"])

                has_factory = algo._draft_worker_factory is not None
                self.assertEqual(has_factory, expectations["has_factory"])

                self.assertIs(SpeculativeAlgorithm.from_value(int(algo)), algo)

        self.assertIs(SpeculativeAlgorithm.from_string(None), SpeculativeAlgorithm.NONE)

    def test_iteration_returns_registration_order(self):
        # Iterate the class itself so the metaclass `__iter__` is exercised.
        iterated = list(SpeculativeAlgorithm)
        self.assertEqual(iterated, list(SpeculativeAlgorithm._registration_order))

        names = [algo.name for algo in iterated]
        for required in self.BUILTIN_NAMES:
            self.assertIn(required, names)

        # Built-ins must keep their relative registration order even if other
        # algorithms have been registered alongside them.
        builtin_in_order = [name for name in names if name in self.BUILTIN_NAMES]
        self.assertEqual(builtin_in_order, self.BUILTIN_NAMES)

    def test_create_draft_worker_returns_none_for_none_algorithm(self):
        self.assertIsNone(SpeculativeAlgorithm.NONE.create_draft_worker())

    def test_register_draft_worker_override(self):
        algo = SpeculativeAlgorithm.from_string("EAGLE")
        original_factory = algo._draft_worker_factory

        def dummy_factory(_: SpeculativeAlgorithm, **kwargs):
            return "dummy"

        SpeculativeAlgorithm.register_draft_worker(algo, dummy_factory)
        # Restore the real EAGLE factory so other tests see the built-in state.
        self.addCleanup(
            SpeculativeAlgorithm.register_draft_worker, algo, original_factory
        )

        self.assertEqual(algo.create_draft_worker(), "dummy")

    def _cleanup_registered_algorithm(self, algorithm: SpeculativeAlgorithm, aliases):
        """Remove a test-registered algorithm from every registry it was added to."""
        name = algorithm.name
        SpeculativeAlgorithm._registry_by_value.pop(algorithm.value, None)
        SpeculativeAlgorithm._registry_by_name.pop(name, None)
        if hasattr(SpeculativeAlgorithm, name):
            delattr(SpeculativeAlgorithm, name)

        for alias in aliases:
            # The registry stores aliases upper-cased, so match that here.
            SpeculativeAlgorithm._registry_by_name.pop(alias.upper(), None)

        try:
            SpeculativeAlgorithm._registration_order.remove(algorithm)
        except ValueError:
            pass

        for flag_values in SpeculativeAlgorithm._flags.values():
            flag_values.discard(algorithm.value)

        spec_info_module._REGISTERED_WORKERS.pop(algorithm, None)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment