"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "3fff6789d417f9b43363e3d65477f883fc794bb7"
Unverified Commit a6ea3add authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

[Auto Sync] Update scheduler.py, spec_info.py, run_suite.py... (20251027) (#12235)


Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: default avatargongwei-130 <56567052+gongwei-130@users.noreply.github.com>
parent 326c84c4
......@@ -327,8 +327,28 @@ class Scheduler(
# Launch a draft worker for speculative decoding
self.launch_draft_worker(
gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
draft_worker_kwargs = dict(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
if server_args.speculative_draft_load_format is not None:
server_args.load_format = server_args.speculative_draft_load_format
logger.info(
f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
)
# Draft workers are looked up via `SpeculativeAlgorithm` registry; new
# algorithms should register their factory instead of patching this code.
if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
draft_worker_kwargs["enable_overlap"] = self.enable_overlap
self.draft_worker = self.spec_algorithm.create_draft_worker(
**draft_worker_kwargs
)
# Dispatch the model worker
......@@ -557,57 +577,6 @@ class Scheduler(
]
)
def launch_draft_worker(
self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
if server_args.speculative_draft_load_format is not None:
server_args.load_format = server_args.speculative_draft_load_format
logger.info(
f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
)
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
self.draft_worker = WorkerClass(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_standalone():
from sglang.srt.speculative.standalone_worker import StandaloneWorker
self.draft_worker = StandaloneWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_worker import NGRAMWorker
self.draft_worker = NGRAMWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
context = zmq.Context(2)
self.idle_sleeper = None
......
from __future__ import annotations
import threading
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import IntEnum, auto
from functools import lru_cache
from typing import List, Tuple
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
DraftWorkerClass = Callable[..., Any]
DraftWorkerFactory = Callable[..., Any]
class SpeculativeAlgorithm(IntEnum):
NONE = auto()
EAGLE = auto()
EAGLE3 = auto()
STANDALONE = auto()
NGRAM = auto()
def is_none(self):
return self == SpeculativeAlgorithm.NONE
class _SpeculativeAlgorithmMeta(type):
def __iter__(cls) -> Iterator["SpeculativeAlgorithm"]:
return iter(cls._registration_order)
def is_eagle(self):
return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3
def is_eagle3(self):
return self == SpeculativeAlgorithm.EAGLE3
class SpeculativeAlgorithm(metaclass=_SpeculativeAlgorithmMeta):
"""Registry-backed representation of speculative decoding algorithms."""
def is_standalone(self):
return self == SpeculativeAlgorithm.STANDALONE
__slots__ = ("name", "value", "_draft_worker_factory")
def is_ngram(self):
return self == SpeculativeAlgorithm.NGRAM
_registry_by_name: Dict[str, "SpeculativeAlgorithm"] = {}
_registry_by_value: Dict[int, "SpeculativeAlgorithm"] = {}
_registration_order: List["SpeculativeAlgorithm"] = []
_flags: DefaultDict[str, Set[int]] = defaultdict(set)
_next_value: int = 0
@lru_cache(maxsize=None)
@staticmethod
def from_string(name: str):
name_map = {
"EAGLE": SpeculativeAlgorithm.EAGLE,
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
"NGRAM": SpeculativeAlgorithm.NGRAM,
None: SpeculativeAlgorithm.NONE,
}
if name is not None:
name = name.upper()
return name_map[name]
def __init__(
self,
name: str,
value: int,
draft_worker_factory: Optional[DraftWorkerFactory] = None,
):
self.name = name
self.value = value
self._draft_worker_factory = draft_worker_factory
def __repr__(self) -> str: # pragma: no cover - trivial
return f"SpeculativeAlgorithm.{self.name}"
def __str__(self) -> str: # pragma: no cover - trivial
return self.name
def __hash__(self) -> int:
return hash(self.value)
def __eq__(self, other: object) -> bool:
if isinstance(other, SpeculativeAlgorithm):
return self.value == other.value
return NotImplemented
def __int__(self) -> int:
return self.value
@classmethod
def register(
cls,
name: str,
*,
aliases: Optional[Sequence[str]] = None,
value: Optional[int] = None,
draft_worker_factory: Optional[DraftWorkerFactory] = None,
) -> SpeculativeAlgorithm:
normalized_name = name.upper()
if normalized_name in cls._registry_by_name:
raise ValueError(
f"SpeculativeAlgorithm '{normalized_name}' already registered"
)
if value is None:
value = cls._next_value
cls._next_value = max(cls._next_value, value + 1)
algorithm = cls(
normalized_name,
value,
draft_worker_factory=draft_worker_factory,
)
cls._registry_by_name[normalized_name] = algorithm
cls._registry_by_value[value] = algorithm
cls._registration_order.append(algorithm)
setattr(cls, normalized_name, algorithm)
if aliases:
cls.register_aliases(algorithm, *aliases)
return algorithm
@classmethod
def register_aliases(cls, algorithm: SpeculativeAlgorithm, *aliases: str) -> None:
for alias in aliases:
cls._registry_by_name[alias.upper()] = algorithm
@classmethod
def register_draft_worker(
cls,
algorithm: SpeculativeAlgorithm | str,
factory: DraftWorkerFactory,
) -> None:
algo = cls._ensure_algorithm(algorithm)
algo._draft_worker_factory = factory
@classmethod
def _ensure_algorithm(
cls, algorithm: SpeculativeAlgorithm | str
) -> SpeculativeAlgorithm:
if isinstance(algorithm, SpeculativeAlgorithm):
return algorithm
if isinstance(algorithm, str):
return cls.from_string(algorithm)
raise TypeError(f"Unsupported algorithm identifier: {algorithm!r}")
@classmethod
def _add_flag(
cls, flag: str | Sequence[str], algorithm: SpeculativeAlgorithm | str
) -> None:
algo = cls._ensure_algorithm(algorithm)
if isinstance(flag, str):
flag_iter = (flag,)
else:
flag_iter = flag
for flag_name in flag_iter:
cls._flags[flag_name.upper()].add(algo.value)
@classmethod
@classmethod
def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm:
if name is None:
return cls.NONE
try:
return cls._registry_by_name[name.upper()]
except KeyError as exc:
raise ValueError(f"Unknown speculative algorithm '{name}'") from exc
@classmethod
def from_value(cls, value: int) -> SpeculativeAlgorithm:
try:
return cls._registry_by_value[value]
except KeyError as exc:
raise ValueError(f"Unknown speculative algorithm id {value}") from exc
def _has_flag(self, flag: str) -> bool:
return self.value in type(self)._flags.get(flag.upper(), set())
def is_none(self) -> bool:
return self is SpeculativeAlgorithm.NONE
def is_eagle(self) -> bool:
return self._has_flag("EAGLE")
def is_eagle3(self) -> bool:
return self._has_flag("EAGLE3")
def is_standalone(self) -> bool:
return self._has_flag("STANDALONE")
def is_ngram(self) -> bool:
return self._has_flag("NGRAM")
def create_draft_worker(self, **factory_kwargs: Any) -> Any:
if self._draft_worker_factory is None:
return None
return self._draft_worker_factory(self, **factory_kwargs)
# Registry helpers backed by `SpeculativeAlgorithm`.
_LOCK = threading.RLock()
_REGISTERED_WORKERS: Dict[SpeculativeAlgorithm, DraftWorkerClass] = {}
_FLAG_MARKERS: Dict[str, Callable[[Union[SpeculativeAlgorithm, str]], None]] = {
"EAGLE": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE", algorithm),
"EAGLE3": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE3", algorithm),
"STANDALONE": lambda algorithm: SpeculativeAlgorithm._add_flag(
"STANDALONE", algorithm
),
"NGRAM": lambda algorithm: SpeculativeAlgorithm._add_flag("NGRAM", algorithm),
}
def _wrap_worker_class(worker_cls: DraftWorkerClass) -> DraftWorkerFactory:
def _factory(_: SpeculativeAlgorithm, **kwargs: Any) -> Any:
return worker_cls(**kwargs)
return _factory
def register_speculative_algorithm(
name: str,
worker_cls: DraftWorkerClass,
*,
aliases: Optional[Sequence[str]] = None,
flags: Optional[Iterable[str]] = None,
value: Optional[int] = None,
override_worker: bool = False,
) -> SpeculativeAlgorithm:
"""Register a speculative algorithm and the associated draft worker class.
Example:
>>> from sglang.srt.speculative.spec_info import register_speculative_algorithm
>>> register_speculative_algorithm("MY_ALGO", MyDraftWorker, flags=("EAGLE",))
"""
name_upper = name.upper()
with _LOCK:
try:
algorithm = SpeculativeAlgorithm.from_string(name_upper)
exists = True
except ValueError:
algorithm = SpeculativeAlgorithm.register(
name_upper,
aliases=aliases,
value=value,
)
SpeculativeAlgorithm.register_draft_worker(
algorithm, _wrap_worker_class(worker_cls)
)
exists = False
if exists:
if aliases:
SpeculativeAlgorithm.register_aliases(algorithm, *aliases)
if not override_worker and algorithm in _REGISTERED_WORKERS:
raise ValueError(
f"Worker already registered for {algorithm!r}. "
"Pass override_worker=True to replace it."
)
SpeculativeAlgorithm.register_draft_worker(
algorithm, _wrap_worker_class(worker_cls)
)
_REGISTERED_WORKERS[algorithm] = worker_cls
if flags:
for flag in flags:
marker = _FLAG_MARKERS.get(flag.upper())
if marker is None:
raise ValueError(f"Unsupported flag '{flag}'")
marker(algorithm)
return algorithm
def list_registered_workers() -> Dict[str, DraftWorkerClass]:
"""Return a snapshot of registered speculative worker classes keyed by algorithm name."""
with _LOCK:
return {algo.name: cls for algo, cls in _REGISTERED_WORKERS.items()}
def _create_eagle_worker(**kwargs: Any) -> Any:
enable_overlap = kwargs.pop("enable_overlap", False)
if enable_overlap:
from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
return EAGLEWorkerV2(**kwargs)
from sglang.srt.speculative.eagle_worker import EAGLEWorker
return EAGLEWorker(**kwargs)
def _create_standalone_worker(**kwargs: Any) -> Any:
from sglang.srt.speculative.standalone_worker import StandaloneWorker
return StandaloneWorker(**kwargs)
def _create_ngram_worker(**kwargs: Any) -> Any:
from sglang.srt.speculative.ngram_worker import NGRAMWorker
return NGRAMWorker(**kwargs)
# Register built-in algorithms.
# Third-party integrations should import `SpeculativeAlgorithm` and either
# call `register_speculative_algorithm` or use the helpers below to attach
# additional draft workers.
SpeculativeAlgorithm.register("NONE")
register_speculative_algorithm(
"EAGLE",
aliases=("NEXTN",),
worker_cls=_create_eagle_worker,
flags=("EAGLE",),
)
register_speculative_algorithm(
"EAGLE3",
worker_cls=_create_eagle_worker,
flags=("EAGLE", "EAGLE3"),
)
register_speculative_algorithm(
"STANDALONE",
worker_cls=_create_standalone_worker,
flags=("STANDALONE",),
)
register_speculative_algorithm(
"NGRAM",
worker_cls=_create_ngram_worker,
flags=("NGRAM",),
)
class SpecInputType(IntEnum):
......
......@@ -119,6 +119,7 @@ suites = {
TestFile("test_retract_decode.py", 90),
TestFile("test_score_api.py", 310),
TestFile("test_server_args.py", 1),
TestFile("test_speculative_registry.py", 1),
TestFile("test_skip_tokenizer_init.py", 117),
TestFile("test_srt_endpoint.py", 130),
TestFile("test_srt_engine.py", 261),
......
import unittest
from sglang.srt.speculative import spec_info as spec_info_module
from sglang.srt.speculative.spec_info import (
SpeculativeAlgorithm,
register_speculative_algorithm,
)
class DummyWorker:
def __init__(self, **kwargs):
self.kwargs = kwargs
class SpeculativeRegistryTests(unittest.TestCase):
def test_nextn_alias_maps_to_eagle(self):
eagle = SpeculativeAlgorithm.from_string("EAGLE")
alias = SpeculativeAlgorithm.from_string("NEXTN")
self.assertIs(alias, eagle)
def test_register_speculative_algorithm_registers_worker_and_flags(self):
original_next_value = SpeculativeAlgorithm._next_value
algo = register_speculative_algorithm(
"TEST_SPEC_ALGO",
DummyWorker,
aliases=("TEST_SPEC_ALIAS",),
flags=("EAGLE",),
override_worker=True,
)
self.addCleanup(self._cleanup_registered_algorithm, algo, ("TEST_SPEC_ALIAS",))
self.addCleanup(
setattr, SpeculativeAlgorithm, "_next_value", original_next_value
)
self.assertIs(SpeculativeAlgorithm.from_string("TEST_SPEC_ALGO"), algo)
self.assertIs(SpeculativeAlgorithm.from_string("TEST_SPEC_ALIAS"), algo)
self.assertTrue(algo.is_eagle())
self.assertIs(SpeculativeAlgorithm.from_value(int(algo)), algo)
self.assertIn(algo, list(spec_info_module._REGISTERED_WORKERS))
worker = algo.create_draft_worker(example_arg=42)
self.assertIsInstance(worker, DummyWorker)
self.assertEqual(worker.kwargs["example_arg"], 42)
def test_builtin_algorithms_flags_and_factories(self):
cases = {
"NONE": {
"is_none": True,
"is_eagle": False,
"is_eagle3": False,
"is_standalone": False,
"is_ngram": False,
"has_factory": False,
},
"EAGLE": {
"is_none": False,
"is_eagle": True,
"is_eagle3": False,
"is_standalone": False,
"is_ngram": False,
"has_factory": True,
},
"EAGLE3": {
"is_none": False,
"is_eagle": True,
"is_eagle3": True,
"is_standalone": False,
"is_ngram": False,
"has_factory": True,
},
"STANDALONE": {
"is_none": False,
"is_eagle": False,
"is_eagle3": False,
"is_standalone": True,
"is_ngram": False,
"has_factory": True,
},
"NGRAM": {
"is_none": False,
"is_eagle": False,
"is_eagle3": False,
"is_standalone": False,
"is_ngram": True,
"has_factory": True,
},
}
for name, expectations in cases.items():
with self.subTest(name=name):
algo = SpeculativeAlgorithm.from_string(name)
self.assertEqual(algo.name, name)
self.assertEqual(algo.is_none(), expectations["is_none"])
self.assertEqual(algo.is_eagle(), expectations["is_eagle"])
self.assertEqual(algo.is_eagle3(), expectations["is_eagle3"])
self.assertEqual(algo.is_standalone(), expectations["is_standalone"])
self.assertEqual(algo.is_ngram(), expectations["is_ngram"])
has_factory = algo._draft_worker_factory is not None
self.assertEqual(has_factory, expectations["has_factory"])
self.assertIs(SpeculativeAlgorithm.from_value(int(algo)), algo)
self.assertIs(SpeculativeAlgorithm.from_string(None), SpeculativeAlgorithm.NONE)
def test_iteration_returns_registration_order(self):
names = [algo.name for algo in SpeculativeAlgorithm._registration_order]
for required in ["NONE", "EAGLE", "EAGLE3", "STANDALONE", "NGRAM"]:
self.assertIn(required, names)
def test_create_draft_worker_returns_none_for_none_algorithm(self):
self.assertIsNone(SpeculativeAlgorithm.NONE.create_draft_worker())
def test_register_draft_worker_override(self):
algo = SpeculativeAlgorithm.from_string("EAGLE")
original_factory = algo._draft_worker_factory
def dummy_factory(_: SpeculativeAlgorithm, **kwargs):
return "dummy"
SpeculativeAlgorithm.register_draft_worker(algo, dummy_factory)
self.addCleanup(
SpeculativeAlgorithm.register_draft_worker, algo, original_factory
)
self.assertEqual(algo.create_draft_worker(), "dummy")
def _cleanup_registered_algorithm(self, algorithm: SpeculativeAlgorithm, aliases):
name = algorithm.name
SpeculativeAlgorithm._registry_by_value.pop(algorithm.value, None)
SpeculativeAlgorithm._registry_by_name.pop(name, None)
if hasattr(SpeculativeAlgorithm, name):
delattr(SpeculativeAlgorithm, name)
for alias in aliases:
SpeculativeAlgorithm._registry_by_name.pop(alias, None)
try:
SpeculativeAlgorithm._registration_order.remove(algorithm)
except ValueError:
pass
for flag_values in SpeculativeAlgorithm._flags.values():
flag_values.discard(algorithm.value)
spec_info_module._REGISTERED_WORKERS.pop(algorithm, None)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment