Unverified Commit 78f13981 authored by Liangsheng Yin, committed by GitHub

[1/N] DP-Refactor: move communicators into `tokenizer_communicator_mixin` (#10028)

parent bfd7a18d
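
For orientation, here is a minimal sketch (not part of this commit) of how the new mixin is meant to plug into `TokenizerManager`: the manager inherits `TokenizerCommunicatorMixin`, provides a `send_to_scheduler` sender and a `_result_dispatcher`, then calls `init_communicators()`. The `_FakeSender` and `DemoManager` names are hypothetical stand-ins for the real ZMQ socket and manager.

from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
from sglang.srt.server_args import ServerArgs
from sglang.utils import TypeBasedDispatcher


class _FakeSender:
    """Hypothetical stand-in for the ZMQ socket that talks to the scheduler."""

    def send_pyobj(self, obj):
        print("would send", type(obj).__name__)


class DemoManager(TokenizerCommunicatorMixin):
    """Hypothetical manager showing the attributes the mixin relies on."""

    def __init__(self, server_args: ServerArgs):
        self.server_args = server_args
        self.send_to_scheduler = _FakeSender()
        self._result_dispatcher = TypeBasedDispatcher([])
        # Creates one _Communicator per control request type and extends
        # _result_dispatcher with their handle_recv callbacks.
        self.init_communicators(server_args)
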
@@ -36,7 +36,8 @@ from sglang.srt.managers.io_struct import (
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
-from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
New file: sglang/srt/managers/tokenizer_communicator_mixin.py
from __future__ import annotations
import asyncio
import logging
import os
import time
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Deque,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
import fastapi
from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReqInput,
FlushCacheReqOutput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
MultiTokenizerWrapper,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.server_args import LoRARef, ServerArgs
from sglang.srt.utils import get_bool_env_var
from sglang.utils import TypeBasedDispatcher
if TYPE_CHECKING:
from sglang.srt.managers.tokenizer_manager import TokenizerManager
T = TypeVar("T")
logger = logging.getLogger(__name__)
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
enable_multi_tokenizer = False
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
self._result_event: Optional[asyncio.Event] = None
self._result_values: Optional[List[T]] = None
self._ready_queue: Deque[asyncio.Future] = deque()
async def __call__(self, obj):
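        # Serialize callers: if a request is already in flight or queued,
        # wait for our turn in FIFO order.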
ready_event = asyncio.Event()
if self._result_event is not None or len(self._ready_queue) > 0:
self._ready_queue.append(ready_event)
await ready_event.wait()
assert self._result_event is None
assert self._result_values is None
if obj:
if _Communicator.enable_multi_tokenizer:
obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
self._sender.send_pyobj(obj)
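        # Expect `fan_out` replies for this request before returning.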
self._result_event = asyncio.Event()
self._result_values = []
await self._result_event.wait()
result_values = self._result_values
self._result_event = self._result_values = None
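        # Hand the turn to the next queued caller, if any.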
if len(self._ready_queue) > 0:
self._ready_queue.popleft().set()
return result_values
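    # Called by the TokenizerManager result dispatcher for every matching response;
    # releases the waiting __call__ once all `fan_out` responses have arrived.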
def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_event.set()
class TokenizerCommunicatorMixin:
"""Mixin class for TokenizerManager to handle communication with the scheduler."""
def init_communicators(self: TokenizerManager, server_args: ServerArgs):
# Communicators
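        # One communicator per control-plane request type. Each one fans out to all
        # `dp_size` schedulers and collects one reply per rank.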
self.init_weights_update_group_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_distributed_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_tensor_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.release_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.clear_hicache_storage_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher += self._get_communicator_dispatcher()
def _get_communicator_dispatcher(self: TokenizerManager):
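        # Route each scheduler response type to the matching communicator's handle_recv.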
return TypeBasedDispatcher(
[
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
),
(
UpdateWeightsFromDistributedReqOutput,
self.update_weights_from_distributed_communicator.handle_recv,
),
(
UpdateWeightsFromTensorReqOutput,
self.update_weights_from_tensor_communicator.handle_recv,
),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
),
(
ReleaseMemoryOccupationReqOutput,
self.release_memory_occupation_communicator.handle_recv,
),
(
ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv,
),
(
SlowDownReqOutput,
self.slow_down_communicator.handle_recv,
),
(
ClearHiCacheReqOutput,
self.clear_hicache_storage_communicator.handle_recv,
),
(
FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv,
),
(
ProfileReqOutput,
self.profile_communicator.handle_recv,
),
(
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(
LoRAUpdateResult,
self.update_lora_adapter_communicator.handle_recv,
),
]
)
async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput:
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput:
"""Clear the hierarchical cache storage."""
# Delegate to the scheduler to handle HiCacheStorage clearing
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
0
]
async def start_profile(
self: TokenizerManager,
output_dir: Optional[str] = None,
start_step: Optional[int] = None,
num_steps: Optional[int] = None,
activities: Optional[List[str]] = None,
with_stack: Optional[bool] = None,
record_shapes: Optional[bool] = None,
profile_by_stage: bool = False,
):
self.auto_create_handle_loop()
env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
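        # Stack capture defaults to on; it is disabled only when the argument or the
        # SGLANG_PROFILE_WITH_STACK environment variable explicitly turns it off.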
with_stack = False if with_stack is False or env_with_stack is False else True
req = ProfileReq(
type=ProfileReqType.START_PROFILE,
output_dir=output_dir,
start_step=start_step,
num_steps=num_steps,
activities=activities,
with_stack=with_stack,
record_shapes=record_shapes,
profile_by_stage=profile_by_stage,
profile_id=str(time.time()),
)
return await self._execute_profile(req)
async def stop_profile(self: TokenizerManager):
self.auto_create_handle_loop()
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
return await self._execute_profile(req)
async def _execute_profile(self: TokenizerManager, req: ProfileReq):
result = (await self.profile_communicator(req))[0]
if not result.success:
raise RuntimeError(result.message)
return result
async def start_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
async def stop_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
async def dump_expert_distribution_record(self: TokenizerManager):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def init_weights_update_group(
self: TokenizerManager,
obj: InitWeightsUpdateGroupReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
result = (await self.init_weights_update_group_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_distributed(
self: TokenizerManager,
obj: UpdateWeightsFromDistributedReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
        # Hold the model update writer lock so that weight sync
        # cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_distributed_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_tensor(
self: TokenizerManager,
obj: UpdateWeightsFromTensorReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
assert (
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
        # Hold the model update writer lock so that weight sync
        # cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def load_lora_adapter(
self: TokenizerManager,
obj: LoadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
try:
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start load Lora adapter. Lora name=%s, path=%s",
obj.lora_name,
obj.lora_path,
)
async with self.lora_update_lock:
if (
self.server_args.max_loaded_loras is not None
and self.lora_registry.num_registered_loras
>= self.server_args.max_loaded_loras
):
raise ValueError(
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
"Please unload some LoRA adapters before loading new ones."
)
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
pinned=obj.pinned,
)
# Trigger the actual loading operation at the backend processes.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
# Register the LoRA adapter only after loading is successful.
if result.success:
await self.lora_registry.register(new_adapter)
return result
except ValueError as e:
return LoadLoRAAdapterReqOutput(
success=False,
error_message=str(e),
)
async def unload_lora_adapter(
self: TokenizerManager,
obj: UnloadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
try:
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
logger.info(
"Start unload Lora adapter. Lora name=%s",
obj.lora_name,
)
async with self.lora_update_lock:
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
# from being started.
lora_id = await self.lora_registry.unregister(obj.lora_name)
obj.lora_id = lora_id
# Initiate the actual unloading operation at the backend processes only after all
# ongoing requests using this LoRA adapter are finished.
await self.lora_registry.wait_for_unload(lora_id)
result = (await self.update_lora_adapter_communicator(obj))[0]
return result
except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
async def get_weights_by_name(
self: TokenizerManager,
obj: GetWeightsByNameReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
results = await self.get_weights_by_name_communicator(obj)
all_parameters = [r.parameter for r in results]
if self.server_args.dp_size == 1:
return all_parameters[0]
else:
return all_parameters
async def release_memory_occupation(
self: TokenizerManager,
obj: ReleaseMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.release_memory_occupation_communicator(obj)
async def resume_memory_occupation(
self: TokenizerManager,
obj: ResumeMemoryOccupationReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.resume_memory_occupation_communicator(obj)
async def slow_down(
self: TokenizerManager,
obj: SlowDownReqInput,
request: Optional[fastapi.Request] = None,
):
self.auto_create_handle_loop()
await self.slow_down_communicator(obj)
async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]:
req = GetInternalStateReq()
responses: List[GetInternalStateReqOutput] = (
await self.get_internal_state_communicator(req)
)
        # One entry per DP rank.
return [res.internal_state for res in responses]
async def set_internal_state(
self: TokenizerManager, obj: SetInternalStateReq
) -> List[bool]:
responses: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return [res.updated for res in responses]
async def get_load(self: TokenizerManager) -> dict:
# TODO(lsyin): fake load report server
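        # Refresh the cached load only when no other coroutine is already doing so;
        # concurrent callers just return the last known value.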
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer
@@ -473,6 +473,10 @@ class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
+    def __iadd__(self, other: "TypeBasedDispatcher"):
+        self._mapping.extend(other._mapping)
+        return self
     def __call__(self, obj: Any):
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
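
The new `__iadd__` is what lets `init_communicators()` extend the manager's existing dispatcher in place (`self._result_dispatcher += self._get_communicator_dispatcher()`). A minimal illustration, with hypothetical handler functions; only `TypeBasedDispatcher`, the response types, and the `+=` semantics come from this diff.

from sglang.srt.managers.io_struct import FlushCacheReqOutput, ProfileReqOutput
from sglang.utils import TypeBasedDispatcher


def on_flush(out):
    print("flush done:", out)


def on_profile(out):
    print("profile done:", out)


base = TypeBasedDispatcher([(FlushCacheReqOutput, on_flush)])
# Merge another dispatcher's mapping into `base` in place.
base += TypeBasedDispatcher([(ProfileReqOutput, on_profile)])
# `base` now routes both response types to their handlers.
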